[Bugfix] Fix profiling OOM and decouple encoder multimodal profiling (#14361)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
parent
db84f5eb3b
commit
609ef61fea
@ -873,7 +873,7 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
|
||||
exc_ctx = pytest.raises(ValueError, match="this model only supports")
|
||||
|
||||
with exc_ctx:
|
||||
profiler.get_dummy_data(model_config.max_model_len)
|
||||
profiler.get_decoder_dummy_data(model_config.max_model_len)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
|
||||
|
@ -335,8 +335,10 @@ class InputRegistry:
|
||||
tokenizer,
|
||||
disable_cache=True)
|
||||
profiler = MultiModalProfiler(processor)
|
||||
dummy_data = profiler.get_dummy_data(
|
||||
seq_len, is_encoder_data=is_encoder_data)
|
||||
dummy_data_factory = (profiler.get_encoder_dummy_data
|
||||
if is_encoder_data else
|
||||
profiler.get_decoder_dummy_data)
|
||||
dummy_data = dummy_data_factory(seq_len)
|
||||
else:
|
||||
model_cls, _ = get_model_architecture(model_config)
|
||||
if is_encoder_data:
|
||||
|
@ -3,7 +3,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Generic, TypeVar
|
||||
from typing import Generic, TypeVar, cast
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
@ -13,7 +13,8 @@ import vllm.envs as envs
|
||||
from vllm.inputs import DummyData
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .inputs import MultiModalDataDict, MultiModalInputs
|
||||
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
|
||||
MultiModalInputs)
|
||||
from .processing import BaseMultiModalProcessor, BaseProcessingInfo
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -142,14 +143,10 @@ class MultiModalProfiler(Generic[_I]):
|
||||
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
|
||||
)
|
||||
|
||||
def get_dummy_data(
|
||||
def get_and_validate_mm_inputs(
|
||||
self,
|
||||
seq_len: int,
|
||||
is_encoder_data: bool = False,
|
||||
) -> DummyData:
|
||||
# Avoid circular import
|
||||
from vllm.sequence import SequenceData
|
||||
|
||||
) -> tuple[MultiModalInputs, Mapping[str, int]]:
|
||||
mm_counts = self.get_mm_limits()
|
||||
|
||||
info = self.processing_info
|
||||
@ -165,11 +162,6 @@ class MultiModalProfiler(Generic[_I]):
|
||||
|
||||
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
|
||||
placeholders_by_modality = mm_inputs["mm_placeholders"]
|
||||
# For encoder-decoder models, use encoder prompt token ids instead of
|
||||
# decoder prompt to construct dummy seq_data for encoder profiling.
|
||||
prompt_token_ids = (
|
||||
mm_inputs["prompt_token_ids"] if not is_encoder_data else
|
||||
mm_inputs["encoder_prompt_token_ids"]) # type: ignore
|
||||
|
||||
total_placeholders_by_modality = {
|
||||
modality: sum(item["length"] for item in placeholders)
|
||||
@ -185,28 +177,60 @@ class MultiModalProfiler(Generic[_I]):
|
||||
f"{total_placeholders_by_modality} placeholder tokens, which "
|
||||
f"is not the expected {expected_placeholders_by_modality} "
|
||||
"tokens.")
|
||||
return mm_inputs, total_placeholders_by_modality
|
||||
|
||||
def get_encoder_dummy_data(
|
||||
self,
|
||||
seq_len: int,
|
||||
) -> DummyData:
|
||||
# Avoid circular import
|
||||
from vllm.sequence import SequenceData
|
||||
|
||||
mm_inputs, _ = self.get_and_validate_mm_inputs(seq_len)
|
||||
mm_inputs = cast(MultiModalEncDecInputs, mm_inputs)
|
||||
|
||||
# For encoder-decoder models, use encoder prompt token ids instead of
|
||||
# decoder prompt to construct dummy seq_data for encoder profiling.
|
||||
encoder_prompt_token_ids = mm_inputs["encoder_prompt_token_ids"]
|
||||
|
||||
total_len = len(encoder_prompt_token_ids)
|
||||
num_tokens_to_pad = max(total_len, seq_len) - total_len
|
||||
encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)
|
||||
|
||||
return DummyData(
|
||||
seq_data=SequenceData.from_seqs(encoder_prompt_token_ids),
|
||||
multi_modal_data=None,
|
||||
multi_modal_placeholders=None,
|
||||
)
|
||||
|
||||
def get_decoder_dummy_data(
|
||||
self,
|
||||
seq_len: int,
|
||||
) -> DummyData:
|
||||
# Avoid circular import
|
||||
from vllm.sequence import SequenceData
|
||||
|
||||
(mm_inputs, total_placeholders_by_modality
|
||||
) = self.get_and_validate_mm_inputs(seq_len)
|
||||
|
||||
prompt_token_ids = mm_inputs["prompt_token_ids"]
|
||||
total_len = len(prompt_token_ids)
|
||||
|
||||
# V0 does not support chunked prefill.
|
||||
if (total_len > seq_len and not envs.VLLM_USE_V1) or is_encoder_data:
|
||||
if total_len > seq_len and not is_encoder_data:
|
||||
logger.warning(
|
||||
"The context length (%d) of the model is too short "
|
||||
"to hold the multi-modal embeddings in the worst case "
|
||||
"(%d tokens in total, out of which %s are reserved for "
|
||||
"multi-modal embeddings). This may cause certain "
|
||||
"multi-modal inputs to fail during inference, even when "
|
||||
"the input text is short. To avoid this, you should "
|
||||
"increase `max_model_len`, reduce `max_num_seqs`, "
|
||||
"and/or reduce `mm_counts`.", seq_len, total_len,
|
||||
total_placeholders_by_modality)
|
||||
|
||||
num_tokens_to_pad = max(total_len, seq_len) - total_len
|
||||
prompt_token_ids.extend([0] * num_tokens_to_pad)
|
||||
if total_len > seq_len and not envs.VLLM_USE_V1:
|
||||
logger.warning(
|
||||
"The context length (%d) of the model is too short "
|
||||
"to hold the multi-modal embeddings in the worst case "
|
||||
"(%d tokens in total, out of which %s are reserved for "
|
||||
"multi-modal embeddings). This may cause certain "
|
||||
"multi-modal inputs to fail during inference, even when "
|
||||
"the input text is short. To avoid this, you should "
|
||||
"increase `max_model_len`, reduce `max_num_seqs`, "
|
||||
"and/or reduce `mm_counts`.", seq_len, total_len,
|
||||
total_placeholders_by_modality)
|
||||
|
||||
return DummyData(
|
||||
seq_data=SequenceData.from_seqs(prompt_token_ids),
|
||||
seq_data=SequenceData.from_prompt_token_counts((0, seq_len)),
|
||||
multi_modal_data=None,
|
||||
multi_modal_placeholders=None,
|
||||
)
|
||||
@ -216,5 +240,5 @@ class MultiModalProfiler(Generic[_I]):
|
||||
return DummyData(
|
||||
seq_data=SequenceData.from_seqs(prompt_token_ids),
|
||||
multi_modal_data=mm_inputs["mm_kwargs"],
|
||||
multi_modal_placeholders=placeholders_by_modality,
|
||||
multi_modal_placeholders=mm_inputs["mm_placeholders"],
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user