From 8e5314a46859cdd089a71c0e27ce6dd5fe05b17f Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Tue, 8 Apr 2025 00:24:07 -0600
Subject: [PATCH] [V1] Add `disable_chunked_mm_input` arg to disable partial mm input prefill (#15837)

Signed-off-by: mgoin
---
 tests/v1/core/test_scheduler.py | 45 +++++++++++++++++++++++++++++++++
 vllm/config.py                  |  8 ++++++
 vllm/engine/arg_utils.py        | 16 ++++++++++++
 vllm/v1/core/sched/scheduler.py | 11 ++++++++
 4 files changed, 80 insertions(+)

diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py
index 21a1cbf5..75c50755 100644
--- a/tests/v1/core/test_scheduler.py
+++ b/tests/v1/core/test_scheduler.py
@@ -24,6 +24,7 @@ def create_scheduler(
     max_num_batched_tokens: int = 8192,
     enable_prefix_caching: Optional[bool] = None,
     long_prefill_token_threshold: int = 0,
+    disable_chunked_mm_input: bool = False,
 ) -> Scheduler:
     '''Create scheduler under test.
 
@@ -43,6 +44,7 @@
         max_num_batched_tokens=max_num_batched_tokens,
         max_model_len=max_num_batched_tokens,
         long_prefill_token_threshold=long_prefill_token_threshold,
+        disable_chunked_mm_input=disable_chunked_mm_input,
     )
     model_config = ModelConfig(
         model=model,
@@ -278,6 +280,49 @@ def test_schedule_partial_requests():
     assert requests[2].request_id not in output.num_scheduled_tokens
 
 
+def test_no_mm_input_chunking():
+    # Disable multimodal input chunking.
+    scheduler = create_scheduler(
+        model="llava-hf/llava-1.5-7b-hf",
+        max_num_batched_tokens=1024,
+        disable_chunked_mm_input=True,
+    )
+    mm_positions = [[PlaceholderRange(offset=400, length=800)]]
+    requests = create_requests(num_requests=1,
+                               num_tokens=1200,
+                               mm_positions=mm_positions)
+    for request in requests:
+        scheduler.add_request(request)
+
+    output = scheduler.schedule()
+    assert len(output.scheduled_new_reqs) == 1
+    assert len(output.scheduled_cached_reqs) == 0
+    assert len(output.finished_req_ids) == 0
+    # We want to only see the 400 text tokens at the start scheduled
+    assert output.num_scheduled_tokens[requests[0].request_id] == 400
+
+    req_to_index = {
+        request.request_id: i
+        for i, request in enumerate(requests)
+    }
+    model_runner_output = ModelRunnerOutput(
+        req_ids=[request.request_id for request in requests],
+        req_id_to_index=req_to_index,
+        sampled_token_ids=[[] for _ in range(len(requests))],
+        spec_token_ids=None,
+        logprobs=None,
+        prompt_logprobs_dict={},
+    )
+    scheduler.update_from_output(output, model_runner_output)
+
+    output = scheduler.schedule()
+    assert len(scheduler.running) == 1
+    assert len(output.scheduled_new_reqs) == 0
+    assert len(output.scheduled_cached_reqs) == 1
+    assert len(output.finished_req_ids) == 0
+    assert output.num_scheduled_tokens[requests[0].request_id] == 800
+
+
 @pytest.mark.parametrize("enable_prefix_caching", [True, False])
 def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
     """Test scheduling behavior with concurrent partial requests.
diff --git a/vllm/config.py b/vllm/config.py
index c232f0f5..439e27b1 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1721,6 +1721,14 @@ class SchedulerConfig:
 
     chunked_prefill_enabled: bool = field(init=False)
 
+    # If set to true and chunked prefill is enabled, we do not want to
+    # partially schedule a multimodal item. Only used in V1.
+    # This ensures that if a request has a mixed prompt
+    # (like text tokens TTTT followed by image tokens IIIIIIIIII) where only
+    # some image tokens can be scheduled (like TTTTIIIII, leaving IIIII),
+    # it will be scheduled as TTTT in one step and IIIIIIIIII in the next.
+    disable_chunked_mm_input: bool = False
+
     # scheduler class or path. "vllm.core.scheduler.Scheduler" (default)
     # or "mod.custom_class".
     scheduler_cls: Union[str, type[object]] = "vllm.core.scheduler.Scheduler"
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 6d9f89fa..0c81e3ed 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -179,6 +179,7 @@ class EngineArgs:
 
     scheduler_delay_factor: float = 0.0
     enable_chunked_prefill: Optional[bool] = None
+    disable_chunked_mm_input: bool = False
 
     guided_decoding_backend: str = 'xgrammar'
     logits_processor_pattern: Optional[str] = None
@@ -1017,6 +1018,20 @@ class EngineArgs:
             "Note that even if this is set to False, cascade attention will be "
             "only used when the heuristic tells that it's beneficial.")
 
+        parser.add_argument(
+            "--disable-chunked-mm-input",
+            action=StoreBoolean,
+            default=EngineArgs.disable_chunked_mm_input,
+            nargs="?",
+            const="True",
+            help="Disable multimodal input chunking for V1. "
+            "If set to true and chunked prefill is enabled, we do not want to"
+            " partially schedule a multimodal item. This ensures that if a "
+            "request has a mixed prompt (like text tokens TTTT followed by "
+            "image tokens IIIIIIIIII) where only some image tokens can be "
+            "scheduled (like TTTTIIIII, leaving IIIII), it will be scheduled "
+            "as TTTT in one step and IIIIIIIIII in the next.")
+
         return parser
 
     @classmethod
@@ -1261,6 +1276,7 @@
             num_lookahead_slots=num_lookahead_slots,
             delay_factor=self.scheduler_delay_factor,
             enable_chunked_prefill=self.enable_chunked_prefill,
+            disable_chunked_mm_input=self.disable_chunked_mm_input,
             is_multimodal_model=model_config.is_multimodal_model,
             preemption_mode=self.preemption_mode,
             num_scheduler_steps=self.num_scheduler_steps,
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index b3905987..488d32cb 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -522,6 +522,17 @@ class Scheduler(SchedulerInterface):
             if self.encoder_cache_manager.has_cache(request, i):
                 # The encoder input is already computed and cached.
                 continue
+
+            # If no encoder input chunking is allowed, we do not want to
+            # partially schedule a multimodal item. If the scheduled range would
+            # only cover part of the mm input, roll back to before the mm item.
+            if (self.scheduler_config.disable_chunked_mm_input
+                    and num_computed_tokens < start_pos
+                    and (num_computed_tokens + num_new_tokens)
+                    < (start_pos + num_encoder_tokens)):
+                num_new_tokens = start_pos - num_computed_tokens
+                break
+
             if (not self.encoder_cache_manager.can_allocate(request, i)
                     or num_encoder_tokens > encoder_budget):
                 # The encoder cache is full or the encoder budget is exhausted.
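For reviewers who want to exercise the new behavior end to end, here is a minimal offline-inference sketch, not part of the patch itself. It assumes a vLLM build with this change applied; the model name and token budget mirror the values used in test_no_mm_input_chunking above, while the prompt template and dummy image are illustrative assumptions. The same switch is exposed on the CLI as --disable-chunked-mm-input.

# Usage sketch (assumes this patch is applied; prompt and dummy image are
# illustrative assumptions, mirroring the unit test above).
from PIL import Image

from vllm import LLM, SamplingParams

llm = LLM(
    model="llava-hf/llava-1.5-7b-hf",
    max_num_batched_tokens=1024,    # small budget, so a mixed prompt must be chunked
    disable_chunked_mm_input=True,  # new flag: never split a multimodal item
)

# Placeholder image standing in for real multimodal input.
image = Image.new("RGB", (336, 336), color=0)

outputs = llm.generate(
    {
        "prompt": "USER: <image>\nDescribe the image. ASSISTANT:",
        "multi_modal_data": {"image": image},
    },
    SamplingParams(max_tokens=32),
)
print(outputs[0].outputs[0].text)

With the flag enabled, the scheduler should emit the text prefix in one step and the whole image placeholder span in a later step, rather than splitting the image tokens across steps.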