[V1] Override mm_counts for dummy data creation (#15703)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

parent 7fd8c0f85c
commit 803d5c35f3
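The diff below threads an optional mm_counts mapping from the multimodal profiling entry points down to each model's ProcessingInfo helpers, so a caller can override how many items of each modality the dummy data should contain instead of always profiling at the configured per-prompt limits. A rough usage sketch of the new signatures follows; the model_config setup is assumed and abbreviated, only the mm_counts override itself mirrors what this commit adds.

# Sketch only: exercising the overridden-count path introduced in this commit.
from vllm.multimodal import MULTIMODAL_REGISTRY

def build_dummy_data(model_config, seq_len: int):
    # Default behaviour (mm_counts omitted) falls back to the per-prompt limits.
    default_dummy = MULTIMODAL_REGISTRY.get_decoder_dummy_data(
        model_config, seq_len)

    # Overridden behaviour: profile with exactly one video and no other
    # modalities, regardless of the configured limits.
    video_only_dummy = MULTIMODAL_REGISTRY.get_decoder_dummy_data(
        model_config, seq_len, mm_counts={"video": 1})

    return default_dummy, video_only_dummy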
@@ -385,7 +385,7 @@ VLM_TEST_SETTINGS = {
     ),
     "minicpmo_26": VLMTestInfo(
         models=["openbmb/MiniCPM-o-2_6"],
-        test_type=(VLMTestType.IMAGE),
+        test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
         prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501
         img_idx_to_prompt=lambda idx: "(<image>./</image>)\n",
         max_model_len=4096,
@@ -394,21 +394,9 @@ VLM_TEST_SETTINGS = {
         hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
         patch_hf_runner=model_utils.minicpmo_26_patch_hf_runner,
     ),
-    "minicpmo_26_multi_image": VLMTestInfo(
-        models=["openbmb/MiniCPM-o-2_6"],
-        test_type=(VLMTestType.MULTI_IMAGE),
-        prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501
-        img_idx_to_prompt=lambda idx: "(<image>./</image>)\n",
-        max_model_len=4096,
-        max_num_seqs=2,
-        get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']), # noqa: E501
-        hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
-        patch_hf_runner=model_utils.minicpmo_26_patch_hf_runner,
-        marks=[large_gpu_mark(min_gb=32)],
-    ),
     "minicpmv_26": VLMTestInfo(
         models=["openbmb/MiniCPM-V-2_6"],
-        test_type=(VLMTestType.IMAGE),
+        test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
         prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501
         img_idx_to_prompt=lambda idx: "(<image>./</image>)\n",
         max_model_len=4096,
@@ -417,18 +405,6 @@ VLM_TEST_SETTINGS = {
         hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
         patch_hf_runner=model_utils.minicpmv_26_patch_hf_runner,
     ),
-    "minicpmv_26_multi_image": VLMTestInfo(
-        models=["openbmb/MiniCPM-V-2_6"],
-        test_type=(VLMTestType.MULTI_IMAGE),
-        prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501
-        img_idx_to_prompt=lambda idx: "(<image>./</image>)\n",
-        max_model_len=4096,
-        max_num_seqs=2,
-        get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']), # noqa: E501
-        hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
-        patch_hf_runner=model_utils.minicpmv_26_patch_hf_runner,
-        marks=[large_gpu_mark(min_gb=32)],
-    ),
     "molmo": VLMTestInfo(
         models=["allenai/Molmo-7B-D-0924"],
         test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
@@ -71,7 +71,8 @@ class LlavaNextVideoProcessingInfo(BaseProcessingInfo):
         max_video_tokens = self.get_num_video_tokens(
             image_width=target_width,
             image_height=target_height,
-            num_frames=self.get_num_frames_with_most_features(seq_len),
+            num_frames=self.get_num_frames_with_most_features(
+                seq_len, mm_counts),
         )

         return {"video": max_video_tokens}
@@ -130,9 +131,12 @@ class LlavaNextVideoProcessingInfo(BaseProcessingInfo):

         return num_frames

-    def get_num_frames_with_most_features(self, seq_len: int) -> int:
-        mm_config = self.ctx.get_mm_config()
-        max_videos = mm_config.get_limit_per_prompt("video")
+    def get_num_frames_with_most_features(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> int:
+        max_videos = mm_counts.get("video", 0)

         max_total_frames = self._get_max_video_frames(seq_len)

@@ -155,7 +159,7 @@ class LlavaNextVideoDummyInputsBuilder(
         target_width, target_height = \
             self.info.get_image_size_with_most_features()
         target_num_frames = \
-            self.info.get_num_frames_with_most_features(seq_len)
+            self.info.get_num_frames_with_most_features(seq_len, mm_counts)

         mm_data = {
             "video":
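In each model's ProcessingInfo, the pattern above replaces the engine-wide limit lookup (mm_config.get_limit_per_prompt(...)) with the counts handed in by the profiler. A minimal standalone sketch of the new behaviour; the helper name is illustrative, only the mm_counts.get(...) idiom comes from the diff.

from collections.abc import Mapping

def max_videos_for_profiling(mm_counts: Mapping[str, int]) -> int:
    # A modality missing from the override simply contributes zero items.
    return mm_counts.get("video", 0)

assert max_videos_for_profiling({"video": 1}) == 1
assert max_videos_for_profiling({"image": 2}) == 0  # video disabled for this run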
@@ -108,7 +108,7 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo):
     ) -> Mapping[str, int]:
         return {
             "image": self.get_max_image_tokens(),
-            "video": self.get_max_video_tokens(seq_len),
+            "video": self.get_max_video_tokens(seq_len, mm_counts),
         }

     # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86
@@ -202,10 +202,13 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo):

         return num_frames

-    def get_num_frames_with_most_features(self, seq_len: int) -> int:
-        mm_config = self.ctx.get_mm_config()
-        max_images = mm_config.get_limit_per_prompt("image")
-        max_videos = mm_config.get_limit_per_prompt("video")
+    def get_num_frames_with_most_features(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> int:
+        max_images = mm_counts.get("image", 0)
+        max_videos = mm_counts.get("video", 0)

         max_image_tokens = self.get_max_image_tokens() * max_images
         max_total_frames = self._get_max_video_frames(seq_len -
@@ -215,13 +218,18 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo):

         return max(max_frames_per_video, 1)

-    def get_max_video_tokens(self, seq_len: int) -> int:
+    def get_max_video_tokens(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> int:
         target_width, target_height = self.get_image_size_with_most_features()

         return self.get_num_video_tokens(
             image_width=target_width,
             image_height=target_height,
-            num_frames=self.get_num_frames_with_most_features(seq_len),
+            num_frames=self.get_num_frames_with_most_features(
+                seq_len, mm_counts),
         )


@@ -243,7 +251,8 @@ class LlavaOnevisionDummyInputsBuilder(
         target_width, target_height = \
             self.info.get_image_size_with_most_features()
         target_num_frames = \
-            self.info.get_num_frames_with_most_features(seq_len)
+            self.info.get_num_frames_with_most_features(seq_len,
+                                                        mm_counts)

         mm_data = {
             "image":
@@ -43,7 +43,8 @@ from vllm.multimodal.parse import (AudioItem, AudioProcessorItems,
 from vllm.multimodal.processing import PromptReplacement, PromptUpdate
 from vllm.multimodal.profiling import ProcessorInputs

-from .minicpmv import (MiniCPMV2_6, MiniCPMVDummyInputsBuilder,
+from .minicpmv import (_MAX_FRAMES_PER_VIDEO, MiniCPMV2_6,
+                       MiniCPMVDummyInputsBuilder,
                        MiniCPMVMultiModalDataParser,
                        MiniCPMVMultiModalProcessor, MiniCPMVProcessingInfo,
                        _minicpmv_field_config)
@@ -203,8 +204,8 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
         return 30

     def get_max_audio_tokens(self) -> int:
-        return self.get_max_audio_tokens_per_chunk(
-        ) * self.get_max_audio_chunks_with_most_features()
+        num_chunks = self.get_max_audio_chunks_with_most_features()
+        return self.get_max_audio_tokens_per_chunk() * num_chunks

     def get_audio_len_by_num_chunks(self, num_chunks: int) -> int:
         sampling_rate = self.get_default_audio_sampling_rate()
@@ -212,21 +213,24 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
         num_tokens_per_chunk = self.get_max_audio_tokens_per_chunk() - 2
         return int(num_chunks * sampling_rate / num_tokens_per_chunk) + 1

-    def get_num_frames_with_most_features(self, seq_len: int) -> int:
-        mm_config = self.ctx.get_mm_config()
-        max_images = mm_config.get_limit_per_prompt("image")
-        max_videos = mm_config.get_limit_per_prompt("video")
-        max_audios = mm_config.get_limit_per_prompt("audio")
+    def get_num_frames_with_most_features(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> int:
+        max_images = mm_counts.get("image", 0)
+        max_videos = mm_counts.get("video", 0)
+        max_audios = mm_counts.get("audio", 0)

         max_image_tokens = self.get_max_image_tokens() * max_images
         max_audio_tokens = self.get_max_audio_tokens() * max_audios
         max_total_frames = self.get_max_video_frames(seq_len -
                                                      max_image_tokens -
                                                      max_audio_tokens)
-
-        num_frames = max(max_total_frames // max(max_videos, 1), 1)
+        max_frames_per_video = min(max_total_frames // max(max_videos, 1),
+                                   _MAX_FRAMES_PER_VIDEO)

-        return num_frames
+        return max(max_frames_per_video, 1)


 class MiniCPMODummyInputsBuilder(
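The budget math above splits the sequence length between image, audio, and video placeholders before capping the per-video frame count at _MAX_FRAMES_PER_VIDEO. A self-contained sketch of that arithmetic, with made-up token counts (the real per-item token costs come from the model's processor, and the real code converts the remaining tokens to frames inside get_max_video_frames):

_MAX_FRAMES_PER_VIDEO = 16  # same cap as introduced in minicpmv.py

def frames_per_video(seq_len: int, image_tokens: int, audio_tokens: int,
                     max_videos: int, tokens_per_frame: int) -> int:
    # Tokens left after images and audio, converted to a total frame budget.
    max_total_frames = (seq_len - image_tokens - audio_tokens) // tokens_per_frame
    # Spread across the requested videos, capped per video, never below one.
    per_video = min(max_total_frames // max(max_videos, 1), _MAX_FRAMES_PER_VIDEO)
    return max(per_video, 1)

# Made-up example: 4096-token budget, 1225 image tokens, 550 audio tokens,
# one video at 64 tokens per frame -> min(36, 16) = 16 frames.
print(frames_per_video(4096, 1225, 550, 1, 64))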
@@ -69,6 +69,9 @@ from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
                     merge_multimodal_embeddings)
 from .vision import scatter_patch_features, select_patch_features

+# For profile run
+_MAX_FRAMES_PER_VIDEO = 16
+

 class MiniCPMVImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
@@ -369,7 +372,8 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
     ) -> Mapping[str, int]:
         mm_max_tokens = {"image": self.get_max_image_tokens()}
         if self.get_model_version() == (2, 6):
-            mm_max_tokens["video"] = self.get_max_video_tokens(seq_len)
+            mm_max_tokens["video"] = self.get_max_video_tokens(
+                seq_len, mm_counts)

         return mm_max_tokens

@@ -432,9 +436,14 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
             use_image_id=False,
         )

-    def get_max_video_tokens(self, seq_len: int) -> int:
-        return self.get_max_video_frame_tokens(
-        ) * self.get_num_frames_with_most_features(seq_len)
+    def get_max_video_tokens(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> int:
+        num_frames = self.get_num_frames_with_most_features(seq_len, mm_counts)
+        num_video_tokens_total = self.get_max_video_frame_tokens() * num_frames
+        return num_video_tokens_total

     def get_video_max_slice_num(self) -> int:
         return 1
@@ -449,18 +458,21 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
         num_frames = max_tokens // num_frame_tokens
         return num_frames

-    def get_num_frames_with_most_features(self, seq_len: int) -> int:
-        mm_config = self.ctx.get_mm_config()
-        max_images = mm_config.get_limit_per_prompt("image")
-        max_videos = mm_config.get_limit_per_prompt("video")
+    def get_num_frames_with_most_features(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> int:
+        max_images = mm_counts.get("image", 0)
+        max_videos = mm_counts.get("video", 0)

         max_image_tokens = self.get_max_image_tokens() * max_images
         max_total_frames = self.get_max_video_frames(seq_len -
                                                      max_image_tokens)
-
-        num_frames = max(max_total_frames // max(max_videos, 1), 1)
+        max_frames_per_video = min(max_total_frames // max(max_videos, 1),
+                                   _MAX_FRAMES_PER_VIDEO)

-        return num_frames
+        return max(max_frames_per_video, 1)


 _I = TypeVar("_I",
@@ -483,7 +495,7 @@ class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
         video_width, video_height = \
             self.info.get_video_frame_size_with_most_features()
         num_video_frames = \
             self.info.get_num_frames_with_most_features(seq_len)
-            self.info.get_num_frames_with_most_features(seq_len)
+            self.info.get_num_frames_with_most_features(seq_len, mm_counts)

         mm_data = {
             "image":
@@ -806,7 +806,7 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
         max_pixels: Optional[int] = None,
         size: Optional[dict[str, int]] = None,
         **kwargs: object,
-    ):
+    ) -> Qwen2VLImageProcessor:
         return cached_image_processor_from_config(
             self.ctx.model_config,
             **self._get_image_processor_kwargs(min_pixels=min_pixels,
@@ -825,7 +825,7 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
     ) -> Mapping[str, int]:
         return {
             "image": self.get_max_image_tokens(),
-            "video": self.get_max_video_tokens(seq_len),
+            "video": self.get_max_video_tokens(seq_len, mm_counts),
         }

     def _get_vision_info(
@@ -941,10 +941,13 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):

         return num_frames

-    def get_num_frames_with_most_features(self, seq_len: int) -> int:
-        mm_config = self.ctx.get_mm_config()
-        max_images = mm_config.get_limit_per_prompt("image")
-        max_videos = mm_config.get_limit_per_prompt("video")
+    def get_num_frames_with_most_features(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> int:
+        max_images = mm_counts.get("image", 0)
+        max_videos = mm_counts.get("video", 0)

         max_image_tokens = self.get_max_image_tokens() * max_images
         max_total_frames = self._get_max_video_frames(seq_len -
@@ -954,13 +957,18 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):

         return max(max_frames_per_video, 1)

-    def get_max_video_tokens(self, seq_len: int) -> int:
+    def get_max_video_tokens(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> int:
         target_width, target_height = self.get_image_size_with_most_features()

         return self.get_num_video_tokens(
             image_width=target_width,
             image_height=target_height,
-            num_frames=self.get_num_frames_with_most_features(seq_len),
+            num_frames=self.get_num_frames_with_most_features(
+                seq_len, mm_counts),
             image_processor=None,
         )

@@ -982,7 +990,7 @@ class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]):
         target_width, target_height = \
             self.info.get_image_size_with_most_features()
         target_num_frames = \
-            self.info.get_num_frames_with_most_features(seq_len)
+            self.info.get_num_frames_with_most_features(seq_len, mm_counts)

         mm_data = {
             "image":
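One practical effect of passing mm_counts through get_max_video_tokens is that the video budget now reflects the requested counts rather than the configured limits: when the override asks for video only, no image tokens are reserved. The numbers below are purely illustrative; only the shape of the calculation mirrors get_num_frames_with_most_features.

def video_frame_budget(seq_len: int, mm_counts: dict[str, int],
                       image_tokens: int = 1225, frame_tokens: int = 64) -> int:
    # Images claim their tokens first; the remainder is divided into frames
    # and spread across the requested videos.
    max_images = mm_counts.get("image", 0)
    max_videos = mm_counts.get("video", 0)
    remaining = seq_len - image_tokens * max_images
    return max(remaining // frame_tokens // max(max_videos, 1), 1)

print(video_frame_budget(32768, {"image": 1, "video": 1}))  # images reserve tokens
print(video_frame_budget(32768, {"video": 1}))              # whole budget for video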
@@ -3,7 +3,7 @@

 from abc import ABC, abstractmethod
 from collections.abc import Mapping
 from dataclasses import dataclass, field
-from typing import Generic, NamedTuple, TypeVar, cast
+from typing import Generic, NamedTuple, Optional, TypeVar, cast

 import numpy as np
 import numpy.typing as npt
@@ -160,17 +160,19 @@ class MultiModalProfiler(Generic[_I]):
     def get_and_validate_mm_inputs(
         self,
         seq_len: int,
+        mm_counts: Optional[Mapping[str, int]] = None,
     ) -> tuple[MultiModalInputs, Mapping[str, int]]:
-        mm_counts = self.get_mm_limits()
+        if mm_counts is None:
+            mm_counts = self.get_mm_limits()

         info = self.processing_info
         mm_max_tokens_per_item = info.get_mm_max_tokens_per_item(
             seq_len, mm_counts)

-        if mm_counts.keys() != mm_max_tokens_per_item.keys():
+        if mm_counts.keys() - mm_max_tokens_per_item.keys():
             raise AssertionError(
                 "The keys returned by `get_supported_mm_limits` "
-                f"({set(mm_counts.keys())}) should be the same as those "
+                f"({set(mm_counts.keys())}) should be a subset of those "
                 "returned by `get_mm_max_tokens_per_item` "
                 f"({set(mm_max_tokens_per_item.keys())})")

@@ -193,8 +195,12 @@ class MultiModalProfiler(Generic[_I]):
                 "tokens.")
         return mm_inputs, total_placeholders_by_modality

-    def get_encoder_dummy_data(self, seq_len: int) -> DummyEncoderData:
-        mm_inputs, _ = self.get_and_validate_mm_inputs(seq_len)
+    def get_encoder_dummy_data(
+        self,
+        seq_len: int,
+        mm_counts: Optional[Mapping[str, int]] = None,
+    ) -> DummyEncoderData:
+        mm_inputs, _ = self.get_and_validate_mm_inputs(seq_len, mm_counts)
         mm_inputs = cast(MultiModalEncDecInputs, mm_inputs)

         # For encoder-decoder models, use encoder prompt token ids instead of
@@ -207,9 +213,15 @@ class MultiModalProfiler(Generic[_I]):

         return DummyEncoderData(encoder_prompt_token_ids)

-    def get_decoder_dummy_data(self, seq_len: int) -> DummyDecoderData:
-        (mm_inputs, total_placeholders_by_modality
-         ) = self.get_and_validate_mm_inputs(seq_len)
+    def get_decoder_dummy_data(
+        self,
+        seq_len: int,
+        mm_counts: Optional[Mapping[str, int]] = None,
+    ) -> DummyDecoderData:
+        (
+            mm_inputs,
+            total_placeholders_by_modality,
+        ) = self.get_and_validate_mm_inputs(seq_len, mm_counts)

         prompt_token_ids = mm_inputs["prompt_token_ids"]
         total_len = len(prompt_token_ids)
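Two behavioural details in the profiler change above: passing mm_counts=None keeps the old behaviour by falling back to get_mm_limits(), and the key check is relaxed from equality to a subset test, since an overridden mm_counts may name fewer modalities than the model reports token budgets for. A distilled standalone sketch of that logic (names are illustrative, the checks mirror the diff):

from collections.abc import Mapping
from typing import Optional

def resolve_counts(mm_counts: Optional[Mapping[str, int]],
                   default_limits: Mapping[str, int],
                   max_tokens_per_item: Mapping[str, int]) -> Mapping[str, int]:
    if mm_counts is None:
        mm_counts = default_limits  # old default: profile at the configured limits

    # The override may cover fewer modalities than max_tokens_per_item,
    # but never more (that would request an unsupported modality).
    if mm_counts.keys() - max_tokens_per_item.keys():
        raise AssertionError("mm_counts contains unsupported modalities")
    return mm_counts

resolve_counts(None, {"image": 2, "video": 1}, {"image": 1024, "video": 4096})
resolve_counts({"video": 1}, {"image": 2, "video": 1}, {"image": 1024, "video": 4096})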
@@ -458,6 +458,7 @@ class MultiModalRegistry:
         self,
         model_config: "ModelConfig",
         seq_len: int,
+        mm_counts: Optional[Mapping[str, int]] = None,
     ) -> DummyDecoderData:
         """
         Create dummy data for profiling the memory usage of a model.
@@ -466,7 +467,7 @@ class MultiModalRegistry:
         """
         processor = self.create_processor(model_config, disable_cache=True)
         profiler = MultiModalProfiler(processor)
-        dummy_data = profiler.get_decoder_dummy_data(seq_len)
+        dummy_data = profiler.get_decoder_dummy_data(seq_len, mm_counts)

         # Having more tokens is over-conservative but otherwise fine
         token_ids = dummy_data.prompt_token_ids
@@ -481,6 +482,7 @@ class MultiModalRegistry:
         self,
         model_config: "ModelConfig",
         seq_len: int,
+        mm_counts: Optional[Mapping[str, int]] = None,
     ) -> DummyEncoderData:
         """
         Create dummy data for profiling the memory usage of a model.
@@ -489,7 +491,7 @@ class MultiModalRegistry:
         """
         processor = self.create_processor(model_config, disable_cache=True)
         profiler = MultiModalProfiler(processor)
-        dummy_data = profiler.get_encoder_dummy_data(seq_len)
+        dummy_data = profiler.get_encoder_dummy_data(seq_len, mm_counts)

         # Having more tokens is over-conservative but otherwise fine
         token_ids = dummy_data.prompt_token_ids
@@ -1470,19 +1470,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             encoder_budget, max_num_mm_items, dummy_data_modality)

         # Create dummy batch of multimodal inputs.
-        dummy_request_data = self.mm_registry.get_decoder_dummy_data(
+        dummy_mm_kwargs = self.mm_registry.get_decoder_dummy_data(
             model_config=self.model_config,
             seq_len=self.max_num_tokens,
-        )
-        dummy_mm_data = dummy_request_data.multi_modal_data
-
-        # Dummy data definition may contain multiple multimodal items
-        # (e.g, multiple images) for a single request, therefore here we
-        # always replicate first item by max_num_mm_items times since in V1
-        # they are scheduled to be processed separately.
-        dummy_mm_item = dummy_mm_data.get_item(
-            modality=dummy_data_modality, item_index=0)
-        dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item])
+            mm_counts={
+                dummy_data_modality: 1
+            },
+        ).multi_modal_data

         batched_dummy_mm_inputs = MultiModalKwargs.batch(
             [dummy_mm_kwargs] * max_num_mm_items)
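With the override set to {dummy_data_modality: 1}, the V1 runner no longer needs to pull the first item out of a multi-item dummy request: the returned kwargs already describe exactly one item, and profiling simply replicates them. A rough sketch of the resulting calling pattern; the surrounding budget and modality variables are assumed to have been computed earlier, as in the hunk above.

from vllm.multimodal import MultiModalKwargs

def build_profiling_batch(self, dummy_data_modality: str, max_num_mm_items: int):
    # One item of the chosen modality, sized for the full token budget.
    dummy_mm_kwargs = self.mm_registry.get_decoder_dummy_data(
        model_config=self.model_config,
        seq_len=self.max_num_tokens,
        mm_counts={dummy_data_modality: 1},
    ).multi_modal_data

    # V1 schedules multimodal items separately, so replicate the single item
    # to fill the encoder budget.
    return MultiModalKwargs.batch([dummy_mm_kwargs] * max_num_mm_items)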