[Bugfix] Fix profiling OOM and decouple encoder multimodal profiling (#14361)

Signed-off-by: Isotr0py <2037008807@qq.com>
Isotr0py 2025-03-09 00:52:34 +08:00 committed by GitHub
parent db84f5eb3b
commit 609ef61fea
3 changed files with 59 additions and 33 deletions
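
Summary of the API change: `MultiModalProfiler.get_dummy_data(seq_len, is_encoder_data=...)` is split into two dedicated entry points, `get_encoder_dummy_data(seq_len)` and `get_decoder_dummy_data(seq_len)`, with the shared validation factored into `get_and_validate_mm_inputs(seq_len)`. A minimal usage sketch; the `processor` object and `seq_len` value are placeholders assumed to come from the caller (as in the registry diff below), and the import path is assumed from the profiling module touched in this commit:

    # Sketch only: `processor` is assumed to be an already-built multimodal processor.
    from vllm.multimodal.profiling import MultiModalProfiler

    profiler = MultiModalProfiler(processor)

    # Decoder profiling (decoder-only models, or the decoder side of enc-dec models).
    decoder_dummy_data = profiler.get_decoder_dummy_data(seq_len)

    # Encoder profiling for encoder-decoder multimodal models.
    encoder_dummy_data = profiler.get_encoder_dummy_data(seq_len)
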


@@ -873,7 +873,7 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
         exc_ctx = pytest.raises(ValueError, match="this model only supports")

     with exc_ctx:
-        profiler.get_dummy_data(model_config.max_model_len)
+        profiler.get_decoder_dummy_data(model_config.max_model_len)


 @pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])


@@ -335,8 +335,10 @@ class InputRegistry:
                 tokenizer,
                 disable_cache=True)
             profiler = MultiModalProfiler(processor)
-            dummy_data = profiler.get_dummy_data(
-                seq_len, is_encoder_data=is_encoder_data)
+            dummy_data_factory = (profiler.get_encoder_dummy_data
+                                  if is_encoder_data else
+                                  profiler.get_decoder_dummy_data)
+            dummy_data = dummy_data_factory(seq_len)
         else:
             model_cls, _ = get_model_architecture(model_config)
             if is_encoder_data:


@@ -3,7 +3,7 @@
 from abc import ABC, abstractmethod
 from collections.abc import Mapping
 from dataclasses import dataclass, field
-from typing import Generic, TypeVar
+from typing import Generic, TypeVar, cast

 import numpy as np
 import numpy.typing as npt
@@ -13,7 +13,8 @@ import vllm.envs as envs
 from vllm.inputs import DummyData
 from vllm.logger import init_logger

-from .inputs import MultiModalDataDict, MultiModalInputs
+from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
+                     MultiModalInputs)
 from .processing import BaseMultiModalProcessor, BaseProcessingInfo

 logger = init_logger(__name__)
@@ -142,14 +143,10 @@ class MultiModalProfiler(Generic[_I]):
             hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
         )

-    def get_dummy_data(
+    def get_and_validate_mm_inputs(
         self,
         seq_len: int,
-        is_encoder_data: bool = False,
-    ) -> DummyData:
-        # Avoid circular import
-        from vllm.sequence import SequenceData
-
+    ) -> tuple[MultiModalInputs, Mapping[str, int]]:
         mm_counts = self.get_mm_limits()

         info = self.processing_info
@@ -165,11 +162,6 @@ class MultiModalProfiler(Generic[_I]):
         mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
         placeholders_by_modality = mm_inputs["mm_placeholders"]

-        # For encoder-decoder models, use encoder prompt token ids instead of
-        # decoder prompt to construct dummy seq_data for encoder profiling.
-        prompt_token_ids = (
-            mm_inputs["prompt_token_ids"] if not is_encoder_data else
-            mm_inputs["encoder_prompt_token_ids"])  # type: ignore
-
         total_placeholders_by_modality = {
             modality: sum(item["length"] for item in placeholders)
@@ -185,28 +177,60 @@ class MultiModalProfiler(Generic[_I]):
                 f"{total_placeholders_by_modality} placeholder tokens, which "
                 f"is not the expected {expected_placeholders_by_modality} "
                 "tokens.")

+        return mm_inputs, total_placeholders_by_modality
+
+    def get_encoder_dummy_data(
+        self,
+        seq_len: int,
+    ) -> DummyData:
+        # Avoid circular import
+        from vllm.sequence import SequenceData
+
+        mm_inputs, _ = self.get_and_validate_mm_inputs(seq_len)
+        mm_inputs = cast(MultiModalEncDecInputs, mm_inputs)
+
+        # For encoder-decoder models, use encoder prompt token ids instead of
+        # decoder prompt to construct dummy seq_data for encoder profiling.
+        encoder_prompt_token_ids = mm_inputs["encoder_prompt_token_ids"]
+
+        total_len = len(encoder_prompt_token_ids)
+        num_tokens_to_pad = max(total_len, seq_len) - total_len
+        encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)
+
+        return DummyData(
+            seq_data=SequenceData.from_seqs(encoder_prompt_token_ids),
+            multi_modal_data=None,
+            multi_modal_placeholders=None,
+        )
+
+    def get_decoder_dummy_data(
+        self,
+        seq_len: int,
+    ) -> DummyData:
+        # Avoid circular import
+        from vllm.sequence import SequenceData
+
+        (mm_inputs, total_placeholders_by_modality
+         ) = self.get_and_validate_mm_inputs(seq_len)
+        prompt_token_ids = mm_inputs["prompt_token_ids"]
         total_len = len(prompt_token_ids)

         # V0 does not support chunked prefill.
-        if (total_len > seq_len and not envs.VLLM_USE_V1) or is_encoder_data:
-            if total_len > seq_len and not is_encoder_data:
-                logger.warning(
-                    "The context length (%d) of the model is too short "
-                    "to hold the multi-modal embeddings in the worst case "
-                    "(%d tokens in total, out of which %s are reserved for "
-                    "multi-modal embeddings). This may cause certain "
-                    "multi-modal inputs to fail during inference, even when "
-                    "the input text is short. To avoid this, you should "
-                    "increase `max_model_len`, reduce `max_num_seqs`, "
-                    "and/or reduce `mm_counts`.", seq_len, total_len,
-                    total_placeholders_by_modality)
-
-            num_tokens_to_pad = max(total_len, seq_len) - total_len
-            prompt_token_ids.extend([0] * num_tokens_to_pad)
+        if total_len > seq_len and not envs.VLLM_USE_V1:
+            logger.warning(
+                "The context length (%d) of the model is too short "
+                "to hold the multi-modal embeddings in the worst case "
+                "(%d tokens in total, out of which %s are reserved for "
+                "multi-modal embeddings). This may cause certain "
+                "multi-modal inputs to fail during inference, even when "
+                "the input text is short. To avoid this, you should "
+                "increase `max_model_len`, reduce `max_num_seqs`, "
+                "and/or reduce `mm_counts`.", seq_len, total_len,
+                total_placeholders_by_modality)

             return DummyData(
-                seq_data=SequenceData.from_seqs(prompt_token_ids),
+                seq_data=SequenceData.from_prompt_token_counts((0, seq_len)),
                 multi_modal_data=None,
                 multi_modal_placeholders=None,
             )
@@ -216,5 +240,5 @@ class MultiModalProfiler(Generic[_I]):
         return DummyData(
             seq_data=SequenceData.from_seqs(prompt_token_ids),
             multi_modal_data=mm_inputs["mm_kwargs"],
-            multi_modal_placeholders=placeholders_by_modality,
+            multi_modal_placeholders=mm_inputs["mm_placeholders"],
         )
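
Note on the OOM part of the fix: in the removed early-return branch above, `prompt_token_ids` was padded up to `max(total_len, seq_len)`, so a worst-case multimodal dummy prompt much longer than `seq_len` produced an equally long dummy sequence during memory profiling; the new branch always builds a dummy sequence of exactly `seq_len` tokens via `SequenceData.from_prompt_token_counts((0, seq_len))`. A rough sketch of the size difference, with illustrative values and the old behaviour reconstructed from the removed lines (both constructors are used in the diff above):

    from vllm.sequence import SequenceData

    seq_len = 4096        # illustrative max_model_len used for profiling
    total_len = 100_000   # illustrative worst-case multimodal prompt length

    # Old early-return (removed): the dummy sequence keeps all total_len tokens.
    old_dummy = SequenceData.from_seqs([0] * max(total_len, seq_len))

    # New early-return: the dummy sequence is capped at exactly seq_len tokens.
    new_dummy = SequenceData.from_prompt_token_counts((0, seq_len))
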