[VLM] Separate out profiling-related logic (#11746)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 2a622d704a
commit 996357e480
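Every model touched below follows the same refactoring pattern: a shared processing mixin holds the HF-config/processor accessors, a dedicated profiling-info class owns the supported-limits, max-token, and dummy-input logic, and the multimodal processor composes the two through _get_profiling_info(). A minimal sketch of that layering, using plain-Python stand-ins (the "MyModel" names are hypothetical; the real base classes are ProcessingMixin, BaseProfilingInfo, and BaseMultiModalProcessor in vllm.multimodal):

from typing import Any, Mapping, Optional


class MyModelProcessingMixin:
    """Shared accessors reused by both the processor and its profiling info."""

    def __init__(self, ctx: Any) -> None:
        self.ctx = ctx  # InputProcessingContext in the real code

    def _get_hf_config(self) -> Any:
        return self.ctx.get_hf_config()

    def _get_num_image_tokens(self) -> int:
        # Hypothetical fixed per-image feature size, for illustration only.
        return 64


class MyModelProfilingInfo(MyModelProcessingMixin):
    """Profiling-only logic: supported limits, max tokens, dummy inputs."""

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"image": 1}

    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
        return {"image": self._get_num_image_tokens()}

    def get_dummy_processor_inputs(self, seq_len: int,
                                   mm_counts: Mapping[str, int]) -> Any:
        # Would build dummy images and a dummy prompt for memory profiling.
        raise NotImplementedError


class MyModelMultiModalProcessor(MyModelProcessingMixin):
    """Runtime processing; profiling is obtained via _get_profiling_info()."""

    def _get_profiling_info(self) -> MyModelProfilingInfo:
        return MyModelProfilingInfo(self.ctx)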
@@ -586,9 +586,10 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
)

processor = processor_factory(ctx, cache=None)
profiler = processor.profiling_info

mock_supported_mm_limits = MagicMock(return_value={"image": num_supported})
processor.get_supported_mm_limits = mock_supported_mm_limits
profiler.get_supported_mm_limits = mock_supported_mm_limits

if is_valid:
exc_ctx = nullcontext()
@@ -596,7 +597,7 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
exc_ctx = pytest.raises(ValueError, match="this model only supports")

with exc_ctx:
processor._get_and_validate_dummy_mm_counts()
profiler.get_mm_limits()


@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@@ -723,7 +724,7 @@ def _test_processing_cache_correctness(
}

mm_counts = {k: len(vs) for k, vs in mm_data.items()}
prompt = baseline_processor._get_dummy_processor_inputs(
prompt = baseline_processor.profiling_info.get_dummy_processor_inputs(
model_config.max_model_len,
mm_counts,
).prompt_text
@@ -24,8 +24,9 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs,
MultiModalDataItems, ProcessingMixin,
PromptReplacement)
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.aria import (AriaMoELMConfig,
AriaVisionConfig)
@@ -444,54 +445,33 @@ def build_mm_projector(config: PretrainedConfig):
)


class AriaMultiModalProcessor(BaseMultiModalProcessor):
class AriaProcessingMixin(ProcessingMixin):

def _get_hf_config(self):
return self.ctx.get_hf_config()

def _get_vision_config(self) -> AriaVisionConfig:
return self._get_hf_config().vision_config

def _get_num_image_tokens(self) -> int:
hf_config = self._get_hf_config()
return max(hf_config.projector_patch_to_query_dict.values())


class AriaProfilingInfo(AriaProcessingMixin, BaseProfilingInfo):

def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}

def _get_num_image_tokens(self) -> int:
hf_config = self.ctx.get_hf_config()
return max(hf_config.projector_patch_to_query_dict.values())

def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {"image": self._get_num_image_tokens()}

def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
pixel_mask=MultiModalFieldConfig.batched("image"),
)

def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_config = self.ctx.get_hf_config()
image_token_id = hf_config.image_token_index

num_image_tokens = self._get_num_image_tokens()

return [
PromptReplacement(
modality="image",
target=[image_token_id],
replacement=[image_token_id] * num_image_tokens,
)
]

def _get_dummy_processor_inputs(
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
hf_config = self.ctx.get_hf_config()
vision_config: AriaVisionConfig = hf_config.vision_config
vision_config = self._get_vision_config()

max_image_size = vision_config.image_size
num_images = mm_counts.get("image", 0)
@@ -512,6 +492,41 @@ class AriaMultiModalProcessor(BaseMultiModalProcessor):
)


class AriaMultiModalProcessor(AriaProcessingMixin, BaseMultiModalProcessor):

def _get_profiling_info(self) -> BaseProfilingInfo:
return AriaProfilingInfo(self.ctx)

def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
pixel_mask=MultiModalFieldConfig.batched("image"),
)

def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_config = self._get_hf_config()
image_token_id = hf_config.image_token_index

num_image_tokens = self._get_num_image_tokens()

return [
PromptReplacement(
modality="image",
target=[image_token_id],
replacement=[image_token_id] * num_image_tokens,
)
]


@MULTIMODAL_REGISTRY.register_processor(AriaMultiModalProcessor)
class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
"""
@@ -4,8 +4,8 @@ from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,

import torch
import torch.nn as nn
from transformers import (BatchFeature, Blip2Config, Blip2Processor,
Blip2QFormerConfig, apply_chunking_to_forward)
from transformers import (BatchFeature, Blip2Config, Blip2QFormerConfig,
apply_chunking_to_forward)

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
@@ -18,8 +18,9 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs,
NestedTensors, PlaceholderRange)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs,
MultiModalDataItems, ProcessingMixin,
PromptReplacement)
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
from vllm.sequence import IntermediateTensors

from .blip import BlipVisionModel
@@ -396,20 +397,52 @@ class Blip2QFormerModel(nn.Module):
return sequence_output


class Blip2MultiModalProcessor(BaseMultiModalProcessor):
class Blip2ProcessingMixin(ProcessingMixin):

def _get_hf_config(self):
return self.ctx.get_hf_config(Blip2Config)

def _get_num_image_tokens(self) -> int:
hf_config = self._get_hf_config()
return hf_config.num_query_tokens


class Blip2ProfilingInfo(Blip2ProcessingMixin, BaseProfilingInfo):

def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}

def _get_num_image_tokens(self) -> int:
hf_config = self.ctx.get_hf_config(Blip2Config)
return hf_config.num_query_tokens

def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {"image": self._get_num_image_tokens()}

def _get_hf_processor(self) -> Blip2Processor:
return self.ctx.get_hf_processor(Blip2Processor)
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
hf_config = self._get_hf_config()
vision_config = hf_config.vision_config

max_image_size = vision_config.image_size
num_images = mm_counts.get("image", 0)

mm_data = {
"image":
self._get_dummy_images(width=max_image_size,
height=max_image_size,
num_images=num_images)
}

return ProcessorInputs(
prompt_text="",
mm_data=mm_data,
)


class Blip2MultiModalProcessor(Blip2ProcessingMixin, BaseMultiModalProcessor):

def _get_profiling_info(self) -> BaseProfilingInfo:
return Blip2ProfilingInfo(self.ctx)

def _get_mm_fields_config(
self,
@@ -427,13 +460,13 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor):
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
max_image_tokens = self._get_num_image_tokens()
num_image_tokens = self._get_num_image_tokens()

return [
PromptReplacement(
modality="image",
target="</s>",
replacement="<image>" * max_image_tokens + "</s>",
replacement="<image>" * num_image_tokens + "</s>",
)
]

@@ -457,29 +490,6 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor):

return result

def _get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
hf_config = self.ctx.get_hf_config(Blip2Config)
vision_config = hf_config.vision_config

max_image_size = vision_config.image_size
num_images = mm_counts.get("image", 0)

mm_data = {
"image":
self._get_dummy_images(width=max_image_size,
height=max_image_size,
num_images=num_images)
}

return ProcessorInputs(
prompt_text="",
mm_data=mm_data,
)


@MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor)
class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
@@ -31,8 +31,9 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs,
NestedTensors, PlaceholderRange)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs,
MultiModalDataItems, ProcessingMixin,
PromptReplacement)
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import print_warning_once

@@ -48,54 +49,33 @@ class ChameleonImagePixelInputs(TypedDict):
"""Shape: `(batch_size * num_images, num_channels, height, width)`"""


class ChameleonMultiModalProcessor(BaseMultiModalProcessor):
class ChameleonProcessingMixin(ProcessingMixin):

def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}
def _get_hf_config(self):
return self.ctx.get_hf_config(ChameleonConfig)

def _get_hf_processor(self):
return self.ctx.get_hf_processor(ChameleonProcessor)

def _get_num_image_tokens(self) -> int:
processor = self._get_hf_processor()
return processor.image_seq_length


class ChameleonProfilingInfo(ChameleonProcessingMixin, BaseProfilingInfo):

def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}

def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {"image": self._get_num_image_tokens()}

def _get_hf_processor(self) -> ChameleonProcessor:
return self.ctx.get_hf_processor(ChameleonProcessor)

def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(pixel_values=MultiModalFieldConfig.batched("image"))

def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
processor = self._get_hf_processor()

return [
PromptReplacement(
modality="image",
target="<image>",
replacement="".join([
processor.image_start_token,
processor.image_token * self._get_num_image_tokens(),
processor.image_end_token,
]),
)
]

def _get_dummy_processor_inputs(
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
config = self.ctx.get_hf_config(ChameleonConfig)
config = self._get_hf_config()

width = height = config.vq_config.resolution
num_images = mm_counts.get("image", 0)
@@ -112,6 +92,40 @@ class ChameleonMultiModalProcessor(BaseMultiModalProcessor):
mm_data=mm_data,
)


class ChameleonMultiModalProcessor(ChameleonProcessingMixin,
BaseMultiModalProcessor):

def _get_profiling_info(self) -> BaseProfilingInfo:
return ChameleonProfilingInfo(self.ctx)

def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(pixel_values=MultiModalFieldConfig.batched("image"))

def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
processor = self._get_hf_processor(**hf_processor_mm_kwargs)

return [
PromptReplacement(
modality="image",
target="<image>",
replacement="".join([
processor.image_start_token,
processor.image_token * self._get_num_image_tokens(),
processor.image_end_token,
]),
)
]

def apply(
self,
prompt_text: str,
@@ -35,8 +35,9 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
NestedTensors, PlaceholderRange)
from vllm.multimodal.parse import ImageProcessorItems, ImageSize
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs,
MultiModalDataItems, ProcessingMixin,
PromptReplacement)
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsMultiModal, SupportsPP
@@ -63,18 +64,16 @@ class FuyuImagePatchInputs(TypedDict):
"""


class FuyuMultiModalProcessor(BaseMultiModalProcessor):
class FuyuProcessingMixin(ProcessingMixin):

def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}
def _get_hf_config(self):
return self.ctx.get_hf_config(FuyuConfig)

def _get_image_target_size(self) -> ImageSize:
processor = self._get_hf_processor()
image_processor: FuyuImageProcessor = processor.image_processor
def _get_hf_processor(self):
return self.ctx.get_hf_processor(FuyuProcessor)

target_size = image_processor.size
return ImageSize(width=target_size["width"],
height=target_size["height"])
def _get_image_processor(self) -> FuyuImageProcessor:
return self._get_hf_processor().image_processor

def _get_image_feature_grid_size(
self,
@@ -82,7 +81,9 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor):
image_width: int,
image_height: int,
) -> tuple[int, int]:
target_width, target_height = self._get_image_target_size()
image_processor = self._get_image_processor()
target_width = image_processor.size["width"]
target_height = image_processor.size["height"]

if not (image_width <= target_width and image_height <= target_height):
height_scale_factor = target_height / image_height
@@ -96,8 +97,14 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor):
nrows = math.ceil(image_height / 30)
return ncols, nrows


class FuyuProfilingInfo(FuyuProcessingMixin, BaseProfilingInfo):

def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}

def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
target_width, target_height = self._get_image_target_size()
target_width, target_height = self._get_image_size_with_most_features()

max_ncols, max_nrows = self._get_image_feature_grid_size(
image_width=target_width,
@@ -107,8 +114,36 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor):

return {"image": max_image_tokens}

def _get_hf_processor(self) -> FuyuProcessor:
return self.ctx.get_hf_processor(FuyuProcessor)
def _get_image_size_with_most_features(self) -> ImageSize:
image_processor = self._get_image_processor()
return ImageSize(width=image_processor.size["width"],
height=image_processor.size["height"])

def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
target_width, target_height = self._get_image_size_with_most_features()
num_images = mm_counts.get("image", 0)

mm_data = {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images)
}

return ProcessorInputs(
prompt_text="",
mm_data=mm_data,
)


class FuyuMultiModalProcessor(FuyuProcessingMixin, BaseMultiModalProcessor):

def _get_profiling_info(self) -> BaseProfilingInfo:
return FuyuProfilingInfo(self.ctx)

def _call_hf_processor(
self,
@@ -161,7 +196,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor):
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_config = self.ctx.get_hf_config(FuyuConfig)
hf_config = self._get_hf_config()
bos_token_id = hf_config.bos_token_id

tokenizer = self._get_tokenizer()
@@ -208,26 +243,6 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor):

return result

def _get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
target_width, target_height = self._get_image_target_size()
num_images = mm_counts.get("image", 0)

mm_data = {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images)
}

return ProcessorInputs(
prompt_text="",
mm_data=mm_data,
)


@MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor)
class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
@@ -1,4 +1,4 @@
from abc import abstractmethod
from abc import ABC, abstractmethod
from functools import cached_property
from typing import (Final, Iterable, List, Literal, Mapping, Optional,
Protocol, Set, Tuple, TypedDict, Union)
@@ -13,6 +13,7 @@ from transformers.models.pixtral import PixtralProcessor

from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.inputs import InputProcessingContext
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
@@ -25,9 +26,10 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
NestedTensors)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize)
from vllm.multimodal.processing import (InputProcessingContext,
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessingCache,
ProcessorInputs, PromptReplacement)
ProcessingMixin, PromptReplacement)
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
from vllm.sequence import IntermediateTensors

from .clip import CLIPVisionModel
@@ -37,7 +39,7 @@ from .pixtral import (PixtralHFVisionModel,
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
from .vision import BaseVisionLanguageMultiModalProcessor
from .vision import get_vision_encoder_info


class LlavaImagePixelInputs(TypedDict):
@@ -94,30 +96,42 @@ class LlavaMultiModalProjector(nn.Module):

class LlavaLikeConfig(Protocol):
vision_config: Final[PretrainedConfig]
image_token_index: Final[int]
vision_feature_select_strategy: Final[str]
vision_feature_layer: Final[Union[int, List[int]]]
vision_feature_layer: Final[Union[int, list[int]]]


class BaseLlavaMultiModalProcessor(BaseVisionLanguageMultiModalProcessor):
class LlavaLikeProcessor(Protocol):
image_token: Final[str]


class BaseLlavaProcessingMixin(ProcessingMixin, ABC):

def _get_hf_config(self) -> LlavaLikeConfig:
return self.ctx.get_hf_config(LlavaConfig)

def _get_vision_encoder_info(self):
return get_vision_encoder_info(self._get_hf_config())

@abstractmethod
def _get_hf_config(self) -> LlavaLikeConfig:
def _get_hf_processor(self) -> LlavaLikeProcessor:
raise NotImplementedError

def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}

def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {"image": self._get_max_image_tokens()}

def _get_mm_fields_config(
def _get_num_image_tokens(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
*,
image_width: int,
image_height: int,
) -> int:
hf_config = self._get_hf_config()
vision_encoder_info = self._get_vision_encoder_info()

return self._apply_feature_select_strategy(
hf_config.vision_feature_select_strategy,
vision_encoder_info.get_num_image_tokens(
image_width=image_width,
image_height=image_height,
),
)

def _apply_feature_select_strategy(
@@ -133,31 +147,38 @@ class BaseLlavaMultiModalProcessor(BaseVisionLanguageMultiModalProcessor):
msg = f"Unexpected feature select strategy: {strategy!r}"
raise NotImplementedError(msg)

def _get_max_image_tokens(self) -> int:
hf_config = self._get_hf_config()

return self._apply_feature_select_strategy(
hf_config.vision_feature_select_strategy,
self._vision_encoder_info.get_max_image_tokens(),
class BaseLlavaProfilingInfo(BaseLlavaProcessingMixin, BaseProfilingInfo):

def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}

def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {"image": self._get_max_image_tokens()}

def _get_image_size_with_most_features(self) -> ImageSize:
vision_encoder_info = self._get_vision_encoder_info()
width = height = vision_encoder_info.get_image_size()
return ImageSize(width=width, height=height)

def _get_max_image_tokens(self) -> int:
target_width, target_height = self._get_image_size_with_most_features()

return self._get_num_image_tokens(
image_width=target_width,
image_height=target_height,
)

def _get_dummy_image_size(self) -> ImageSize:
image_size = self._vision_encoder_info.get_image_size()
return ImageSize(image_size, image_size)

@abstractmethod
def _get_image_token(self) -> str:
raise NotImplementedError

def _get_dummy_processor_inputs(
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
num_images = mm_counts.get("image", 0)

image_token = self._get_image_token()
target_width, target_height = self._get_dummy_image_size()
processor = self._get_hf_processor()
image_token = processor.image_token
target_width, target_height = self._get_image_size_with_most_features()

mm_data = {
"image":
@@ -172,32 +193,32 @@ class BaseLlavaMultiModalProcessor(BaseVisionLanguageMultiModalProcessor):
)


class LlavaMultiModalProcessor(BaseLlavaMultiModalProcessor):
class LlavaProcessingMixin(BaseLlavaProcessingMixin):

def _get_hf_config(self) -> LlavaConfig:
return self.ctx.get_hf_config(LlavaConfig)

def _get_hf_processor(self) -> LlavaProcessor:
def _get_hf_processor(self):
return self.ctx.get_hf_processor(LlavaProcessor)

def _get_image_token(self) -> str:
return self._get_hf_processor().image_token

def _get_num_image_tokens(
class LlavaProfilingInfo(LlavaProcessingMixin, BaseLlavaProfilingInfo):
pass


class BaseLlavaMultiModalProcessor(LlavaProcessingMixin,
BaseMultiModalProcessor):

# Copied from BaseMultiModalProcessor
@abstractmethod
def _get_profiling_info(self) -> BaseProfilingInfo:
raise NotImplementedError

# Copied from BaseMultiModalProcessor
@abstractmethod
def _get_mm_fields_config(
self,
*,
image_width: int,
image_height: int,
) -> int:
hf_config = self._get_hf_config()

return self._apply_feature_select_strategy(
hf_config.vision_feature_select_strategy,
self._vision_encoder_info.get_num_image_tokens(
image_width=image_width,
image_height=image_height,
),
)
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
raise NotImplementedError

def _get_prompt_replacements(
self,
@@ -232,16 +253,37 @@ class LlavaMultiModalProcessor(BaseLlavaMultiModalProcessor):
]


class PixtralHFMultiModalProcessor(BaseLlavaMultiModalProcessor):
class LlavaMultiModalProcessor(BaseLlavaMultiModalProcessor):

def _get_hf_config(self) -> LlavaConfig:
return self.ctx.get_hf_config(LlavaConfig)
def _get_profiling_info(self) -> BaseProfilingInfo:
return LlavaProfilingInfo(self.ctx)

def _get_hf_processor(self) -> PixtralProcessor:
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
)


class PixtralHFProcessingMixin(BaseLlavaProcessingMixin):

def _get_hf_processor(self):
return self.ctx.get_hf_processor(PixtralProcessor)

def _get_image_token(self) -> str:
return self._get_hf_processor().image_token

class PixtralHFProfilingInfo(PixtralHFProcessingMixin, BaseLlavaProfilingInfo):
pass


class PixtralHFMultiModalProcessor(PixtralHFProcessingMixin,
BaseMultiModalProcessor):

def _get_profiling_info(self) -> BaseProfilingInfo:
return PixtralHFProfilingInfo(self.ctx)

def _call_hf_processor(
self,
@@ -270,6 +312,16 @@ class PixtralHFMultiModalProcessor(BaseLlavaMultiModalProcessor):

return processed_outputs

def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
)

def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
@@ -316,7 +368,7 @@ def _build_llava_or_pixtral_hf_processor(
*,
cache: Optional[ProcessingCache] = None,
enable_sanity_checks: bool = True,
) -> BaseLlavaMultiModalProcessor:
) -> BaseMultiModalProcessor:
hf_config = ctx.get_hf_config(LlavaConfig)

if isinstance(hf_config.vision_config, PixtralVisionConfig):
@@ -663,16 +715,13 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):

class MantisMultiModalProcessor(LlavaMultiModalProcessor):

def _get_hf_processor(self):
return self.ctx.get_hf_processor(LlavaProcessor)

def apply(
self,
prompt_text: str,
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalInputsV2:
hf_config = self.ctx.get_hf_config(LlavaConfig)
hf_config = self._get_hf_config()
image_token_id = hf_config.image_token_index

# Assume that it doesn't depend on the image size
@@ -1,6 +1,6 @@
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union)
from typing import (Final, Iterable, List, Literal, Mapping, Optional,
Protocol, Set, Tuple, TypedDict, Union)

import numpy as np
import torch
@@ -17,12 +17,14 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig, NestedTensors
from vllm.multimodal.parse import ImageSize
from vllm.multimodal.profiling import BaseProfilingInfo
from vllm.sequence import IntermediateTensors

from .clip import CLIPVisionModel
from .interfaces import SupportsMultiModal, SupportsPP
from .llava import (LlavaMultiModalProcessor, LlavaMultiModalProjector,
init_vision_tower_for_llava)
from .llava import (BaseLlavaMultiModalProcessor, BaseLlavaProcessingMixin,
BaseLlavaProfilingInfo, LlavaLikeConfig,
LlavaMultiModalProjector, init_vision_tower_for_llava)
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn,
init_vllm_registered_model, maybe_prefix)
@@ -60,36 +62,18 @@ LlavaNextImageInputs = Union[LlavaNextImagePixelInputs,
LlavaNextImageEmbeddingInputs]


class LlavaNextMultiModalProcessor(LlavaMultiModalProcessor):
class LlavaNextLikeConfig(LlavaLikeConfig, Protocol):
image_grid_pinpoints: Final[list[list[int]]]

def _get_hf_config(self) -> LlavaNextConfig:

class LlavaNextProcessingMixin(BaseLlavaProcessingMixin):

def _get_hf_config(self) -> LlavaNextLikeConfig:
return self.ctx.get_hf_config(LlavaNextConfig)

def _get_hf_processor(self) -> LlavaNextProcessor:
def _get_hf_processor(self):
return self.ctx.get_hf_processor(LlavaNextProcessor)

def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
image_sizes=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
)

def _get_image_token(self) -> str:
return self._get_hf_processor().image_token

def _get_max_image_tokens(self) -> int:
largest_feature_size, _ = self._get_pinpoint_with_most_features()
return largest_feature_size

def _get_dummy_image_size(self) -> ImageSize:
_, pinpoint = self._get_pinpoint_with_most_features()
return pinpoint

# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L106
def _get_num_image_tokens(
self,
@@ -98,7 +82,7 @@ class LlavaNextMultiModalProcessor(LlavaMultiModalProcessor):
image_height: int,
) -> int:
hf_config = self._get_hf_config()
vision_encoder_info = self._vision_encoder_info
vision_encoder_info = self._get_vision_encoder_info()

base_feature_size = self._apply_feature_select_strategy(
hf_config.vision_feature_select_strategy,
@@ -140,7 +124,7 @@ class LlavaNextMultiModalProcessor(LlavaMultiModalProcessor):
current_height = npatches * num_patch_height
current_width = npatches * num_patch_width

# NOTE: HF resizes based on float32
# NOTE: Use float32 to remain consistent with HF output
original_aspect_ratio = np.array(original_width / original_height,
dtype=np.float32)
current_aspect_ratio = np.array(current_width / current_height,
@@ -164,11 +148,10 @@ class LlavaNextMultiModalProcessor(LlavaMultiModalProcessor):

return (unpadded_features, newline_features)

def _get_pinpoint_with_most_features(self) -> tuple[int, ImageSize]:
"""
Get the grid pinpoint with the most features and
the corresponding feature size.
"""

class LlavaNextProfilingInfo(LlavaNextProcessingMixin, BaseLlavaProfilingInfo):

def _get_image_size_with_most_features(self) -> ImageSize:
hf_config = self._get_hf_config()

largest_feature_size, largest_feature_pinpoint = 0, None
@@ -183,7 +166,25 @@ class LlavaNextMultiModalProcessor(LlavaMultiModalProcessor):
if largest_feature_size == 0 or largest_feature_pinpoint is None:
raise ValueError("Cannot have a largest feature size of 0!")

return largest_feature_size, largest_feature_pinpoint
return largest_feature_pinpoint


class LlavaNextMultiModalProcessor(LlavaNextProcessingMixin,
BaseLlavaMultiModalProcessor):

def _get_profiling_info(self) -> BaseProfilingInfo:
return LlavaNextProfilingInfo(self.ctx)

def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
image_sizes=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
)


@MULTIMODAL_REGISTRY.register_processor(LlavaNextMultiModalProcessor)
@@ -15,11 +15,14 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
VideoEmbeddingItems, VideoProcessorItems)
from vllm.multimodal.processing import (MultiModalFieldConfig, ProcessorInputs,
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import (ImageSize, VideoEmbeddingItems,
VideoProcessorItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessingMixin,
PromptReplacement)
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of

@@ -28,7 +31,7 @@ from .llava import init_vision_tower_for_llava
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
from .vision import BaseVisionLanguageMultiModalProcessor
from .vision import get_vision_encoder_info


class LlavaNextVideoPixelInputs(TypedDict):
@@ -44,30 +47,17 @@ class LlavaNextVideoPixelInputs(TypedDict):
"""


class LlavaNextVideoMultiModalProcessor(BaseVisionLanguageMultiModalProcessor):
class LlavaNextVideoProcessingMixin(ProcessingMixin):

def _get_hf_config(self) -> LlavaNextVideoConfig:
def _get_hf_config(self):
return self.ctx.get_hf_config(LlavaNextVideoConfig)

def _get_hf_processor(self) -> LlavaNextVideoProcessor:
def _get_vision_encoder_info(self):
return get_vision_encoder_info(self._get_hf_config())

def _get_hf_processor(self):
return self.ctx.get_hf_processor(LlavaNextVideoProcessor)

def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"video": 1}

def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
num_frames = self._get_dummy_num_frames(seq_len)
max_video_tokens = self._get_max_video_tokens(num_frames)

return {"video": max_video_tokens}

def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(pixel_values_videos=MultiModalFieldConfig.batched("video"))

def _get_num_frame_tokens(
self,
*,
@@ -77,7 +67,8 @@ class LlavaNextVideoMultiModalProcessor(BaseVisionLanguageMultiModalProcessor):
hf_config = self._get_hf_config()
spatial_pool_stride = hf_config.spatial_pool_stride

patch_grid_length = self._vision_encoder_info.get_patch_grid_length()
vision_encoder_info = self._get_vision_encoder_info()
patch_grid_length = vision_encoder_info.get_patch_grid_length()
pooled_grid_length = math.ceil(patch_grid_length / spatial_pool_stride)

return pooled_grid_length * pooled_grid_length
@@ -96,18 +87,43 @@ class LlavaNextVideoMultiModalProcessor(BaseVisionLanguageMultiModalProcessor):

return num_frame_tokens * num_frames

def _get_max_video_tokens(self, num_frames: int) -> int:
return self._get_num_video_tokens(image_width=999999,
image_height=999999,
num_frames=num_frames)

class LlavaNextVideoProfilingInfo(LlavaNextVideoProcessingMixin,
BaseProfilingInfo):

def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"video": 1}

def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
target_width, target_height = self._get_image_size_with_most_features()

max_video_tokens = self._get_num_video_tokens(
image_width=target_width,
image_height=target_height,
num_frames=self._get_dummy_num_frames(seq_len),
)

return {"video": max_video_tokens}

def _get_image_size_with_most_features(self) -> ImageSize:
vision_encoder_info = self._get_vision_encoder_info()
width = height = vision_encoder_info.get_image_size()
return ImageSize(width=width, height=height)

def _get_max_video_frames(self, max_tokens: int) -> int:
target_width, target_height = self._get_image_size_with_most_features()

num_frames = 0

while True:
next_num_frames = num_frames + 1
next_max_tokens = self._get_num_video_tokens(
image_width=target_width,
image_height=target_height,
num_frames=next_num_frames,
)

if self._get_max_video_tokens(next_num_frames) > max_tokens:
if next_max_tokens > max_tokens:
break

num_frames = next_num_frames
@@ -122,12 +138,45 @@ class LlavaNextVideoMultiModalProcessor(BaseVisionLanguageMultiModalProcessor):

return max(max_total_frames // max(max_videos, 1), 1)

def _get_dummy_image_size(self) -> ImageSize:
image_size = self._vision_encoder_info.get_image_size()
return ImageSize(image_size, image_size)
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
num_videos = mm_counts.get("video", 0)

def _get_video_token(self) -> str:
return self._get_hf_processor().video_token
processor = self._get_hf_processor()
video_token = processor.video_token
target_width, target_height = self._get_image_size_with_most_features()

mm_data = {
"video":
self._get_dummy_videos(
width=target_width,
height=target_height,
num_frames=self._get_dummy_num_frames(seq_len),
num_videos=num_videos,
)
}

return ProcessorInputs(
prompt_text=video_token * num_videos,
mm_data=mm_data,
)


class LlavaNextVideoMultiModalProcessor(LlavaNextVideoProcessingMixin,
BaseMultiModalProcessor):

def _get_profiling_info(self) -> BaseProfilingInfo:
return LlavaNextVideoProfilingInfo(self.ctx)

def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(pixel_values_videos=MultiModalFieldConfig.batched("video"))

def _get_prompt_replacements(
self,
@@ -162,36 +211,11 @@ class LlavaNextVideoMultiModalProcessor(BaseVisionLanguageMultiModalProcessor):
),
]

def _get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
num_videos = mm_counts.get("video", 0)

video_token = self._get_video_token()
target_width, target_height = self._get_dummy_image_size()

mm_data = {
"video":
self._get_dummy_videos(
width=target_width,
height=target_height,
num_frames=self._get_dummy_num_frames(seq_len),
num_videos=num_videos,
)
}

return ProcessorInputs(
prompt_text=video_token * num_videos,
mm_data=mm_data,
)


# adopted from transformers modeling_llava_next_video.py
class LlavaNextVideoPooler(nn.Module):

def __init__(self, config):
def __init__(self, config: LlavaNextVideoConfig):
super().__init__()

mode = config.spatial_pool_mode
@@ -209,7 +233,7 @@ class LlavaNextVideoPooler(nn.Module):
raise ValueError(
f"Unknown pooling mode: {mode}. Expected [`average`, `max`]")

def forward(self, image_features):
def forward(self, image_features: torch.Tensor):
ori_width = int(
math.sqrt(image_features.shape[1] * self.image_size //
self.image_size))
@@ -1,7 +1,7 @@
import math
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union)
from typing import (Final, Iterable, List, Literal, Mapping, Optional,
Protocol, Set, Tuple, TypedDict, Union)

import numpy as np
import torch
@@ -21,15 +21,16 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
from vllm.multimodal.parse import (MultiModalDataItems, VideoEmbeddingItems,
VideoProcessorItems)
from vllm.multimodal.processing import (MultiModalFieldConfig, ProcessorInputs,
PromptReplacement)
from vllm.multimodal.processing import MultiModalFieldConfig, PromptReplacement
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of

from .clip import CLIPVisionModel
from .interfaces import SupportsMultiModal, SupportsPP
from .llava import init_vision_tower_for_llava
from .llava_next import LlavaNextMultiModalProcessor
from .llava import BaseLlavaProfilingInfo, init_vision_tower_for_llava
from .llava_next import (LlavaNextLikeConfig, LlavaNextMultiModalProcessor,
LlavaNextProcessingMixin)
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
@@ -82,40 +83,18 @@ LlavaOnevisionMultiInputs = Union[LlavaOnevisionImageInputs,
LlavaOnevisionVideoPixelInputs]


class LlavaOnevisionMultiModalProcessor(LlavaNextMultiModalProcessor):
class LlavaOnevisionLikeConfig(LlavaNextLikeConfig, Protocol):
video_token_index: Final[int]

def _get_hf_config(self) -> LlavaOnevisionConfig:

class LlavaOnevisionProcessingMixin(LlavaNextProcessingMixin):

def _get_hf_config(self) -> LlavaOnevisionLikeConfig:
return self.ctx.get_hf_config(LlavaOnevisionConfig)

def _get_hf_processor(self) -> LlavaOnevisionProcessor:
def _get_hf_processor(self):
return self.ctx.get_hf_processor(LlavaOnevisionProcessor)

def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None}

def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
max_image_tokens = self._get_max_image_tokens()

num_frames = self._get_dummy_num_frames(seq_len)
max_video_tokens = self._get_max_video_tokens(num_frames)

return {
"image": max_image_tokens,
"video": max_video_tokens,
}

def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
image_sizes=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
pixel_values_videos=MultiModalFieldConfig.batched("video"),
)

def _get_num_unpadded_features(
self,
*,
@@ -128,7 +107,7 @@ class LlavaOnevisionMultiModalProcessor(LlavaNextMultiModalProcessor):
current_height = npatches * num_patch_height
current_width = npatches * num_patch_width

# NOTE: HF resizes based on float32
# NOTE: Use float32 to remain consistent with HF output
original_aspect_ratio = np.array(original_width / original_height,
dtype=np.float32)
current_aspect_ratio = np.array(current_width / current_height,
@@ -167,7 +146,8 @@ class LlavaOnevisionMultiModalProcessor(LlavaNextMultiModalProcessor):
hf_config = self._get_hf_config()
spatial_pool_stride = getattr(hf_config, "spatial_pool_stride", 2)

patch_grid_length = self._vision_encoder_info.get_patch_grid_length()
vision_encoder_info = self._get_vision_encoder_info()
patch_grid_length = vision_encoder_info.get_patch_grid_length()
pooled_grid_length = math.ceil(patch_grid_length / spatial_pool_stride)

return pooled_grid_length * pooled_grid_length
@@ -186,18 +166,33 @@ class LlavaOnevisionMultiModalProcessor(LlavaNextMultiModalProcessor):

return num_frame_tokens * num_frames + 1  # Newline token

def _get_max_video_tokens(self, num_frames: int) -> int:
return self._get_num_video_tokens(image_width=999999,
image_height=999999,
num_frames=num_frames)

class LlavaOnevisionProfilingInfo(LlavaOnevisionProcessingMixin,
BaseLlavaProfilingInfo):

def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None}

def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {
"image": self._get_max_image_tokens(),
"video": self._get_max_video_tokens(seq_len),
}

def _get_max_video_frames(self, max_tokens: int) -> int:
target_width, target_height = self._get_image_size_with_most_features()

num_frames = 0

while True:
next_num_frames = num_frames + 1
next_max_tokens = self._get_num_video_tokens(
image_width=target_width,
image_height=target_height,
num_frames=next_num_frames,
)

if self._get_max_video_tokens(next_num_frames) > max_tokens:
if next_max_tokens > max_tokens:
break

num_frames = next_num_frames
@@ -215,8 +210,65 @@ class LlavaOnevisionMultiModalProcessor(LlavaNextMultiModalProcessor):

return max(max_total_frames // max(max_videos, 1), 1)

def _get_video_token(self) -> str:
return self._get_hf_processor().video_token
def _get_max_video_tokens(self, seq_len: int) -> int:
target_width, target_height = self._get_image_size_with_most_features()

return self._get_num_video_tokens(
image_width=target_width,
image_height=target_height,
num_frames=self._get_dummy_num_frames(seq_len),
)

def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0)

processor = self._get_hf_processor()
image_token = processor.image_token
video_token = processor.video_token
target_width, target_height = self._get_image_size_with_most_features()

mm_data = {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images),
"video":
self._get_dummy_videos(
width=target_width,
height=target_height,
num_frames=self._get_dummy_num_frames(seq_len),
num_videos=num_videos,
)
}

return ProcessorInputs(
prompt_text=image_token * num_images + video_token * num_videos,
mm_data=mm_data,
)


class LlavaOnevisionMultiModalProcessor(LlavaOnevisionProcessingMixin,
LlavaNextMultiModalProcessor):

def _get_profiling_info(self) -> BaseProfilingInfo:
return LlavaOnevisionProfilingInfo(self.ctx)

def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
image_sizes=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
pixel_values_videos=MultiModalFieldConfig.batched("video"),
)

def _call_hf_processor(
self,
@@ -235,7 +287,8 @@ class LlavaOnevisionMultiModalProcessor(LlavaNextMultiModalProcessor):
mm_kwargs=mm_kwargs,
)

video_token = self._get_video_token()
processor = self._get_hf_processor()
video_token = processor.video_token

# LLaVA-OneVision processor doesn't support multiple videos
# with different sizes when converting back to tensors
@@ -303,37 +356,6 @@ class LlavaOnevisionMultiModalProcessor(LlavaNextMultiModalProcessor):
),
]

def _get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0)

image_token = self._get_image_token()
video_token = self._get_video_token()
target_width, target_height = self._get_dummy_image_size()

mm_data = {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images),
"video":
self._get_dummy_videos(
width=target_width,
height=target_height,
num_frames=self._get_dummy_num_frames(seq_len),
num_videos=num_videos,
)
}

return ProcessorInputs(
prompt_text=image_token * num_images + video_token * num_videos,
mm_data=mm_data,
)


class LlavaOnevisionMultiModalProjector(nn.Module):
@@ -14,7 +14,7 @@
# limitations under the License.
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
from typing import Any, List, Literal, Optional, Set, Tuple, TypedDict, Union

import torch
import torch.nn as nn
@@ -28,22 +28,23 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs,
NestedTensors, PlaceholderRange)
from vllm.multimodal.parse import ImageEmbeddingItems, ImageProcessorItems
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs,
MultiModalDataItems, ProcessingMixin,
PromptReplacement,
_BoundPromptReplacement,
_PlaceholderInfo)
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of

from .clip import dummy_image_for_clip
from .clip import CLIPVisionModel
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
@@ -54,10 +55,6 @@ logger = init_logger(__name__)
# Cannot find the following 2 numbers from hf config.
_IMAGE_TOKEN_ID = 32044

# Result in the max possible feature size (h:w = 16:1)
MAX_IMAGE_FEATURE_SIZE_HEIGHT = 8000
MAX_IMAGE_FEATURE_SIZE_WIDTH = 50

CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0,
hidden_act="quick_gelu",
hidden_size=1024,
@@ -305,10 +302,17 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
return image_features_hd_newline


class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
class Phi3VProcessingMixin(ProcessingMixin):

def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def _get_hf_processor(
self,
*,
num_crops: Optional[int] = None,
) -> ProcessorMixin:
if num_crops is not None:
return self.ctx.get_hf_processor(num_crops=num_crops)

return self.ctx.get_hf_processor()

def _get_num_image_tokens(
self,
@@ -323,23 +327,55 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
height=image_height,
)


class Phi3VProfilingInfo(Phi3VProcessingMixin, BaseProfilingInfo):

def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}

def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
target_width, target_height = self._get_image_size_with_most_features()

max_image_tokens = self._get_num_image_tokens(
image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
image_width=target_width,
image_height=target_height,
)

return {"image": max_image_tokens}

def _get_hf_processor(
self,
*,
num_crops: Optional[int] = None,
) -> ProcessorMixin:
if num_crops is not None:
return self.ctx.get_hf_processor(num_crops=num_crops)
def _get_image_size_with_most_features(self) -> ImageSize:
# Result in the max possible feature size (h:w = 16:1)
return ImageSize(height=8000, width=50)

return self.ctx.get_hf_processor()
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
num_images = mm_counts.get("image", 0)

target_width, target_height = self._get_image_size_with_most_features()

mm_data = {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images)
}

hf_processor = self._get_hf_processor()
image_tokens: list[str] = hf_processor.img_tokens  # type: ignore

return ProcessorInputs(
prompt_text="".join(image_tokens[:num_images]),
mm_data=mm_data,
)


class Phi3VMultiModalProcessor(Phi3VProcessingMixin, BaseMultiModalProcessor):

def _get_profiling_info(self) -> BaseProfilingInfo:
return Phi3VProfilingInfo(self.ctx)

def _call_hf_processor(
self,
@@ -377,10 +413,10 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_processor = self._get_hf_processor()
hf_processor = self._get_hf_processor(**hf_processor_mm_kwargs)
image_tokens: list[str] = hf_processor.img_tokens  # type: ignore

tokenizer = self._get_tokenizer()
@@ -442,28 +478,6 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):

return token_ids, text, placeholders

def _get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
num_images = mm_counts.get("image", 0)

data = dummy_image_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
|
||||
num_images,
|
||||
image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
|
||||
image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
|
||||
)
|
||||
|
||||
hf_processor = self._get_hf_processor()
|
||||
image_tokens: list[str] = hf_processor.img_tokens # type: ignore
|
||||
|
||||
return ProcessorInputs(
|
||||
prompt_text="".join(image_tokens[:num_images]),
|
||||
mm_data=data,
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
prompt_text: str,
|
||||
|
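The same split is repeated for the audio and vision models below: a shared processing mixin, a profiling-info class, and a processor that wires them together via _get_profiling_info(). A minimal sketch of the pattern (the MyModel* names and the numbers are hypothetical and not part of this commit; only ProcessingMixin, BaseProfilingInfo, ProcessorInputs and BaseMultiModalProcessor come from the diff):

class MyModelProcessingMixin(ProcessingMixin):

    def _get_hf_config(self):
        # Helper shared by both the profiling info and the processor.
        return self.ctx.get_hf_config()


class MyModelProfilingInfo(MyModelProcessingMixin, BaseProfilingInfo):

    def get_supported_mm_limits(self):
        return {"image": 1}  # assumed: single-image model

    def get_mm_max_tokens_per_item(self, seq_len):
        return {"image": 576}  # assumed: fixed feature size

    def get_dummy_processor_inputs(self, seq_len, mm_counts):
        num_images = mm_counts.get("image", 0)
        mm_data = {
            "image":
            self._get_dummy_images(width=336, height=336,
                                   num_images=num_images)
        }
        return ProcessorInputs(prompt_text="<image>" * num_images,
                               mm_data=mm_data)


class MyModelMultiModalProcessor(MyModelProcessingMixin,
                                 BaseMultiModalProcessor):

    def _get_profiling_info(self) -> BaseProfilingInfo:
        return MyModelProfilingInfo(self.ctx)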
@ -20,8 +20,8 @@
# limitations under the License.
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
from functools import cached_property
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
                    Union)
from typing import (Any, Iterable, List, Mapping, Optional, Set, Tuple,
                    TypedDict, Union)

import torch
import torch.nn as nn
@ -40,8 +40,9 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
                                    NestedTensors)
from vllm.multimodal.parse import AudioProcessorItems, MultiModalDataParser
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        MultiModalDataItems, ProcessorInputs,
                                        MultiModalDataItems, ProcessingMixin,
                                        PromptReplacement)
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsMultiModal, SupportsPP
@ -79,17 +80,10 @@ def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
    return feat_lengths, output_lengths


class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
class Qwen2AudioProcessingMixin(ProcessingMixin):

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"audio": None}

    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
        hf_config = self.ctx.get_hf_config(Qwen2AudioConfig)
        max_source_positions = hf_config.audio_config.max_source_positions
        max_output_lengths = (max_source_positions - 2) // 2 + 1

        return {"audio": max_output_lengths}
    def _get_hf_config(self):
        return self.ctx.get_hf_config(Qwen2AudioConfig)

    def _get_hf_processor(
        self,
@ -99,8 +93,57 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
    ) -> Qwen2AudioProcessor:
        return self.ctx.get_hf_processor(Qwen2AudioProcessor)

    def _get_feature_extractor(self) -> WhisperFeatureExtractor:
        return self._get_hf_processor().feature_extractor  # type: ignore
    def _get_feature_extractor(
        self,
        *,
        # Ignored in initialization
        sampling_rate: Optional[int] = None,
    ) -> WhisperFeatureExtractor:
        hf_processor = self._get_hf_processor(sampling_rate=sampling_rate)
        feature_extractor = hf_processor.feature_extractor  # type: ignore
        assert isinstance(feature_extractor, WhisperFeatureExtractor)
        return feature_extractor


class Qwen2AudioProfilingInfo(Qwen2AudioProcessingMixin, BaseProfilingInfo):

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"audio": None}

    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
        hf_config = self._get_hf_config()
        max_source_positions = hf_config.audio_config.max_source_positions
        max_output_lengths = (max_source_positions - 2) // 2 + 1

        return {"audio": max_output_lengths}

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        feature_extractor = self._get_feature_extractor()

        sampling_rate = feature_extractor.sampling_rate
        audio_len = feature_extractor.chunk_length * sampling_rate
        num_audios = mm_counts.get("audio", 0)

        mm_data = {
            "audio":
            self._get_dummy_audios(length=audio_len, num_audios=num_audios)
        }

        return ProcessorInputs(
            prompt_text="<|AUDIO|>" * num_audios,
            mm_data=mm_data,
        )


class Qwen2AudioMultiModalProcessor(Qwen2AudioProcessingMixin,
                                    BaseMultiModalProcessor):

    def _get_profiling_info(self) -> BaseProfilingInfo:
        return Qwen2AudioProfilingInfo(self.ctx)

    def _get_data_parser(self) -> MultiModalDataParser:
        feature_extractor = self._get_feature_extractor()
@ -110,7 +153,7 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        mm_kwargs: Mapping[str, Any],
    ) -> BatchFeature:
        mm_data = dict(mm_data)
        audios = mm_data.pop("audios", [])
@ -118,7 +161,7 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
        if audios:
            mm_data["audios"] = audios

        feature_extractor = self._get_feature_extractor()
        feature_extractor = self._get_feature_extractor(**mm_kwargs)
        mm_kwargs = dict(
            **mm_kwargs,
            sampling_rate=feature_extractor.sampling_rate,
@ -151,7 +194,7 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        hf_config = self.ctx.get_hf_config(Qwen2AudioConfig)
        hf_config = self._get_hf_config()
        placeholder = hf_config.audio_token_index

        feature_attention_mask = out_mm_kwargs.get("feature_attention_mask")
@ -191,27 +234,6 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
        # tokens than the number of audio items)
        return True

    def _get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        feature_extractor = self._get_feature_extractor()

        sampling_rate = feature_extractor.sampling_rate
        audio_len = feature_extractor.chunk_length * sampling_rate
        num_audios = mm_counts.get("audio", 0)

        mm_data = {
            "audio":
            self._get_dummy_audios(length=audio_len, num_audios=num_audios)
        }

        return ProcessorInputs(
            prompt_text="<|AUDIO|>" * num_audios,
            mm_data=mm_data,
        )


@MULTIMODAL_REGISTRY.register_processor(Qwen2AudioMultiModalProcessor)
class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
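For a sense of the audio token budget above: assuming a Whisper-style feature extractor with max_source_positions = 1500 (an assumed value, not read from this diff), the formula reserves 750 tokens per audio item during profiling:

max_source_positions = 1500                              # assumed encoder setting
max_output_lengths = (max_source_positions - 2) // 2 + 1
assert max_output_lengths == 750                         # tokens per audio item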
@ -59,8 +59,9 @@ from vllm.multimodal.inputs import (ImageItem, ModalityData,
from vllm.multimodal.parse import (ImageSize, ModalityDataItems,
                                   MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        MultiModalDataItems, ProcessorInputs,
                                        MultiModalDataItems, ProcessingMixin,
                                        PromptReplacement)
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import uses_mrope
@ -708,10 +709,44 @@ class Qwen2MultiModalDataParser(MultiModalDataParser):
        return super()._parse_video_data(data)


class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
class Qwen2VLProcessingMixin(ProcessingMixin):

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"image": None, "video": None}
    def _get_hf_config(self):
        return self.ctx.get_hf_config(Qwen2VLConfig)

    def _get_hf_processor(
        self,
        *,
        min_pixels: Optional[int] = None,
        max_pixels: Optional[int] = None,
    ) -> Qwen2VLProcessor:
        hf_processor = self.ctx.get_hf_processor(Qwen2VLProcessor)
        image_processor = hf_processor.image_processor  # type: ignore
        assert isinstance(image_processor, Qwen2VLImageProcessor)

        if min_pixels:
            image_processor.min_pixels = min_pixels
        if max_pixels:
            image_processor.max_pixels = max_pixels
        if max_pixels or min_pixels:
            image_processor.size = {
                "min_pixels": image_processor.min_pixels,
                "max_pixels": image_processor.max_pixels,
            }

        return hf_processor

    def _get_image_processor(
        self,
        *,
        min_pixels: Optional[int] = None,
        max_pixels: Optional[int] = None,
    ):
        hf_processor = self._get_hf_processor(min_pixels=min_pixels,
                                              max_pixels=max_pixels)
        image_processor = hf_processor.image_processor  # type: ignore
        assert isinstance(image_processor, Qwen2VLImageProcessor)
        return image_processor

    def _get_vision_info(
        self,
@ -721,14 +756,13 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
        num_frames: int = 1,
        do_resize: bool = True,
    ) -> tuple[ImageSize, int]:
        hf_config = self.ctx.get_hf_config(Qwen2VLConfig)
        hf_config = self._get_hf_config()
        vision_config = hf_config.vision_config
        patch_size = vision_config.patch_size
        merge_size = vision_config.spatial_merge_size
        temporal_patch_size = vision_config.temporal_patch_size

        hf_processor = self._get_hf_processor()
        image_processor = self._get_image_processor(hf_processor)
        image_processor = self._get_image_processor()

        if do_resize:
            resized_height, resized_width = smart_resize(
@ -753,7 +787,45 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):

        return preprocessed_size, num_vision_tokens

    def _get_dummy_image_size(self) -> ImageSize:
    def _get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        _, num_image_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
        )
        return num_image_tokens

    def _get_num_video_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int,
    ) -> int:
        _, num_video_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=num_frames,
        )
        return num_video_tokens


class Qwen2VLProfilingInfo(Qwen2VLProcessingMixin, BaseProfilingInfo):

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"image": None, "video": None}

    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
        return {
            "image": self._get_max_image_tokens(),
            "video": self._get_max_video_tokens(seq_len),
        }

    def _get_image_size_with_most_features(self) -> ImageSize:
        max_image_size, _ = self._get_vision_info(
            image_width=9999999,
            image_height=9999999,
@ -761,27 +833,27 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
        return max_image_size

    def _get_max_image_tokens(self) -> int:
        _, max_image_tokens = self._get_vision_info(
            image_width=9999999,
            image_height=9999999,
        )
        return max_image_tokens
        target_width, target_height = self._get_image_size_with_most_features()

    def _get_max_video_tokens(self, num_frames: int) -> int:
        _, max_video_tokens = self._get_vision_info(
            image_width=9999999,
            image_height=9999999,
            num_frames=num_frames,
        return self._get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
        )
        return max_video_tokens

    def _get_max_video_frames(self, max_tokens: int) -> int:
        target_width, target_height = self._get_image_size_with_most_features()

        num_frames = 0

        while True:
            next_num_frames = num_frames + 1
            next_max_tokens = self._get_num_video_tokens(
                image_width=target_width,
                image_height=target_height,
                num_frames=next_num_frames,
            )

            if self._get_max_video_tokens(next_num_frames) > max_tokens:
            if next_max_tokens > max_tokens:
                break

            num_frames = next_num_frames
@ -797,56 +869,73 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
        max_total_frames = self._get_max_video_frames(seq_len -
                                                      max_image_tokens)

        return max(max_total_frames // max(max_videos, 1), 1)
        num_frames = max(max_total_frames // max(max_videos, 1), 1)

    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
        max_image_tokens = self._get_max_image_tokens()
        # Temporary workaround for https://github.com/huggingface/transformers/issues/35412
        if num_frames > 1 and num_frames % 2 == 1:
            num_frames += 1

        num_frames = self._get_dummy_num_frames(seq_len)
        max_video_tokens = self._get_max_video_tokens(num_frames)
        return num_frames

        return {
            "image": max_image_tokens,
            "video": max_video_tokens,
    def _get_max_video_tokens(self, seq_len: int) -> int:
        target_width, target_height = self._get_image_size_with_most_features()

        return self._get_num_video_tokens(
            image_width=target_width,
            image_height=target_height,
            num_frames=self._get_dummy_num_frames(seq_len),
        )

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        hf_processor = self._get_hf_processor()
        image_token: str = hf_processor.image_token
        video_token: str = hf_processor.video_token
        target_width, target_height = self._get_image_size_with_most_features()

        mm_data = {
            "image":
            self._get_dummy_images(width=target_width,
                                   height=target_height,
                                   num_images=num_images),
            "video":
            self._get_dummy_videos(
                width=target_width,
                height=target_height,
                num_frames=self._get_dummy_num_frames(seq_len),
                num_videos=num_videos,
            )
        }

        return ProcessorInputs(
            prompt_text=image_token * num_images + video_token * num_videos,
            mm_data=mm_data,
        )


class Qwen2VLMultiModalProcessor(Qwen2VLProcessingMixin,
                                 BaseMultiModalProcessor):

    def _get_profiling_info(self) -> BaseProfilingInfo:
        return Qwen2VLProfilingInfo(self.ctx)

    def _get_data_parser(self) -> MultiModalDataParser:
        return Qwen2MultiModalDataParser()

    def _get_image_processor(self, hf_processor: Qwen2VLProcessor):
        image_processor = hf_processor.image_processor  # type: ignore
        assert isinstance(image_processor, Qwen2VLImageProcessor)
        return image_processor

    def _get_hf_processor(
        self,
        *,
        min_pixels: Optional[int] = None,
        max_pixels: Optional[int] = None,
    ) -> Qwen2VLProcessor:
        hf_processor = self.ctx.get_hf_processor(Qwen2VLProcessor)
        image_processor = self._get_image_processor(hf_processor)

        if min_pixels:
            image_processor.min_pixels = min_pixels
        if max_pixels:
            image_processor.max_pixels = max_pixels
        if max_pixels or min_pixels:
            image_processor.size = {
                "min_pixels": image_processor.min_pixels,
                "max_pixels": image_processor.max_pixels,
            }

        return hf_processor

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        hf_processor = self._get_hf_processor()
        image_processor = self._get_image_processor(hf_processor)
        hf_processor = self._get_hf_processor(**hf_processor_mm_kwargs)
        image_processor = self._get_image_processor(**hf_processor_mm_kwargs)

        # NOTE: Only Qwen2VLProcessor in transformers 4.47.0 has
        # image_token and video_token registered
@ -901,38 +990,6 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
            video_grid_thw=MultiModalFieldConfig.batched("video"),
        )

    def _get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        hf_processor = self._get_hf_processor()
        image_token: str = hf_processor.image_token
        video_token: str = hf_processor.video_token
        target_width, target_height = self._get_dummy_image_size()

        mm_data = {
            "image":
            self._get_dummy_images(width=target_width,
                                   height=target_height,
                                   num_images=num_images),
            "video":
            self._get_dummy_videos(
                width=target_width,
                height=target_height,
                num_frames=self._get_dummy_num_frames(seq_len),
                num_videos=num_videos,
            )
        }

        return ProcessorInputs(
            prompt_text=image_token * num_images + video_token * num_videos,
            mm_data=mm_data,
        )


@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor)
class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
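The video budget logic above can be read as: take what is left of seq_len after the worst-case image, grow the frame count until the video-token estimate exceeds that budget, split the total across videos, and round an odd frame count up to even (the transformers#35412 workaround). A standalone sketch under the simplifying assumption of a constant token cost per frame (the real code re-estimates via _get_num_video_tokens for each candidate count):

def dummy_num_frames(seq_len: int, max_image_tokens: int,
                     tokens_per_frame: int, max_videos: int) -> int:
    # Assumed constant cost per frame; roughly mirrors _get_max_video_frames
    # plus _get_dummy_num_frames from the diff above.
    budget = seq_len - max_image_tokens
    total_frames = max(budget // tokens_per_frame, 0)
    num_frames = max(total_frames // max(max_videos, 1), 1)
    if num_frames > 1 and num_frames % 2 == 1:
        num_frames += 1  # round odd counts up to even (workaround)
    return num_frames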
@ -3,8 +3,8 @@

import math
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
                    TypedDict, Union)
from typing import (Any, Iterable, List, Literal, Mapping, Optional, Set,
                    Tuple, TypedDict, Union)

import torch
import torch.utils.checkpoint
@ -26,8 +26,9 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
                                    NestedTensors)
from vllm.multimodal.parse import MultiModalDataParser
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        MultiModalDataItems, ProcessorInputs,
                                        MultiModalDataItems, ProcessingMixin,
                                        PromptReplacement)
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.ultravox import UltravoxConfig

@ -55,7 +56,30 @@ UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs,
                            UltravoxAudioEmbeddingInputs]


class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
class UltravoxProcessingMixin(ProcessingMixin):

    def _get_hf_processor(
        self,
        *,
        # Ignored in initialization
        sampling_rate: Optional[int] = None,
    ) -> ProcessorMixin:
        return self.ctx.get_hf_processor()

    def _get_feature_extractor(
        self,
        *,
        # Ignored in initialization
        sampling_rate: Optional[int] = None,
    ) -> WhisperFeatureExtractor:
        hf_processor = self._get_hf_processor(sampling_rate=sampling_rate)
        audio_processor = hf_processor.audio_processor  # type: ignore
        feature_extractor = audio_processor.feature_extractor  # type: ignore
        assert isinstance(feature_extractor, WhisperFeatureExtractor)
        return feature_extractor


class UltravoxProfilingInfo(UltravoxProcessingMixin, BaseProfilingInfo):

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"audio": None}
@ -67,17 +91,33 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):

        return {"audio": max_audio_tokens}

    def _get_hf_processor(
    def get_dummy_processor_inputs(
        self,
        *,
        # Ignored in initialization
        sampling_rate: Optional[int] = None,
    ) -> ProcessorMixin:
        return self.ctx.get_hf_processor()
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        feature_extractor = self._get_feature_extractor()

    def _get_feature_extractor(self) -> WhisperFeatureExtractor:
        hf_processor = self._get_hf_processor()
        return hf_processor.audio_processor.feature_extractor  # type: ignore
        sampling_rate = feature_extractor.sampling_rate
        audio_len = feature_extractor.chunk_length * sampling_rate
        num_audios = mm_counts.get("audio", 0)

        mm_data = {
            "audio":
            self._get_dummy_audios(length=audio_len, num_audios=num_audios)
        }

        return ProcessorInputs(
            prompt_text="<|audio|>" * num_audios,
            mm_data=mm_data,
        )


class UltravoxMultiModalProcessor(UltravoxProcessingMixin,
                                  BaseMultiModalProcessor):

    def _get_profiling_info(self) -> BaseProfilingInfo:
        return UltravoxProfilingInfo(self.ctx)

    def _get_data_parser(self) -> MultiModalDataParser:
        feature_extractor = self._get_feature_extractor()
@ -155,10 +195,10 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        hf_processor = self._get_hf_processor()
        hf_processor = self._get_hf_processor(**hf_processor_mm_kwargs)
        placeholder = hf_processor.audio_token_replacement  # type: ignore

        def get_replacement_ultravox(item_idx: int):
@ -173,27 +213,6 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
            )
        ]

    def _get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        feature_extractor = self._get_feature_extractor()

        sampling_rate = feature_extractor.sampling_rate
        audio_len = feature_extractor.chunk_length * sampling_rate
        num_audios = mm_counts.get("audio", 0)

        mm_data = {
            "audio":
            self._get_dummy_audios(length=audio_len, num_audios=num_audios)
        }

        return ProcessorInputs(
            prompt_text="<|audio|>" * num_audios,
            mm_data=mm_data,
        )


class StackAudioFrames(nn.Module):
    """
@ -1,12 +1,8 @@
from abc import ABC, abstractmethod
from typing import Final, Generic, Optional, Protocol, TypeVar
from typing import Final, Generic, Protocol, TypeVar

from transformers import PretrainedConfig

from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        InputProcessingContext,
                                        ProcessingCache)

_C = TypeVar("_C", bound=PretrainedConfig)


@ -43,12 +39,18 @@ class VisionEncoderInfo(ABC, Generic[_C]):
        raise NotImplementedError


def vision_encoder_info(vision_config: PretrainedConfig) -> VisionEncoderInfo:
class VisionLanguageConfig(Protocol):
    vision_config: Final[PretrainedConfig]


def get_vision_encoder_info(
        hf_config: VisionLanguageConfig) -> VisionEncoderInfo:
    # Avoid circular imports
    from .clip import CLIPEncoderInfo, CLIPVisionConfig
    from .pixtral import PixtralHFEncoderInfo, PixtralVisionConfig
    from .siglip import SiglipEncoderInfo, SiglipVisionConfig

    vision_config = hf_config.vision_config
    if isinstance(vision_config, CLIPVisionConfig):
        return CLIPEncoderInfo(vision_config)
    if isinstance(vision_config, PixtralVisionConfig):
@ -58,26 +60,3 @@ def vision_encoder_info(vision_config: PretrainedConfig) -> VisionEncoderInfo:

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


class VisionLanguageConfig(Protocol):
    vision_config: Final[PretrainedConfig]


class BaseVisionLanguageMultiModalProcessor(BaseMultiModalProcessor):

    def __init__(self,
                 ctx: InputProcessingContext,
                 *,
                 cache: Optional[ProcessingCache] = None,
                 enable_sanity_checks: bool = True) -> None:
        super().__init__(ctx,
                         cache=cache,
                         enable_sanity_checks=enable_sanity_checks)

        vision_config = self._get_hf_config().vision_config
        self._vision_encoder_info = vision_encoder_info(vision_config)

    @abstractmethod
    def _get_hf_config(self) -> VisionLanguageConfig:
        raise NotImplementedError
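The helper now takes the top-level HF config (anything exposing a vision_config attribute, per the VisionLanguageConfig protocol) instead of the vision config itself. A minimal usage sketch; the get_max_image_tokens() call is an assumed VisionEncoderInfo method, not shown in this hunk:

encoder_info = get_vision_encoder_info(hf_config)        # hf_config.vision_config must exist
max_image_tokens = encoder_info.get_max_image_tokens()   # assumed method name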
@ -8,11 +8,10 @@ from functools import lru_cache
from typing import Any, NamedTuple, Optional, Protocol, TypeVar, Union

import numpy as np
import numpy.typing as npt
import torch
from blake3 import blake3
from PIL import Image
from transformers import BatchFeature, ProcessorMixin
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin

from vllm.inputs import DummyData, InputProcessingContext
from vllm.logger import init_logger
@ -24,6 +23,7 @@ from .inputs import (MultiModalDataDict, MultiModalFieldConfig,
                     MultiModalInputsV2, MultiModalKwargs,
                     MultiModalKwargsItem, PlaceholderRange)
from .parse import MultiModalDataItems, MultiModalDataParser
from .profiling import BaseProfilingInfo

logger = init_logger(__name__)

@ -466,14 +466,6 @@ def find_mm_placeholders(
    return dict(full_groupby_modality(it))


@dataclass
class ProcessorInputs:
    """Keyword arguments to :meth:`BaseMultiModalProcessor`."""
    prompt_text: str
    mm_data: MultiModalDataDict
    hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)


class ProcessingCache:

    def __init__(self, capacity: int) -> None:
@ -585,9 +577,33 @@ class ProcessingCache:
        self._cache.put(cache_key, output_kwargs)


class BaseMultiModalProcessor(ABC):
class ProcessingMixin:
    """
    Contains helper functions to perform processing.

    Not to be confused with :class:`transformers.ProcessorMixin`.
    """
    ctx: InputProcessingContext

    def _get_tokenizer(self) -> AnyTokenizer:
        return self.ctx.tokenizer

    def _get_hf_config(self) -> PretrainedConfig:
        return self.ctx.get_hf_config()

    def _get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
        """
        Subclasses can override this method to handle
        specific kwargs from model config or user inputs.
        """
        return self.ctx.get_hf_processor(**kwargs)


class BaseMultiModalProcessor(ProcessingMixin, ABC):
    """
    Abstract base class to process multi-modal inputs to be used in vLLM.

    Not to be confused with :class:`transformers.ProcessorMixin`.
    """

    def __init__(self,
@ -601,6 +617,9 @@ class BaseMultiModalProcessor(ABC):
        self.cache = cache
        self.enable_sanity_checks = enable_sanity_checks

        self.data_parser = self._get_data_parser()
        self.profiling_info = self._get_profiling_info()

    def __call__(
        self,
        prompt: str,
@ -609,32 +628,9 @@ class BaseMultiModalProcessor(ABC):
    ) -> MultiModalInputsV2:
        return self.apply(prompt, mm_data, hf_processor_mm_kwargs)

    @abstractmethod
    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """
        Return the maximum supported number of items for each modality.

        A value of `None` means unlimited number of items.

        Omitting a modality from the returned dictionary means that
        it is not supported at all.
        """
        raise NotImplementedError

    @abstractmethod
    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
        """
        Get the maximum possible number of tokens per data item
        for each modality.

        The dictionary returned by this method should have the same
        keys as that returned by :meth:`get_supported_mm_limits`.
        """
        raise NotImplementedError

    def _get_data_parser(self) -> MultiModalDataParser:
        """
        Construct a data parser to preprocess multi-modal data items
        Construct a parser to preprocess multi-modal data items
        before passing them to :meth:`_get_hf_mm_data`.

        You can support additional modalities by creating a subclass
@ -642,15 +638,12 @@ class BaseMultiModalProcessor(ABC):
        """
        return MultiModalDataParser()

    def _get_hf_processor(self) -> ProcessorMixin:
    def _get_profiling_info(self) -> BaseProfilingInfo:
        """
        Subclasses can add keyword arguments to this method to accept
        additional kwargs from model config or user inputs.
        Get the profiling information to find the worst-case memory usage of
        the model.
        """
        return self.ctx.get_hf_processor()

    def _get_tokenizer(self) -> AnyTokenizer:
        return self.ctx.tokenizer
        raise NotImplementedError

    def _to_mm_items(
        self,
@ -660,8 +653,7 @@ class BaseMultiModalProcessor(ABC):
        Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`
        before passing them to :meth:`_get_hf_mm_data`.
        """
        parser = self._get_data_parser()
        mm_items = parser.parse_mm_data(mm_data)
        mm_items = self.data_parser.parse_mm_data(mm_data)

        mm_limits = self.ctx.get_mm_config().limit_per_prompt
        for modality, items in mm_items.items():
@ -799,7 +791,7 @@ class BaseMultiModalProcessor(ABC):

            # Some HF processors (e.g. Qwen2-VL) expect corresponding
            # multi-modal tokens to be in the prompt text
            dummy_inputs = self._get_dummy_processor_inputs(
            dummy_inputs = self.profiling_info.get_dummy_processor_inputs(
                self.ctx.model_config.max_model_len,
                mm_missing_counts,
            )
@ -1133,73 +1125,14 @@ class BaseMultiModalProcessor(ABC):
            mm_placeholders=mm_placeholder_ranges,
        )

    def _get_dummy_audios(
        self,
        *,
        length: int,
        num_audios: int,
    ) -> list[npt.NDArray]:
        audio = np.zeros((length, ))
        return [audio] * num_audios

    def _get_dummy_images(
        self,
        *,
        width: int,
        height: int,
        num_images: int,
    ) -> list[Image.Image]:
        image = Image.new("RGB", (width, height), color=0)
        return [image] * num_images

    def _get_dummy_videos(
        self,
        *,
        width: int,
        height: int,
        num_frames: int,
        num_videos: int,
    ) -> list[npt.NDArray]:
        video = np.zeros((num_frames, width, height, 3))
        return [video] * num_videos

    @abstractmethod
    def _get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """
        Build the multi-modal portion of the input which, after processing,
        results in `mm_max_tokens` in :meth:`get_dummy_data`.
        """
        raise NotImplementedError

    def _get_and_validate_dummy_mm_counts(self) -> Mapping[str, int]:
        mm_limit_per_prompt = self.ctx.get_mm_config().limit_per_prompt
        supported_mm_limits = self.get_supported_mm_limits()

        mm_limits = {
            modality: mm_limit_per_prompt.get(modality, 1)
            for modality in supported_mm_limits
        }

        for modality, supported_limit in supported_mm_limits.items():
            limit = mm_limits[modality]
            if supported_limit is not None and supported_limit < limit:
                raise ValueError(
                    f"You set {modality}={limit} (or defaulted to 1) in "
                    f"`--limit-mm-per-prompt`, but this model only supports "
                    f"at most {supported_limit} {modality} items.")

        return mm_limits

    def _get_dummy_mm_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalInputsV2:
        processor_inputs = self._get_dummy_processor_inputs(seq_len, mm_counts)
        profiling = self.profiling_info
        processor_inputs = profiling.get_dummy_processor_inputs(
            seq_len, mm_counts)

        return self.apply(
            prompt_text=processor_inputs.prompt_text,
@ -1211,8 +1144,9 @@ class BaseMultiModalProcessor(ABC):
        # Avoid circular import
        from vllm.sequence import SequenceData

        mm_counts = self._get_and_validate_dummy_mm_counts()
        mm_max_tokens_per_item = self.get_mm_max_tokens_per_item(seq_len)
        profiling = self.profiling_info
        mm_counts = profiling.get_mm_limits()
        mm_max_tokens_per_item = profiling.get_mm_max_tokens_per_item(seq_len)
        if mm_counts.keys() != mm_max_tokens_per_item.keys():
            raise AssertionError(
                "The keys returned by `get_supported_mm_limits`"
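Taken together, the profiling path now runs entirely through the processor's attached profiling_info. A simplified view of the flow (not the literal implementation; error handling and sequence padding omitted):

profiling = processor.profiling_info
mm_counts = profiling.get_mm_limits()                          # items per modality
mm_max_tokens = profiling.get_mm_max_tokens_per_item(seq_len)  # worst-case tokens
dummy = profiling.get_dummy_processor_inputs(seq_len, mm_counts)
mm_inputs = processor.apply(prompt_text=dummy.prompt_text,
                            mm_data=dummy.mm_data,
                            hf_processor_mm_kwargs=dummy.hf_processor_mm_kwargs)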
121
vllm/multimodal/profiling.py
Normal file
@ -0,0 +1,121 @@
from abc import ABC, abstractmethod
from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Optional

import numpy as np
import numpy.typing as npt
from PIL import Image

from vllm.inputs import InputProcessingContext
from vllm.logger import init_logger

from .inputs import MultiModalDataDict

logger = init_logger(__name__)


@dataclass
class ProcessorInputs:
    """Keyword arguments to :meth:`BaseMultiModalProcessor`."""
    prompt_text: str
    mm_data: MultiModalDataDict
    hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)


class BaseProfilingInfo(ABC):
    """
    Abstract base class that provides the information necessary to profile
    multi-modal models.
    """

    def __init__(self, ctx: InputProcessingContext) -> None:
        super().__init__()

        self.ctx = ctx

    @abstractmethod
    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """
        Return the maximum supported number of items for each modality.

        A value of `None` means unlimited number of items.

        Omitting a modality from the returned dictionary means that
        it is not supported at all.
        """
        raise NotImplementedError

    @abstractmethod
    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
        """
        Get the maximum possible number of tokens per data item
        for each modality.

        The dictionary returned by this method should have the same
        keys as that returned by :meth:`get_supported_mm_limits`.
        """
        raise NotImplementedError

    @abstractmethod
    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """
        Build the multi-modal portion of the input which, after processing,
        results in `mm_max_tokens` in :meth:`get_mm_max_tokens_per_item`.
        """
        raise NotImplementedError

    def _get_dummy_audios(
        self,
        *,
        length: int,
        num_audios: int,
    ) -> list[npt.NDArray]:
        audio = np.zeros((length, ))
        return [audio] * num_audios

    def _get_dummy_images(
        self,
        *,
        width: int,
        height: int,
        num_images: int,
    ) -> list[Image.Image]:
        image = Image.new("RGB", (width, height), color=0)
        return [image] * num_images

    def _get_dummy_videos(
        self,
        *,
        width: int,
        height: int,
        num_frames: int,
        num_videos: int,
    ) -> list[npt.NDArray]:
        video = np.zeros((num_frames, width, height, 3))
        return [video] * num_videos

    def get_mm_limits(self) -> Mapping[str, int]:
        mm_config = self.ctx.get_mm_config()
        mm_limit_per_prompt = mm_config.limit_per_prompt

        supported_mm_limits = self.get_supported_mm_limits()

        mm_limits = {
            modality: mm_limit_per_prompt.get(modality, 1)
            for modality in supported_mm_limits
        }

        for modality, supported_limit in supported_mm_limits.items():
            limit = mm_limits[modality]
            if supported_limit is not None and supported_limit < limit:
                raise ValueError(
                    f"You set {modality}={limit} (or defaulted to 1) in "
                    f"`--limit-mm-per-prompt`, but this model only supports "
                    f"at most {supported_limit} {modality} items.")

        return mm_limits
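As a worked example of get_mm_limits() above (hypothetical values): if a model reports {"image": 1} from get_supported_mm_limits and the user passes --limit-mm-per-prompt with image=4, the check raises the ValueError; with no user override, the limit simply defaults to 1 item per supported modality.

# supported limits from the model:         {"image": 1}
# user override via --limit-mm-per-prompt: {"image": 4}
# -> ValueError: "You set image=4 ... this model only supports at most 1 image items."
# no override -> {"image": 1}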
@ -224,7 +224,7 @@ class MultiModalRegistry:
            tokenizer = cached_get_tokenizer(model_config.tokenizer)
            processor = self.create_processor(model_config, tokenizer)
            seq_len = model_config.max_model_len
            return processor.get_mm_max_tokens_per_item(seq_len)
            return processor.profiling_info.get_mm_max_tokens_per_item(seq_len)

        return {
            key: plugin.get_max_multimodal_tokens(model_config)