Enforce valid max_num_batched_tokens when disable_chunked_mm_input=True (#16447)
Signed-off-by: mgoin <mgoin64@gmail.com>
parent f7030df3be
commit aa3b3d76e0
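
With chunked prefill enabled but chunked multimodal input disabled, a multimodal item has to be scheduled in a single step, so a single item's encoder output must fit within max_num_batched_tokens. This change validates that constraint where the encoder budget is computed, raising a ValueError up front instead of accepting a configuration the scheduler can never satisfy, and also fixes --disable-chunked-mm-input so that passing the bare flag resolves to True rather than False.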
@@ -322,6 +322,15 @@ def test_no_mm_input_chunking():
     assert len(output.finished_req_ids) == 0
     assert output.num_scheduled_tokens[requests[0].request_id] == 800
 
+    # Test that we fail if we disable chunked mm input and use too small
+    # of a max_num_batched_tokens for the mm input.
+    with pytest.raises(ValueError):
+        _ = create_scheduler(
+            model="llava-hf/llava-1.5-7b-hf",
+            max_num_batched_tokens=100,
+            disable_chunked_mm_input=True,
+        )
+
 
 @pytest.mark.parametrize("enable_prefix_caching", [True, False])
 def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
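
As background (not part of the diff): create_scheduler is the helper already used throughout this test module, and a LLaVA-1.5-style image expands to several hundred encoder tokens (roughly 576 for llava-1.5-7b, an assumed figure), so max_num_batched_tokens=100 cannot hold a single item. A hedged variant of the new assertion, which additionally pins the failure to the validation message via pytest's match argument:

import pytest

# Hypothetical variant, not part of the commit. Assumes the create_scheduler
# helper from this test module is in scope; `match` ensures the ValueError
# comes from the new max_num_batched_tokens validation, not elsewhere.
def test_no_mm_input_chunking_rejects_small_budget():
    with pytest.raises(ValueError, match="max_num_batched_tokens"):
        create_scheduler(
            model="llava-hf/llava-1.5-7b-hf",
            max_num_batched_tokens=100,
            disable_chunked_mm_input=True,
        )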
@@ -1030,7 +1030,7 @@ class EngineArgs:
             action=StoreBoolean,
             default=EngineArgs.disable_chunked_mm_input,
             nargs="?",
-            const="False",
+            const="True",
             help="Disable multimodal input chunking attention for V1. "
             "If set to true and chunked prefill is enabled, we do not want to"
             " partially schedule a multimodal item. This ensures that if a "
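
The const fix matters because the flag is declared with nargs="?": const is the value argparse substitutes when the flag is passed without an explicit argument, so a bare --disable-chunked-mm-input previously resolved to "False" and effectively did nothing. A minimal standalone sketch of that argparse behavior (plain default action for illustration, not vLLM's StoreBoolean):

import argparse

parser = argparse.ArgumentParser()
# default applies when the flag is absent; const applies when the flag is
# given with no value, which is the case this commit corrects.
parser.add_argument("--disable-chunked-mm-input",
                    nargs="?",
                    const="True",
                    default="False")

print(parser.parse_args([]))
# Namespace(disable_chunked_mm_input='False')
print(parser.parse_args(["--disable-chunked-mm-input"]))
# Namespace(disable_chunked_mm_input='True')   <- bare flag now enables it
print(parser.parse_args(["--disable-chunked-mm-input=False"]))
# Namespace(disable_chunked_mm_input='False')  <- explicit value still wins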
@@ -133,6 +133,14 @@ def _compute_encoder_budget_multimodal(
     _, max_tokens_per_mm_item = max(max_tokens_by_modality_dict.items(),
                                     key=lambda item: item[1])
 
+    if (scheduler_config.disable_chunked_mm_input and max_tokens_per_mm_item
+            > scheduler_config.max_num_batched_tokens):
+        raise ValueError(
+            "Chunked MM input disabled but max_tokens_per_mm_item "
+            f"({max_tokens_per_mm_item}) is larger than max_num_batched_tokens"
+            f" ({scheduler_config.max_num_batched_tokens}). Please increase "
+            "max_num_batched_tokens.")
+
     encoder_compute_budget = max(scheduler_config.max_num_encoder_input_tokens,
                                  max_tokens_per_mm_item)
     encoder_cache_size = max(scheduler_config.encoder_cache_size,
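
Below is a minimal, self-contained sketch of the new check using a mock config object rather than vLLM's actual SchedulerConfig; the 576-token figure is an assumed per-image encoder count (LLaVA-1.5-style) used only for illustration:

from dataclasses import dataclass


@dataclass
class MockSchedulerConfig:
    # Only the two fields the new check reads; not vLLM's real SchedulerConfig.
    disable_chunked_mm_input: bool
    max_num_batched_tokens: int


def check_mm_item_fits(cfg: MockSchedulerConfig, max_tokens_per_mm_item: int) -> None:
    # Same inequality as the diff above: with chunking disabled, one multimodal
    # item must fit entirely within a single batch of max_num_batched_tokens.
    if (cfg.disable_chunked_mm_input
            and max_tokens_per_mm_item > cfg.max_num_batched_tokens):
        raise ValueError(
            "Chunked MM input disabled but max_tokens_per_mm_item "
            f"({max_tokens_per_mm_item}) is larger than max_num_batched_tokens "
            f"({cfg.max_num_batched_tokens}). Please increase "
            "max_num_batched_tokens.")


try:
    check_mm_item_fits(MockSchedulerConfig(True, 100), 576)
except ValueError as err:
    print(err)  # mirrors the error the engine now raises during construction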