[Bugfix] Fix profiling OOM and decouple encoder multimodal profiling (#14361)
Signed-off-by: Isotr0py <2037008807@qq.com>
parent db84f5eb3b
commit 609ef61fea
@@ -873,7 +873,7 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
         exc_ctx = pytest.raises(ValueError, match="this model only supports")
 
     with exc_ctx:
-        profiler.get_dummy_data(model_config.max_model_len)
+        profiler.get_decoder_dummy_data(model_config.max_model_len)
 
 
 @pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@@ -335,8 +335,10 @@ class InputRegistry:
                 tokenizer,
                 disable_cache=True)
             profiler = MultiModalProfiler(processor)
-            dummy_data = profiler.get_dummy_data(
-                seq_len, is_encoder_data=is_encoder_data)
+            dummy_data_factory = (profiler.get_encoder_dummy_data
+                                  if is_encoder_data else
+                                  profiler.get_decoder_dummy_data)
+            dummy_data = dummy_data_factory(seq_len)
         else:
             model_cls, _ = get_model_architecture(model_config)
             if is_encoder_data:
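For code calling the profiler directly rather than through `InputRegistry`, the migration implied by this hunk is mechanical. A hedged sketch (assuming `profiler`, `seq_len`, and `is_encoder_data` are in scope as in the registry code above):

# Before this commit (flag-based API, now removed):
#   dummy_data = profiler.get_dummy_data(seq_len,
#                                        is_encoder_data=is_encoder_data)

# After this commit: pick the dedicated entry point instead.
dummy_data = (profiler.get_encoder_dummy_data(seq_len)
              if is_encoder_data else
              profiler.get_decoder_dummy_data(seq_len))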
@@ -3,7 +3,7 @@
 from abc import ABC, abstractmethod
 from collections.abc import Mapping
 from dataclasses import dataclass, field
-from typing import Generic, TypeVar
+from typing import Generic, TypeVar, cast
 
 import numpy as np
 import numpy.typing as npt
@@ -13,7 +13,8 @@ import vllm.envs as envs
 from vllm.inputs import DummyData
 from vllm.logger import init_logger
 
-from .inputs import MultiModalDataDict, MultiModalInputs
+from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
+                     MultiModalInputs)
 from .processing import BaseMultiModalProcessor, BaseProcessingInfo
 
 logger = init_logger(__name__)
@@ -142,14 +143,10 @@ class MultiModalProfiler(Generic[_I]):
             hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
         )
 
-    def get_dummy_data(
+    def get_and_validate_mm_inputs(
         self,
         seq_len: int,
-        is_encoder_data: bool = False,
-    ) -> DummyData:
-        # Avoid circular import
-        from vllm.sequence import SequenceData
-
+    ) -> tuple[MultiModalInputs, Mapping[str, int]]:
         mm_counts = self.get_mm_limits()
 
         info = self.processing_info
@@ -165,11 +162,6 @@ class MultiModalProfiler(Generic[_I]):
 
         mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
         placeholders_by_modality = mm_inputs["mm_placeholders"]
-        # For encoder-decoder models, use encoder prompt token ids instead of
-        # decoder prompt to construct dummy seq_data for encoder profiling.
-        prompt_token_ids = (
-            mm_inputs["prompt_token_ids"] if not is_encoder_data else
-            mm_inputs["encoder_prompt_token_ids"])  # type: ignore
 
         total_placeholders_by_modality = {
             modality: sum(item["length"] for item in placeholders)
@@ -185,28 +177,60 @@ class MultiModalProfiler(Generic[_I]):
                 f"{total_placeholders_by_modality} placeholder tokens, which "
                 f"is not the expected {expected_placeholders_by_modality} "
                 "tokens.")
 
+        return mm_inputs, total_placeholders_by_modality
+
+    def get_encoder_dummy_data(
+        self,
+        seq_len: int,
+    ) -> DummyData:
+        # Avoid circular import
+        from vllm.sequence import SequenceData
+
+        mm_inputs, _ = self.get_and_validate_mm_inputs(seq_len)
+        mm_inputs = cast(MultiModalEncDecInputs, mm_inputs)
+
+        # For encoder-decoder models, use encoder prompt token ids instead of
+        # decoder prompt to construct dummy seq_data for encoder profiling.
+        encoder_prompt_token_ids = mm_inputs["encoder_prompt_token_ids"]
+
+        total_len = len(encoder_prompt_token_ids)
+        num_tokens_to_pad = max(total_len, seq_len) - total_len
+        encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)
+
+        return DummyData(
+            seq_data=SequenceData.from_seqs(encoder_prompt_token_ids),
+            multi_modal_data=None,
+            multi_modal_placeholders=None,
+        )
+
+    def get_decoder_dummy_data(
+        self,
+        seq_len: int,
+    ) -> DummyData:
+        # Avoid circular import
+        from vllm.sequence import SequenceData
+
+        (mm_inputs, total_placeholders_by_modality
+         ) = self.get_and_validate_mm_inputs(seq_len)
+
+        prompt_token_ids = mm_inputs["prompt_token_ids"]
         total_len = len(prompt_token_ids)
 
         # V0 does not support chunked prefill.
-        if (total_len > seq_len and not envs.VLLM_USE_V1) or is_encoder_data:
-            if total_len > seq_len and not is_encoder_data:
-                logger.warning(
-                    "The context length (%d) of the model is too short "
-                    "to hold the multi-modal embeddings in the worst case "
-                    "(%d tokens in total, out of which %s are reserved for "
-                    "multi-modal embeddings). This may cause certain "
-                    "multi-modal inputs to fail during inference, even when "
-                    "the input text is short. To avoid this, you should "
-                    "increase `max_model_len`, reduce `max_num_seqs`, "
-                    "and/or reduce `mm_counts`.", seq_len, total_len,
-                    total_placeholders_by_modality)
-
-            num_tokens_to_pad = max(total_len, seq_len) - total_len
-            prompt_token_ids.extend([0] * num_tokens_to_pad)
-
+        if total_len > seq_len and not envs.VLLM_USE_V1:
+            logger.warning(
+                "The context length (%d) of the model is too short "
+                "to hold the multi-modal embeddings in the worst case "
+                "(%d tokens in total, out of which %s are reserved for "
+                "multi-modal embeddings). This may cause certain "
+                "multi-modal inputs to fail during inference, even when "
+                "the input text is short. To avoid this, you should "
+                "increase `max_model_len`, reduce `max_num_seqs`, "
+                "and/or reduce `mm_counts`.", seq_len, total_len,
+                total_placeholders_by_modality)
+            return DummyData(
+                seq_data=SequenceData.from_prompt_token_counts((0, seq_len)),
+                multi_modal_data=None,
+                multi_modal_placeholders=None,
+            )
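Two details in this hunk carry the OOM fix. First, the V0 early-return path now builds its text-only dummy with `SequenceData.from_prompt_token_counts((0, seq_len))`, capping it at exactly `seq_len` tokens instead of reusing the (possibly much longer) padded multimodal prompt. Second, the run-length constructor avoids materializing a large intermediate Python list. A minimal sketch of the difference, using only the two constructors that appear in the diff (the `seq_len` value here is arbitrary):

from vllm.sequence import SequenceData

seq_len = 32_768

# List-based construction, as in the removed code path: allocates a
# seq_len-element Python list of token ids before building SequenceData.
dummy_old = SequenceData.from_seqs([0] * seq_len)

# Run-length construction, as in the new code path: one (token_id, count)
# pair describes the whole zero-filled dummy prompt.
dummy_new = SequenceData.from_prompt_token_counts((0, seq_len))

assert dummy_old.get_len() == dummy_new.get_len() == seq_len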
@@ -216,5 +240,5 @@ class MultiModalProfiler(Generic[_I]):
         return DummyData(
             seq_data=SequenceData.from_seqs(prompt_token_ids),
             multi_modal_data=mm_inputs["mm_kwargs"],
-            multi_modal_placeholders=placeholders_by_modality,
+            multi_modal_placeholders=mm_inputs["mm_placeholders"],
         )
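End to end, the decoupled entry points can be exercised roughly as follows. This is a sketch, not code from this commit: `profile_encoder_decoder` is a hypothetical helper, and it assumes a `BaseMultiModalProcessor` built elsewhere, for example via the registry as in the hunk above.

from vllm.inputs import DummyData
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.multimodal.profiling import MultiModalProfiler

def profile_encoder_decoder(processor: BaseMultiModalProcessor,
                            seq_len: int) -> tuple[DummyData, DummyData]:
    profiler = MultiModalProfiler(processor)
    # Encoder side: padded text-only dummy, so multi_modal_data and
    # multi_modal_placeholders are None by construction.
    encoder_dummy = profiler.get_encoder_dummy_data(seq_len)
    # Decoder side: dummy seq_data plus the validated mm_kwargs and
    # mm_placeholders (unless the V0 fallback path is taken).
    decoder_dummy = profiler.get_decoder_dummy_data(seq_len)
    return encoder_dummy, decoder_dummy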