[V1] Add disable_chunked_mm_input arg to disable partial mm input prefill (#15837)

Signed-off-by: mgoin <mgoin64@gmail.com>

parent 87918e40c4
commit 8e5314a468
@@ -24,6 +24,7 @@ def create_scheduler(
     max_num_batched_tokens: int = 8192,
     enable_prefix_caching: Optional[bool] = None,
     long_prefill_token_threshold: int = 0,
+    disable_chunked_mm_input: bool = False,
 ) -> Scheduler:
     '''Create scheduler under test.
 
@@ -43,6 +44,7 @@ def create_scheduler(
         max_num_batched_tokens=max_num_batched_tokens,
         max_model_len=max_num_batched_tokens,
         long_prefill_token_threshold=long_prefill_token_threshold,
+        disable_chunked_mm_input=disable_chunked_mm_input,
     )
     model_config = ModelConfig(
         model=model,
@@ -278,6 +280,49 @@ def test_schedule_partial_requests():
     assert requests[2].request_id not in output.num_scheduled_tokens
 
 
+def test_no_mm_input_chunking():
+    # Disable multimodal input chunking.
+    scheduler = create_scheduler(
+        model="llava-hf/llava-1.5-7b-hf",
+        max_num_batched_tokens=1024,
+        disable_chunked_mm_input=True,
+    )
+    mm_positions = [[PlaceholderRange(offset=400, length=800)]]
+    requests = create_requests(num_requests=1,
+                               num_tokens=1200,
+                               mm_positions=mm_positions)
+    for request in requests:
+        scheduler.add_request(request)
+
+    output = scheduler.schedule()
+    assert len(output.scheduled_new_reqs) == 1
+    assert len(output.scheduled_cached_reqs) == 0
+    assert len(output.finished_req_ids) == 0
+    # We want to only see the 400 text tokens at the start scheduled
+    assert output.num_scheduled_tokens[requests[0].request_id] == 400
+
+    req_to_index = {
+        request.request_id: i
+        for i, request in enumerate(requests)
+    }
+    model_runner_output = ModelRunnerOutput(
+        req_ids=[request.request_id for request in requests],
+        req_id_to_index=req_to_index,
+        sampled_token_ids=[[] for _ in range(len(requests))],
+        spec_token_ids=None,
+        logprobs=None,
+        prompt_logprobs_dict={},
+    )
+    scheduler.update_from_output(output, model_runner_output)
+
+    output = scheduler.schedule()
+    assert len(scheduler.running) == 1
+    assert len(output.scheduled_new_reqs) == 0
+    assert len(output.scheduled_cached_reqs) == 1
+    assert len(output.finished_req_ids) == 0
+    assert output.num_scheduled_tokens[requests[0].request_id] == 800
+
+
 @pytest.mark.parametrize("enable_prefix_caching", [True, False])
 def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
     """Test scheduling behavior with concurrent partial requests.
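A note on the expected counts, derived from the test's inputs: the per-step token budget is 1024 and the image placeholder occupies [400, 1200) of the 1200-token prompt.

    # step 1: 1024 tokens would end inside the image -> clipped to the
    #         400 leading text tokens
    # step 2: the remaining 800 image tokens fit the 1024 budget, so the
    #         whole image is scheduled in one unchunked step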
@@ -1721,6 +1721,14 @@ class SchedulerConfig:
     chunked_prefill_enabled: bool = field(init=False)
 
+    # If set to true and chunked prefill is enabled, we do not want to
+    # partially schedule a multimodal item. Only used in V1.
+    # This ensures that if a request has a mixed prompt
+    # (like text tokens TTTT followed by image tokens IIIIIIIIII) where only
+    # some image tokens can be scheduled (like TTTTIIIII, leaving IIIII),
+    # it will be scheduled as TTTT in one step and IIIIIIIIII in the next.
+    disable_chunked_mm_input: bool = False
+
     # scheduler class or path. "vllm.core.scheduler.Scheduler" (default)
     # or "mod.custom_class".
     scheduler_cls: Union[str, type[object]] = "vllm.core.scheduler.Scheduler"
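The invariant this flag enforces can be stated as a small standalone predicate. A sketch in plain Python; the function and its signature are illustrative, not part of vLLM:

    def step_splits_mm_item(num_computed: int, num_new: int,
                            items: list[tuple[int, int]]) -> bool:
        """True if the scheduled range [num_computed, num_computed+num_new)
        ends strictly inside some (start, length) multimodal item, i.e. the
        step would leave that item only partially prefilled."""
        end = num_computed + num_new
        return any(start < end < start + length for start, length in items)

    # With disable_chunked_mm_input, the scheduler shrinks num_new so this
    # predicate stays False, e.g. 4 text + 10 image tokens with budget 9:
    assert step_splits_mm_item(0, 9, [(4, 10)])      # TTTTIIIII would split
    assert not step_splits_mm_item(0, 4, [(4, 10)])  # TTTT, then IIIIIIIIII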
@@ -179,6 +179,7 @@ class EngineArgs:
     scheduler_delay_factor: float = 0.0
     enable_chunked_prefill: Optional[bool] = None
+    disable_chunked_mm_input: bool = False
 
     guided_decoding_backend: str = 'xgrammar'
     logits_processor_pattern: Optional[str] = None
@@ -1017,6 +1018,20 @@ class EngineArgs:
             "Note that even if this is set to False, cascade attention will be "
             "only used when the heuristic tells that it's beneficial.")
 
+        parser.add_argument(
+            "--disable-chunked-mm-input",
+            action=StoreBoolean,
+            default=EngineArgs.disable_chunked_mm_input,
+            nargs="?",
+            const="False",
+            help="Disable multimodal input chunking attention for V1. "
+            "If set to true and chunked prefill is enabled, we do not want to"
+            " partially schedule a multimodal item. This ensures that if a "
+            "request has a mixed prompt (like text tokens TTTT followed by "
+            "image tokens IIIIIIIIII) where only some image tokens can be "
+            "scheduled (like TTTTIIIII, leaving IIIII), it will be scheduled "
+            "as TTTT in one step and IIIIIIIIII in the next.")
+
         return parser
 
     @classmethod
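For reference, hedged usage sketches (the flag name comes from this diff; exact CLI value handling depends on StoreBoolean's parsing, so the explicit-value form is shown):

    vllm serve llava-hf/llava-1.5-7b-hf --disable-chunked-mm-input true

or programmatically, assuming LLM forwards keyword arguments to EngineArgs as usual:

    from vllm import LLM
    llm = LLM(model="llava-hf/llava-1.5-7b-hf",
              disable_chunked_mm_input=True)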
@@ -1261,6 +1276,7 @@ class EngineArgs:
             num_lookahead_slots=num_lookahead_slots,
             delay_factor=self.scheduler_delay_factor,
             enable_chunked_prefill=self.enable_chunked_prefill,
+            disable_chunked_mm_input=self.disable_chunked_mm_input,
             is_multimodal_model=model_config.is_multimodal_model,
             preemption_mode=self.preemption_mode,
             num_scheduler_steps=self.num_scheduler_steps,
@@ -522,6 +522,17 @@ class Scheduler(SchedulerInterface):
             if self.encoder_cache_manager.has_cache(request, i):
                 # The encoder input is already computed and cached.
                 continue
+
+            # If no encoder input chunking is allowed, we do not want to
+            # partially schedule a multimodal item. If the scheduled range would
+            # only cover part of the mm input, roll back to before the mm item.
+            if (self.scheduler_config.disable_chunked_mm_input
+                    and num_computed_tokens < start_pos
+                    and (num_computed_tokens + num_new_tokens)
+                    < (start_pos + num_encoder_tokens)):
+                num_new_tokens = start_pos - num_computed_tokens
+                break
+
             if (not self.encoder_cache_manager.can_allocate(request, i)
                     or num_encoder_tokens > encoder_budget):
                 # The encoder cache is full or the encoder budget is exhausted.
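A minimal sketch of the rollback above as a free function (variable names mirror the diff; the helper itself is illustrative, not vLLM API):

    def clip_to_mm_boundary(num_computed_tokens: int, num_new_tokens: int,
                            start_pos: int, num_encoder_tokens: int) -> int:
        """If the scheduled range would start before the mm item at
        start_pos but end before the item completes, shrink the chunk so
        prefill stops at the item boundary instead of splitting it."""
        end = num_computed_tokens + num_new_tokens
        if (num_computed_tokens < start_pos
                and end < start_pos + num_encoder_tokens):
            return start_pos - num_computed_tokens
        return num_new_tokens

    # Test-case numbers: the budget covers [0, 1024) but the image spans
    # [400, 1200), so the step is clipped to the 400 text tokens.
    assert clip_to_mm_boundary(0, 1024, 400, 800) == 400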