[V1][VLM] V1 support for selected single-image models. (#11632)

Signed-off-by: Roger Wang <ywang@roblox.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Isotr0py <2037008807@qq.com>
Authored by Roger Wang on 2024-12-31 13:17:22 -08:00; committed by GitHub
parent 8c3230d8c1
commit e7c7c5e822
19 changed files with 575 additions and 621 deletions


@ -570,28 +570,28 @@ See [this page](#generative-models) for more information on how to use generativ
- `rhymes-ai/Aria` - `rhymes-ai/Aria`
- -
- ✅︎ - ✅︎
- - ✅︎
* - `Blip2ForConditionalGeneration` * - `Blip2ForConditionalGeneration`
- BLIP-2 - BLIP-2
- T + I<sup>E</sup> - T + I<sup>E</sup>
- `Salesforce/blip2-opt-2.7b`, `Salesforce/blip2-opt-6.7b`, etc. - `Salesforce/blip2-opt-2.7b`, `Salesforce/blip2-opt-6.7b`, etc.
- -
- ✅︎ - ✅︎
- - ✅︎
* - `ChameleonForConditionalGeneration` * - `ChameleonForConditionalGeneration`
- Chameleon - Chameleon
- T + I - T + I
- `facebook/chameleon-7b` etc. - `facebook/chameleon-7b` etc.
- -
- ✅︎ - ✅︎
- - ✅︎
* - `FuyuForCausalLM` * - `FuyuForCausalLM`
- Fuyu - Fuyu
- T + I - T + I
- `adept/fuyu-8b` etc. - `adept/fuyu-8b` etc.
- -
- ✅︎ - ✅︎
- - ✅︎
* - `ChatGLMModel` * - `ChatGLMModel`
- GLM-4V - GLM-4V
- T + I - T + I
@ -633,7 +633,7 @@ See [this page](#generative-models) for more information on how to use generativ
- `llava-hf/llava-v1.6-mistral-7b-hf`, `llava-hf/llava-v1.6-vicuna-7b-hf`, etc. - `llava-hf/llava-v1.6-mistral-7b-hf`, `llava-hf/llava-v1.6-vicuna-7b-hf`, etc.
- -
- ✅︎ - ✅︎
- - ✅︎
* - `LlavaNextVideoForConditionalGeneration` * - `LlavaNextVideoForConditionalGeneration`
- LLaVA-NeXT-Video - LLaVA-NeXT-Video
- T + V - T + V
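The table rows above mark V1 engine support for Aria, BLIP-2, Chameleon, Fuyu, and LLaVA-NeXT. A minimal, hedged sketch of exercising one of these models; the `VLLM_USE_V1` opt-in flag and the exact generate call are assumptions based on vLLM conventions of this period, not part of this diff:

```python
# Hedged sketch (not from this commit): run one of the newly V1-enabled
# single-image models. VLLM_USE_V1 is assumed to be the V1 opt-in flag.
import os

os.environ["VLLM_USE_V1"] = "1"

from PIL import Image
from vllm import LLM

llm = LLM(model="facebook/chameleon-7b", max_model_len=4096, max_num_seqs=2)
image = Image.new("RGB", (512, 512), color=0)  # stand-in for a real image

outputs = llm.generate({
    "prompt": "What do you see in this image?<image>",
    "multi_modal_data": {"image": image},
})
print(outputs[0].outputs[0].text)
```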


@ -24,10 +24,13 @@ def run_aria(question: str, modality: str):
assert modality == "image" assert modality == "image"
model_name = "rhymes-ai/Aria" model_name = "rhymes-ai/Aria"
# NOTE: Need L40 (or equivalent) to avoid OOM
llm = LLM(model=model_name, llm = LLM(model=model_name,
tokenizer_mode="slow", tokenizer_mode="slow",
trust_remote_code=True,
dtype="bfloat16", dtype="bfloat16",
max_model_len=4096,
max_num_seqs=2,
trust_remote_code=True,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache) disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
prompt = (f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>\n{question}" prompt = (f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>\n{question}"
@ -57,6 +60,7 @@ def run_chameleon(question: str, modality: str):
prompt = f"{question}<image>" prompt = f"{question}<image>"
llm = LLM(model="facebook/chameleon-7b", llm = LLM(model="facebook/chameleon-7b",
max_model_len=4096, max_model_len=4096,
max_num_seqs=2,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache) disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
stop_token_ids = None stop_token_ids = None
return llm, prompt, stop_token_ids return llm, prompt, stop_token_ids
@ -257,7 +261,7 @@ def run_minicpmv(question: str, modality: str):
# 2.5 # 2.5
# model_name = "openbmb/MiniCPM-Llama3-V-2_5" # model_name = "openbmb/MiniCPM-Llama3-V-2_5"
#2.6 # 2.6
model_name = "openbmb/MiniCPM-V-2_6" model_name = "openbmb/MiniCPM-V-2_6"
tokenizer = AutoTokenizer.from_pretrained(model_name, tokenizer = AutoTokenizer.from_pretrained(model_name,
trust_remote_code=True) trust_remote_code=True)
@ -430,9 +434,11 @@ def run_pixtral_hf(question: str, modality: str):
model_name = "mistral-community/pixtral-12b" model_name = "mistral-community/pixtral-12b"
# NOTE: Need L40 (or equivalent) to avoid OOM
llm = LLM( llm = LLM(
model=model_name, model=model_name,
max_model_len=8192, max_model_len=8192,
max_num_seqs=2,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
) )
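The example-script tweaks above mainly add `max_model_len` and `max_num_seqs` caps (and note that Aria and Pixtral-HF need an L40-class GPU). A short sketch of why those two knobs keep the examples within memory, reusing the Chameleon values from the diff; the comments are explanatory, not from the source:

```python
from vllm import LLM

# A smaller max_model_len bounds the KV cache reserved per sequence, and a
# smaller max_num_seqs bounds how many sequences (each carrying dummy
# multimodal data) are profiled and batched together at startup.
llm = LLM(
    model="facebook/chameleon-7b",
    max_model_len=4096,  # as in run_chameleon above
    max_num_seqs=2,      # added by this commit
)
```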


@ -140,10 +140,7 @@ VLM_TEST_SETTINGS = {
"aria": VLMTestInfo( "aria": VLMTestInfo(
models=["rhymes-ai/Aria"], models=["rhymes-ai/Aria"],
tokenizer_mode="slow", tokenizer_mode="slow",
test_type=( test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
VLMTestType.IMAGE,
VLMTestType.MULTI_IMAGE,
),
dtype="bfloat16", dtype="bfloat16",
prompt_formatter=lambda img_prompt: f"<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n ", # noqa: E501 prompt_formatter=lambda img_prompt: f"<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n ", # noqa: E501
img_idx_to_prompt=lambda idx: "<fim_prefix><|img|><fim_suffix>\n", img_idx_to_prompt=lambda idx: "<fim_prefix><|img|><fim_suffix>\n",
@ -179,6 +176,7 @@ VLM_TEST_SETTINGS = {
test_type=VLMTestType.IMAGE, test_type=VLMTestType.IMAGE,
prompt_formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:", prompt_formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:",
max_model_len=4096, max_model_len=4096,
max_num_seqs=2,
auto_cls=AutoModelForVision2Seq, auto_cls=AutoModelForVision2Seq,
postprocess_inputs=model_utils.cast_dtype_post_processor( postprocess_inputs=model_utils.cast_dtype_post_processor(
"pixel_values" "pixel_values"
@ -201,7 +199,6 @@ VLM_TEST_SETTINGS = {
vllm_output_post_proc=model_utils.fuyu_vllm_to_hf_output, vllm_output_post_proc=model_utils.fuyu_vllm_to_hf_output,
num_logprobs=10, num_logprobs=10,
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)], image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
marks=[large_gpu_mark(min_gb=48)],
), ),
"glm4": VLMTestInfo( "glm4": VLMTestInfo(
models=["THUDM/glm-4v-9b"], models=["THUDM/glm-4v-9b"],


@ -528,7 +528,7 @@ def _rand_audio(
def _test_processing_cache_correctness( def _test_processing_cache_correctness(
model_id: str, model_id: str,
modalities: set[str], modalities: dict[str, bool],
hit_rate: float, hit_rate: float,
num_batches: int, num_batches: int,
simplify_rate: float, simplify_rate: float,
@ -583,9 +583,8 @@ def _test_processing_cache_correctness(
partial(_rand_audio, rng, min_len=256, max_len=512, sr=16000), partial(_rand_audio, rng, min_len=256, max_len=512, sr=16000),
} }
input_max_count = { input_max_count = {
"image": 3, modality: 3 if supports_multi else 1
"video": 3, for modality, supports_multi in modalities.items()
"audio": 3,
} }
for batch_idx in range(num_batches): for batch_idx in range(num_batches):
@ -624,12 +623,16 @@ def _test_processing_cache_correctness(
# yapf: disable # yapf: disable
@pytest.mark.parametrize(("model_id", "modalities"), [ @pytest.mark.parametrize(("model_id", "modalities"), [
("llava-hf/llava-1.5-7b-hf", {"image"}), ("rhymes-ai/Aria", {"image": True}),
("TIGER-Lab/Mantis-8B-siglip-llama3", {"image"}), ("Salesforce/blip2-opt-2.7b", {"image": False}),
("mistral-community/pixtral-12b", {"image"}), ("facebook/chameleon-7b", {"image": True}),
("Qwen/Qwen2-VL-2B-Instruct", {"image", "video"}), ("adept/fuyu-8b", {"image": False}),
("Qwen/Qwen2-Audio-7B-Instruct", {"audio"}), ("llava-hf/llava-1.5-7b-hf", {"image": True}),
("fixie-ai/ultravox-v0_3", {"audio"}), ("TIGER-Lab/Mantis-8B-siglip-llama3", {"image": True}),
("mistral-community/pixtral-12b", {"image": True}),
("Qwen/Qwen2-VL-2B-Instruct", {"image": True, "video": True}),
("Qwen/Qwen2-Audio-7B-Instruct", {"audio": True}),
("fixie-ai/ultravox-v0_3", {"audio": True}),
]) ])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0]) @pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32]) @pytest.mark.parametrize("num_batches", [32])
@ -637,7 +640,7 @@ def _test_processing_cache_correctness(
# yapf: enable # yapf: enable
def test_processing_cache_correctness( def test_processing_cache_correctness(
model_id: str, model_id: str,
modalities: set[str], modalities: dict[str, bool],
hit_rate: float, hit_rate: float,
num_batches: int, num_batches: int,
simplify_rate: float, simplify_rate: float,
@ -653,7 +656,7 @@ def test_processing_cache_correctness(
# yapf: disable # yapf: disable
@pytest.mark.parametrize(("model_id", "modalities"), [ @pytest.mark.parametrize(("model_id", "modalities"), [
("microsoft/Phi-3-vision-128k-instruct", {"image"}), ("microsoft/Phi-3-vision-128k-instruct", {"image": True}),
]) ])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0]) @pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32]) @pytest.mark.parametrize("num_batches", [32])
@ -661,7 +664,7 @@ def test_processing_cache_correctness(
# yapf: enable # yapf: enable
def test_processing_cache_correctness_phi3v( def test_processing_cache_correctness_phi3v(
model_id: str, model_id: str,
modalities: set[str], modalities: dict[str, bool],
hit_rate: float, hit_rate: float,
num_batches: int, num_batches: int,
simplify_rate: float, simplify_rate: float,
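The test refactor above changes `modalities` from `set[str]` to `dict[str, bool]`, where the value records whether the model supports multiple inputs of that modality. A standalone restatement of how that drives the per-prompt item limit:

```python
# Same comprehension as in _test_processing_cache_correctness above, lifted
# out so the new parameter shape is easy to read.
modalities = {"image": True, "video": False}  # True = multi-input supported

input_max_count = {
    modality: 3 if supports_multi else 1
    for modality, supports_multi in modalities.items()
}

assert input_max_count == {"image": 3, "video": 1}
```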


@ -1,15 +1,15 @@
import math from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
from typing import Iterable, List, Optional, Set, Tuple, TypedDict, Union Union)
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.nn.init import trunc_normal_ from torch.nn.init import trunc_normal_
from transformers import LlamaConfig from transformers import BatchFeature, PretrainedConfig
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, QuantizationConfig, VllmConfig from vllm.config import CacheConfig, QuantizationConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_rank from vllm.distributed import get_tensor_model_parallel_rank
from vllm.inputs import INPUT_REGISTRY, token_inputs from vllm.inputs import InputContext
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@ -17,30 +17,27 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
get_compressed_tensors_cache_scale) get_compressed_tensors_cache_scale)
from vllm.model_executor.layers.sampler import (Sampler, SamplerOutput, from vllm.model_executor.layers.sampler import (SamplerOutput,
SamplingMetadata) SamplingMetadata, get_sampler)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name) default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.idefics2_vision_model import (
Idefics2VisionTransformer)
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.llama import (LlamaDecoderLayer, LlamaMLP,
LlamaModel)
from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
is_pp_missing_parameter,
maybe_prefix,
merge_multimodal_embeddings)
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import cached_get_image_processor from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors NestedTensors)
from vllm.multimodal.utils import (cached_get_tokenizer, from vllm.multimodal.processing import (BaseMultiModalProcessor,
repeat_and_pad_placeholder_tokens) MultiModalDataItems, ProcessorInputs,
PromptReplacement)
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.aria import (AriaMoELMConfig, from vllm.transformers_utils.configs.aria import (AriaMoELMConfig,
AriaVisionConfig) AriaVisionConfig)
from .utils import flatten_bn from .idefics2_vision_model import Idefics2VisionTransformer
from .interfaces import SupportsMultiModal
from .llama import LlamaDecoderLayer, LlamaMLP, LlamaModel
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
is_pp_missing_parameter, maybe_prefix,
merge_multimodal_embeddings)
class AriaImagePixelInputs(TypedDict): class AriaImagePixelInputs(TypedDict):
@ -251,7 +248,7 @@ class AriaProjector(nn.Module):
class AriaFusedMoE(FusedMoE): class AriaFusedMoE(FusedMoE):
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
shard_id: str) -> Set[str]: shard_id: str) -> None:
# Override the weight_loader to handle the expert weights in the Aria # Override the weight_loader to handle the expert weights in the Aria
# model, which are already packed with experts, and merge the gate and # model, which are already packed with experts, and merge the gate and
# up weights for each expert. # up weights for each expert.
@ -346,7 +343,7 @@ class MoEDecoderLayer(LlamaDecoderLayer):
def __init__( def __init__(
self, self,
config: LlamaConfig, config: AriaMoELMConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
@ -434,7 +431,7 @@ class AriaMoELMModel(LlamaModel):
return loaded_params return loaded_params
def build_mm_projector(config): def build_mm_projector(config: PretrainedConfig):
return AriaProjector( return AriaProjector(
patch_to_query_dict=config.projector_patch_to_query_dict, patch_to_query_dict=config.projector_patch_to_query_dict,
embed_dim=config.vision_config.hidden_size, embed_dim=config.vision_config.hidden_size,
@ -445,75 +442,70 @@ def build_mm_projector(config):
) )
def get_max_multimodal_tokens(ctx): def get_max_aria_image_tokens(ctx: InputContext):
return max(ctx.model_config.hf_config.image_size2tokens.values()) hf_config = ctx.get_hf_config()
return max(hf_config.projector_patch_to_query_dict.values())
def input_mapper_for_aria(ctx, data): class AriaMultiModalProcessor(BaseMultiModalProcessor):
return MultiModalKwargs(data)
def _get_mm_fields_config(
def input_processor(ctx, llm_inputs): self,
multi_modal_data = llm_inputs.get("multi_modal_data") hf_inputs: BatchFeature,
# if it is pure text input, use it as is hf_processor_mm_kwargs: Mapping[str, object],
if multi_modal_data is None or "image" not in multi_modal_data: ) -> Mapping[str, MultiModalFieldConfig]:
return llm_inputs return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
model_config = ctx.model_config pixel_mask=MultiModalFieldConfig.batched("image"),
tokenizer = cached_get_tokenizer(model_config.tokenizer)
image_processor = cached_get_image_processor(
model_config.model, trust_remote_code=model_config.trust_remote_code)
hf_config = model_config.hf_config
# prepare image tokens, the max_image_size is used to determine the number
# of patch_size for every image
max_image_size = multi_modal_data.pop("max_image_size", 980)
_split_image = multi_modal_data.pop("split_image", False)
assert isinstance(max_image_size,
(int, float)), "max_image_size should be float or int"
images = (multi_modal_data["image"] if isinstance(
multi_modal_data["image"], list) else [multi_modal_data["image"]])
image_inputs = image_processor.preprocess(images,
max_image_size=max_image_size,
split_image=_split_image,
return_tensors="pt").data
image_inputs['pixel_values'] = image_inputs['pixel_values'].to(
ctx.model_config.dtype)
num_crops = image_inputs.pop("num_crops")
prompt_token_ids = llm_inputs["prompt_token_ids"]
if num_crops.sum().item() > 0:
_, prompt_token_ids, _ = repeat_and_pad_placeholder_tokens(
tokenizer,
None,
prompt_token_ids,
placeholder_token_id=hf_config.image_token_index,
repeat_count=num_crops,
) )
repeat_count = [hf_config.image_size2tokens[max_image_size] def _get_prompt_replacements(
] * sum(num_crops).item() self,
new_prompt, new_token_ids, _ = repeat_and_pad_placeholder_tokens( mm_items: MultiModalDataItems,
tokenizer, hf_processor_mm_kwargs: Mapping[str, object],
None, out_mm_kwargs: MultiModalKwargs,
prompt_token_ids, ) -> list[PromptReplacement]:
placeholder_token_id=hf_config.image_token_index, hf_config = self.ctx.get_hf_config()
repeat_count=repeat_count, image_token_id = hf_config.image_token_index
)
return token_inputs( max_image_tokens = get_max_aria_image_tokens(self.ctx)
prompt_token_ids=new_token_ids,
prompt=new_prompt, return [
multi_modal_data={"image": image_inputs}, PromptReplacement(
) modality="image",
target=[image_token_id],
replacement=[image_token_id] * max_image_tokens,
)
]
def _get_dummy_mm_inputs(
self,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
hf_config = self.ctx.get_hf_config()
vision_config: AriaVisionConfig = hf_config.vision_config
max_image_size = vision_config.image_size
num_images = mm_counts.get("image", 0)
mm_data = {
"image":
self._get_dummy_images(width=max_image_size,
height=max_image_size,
num_images=num_images)
}
hf_processor = self._get_hf_processor()
image_token: str = hf_processor.image_token # type: ignore
return ProcessorInputs(
prompt_text=image_token * num_images,
mm_data=mm_data,
)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_multimodal_tokens) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_aria_image_tokens)
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_aria) @MULTIMODAL_REGISTRY.register_processor(AriaMultiModalProcessor)
@INPUT_REGISTRY.register_input_processor(input_processor)
class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
""" """
Aria model for conditional generation tasks. Aria model for conditional generation tasks.
@ -540,12 +532,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
# prepare the image_size to tokens mapping for the image preprocess, see
# input_processor
config.image_size2tokens = {
int(math.sqrt(k) * config.vision_config.patch_size): v
for k, v in config.projector_patch_to_query_dict.items()
}
self.config = config self.config = config
self.vision_tower = AriaVisionModel(config.vision_config) self.vision_tower = AriaVisionModel(config.vision_config)
self.multi_modal_projector = build_mm_projector(config) self.multi_modal_projector = build_mm_projector(config)
@ -566,7 +552,7 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
logit_scale = getattr(config, "logit_scale", 1.0) logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
self.vocab_size, logit_scale) self.vocab_size, logit_scale)
self.sampler = Sampler() self.sampler = get_sampler()
def _validate_image_sizes( def _validate_image_sizes(
self, images: List[torch.Tensor]) -> List[torch.Tensor]: self, images: List[torch.Tensor]) -> List[torch.Tensor]:
@ -588,7 +574,12 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
pixel_values = self._validate_image_sizes(pixel_values) pixel_values = self._validate_image_sizes(pixel_values)
pixel_values = flatten_bn(pixel_values, concat=True) pixel_values = flatten_bn(pixel_values, concat=True)
if pixel_mask is not None: if pixel_mask is not None:
if not isinstance(pixel_mask, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel mask. "
f"Got type: {type(pixel_mask)}")
pixel_mask = flatten_bn(pixel_mask, concat=True) pixel_mask = flatten_bn(pixel_mask, concat=True)
return AriaImagePixelInputs( return AriaImagePixelInputs(
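The Aria rewrite above drops the legacy input mapper/processor in favor of an `AriaMultiModalProcessor`, whose `PromptReplacement` expands each image token into `get_max_aria_image_tokens(ctx)` copies of itself. A toy, framework-free sketch of that expansion (names and values here are illustrative, not vLLM APIs):

```python
def expand_image_tokens(token_ids: list[int], image_token_id: int,
                        max_image_tokens: int) -> list[int]:
    """Replace every image token with max_image_tokens copies of itself."""
    out: list[int] = []
    for tok in token_ids:
        out.extend([image_token_id] * max_image_tokens
                   if tok == image_token_id else [tok])
    return out

# With a made-up max_image_tokens = 3 (the real value is the maximum of
# hf_config.projector_patch_to_query_dict.values()):
assert expand_image_tokens([1, 9, 2], image_token_id=9,
                           max_image_tokens=3) == [1, 9, 9, 9, 2]
```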


@ -4,22 +4,16 @@ from typing import Iterable, Optional, Set, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
from PIL import Image
from transformers import Blip2VisionConfig, BlipVisionConfig from transformers import Blip2VisionConfig, BlipVisionConfig
from vllm.attention.layer import MultiHeadAttention from vllm.attention.layer import MultiHeadAttention
from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import DecoderOnlyInputs, token_inputs
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens)
from vllm.sequence import SequenceData
def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int: def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
@ -33,92 +27,6 @@ def get_blip_num_patches(*, image_size: int, patch_size: int) -> int:
return grid_length * grid_length return grid_length * grid_length
def get_blip_image_feature_size(
hf_config: Union[BlipVisionConfig, Blip2VisionConfig]) -> int:
return get_blip_num_patches(image_size=hf_config.image_size,
patch_size=hf_config.patch_size)
def get_max_blip_image_tokens(
hf_config: Union[BlipVisionConfig, Blip2VisionConfig]) -> int:
return get_blip_image_feature_size(hf_config)
def dummy_seq_data_for_blip(
hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
seq_len: int,
num_images: int,
*,
image_token_id: int,
image_feature_size_override: Optional[int] = None,
):
if image_feature_size_override is None:
image_feature_size = get_blip_image_feature_size(hf_config)
else:
image_feature_size = image_feature_size_override
return SequenceData.from_prompt_token_counts(
(image_token_id, image_feature_size * num_images),
(0, seq_len - image_feature_size * num_images),
)
def dummy_image_for_blip(
hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
num_images: int,
*,
image_width_override: Optional[int] = None,
image_height_override: Optional[int] = None,
):
width = height = hf_config.image_size
if image_width_override is not None:
width = image_width_override
if image_height_override is not None:
height = image_height_override
image = Image.new("RGB", (width, height), color=0)
return {"image": image if num_images == 1 else [image] * num_images}
def input_processor_for_blip(
model_config: ModelConfig,
hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
inputs: DecoderOnlyInputs,
*,
image_token_id: int,
image_feature_size_override: Optional[int] = None,
):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return inputs
if "multi_modal_placeholders" in inputs and "image" in inputs[
"multi_modal_placeholders"]:
# The inputs already have placeholders.
return inputs
tokenizer = cached_get_tokenizer(model_config.tokenizer)
if image_feature_size_override is None:
image_feature_size = get_blip_image_feature_size(hf_config)
else:
image_feature_size = image_feature_size_override
new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
tokenizer,
inputs.get("prompt"),
inputs["prompt_token_ids"],
placeholder_token_id=image_token_id,
repeat_count=image_feature_size,
)
# NOTE: Create a defensive copy of the original inputs
return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data,
multi_modal_placeholders={"image": ranges})
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa
class BlipVisionEmbeddings(nn.Module): class BlipVisionEmbeddings(nn.Module):
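With the dummy-data and input-processor helpers removed from this BLIP vision file (the merged processor defined alongside the BLIP-2 model takes over those duties), what remains is the patch-count arithmetic. A worked example; the helper bodies below are assumed, since only the names and the `grid_length * grid_length` product appear in the diff:

```python
def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
    # Assumed body: patches per image side.
    return image_size // patch_size

def get_blip_num_patches(*, image_size: int, patch_size: int) -> int:
    grid_length = get_blip_patch_grid_length(image_size=image_size,
                                             patch_size=patch_size)
    return grid_length * grid_length

# e.g. 224-pixel images with 14-pixel patches: 16 * 16 = 256 patches.
assert get_blip_num_patches(image_size=224, patch_size=14) == 256
```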


@ -4,32 +4,33 @@ from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
import torch import torch
import torch.nn as nn import torch.nn as nn
from transformers import (Blip2Config, Blip2QFormerConfig, Blip2VisionConfig, from transformers import (BatchFeature, Blip2Config, Blip2Processor,
apply_chunking_to_forward) Blip2QFormerConfig, apply_chunking_to_forward)
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.inputs import InputContext
InputContext, token_inputs)
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
from vllm.multimodal.utils import consecutive_placeholder_ranges MultiModalInputsV2, MultiModalKwargs,
from vllm.sequence import IntermediateTensors, SequenceData NestedTensors, PlaceholderRange)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
from vllm.sequence import IntermediateTensors
from .blip import (BlipVisionModel, dummy_image_for_blip, from .blip import BlipVisionModel
get_max_blip_image_tokens)
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, init_vllm_registered_model, from .utils import (AutoWeightsLoader, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings) maybe_prefix, merge_multimodal_embeddings)
# We use this internally as placeholders since there is no image token # We use this internally as placeholders since there is no image token
# defined on the HuggingFace repo # defined on the HuggingFace repo
BLIP2_IMAGE_TOKEN = "<image>" _IMAGE_TOKEN_ID = 50265
BLIP2_IMAGE_TOKEN_ID = 50265
class Blip2ImagePixelInputs(TypedDict): class Blip2ImagePixelInputs(TypedDict):
@ -396,92 +397,87 @@ class Blip2QFormerModel(nn.Module):
return sequence_output return sequence_output
def get_blip2_image_feature_size(hf_config: Blip2Config) -> int: def get_max_blip2_image_tokens(ctx: InputContext):
hf_config = ctx.get_hf_config(Blip2Config)
return hf_config.num_query_tokens return hf_config.num_query_tokens
def get_max_blip2_image_tokens(ctx: InputContext): class Blip2MultiModalProcessor(BaseMultiModalProcessor):
hf_config = ctx.get_hf_config(Blip2Config)
vision_config = hf_config.vision_config
if isinstance(vision_config, Blip2VisionConfig): def _get_hf_processor(self) -> Blip2Processor:
return get_max_blip_image_tokens(vision_config) return self.ctx.get_hf_processor(Blip2Processor)
msg = f"Unsupported vision config: {type(vision_config)}" def _get_mm_fields_config(
raise NotImplementedError(msg) self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
max_image_tokens = get_max_blip2_image_tokens(self.ctx)
return [
PromptReplacement(
modality="image",
target="</s>",
replacement="<image>" * max_image_tokens + "</s>",
)
]
def apply(
self,
prompt_text: str,
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalInputsV2:
result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
# Only <image> tokens should be considered as placeholders,
# so we ignore the trailing bos_token
result["mm_placeholders"] = {
modality: [
PlaceholderRange(offset=p["offset"], length=p["length"] - 1)
for p in ps
]
for modality, ps in result["mm_placeholders"].items()
}
return result
def _get_dummy_mm_inputs(
self,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
hf_config = self.ctx.get_hf_config(Blip2Config)
vision_config = hf_config.vision_config
max_image_size = vision_config.image_size
num_images = mm_counts.get("image", 0)
mm_data = {
"image":
self._get_dummy_images(width=max_image_size,
height=max_image_size,
num_images=num_images)
}
return ProcessorInputs(
prompt_text="",
mm_data=mm_data,
)
def dummy_seq_data_for_blip2(
hf_config: Blip2Config,
seq_len: int,
num_images: int,
*,
image_token_id: int,
image_feature_size_override: Optional[int] = None,
):
if image_feature_size_override is None:
image_feature_size = get_blip2_image_feature_size(hf_config)
else:
image_feature_size = image_feature_size_override
return SequenceData.from_prompt_token_counts(
(image_token_id, image_feature_size * num_images),
(0, seq_len - image_feature_size * num_images),
), {
"image":
consecutive_placeholder_ranges(num_items=num_images,
item_size=image_feature_size)
}
def dummy_data_for_blip2(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
hf_config = ctx.get_hf_config(Blip2Config)
vision_config = hf_config.vision_config
num_images = mm_counts["image"]
seq_data, ranges = dummy_seq_data_for_blip2(
hf_config,
seq_len,
num_images,
image_token_id=BLIP2_IMAGE_TOKEN_ID,
)
if isinstance(vision_config, Blip2VisionConfig):
mm_data = dummy_image_for_blip(vision_config, num_images)
return DummyData(seq_data, mm_data, ranges)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
def input_processor_for_blip2(ctx: InputContext, inputs: DecoderOnlyInputs):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return inputs
hf_config = ctx.get_hf_config(Blip2Config)
image_feature_size = get_blip2_image_feature_size(hf_config)
# The original model places image tokens at the front
# https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1514
new_token_ids = [BLIP2_IMAGE_TOKEN_ID] * image_feature_size
new_token_ids += inputs["prompt_token_ids"]
new_prompt = inputs.get("prompt")
if new_prompt is not None:
new_prompt = BLIP2_IMAGE_TOKEN * image_feature_size + new_prompt
return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_blip2_image_tokens) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_blip2_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_blip2) @MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor)
@INPUT_REGISTRY.register_input_processor(input_processor_for_blip2)
class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@ -627,7 +623,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
if multimodal_embeddings is not None: if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings, input_ids, inputs_embeds, multimodal_embeddings,
BLIP2_IMAGE_TOKEN_ID) _IMAGE_TOKEN_ID)
return inputs_embeds return inputs_embeds
def forward( def forward(
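`Blip2MultiModalProcessor.apply` above shortens each reported placeholder range by one token so that the trailing `</s>` appended by the prompt replacement is not counted as part of the image placeholder. A self-contained restatement (the `Range` TypedDict stands in for vLLM's `PlaceholderRange`):

```python
from typing import TypedDict

class Range(TypedDict):
    offset: int
    length: int

def drop_trailing_token(placeholders: list[Range]) -> list[Range]:
    # Keep only the <image> tokens; the last token of each range is "</s>".
    return [Range(offset=p["offset"], length=p["length"] - 1)
            for p in placeholders]

# e.g. with 32 query tokens plus "</s>", the replacement matches 33 tokens
# in the prompt, but only 32 are reported as image placeholders.
assert drop_trailing_token([Range(offset=0, length=33)]) == [
    Range(offset=0, length=32)
]
```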


@ -3,16 +3,15 @@ from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
Tuple, TypedDict, Union) Tuple, TypedDict, Union)
import torch import torch
import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from PIL import Image from transformers import (BatchFeature, ChameleonConfig, ChameleonProcessor,
from torch import nn ChameleonVQVAEConfig)
from transformers import ChameleonConfig, ChameleonVQVAEConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.inputs import InputContext
InputContext, token_inputs)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@ -29,11 +28,13 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
from vllm.multimodal.utils import (cached_get_tokenizer, MultiModalInputsV2, MultiModalKwargs,
consecutive_placeholder_ranges, NestedTensors, PlaceholderRange)
repeat_and_pad_placeholder_tokens) from vllm.multimodal.processing import (BaseMultiModalProcessor,
from vllm.sequence import IntermediateTensors, SequenceData MultiModalDataItems, ProcessorInputs,
PromptReplacement)
from vllm.sequence import IntermediateTensors
from vllm.utils import print_warning_once from vllm.utils import print_warning_once
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP
@ -45,10 +46,6 @@ from .utils import (is_pp_missing_parameter,
# and processor files, so we hardcode them in the model file for now. # and processor files, so we hardcode them in the model file for now.
CHAMELEON_CROP_SIZE_HEIGHT = CHAMELEON_CROP_SIZE_WIDTH = 512 CHAMELEON_CROP_SIZE_HEIGHT = CHAMELEON_CROP_SIZE_WIDTH = 512
CHAMELEON_IMAGE_SEQ_LENGTH = 1024 CHAMELEON_IMAGE_SEQ_LENGTH = 1024
CHAMELEON_IMAGE_TOKEN_ID = 8711
CHAMELEON_IMAGE_START_TOKEN_ID = 8197
CHAMELEON_IMAGE_END_TOKEN_ID = 8196
CHAMELEON_SEP_TOKEN_ID = 8710
class ChameleonImagePixelInputs(TypedDict): class ChameleonImagePixelInputs(TypedDict):
@ -61,99 +58,75 @@ def get_max_chameleon_image_tokens(ctx: InputContext):
return CHAMELEON_IMAGE_SEQ_LENGTH return CHAMELEON_IMAGE_SEQ_LENGTH
def dummy_seq_data_for_chameleon( class ChameleonMultiModalProcessor(BaseMultiModalProcessor):
seq_len: int,
num_images: int,
*,
image_token_id: int,
image_feature_size_override: Optional[int] = None,
):
if image_feature_size_override is None:
image_feature_size = CHAMELEON_IMAGE_SEQ_LENGTH
else:
image_feature_size = image_feature_size_override
return SequenceData.from_prompt_token_counts( def _get_hf_processor(self) -> ChameleonProcessor:
(image_token_id, image_feature_size * num_images), return self.ctx.get_hf_processor(ChameleonProcessor)
(0, seq_len - image_feature_size * num_images),
), {
"image":
consecutive_placeholder_ranges(num_items=num_images,
item_size=image_feature_size)
}
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(pixel_values=MultiModalFieldConfig.batched("image"))
def dummy_image_for_chameleon( def _get_prompt_replacements(
num_images: int, self,
*, mm_items: MultiModalDataItems,
image_width_override: Optional[int] = None, hf_processor_mm_kwargs: Mapping[str, object],
image_height_override: Optional[int] = None, out_mm_kwargs: MultiModalKwargs,
): ) -> list[PromptReplacement]:
width = CHAMELEON_CROP_SIZE_WIDTH processor = self._get_hf_processor()
height = CHAMELEON_CROP_SIZE_HEIGHT
if image_width_override is not None:
width = image_width_override
if image_height_override is not None:
height = image_height_override
image = Image.new("RGB", (width, height), color=0) return [
return {"image": image if num_images == 1 else [image] * num_images} PromptReplacement(
modality="image",
target="<image>",
replacement="".join([
processor.image_start_token,
processor.image_token * CHAMELEON_IMAGE_SEQ_LENGTH,
processor.image_end_token,
]),
)
]
def _get_dummy_mm_inputs(
self,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
num_images = mm_counts.get("image", 0)
def dummy_data_for_chameleon(ctx: InputContext, seq_len: int, mm_data = {
mm_counts: Mapping[str, int]): "image":
num_images = mm_counts["image"] self._get_dummy_images(width=CHAMELEON_CROP_SIZE_WIDTH,
height=CHAMELEON_CROP_SIZE_HEIGHT,
num_images=num_images)
}
seq_data, ranges = dummy_seq_data_for_chameleon( return ProcessorInputs(
seq_len, prompt_text="<image>" * num_images,
num_images, mm_data=mm_data,
image_token_id=CHAMELEON_IMAGE_TOKEN_ID, )
)
mm_data = dummy_image_for_chameleon(num_images) def apply(
return DummyData(seq_data, mm_data, ranges) self,
prompt_text: str,
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalInputsV2:
result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
# Only <image> tokens should be considered as placeholders,
# so we ignore the image_start_token and image_end_token
result["mm_placeholders"] = {
modality: [
PlaceholderRange(offset=p["offset"] + 1,
length=p["length"] - 2) for p in ps
]
for modality, ps in result["mm_placeholders"].items()
}
def input_processor_for_chameleon(ctx: InputContext, return result
inputs: DecoderOnlyInputs):
"""
Processing input prompt to insert required tokens for image placeholder.
See https://github.com/huggingface/transformers/blob/0fdea8607d7e01eb0e38a1ebeb7feee30a22f0cf/src/transformers/models/chameleon/processing_chameleon.py#L58
""" # noqa
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return inputs
if "multi_modal_placeholders" in inputs and "image" in inputs[
"multi_modal_placeholders"]:
# The inputs already have placeholders.
return inputs
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(model_config.tokenizer)
new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
tokenizer,
inputs.get("prompt"),
inputs["prompt_token_ids"],
placeholder_token_id=CHAMELEON_IMAGE_TOKEN_ID,
repeat_count=CHAMELEON_IMAGE_SEQ_LENGTH,
pad_token_left=CHAMELEON_IMAGE_START_TOKEN_ID,
pad_token_right=CHAMELEON_IMAGE_END_TOKEN_ID,
)
# Appending sep token for chat mode to follow default processor
# behavior
if new_prompt is not None:
new_prompt += tokenizer.sep_token
new_token_ids += [CHAMELEON_SEP_TOKEN_ID]
# NOTE: Create a defensive copy of the original inputs
return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)
class ChameleonLayerNorm(nn.LayerNorm): class ChameleonLayerNorm(nn.LayerNorm):
@ -736,7 +709,7 @@ class ChameleonVQVAEEncoder(nn.Module):
for i_level in range(self.num_resolutions): for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks): for i_block in range(self.num_res_blocks):
hidden_state = self.down[i_level].block[i_block]( hidden_state = self.down[i_level].block[i_block](
hidden_states[-1], ) hidden_states[-1])
if len(self.down[i_level].attn) > 0: if len(self.down[i_level].attn) > 0:
hidden_state = self.down[i_level].attn[i_block]( hidden_state = self.down[i_level].attn[i_block](
hidden_state) hidden_state)
@ -925,10 +898,8 @@ class ChameleonModel(nn.Module):
return hidden_states return hidden_states
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_chameleon_image_tokens) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_chameleon_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_chameleon) @MULTIMODAL_REGISTRY.register_processor(ChameleonMultiModalProcessor)
@INPUT_REGISTRY.register_input_processor(input_processor_for_chameleon)
class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP): SupportsPP):
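`ChameleonMultiModalProcessor` above replaces each `<image>` placeholder with the processor's image-start token, 1024 image tokens, and an image-end token, then adjusts the reported range (`offset + 1`, `length - 2`) so only the inner image tokens count as placeholders. A toy expansion with stand-in delimiter strings (the real ones come from `ChameleonProcessor`):

```python
CHAMELEON_IMAGE_SEQ_LENGTH = 1024  # same constant as in the model file

def expand_image_placeholder(image_start_token: str, image_token: str,
                             image_end_token: str) -> str:
    return "".join([
        image_start_token,
        image_token * CHAMELEON_IMAGE_SEQ_LENGTH,
        image_end_token,
    ])

# "<start>" and "<end>" are illustrative, not Chameleon's actual tokens.
expanded = expand_image_placeholder("<start>", "<image>", "<end>")
assert expanded.count("<image>") == CHAMELEON_IMAGE_SEQ_LENGTH
```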


@ -15,32 +15,30 @@
# limitations under the License. # limitations under the License.
""" PyTorch Fuyu model.""" """ PyTorch Fuyu model."""
import math import math
from array import array
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict) TypedDict)
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.utils.checkpoint from transformers import (BatchFeature, FuyuConfig, FuyuImageProcessor,
from PIL import Image FuyuProcessor)
from transformers import FuyuImageProcessor
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.inputs import InputContext
InputContext, token_inputs)
from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.models.persimmon import PersimmonForCausalLM from vllm.model_executor.models.persimmon import PersimmonForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import cached_get_image_processor from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
from vllm.multimodal.inputs import NestedTensors MultiModalInputsV2, MultiModalKwargs,
from vllm.multimodal.utils import (cached_get_tokenizer, NestedTensors, PlaceholderRange)
consecutive_placeholder_ranges) from vllm.multimodal.parse import ImageProcessorItems
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, from vllm.multimodal.processing import (BaseMultiModalProcessor,
SequenceData) MultiModalDataItems, ProcessorInputs,
from vllm.utils import is_list_of PromptReplacement)
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix, from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
@ -54,178 +52,193 @@ MAX_IMAGE_FEATURE_SIZE_HEIGHT = 1080
MAX_IMAGE_FEATURE_SIZE_WIDTH = 1920 MAX_IMAGE_FEATURE_SIZE_WIDTH = 1920
class FuyuImagePixelInputs(TypedDict): class FuyuImagePatchInputs(TypedDict):
type: Literal["pixel_values"] type: Literal["image_patches"]
data: torch.Tensor data: torch.Tensor
""" """
Shape: Shape:
(batch_size, num_patches, patch_size_x * patch_size_y * num_channels) `(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
"""
patches_per_image: List[int]
"""
List of number of total patches for each image in the batch.
This is used to restore the first two dimensions of `data`.
""" """
def _calculate_num_image_tokens( def _get_fuyu_num_image_tokens(
height: int, image_height: int,
width: int, image_width: int,
) -> Tuple[int, int]: ) -> Tuple[int, int]:
""" """
calculate number of image tokens needed for a given image size Calculate the number of image tokens needed for a given image size.
The expected Fuyu image prompts is in format:
The expected Fuyu image prompts can be expressed as:
.. code-block::
(image_token * ncols + newline_token) * nrows (image_token * ncols + newline_token) * nrows
args:
image_size: Tuple[int, int] - (width, height) of the image Args:
returns: image_size: Tuple[int, int] - `(width, height)` of the image
ncols: int - number of image tokens in x direction
nrows: int - number of image tokens in y direction Returns:
ncols: int - number of image tokens in `x` direction
nrows: int - number of image tokens in `y` direction
""" """
ncol = math.ceil(width / 30) ncols = math.ceil(image_width / 30)
nrow = math.ceil(height / 30) nrows = math.ceil(image_height / 30)
return ncol, nrow return ncols, nrows
def get_max_fuyu_image_feature_size():
return _calculate_num_image_tokens(
height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
)
def get_max_fuyu_image_tokens(ctx: InputContext): def get_max_fuyu_image_tokens(ctx: InputContext):
ncol, nrow = get_max_fuyu_image_feature_size() ncols, nrows = _get_fuyu_num_image_tokens(
return (ncol + 1) * nrow image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
def dummy_seq_data_for_fuyu(ctx: InputContext, seq_len: int, num_images: int):
ncol, nrow = get_max_fuyu_image_feature_size()
image_feature_size = get_max_fuyu_image_tokens(ctx)
image_token_ids = (
array(VLLM_TOKEN_ID_ARRAY_TYPE, [_IMAGE_TOKEN_ID]) * ncol +
array(VLLM_TOKEN_ID_ARRAY_TYPE, [_NEWLINE_TOKEN_ID])) * nrow
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, image_token_ids) * num_images
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
[0]) * (seq_len - image_feature_size * num_images)
return SequenceData(token_ids), {
"image":
consecutive_placeholder_ranges(num_items=num_images,
item_size=image_feature_size)
}
def dummy_image_for_fuyu(
num_images: int,
*,
image_width: int,
image_height: int,
):
image = Image.new("RGB", (image_width, image_height), color=0)
return {"image": image if num_images == 1 else [image] * num_images}
def dummy_data_for_fuyu(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
num_images = mm_counts["image"]
seq_data, ranges = dummy_seq_data_for_fuyu(ctx, seq_len, num_images)
mm_data = dummy_image_for_fuyu(num_images,
image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT)
return DummyData(seq_data, mm_data, ranges)
def _fuyu_image_preprocess(image_processor: FuyuImageProcessor,
data: List[Image.Image]):
image_encoding = image_processor.preprocess(data, return_tensors="pt")
batch_images = torch.stack([img[0] for img in image_encoding["images"]
]).unsqueeze(1)
image_unpadded_heights = torch.tensor(
image_encoding["image_unpadded_heights"])
image_unpadded_widths = torch.tensor(
image_encoding["image_unpadded_widths"])
batch_size = len(image_encoding["images"])
image_present = torch.ones(batch_size, 1, 1)
model_image_input = image_processor.preprocess_with_tokenizer_info(
image_input=batch_images,
image_present=image_present,
image_unpadded_h=image_unpadded_heights,
image_unpadded_w=image_unpadded_widths,
image_placeholder_id=_IMAGE_TOKEN_ID,
image_newline_id=_NEWLINE_TOKEN_ID,
variable_sized=True,
) )
return model_image_input
return (ncols + 1) * nrows
def input_processor_for_fuyu(ctx: InputContext, inputs: DecoderOnlyInputs): class FuyuMultiModalProcessor(BaseMultiModalProcessor):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return inputs
model_config = ctx.model_config def _get_hf_processor(self) -> FuyuProcessor:
image_data = multi_modal_data["image"] return self.ctx.get_hf_processor(FuyuProcessor)
new_multi_modal_data = {}
image_list = image_data if isinstance(image_data, list) else [image_data]
# process image data def _call_hf_processor(
if is_list_of(image_list, Image.Image): self,
# Fuyu's image_processor can also finish token padding prompt: str,
image_processor: FuyuImageProcessor = cached_get_image_processor( mm_data: Mapping[str, object],
model_config.model) mm_kwargs: Mapping[str, object],
) -> BatchFeature:
model_image_input = _fuyu_image_preprocess(image_processor, image_data) if not mm_data:
image_patches = torch.cat([ # Avoid warning from HF logger for text-only input
image_patch[0] # Input_ids format: bos_token_id + prompt_token_ids + boa_token_id
for image_patch in model_image_input["image_patches"] # Tokenizer won't add boa_token_id by default, we add it manually.
]) tokenizer = self._get_tokenizer()
new_multi_modal_data["image"] = image_patches boa_token_id: int = tokenizer.vocab["<0x04>"] # type: ignore
prompt_ids = tokenizer.encode(prompt) + [boa_token_id]
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
elif is_list_of(image_list, torch.Tensor): processed_outputs = super()._call_hf_processor(
raise NotImplementedError("Embeddings input is not supported yet") prompt=prompt,
else: mm_data=mm_data,
raise TypeError(f"Invalid image type: {type(image_data)}") mm_kwargs=mm_kwargs,
)
# process prompts image_patches = processed_outputs.get("image_patches")
prompt = inputs.get("prompt") if image_patches is not None:
prompt_token_ids = inputs["prompt_token_ids"] images = mm_data["images"]
tokenizer = cached_get_tokenizer(model_config.model) assert isinstance(images, list)
# dim0 is batch_size, dim1 is subseq_size which will always be 1
image_input_ids: List[List[
torch.Tensor]] = model_image_input["image_input_ids"]
image_input_ids = image_input_ids[0][0].tolist()
bos_token = tokenizer.encode("<s>", add_special_tokens=False)[1:]
boa_token = tokenizer.encode("\x04", add_special_tokens=False)[1:]
new_prompt = prompt + "\x04" # Original output: (1, num_images, Pn, Px * Py * C)
new_prompt_token_ids = image_input_ids + bos_token + prompt_token_ids[ # New output: (num_images, Pn, Px * Py * C)
1:] + boa_token assert (isinstance(image_patches, list)
and len(image_patches) == 1)
assert (isinstance(image_patches[0], torch.Tensor)
and len(image_patches[0]) == len(images))
return token_inputs(prompt=new_prompt, processed_outputs["image_patches"] = image_patches[0]
prompt_token_ids=new_prompt_token_ids,
multi_modal_data=new_multi_modal_data) return processed_outputs
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(image_patches=MultiModalFieldConfig.batched("image"))
def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_config = self.ctx.get_hf_config(FuyuConfig)
bos_token_id = hf_config.bos_token_id
tokenizer = self._get_tokenizer()
eot_token_id = tokenizer.bos_token_id
assert isinstance(eot_token_id, int)
hf_processor = self._get_hf_processor()
image_processor: FuyuImageProcessor = hf_processor.image_processor
target_size = image_processor.size
target_height, target_width = (target_size["height"],
target_size["width"])
def get_replacement_fuyu(item_idx: int):
images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx)
width, height = image_size.width, image_size.height
if not (width <= target_width and height <= target_height):
height_scale_factor = target_height / height
width_scale_factor = target_width / width
optimal_scale_factor = min(height_scale_factor,
width_scale_factor)
height = int(height * optimal_scale_factor)
width = int(width * optimal_scale_factor)
ncols, nrows = _get_fuyu_num_image_tokens(
image_width=width,
image_height=height,
)
return (([_IMAGE_TOKEN_ID] * ncols + [_NEWLINE_TOKEN_ID]) * nrows +
[bos_token_id])
return [
PromptReplacement(
modality="image",
target=[eot_token_id],
replacement=get_replacement_fuyu,
)
]
def apply(
self,
prompt_text: str,
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalInputsV2:
result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
# Only |SPEAKER| (image) tokens should be considered as placeholders,
# so we ignore the trailing bos_token_id
result["mm_placeholders"] = {
modality: [
PlaceholderRange(offset=p["offset"], length=p["length"] - 1)
for p in ps
]
+            for modality, ps in result["mm_placeholders"].items()
+        }
+
+        return result
+
+    def _get_dummy_mm_inputs(
+        self,
+        mm_counts: Mapping[str, int],
+    ) -> ProcessorInputs:
+        num_images = mm_counts.get("image", 0)
+
+        mm_data = {
+            "image":
+            self._get_dummy_images(width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
+                                   height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
+                                   num_images=num_images)
+        }
+
+        return ProcessorInputs(
+            prompt_text="",
+            mm_data=mm_data,
+        )
+
-def input_mapper_for_fuyu(ctx: InputContext, data: object):
-    model_config = ctx.model_config
-    data_list = data if isinstance(data, list) else [data]
-
-    if is_list_of(data_list, Image.Image):
-        # Fuyu's image_processor can also finish token padding
-        image_processor: FuyuImageProcessor = cached_get_image_processor(
-            model_config.model)
-
-        model_image_input = _fuyu_image_preprocess(image_processor, data_list)
-        data = torch.stack([
-            image_patch[0]
-            for image_patch in model_image_input["image_patches"]
-        ])
-
-    # image has been processed with prompt in input processor
-    return MultiModalKwargs({"pixel_values": data})
-
-
-@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_fuyu)
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_fuyu_image_tokens)
-@INPUT_REGISTRY.register_dummy_data(dummy_data_for_fuyu)
-@INPUT_REGISTRY.register_input_processor(input_processor_for_fuyu)
+@MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor)
 class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -280,28 +293,32 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         return data.to(self.vision_embed_tokens.weight.dtype)
 
     def _parse_and_validate_image_input(
-            self, **kwargs: object) -> Optional[FuyuImagePixelInputs]:
-        pixel_values = kwargs.pop("pixel_values", None)
-
-        if pixel_values is not None:
-            if not isinstance(pixel_values, (torch.Tensor, list)):
+            self, **kwargs: object) -> Optional[FuyuImagePatchInputs]:
+        image_patches = kwargs.pop("image_patches", None)
+        if image_patches is not None:
+            if not isinstance(image_patches, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of image patches. "
-                                 f"Got type: {type(pixel_values)}")
-            return FuyuImagePixelInputs(
-                type="pixel_values",
+                                 f"Got type: {type(image_patches)}")
+
+            image_patches_flat = flatten_bn(image_patches)
+
+            return FuyuImagePatchInputs(
+                type="image_patches",
                 data=self._validate_pixel_values(
-                    flatten_bn(pixel_values, concat=True)),
+                    flatten_bn(image_patches_flat, concat=True)),
+                patches_per_image=[x.size(0) for x in image_patches_flat],
             )
 
         return None
 
     def _process_image_input(
-            self, image_input: FuyuImagePixelInputs) -> torch.Tensor:
+            self, image_input: FuyuImagePatchInputs) -> NestedTensors:
+        image_patches = image_input["data"]
+        patches_per_image = image_input["patches_per_image"]
 
         assert self.vision_embed_tokens is not None
-        vision_embeddings, _ = self.vision_embed_tokens(image_input["data"])
-        return vision_embeddings
+        vision_embeddings, _ = self.vision_embed_tokens(image_patches)
+        return vision_embeddings.split(patches_per_image, dim=0)
 
     def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
         image_input = self._parse_and_validate_image_input(**kwargs)
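The Fuyu change above threads `patches_per_image` through so the flat batch of patch embeddings can be split back into one tensor per image. A minimal standalone sketch of that split, with made-up patch counts and hidden size:

import torch

# Two hypothetical images contributing 3 and 5 patches; hidden size is made up.
patches_per_image = [3, 5]
hidden_size = 16

# Flat (total_patches, hidden_size) tensor, as a vision embedder would produce.
vision_embeddings = torch.randn(sum(patches_per_image), hidden_size)

# Split back into one tensor per image, mirroring the new _process_image_input.
per_image_embeddings = vision_embeddings.split(patches_per_image, dim=0)
assert [t.size(0) for t in per_image_embeddings] == patches_per_image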

View File

@@ -69,7 +69,8 @@ class Idefics2VisionEmbeddings(nn.Module):
                 patch_attention_mask: torch.BoolTensor,
                 tgt_sizes: Optional[torch.IntTensor] = None) -> torch.Tensor:
         batch_size, _, max_im_h, max_im_w = pixel_values.shape
-        patch_embeds = self.patch_embedding(pixel_values)
+        target_dtype = self.patch_embedding.weight.dtype
+        patch_embeds = self.patch_embedding(pixel_values.to(target_dtype))
         embeddings = patch_embeds.flatten(2).transpose(1, 2)
         max_nb_patches_h, max_nb_patches_w = (
             max_im_h // self.patch_size,
@@ -309,7 +310,8 @@ class Idefics2VisionTransformer(nn.Module):
         hidden_states = self.embeddings(
             pixel_values=pixel_values,
             patch_attention_mask=patch_attention_mask,
-            tgt_sizes=tgt_sizes)
+            tgt_sizes=tgt_sizes,
+        )
         encoder_outputs = self.encoder(hidden_states)
         last_hidden_state = self.post_layernorm(encoder_outputs)
         return last_hidden_state
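The Idefics2 change above casts pixel values to the patch embedding's weight dtype before the convolution. A small sketch of the same pattern, using an illustrative float64 module rather than the real half-precision weights:

import torch
import torch.nn as nn

# Stand-in "patch embedding" kept in float64 so the dtype mismatch is visible;
# real checkpoints typically load the conv in float16/bfloat16.
patch_embedding = nn.Conv2d(3, 8, kernel_size=4, stride=4).to(torch.float64)

pixel_values = torch.rand(1, 3, 16, 16)        # float32 input batch
target_dtype = patch_embedding.weight.dtype
patch_embeds = patch_embedding(pixel_values.to(target_dtype))
assert patch_embeds.dtype == target_dtype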

View File

@@ -144,8 +144,8 @@ class LlavaMultiModalProcessor(BaseMultiModalProcessor):
             # Original output: (1, num_images, C, H, W)
             # New output: (num_images, C, H, W)
             assert (isinstance(pixel_values, list)
-                    and len(pixel_values) == 1
-                    and isinstance(pixel_values[0], list)
-                    and len(pixel_values[0]) == len(images))
+                    and len(pixel_values) == 1)
+            assert (isinstance(pixel_values[0], list)
+                    and len(pixel_values[0]) == len(images))
             processed_outputs["pixel_values"] = pixel_values[0]
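A toy illustration of the shape handling noted in the comments above, where the hypothetical processor output carries a singleton batch dimension wrapping the per-image tensors:

import torch

num_images = 2
images = [object()] * num_images               # stand-ins for the input images

# Hypothetical processor output shaped (1, num_images, C, H, W),
# expressed as a nested list of per-image tensors.
pixel_values = [[torch.rand(3, 8, 8) for _ in range(num_images)]]

assert isinstance(pixel_values, list) and len(pixel_values) == 1
assert (isinstance(pixel_values[0], list)
        and len(pixel_values[0]) == len(images))

# Drop the singleton batch dimension: now one entry per image.
processed_pixel_values = pixel_values[0]
assert len(processed_pixel_values) == num_images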

View File

@@ -528,10 +528,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
         stacked_image_features = self._image_pixels_to_features(
             self.vision_tower, stacked_pixel_values)
 
-        return [
-            self.multi_modal_projector(image_features) for image_features in
-            torch.split(stacked_image_features, num_patches_per_batch)
-        ]
+        return torch.split(self.multi_modal_projector(stacked_image_features),
+                           num_patches_per_batch)
 
     def _process_image_input(
         self,
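The LLaVA-NeXT change above runs the projector once over the stacked patch features and splits afterwards; for a per-patch projector the result matches projecting each split separately. A toy check with a stand-in linear projector and invented sizes:

import torch
import torch.nn as nn

torch.manual_seed(0)
projector = nn.Linear(8, 4)                    # stand-in for the MM projector
num_patches_per_batch = [2, 3]                 # patches contributed per image
stacked_image_features = torch.randn(sum(num_patches_per_batch), 8)

# Old order: split per image, then project each chunk.
split_then_project = [
    projector(chunk)
    for chunk in torch.split(stacked_image_features, num_patches_per_batch)
]

# New order: project the stacked tensor once, then split.
project_then_split = torch.split(projector(stacked_image_features),
                                 num_patches_per_batch)

for a, b in zip(split_then_project, project_then_split):
    assert torch.allclose(a, b, atol=1e-6)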

View File

@@ -1,8 +1,8 @@
+import math
 from dataclasses import dataclass, fields
 from functools import cached_property
 from typing import Iterable, List, Mapping, Optional, Set, Tuple, Union
 
-import numpy
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -306,7 +306,7 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
         images: Optional[Union[List[List[torch.Tensor]], List[torch.Tensor],
                                torch.Tensor]] = None,
         image_tokens: Optional[torch.Tensor] = None,
-    ) -> Optional[List[torch.Tensor]]:
+    ) -> Tuple[Optional[List[torch.Tensor]], Optional[torch.Tensor]]:
         if images is None:
             return None, None
@@ -604,11 +604,11 @@ class VisionTransformer(nn.Module):
         return self.args.image_size // self.args.patch_size
 
     @property
-    def device(self) -> torch.device:
+    def device(self) -> torch.types.Device:
         return next(self.parameters()).device
 
     @property
-    def dtype(self) -> torch.device:
+    def dtype(self) -> torch.dtype:
         return next(self.parameters()).dtype
 
     @property
@@ -741,8 +741,8 @@ def get_pixtral_hf_image_feature_size(hf_config: PixtralVisionConfig,
     ratio = max(image_width / max_width, image_height / max_height)
 
     if ratio > 1:
-        image_width = int(numpy.ceil(image_width / ratio))
-        image_height = int(numpy.ceil(image_height / ratio))
+        image_width = int(math.ceil(image_width / ratio))
+        image_height = int(math.ceil(image_height / ratio))
 
     num_height_tokens, num_width_tokens = _get_pixtral_hf_num_image_tokens(
         (image_height, image_width),
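A worked example (with invented numbers) of the resize step above: when the image exceeds the maximum size, both sides are scaled by the same ratio and rounded up with `math.ceil`:

import math

# Invented numbers: a 1300x700 image against a 1024x1024 maximum.
image_width, image_height = 1300, 700
max_width, max_height = 1024, 1024

ratio = max(image_width / max_width, image_height / max_height)  # ~1.27
if ratio > 1:
    image_width = int(math.ceil(image_width / ratio))    # 1024
    image_height = int(math.ceil(image_height / ratio))  # 552

assert (image_width, image_height) == (1024, 552)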

View File

@@ -23,7 +23,6 @@ from functools import cached_property
 from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
                     Union)
 
-import numpy as np
 import torch
 import torch.nn as nn
 from transformers import BatchFeature
@@ -177,16 +176,19 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
         mm_counts: Mapping[str, int],
     ) -> ProcessorInputs:
         feature_extractor = self._get_feature_extractor()
 
         sampling_rate = feature_extractor.sampling_rate
         audio_len = feature_extractor.chunk_length * sampling_rate
+        num_audios = mm_counts.get("audio", 0)
 
-        audio_count = mm_counts.get("audio", 0)
-        audio = np.zeros(audio_len)
-        data = {"audio": [audio] * audio_count}
+        mm_data = {
+            "audio":
+            self._get_dummy_audios(length=audio_len, num_audios=num_audios)
+        }
 
         return ProcessorInputs(
-            prompt_text="<|AUDIO|>" * audio_count,
-            mm_data=data,
+            prompt_text="<|AUDIO|>" * num_audios,
+            mm_data=mm_data,
         )

View File

@@ -29,7 +29,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange, repeat
-from PIL import Image
 from transformers import BatchFeature
 from transformers.models.qwen2_vl import (Qwen2VLImageProcessor,
                                           Qwen2VLProcessor)
@@ -882,12 +881,10 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
         self,
         mm_counts: Mapping[str, int],
     ) -> ProcessorInputs:
-        num_images = mm_counts.get("image", 0)
-
         hf_processor = self._get_hf_processor()
-        image_token: str = hf_processor.image_token
         image_processor = _get_image_processor(hf_processor)
+        image_token: str = hf_processor.image_token
 
-        data = {}
         resized_height, resized_width = smart_resize(
             height=9999999,
             width=9999999,
@@ -895,14 +892,18 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
             min_pixels=image_processor.min_pixels,
             max_pixels=image_processor.max_pixels,
         )
+        num_images = mm_counts.get("image", 0)
 
-        dummy_image = Image.new("RGB", (resized_width, resized_height),
-                                color=0)
-        data["image"] = [dummy_image] * num_images
+        mm_data = {
+            "image":
+            self._get_dummy_images(width=resized_width,
+                                   height=resized_height,
+                                   num_images=num_images)
+        }
 
         return ProcessorInputs(
             prompt_text=image_token * num_images,
-            mm_data=data,
+            mm_data=mm_data,
         )

View File

@@ -188,16 +188,19 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
         mm_counts: Mapping[str, int],
     ) -> ProcessorInputs:
         feature_extractor = self._get_feature_extractor()
 
         sampling_rate = feature_extractor.sampling_rate
         audio_len = feature_extractor.chunk_length * sampling_rate
+        num_audios = mm_counts.get("audio", 0)
 
-        audio_count = mm_counts.get("audio", 0)
-        audio = np.zeros(audio_len)
-        data = {"audio": [audio] * audio_count}
+        mm_data = {
+            "audio":
+            self._get_dummy_audios(length=audio_len, num_audios=num_audios)
+        }
 
         return ProcessorInputs(
-            prompt_text="<|audio|>" * audio_count,
-            mm_data=data,
+            prompt_text="<|audio|>" * num_audios,
+            mm_data=mm_data,
         )

View File

@@ -1,15 +1,17 @@
 import pickle
 import re
 from abc import ABC, abstractmethod
+from collections import defaultdict
 from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence
 from dataclasses import dataclass, field
 from functools import lru_cache
 from typing import Any, NamedTuple, Optional, Protocol, TypeVar, Union
 
 import numpy as np
+import numpy.typing as npt
 import torch
 from blake3 import blake3
-from PIL.Image import Image
+from PIL import Image
 from transformers import BatchFeature, ProcessorMixin
 
 from vllm.inputs import DummyData, InputProcessingContext
@@ -353,13 +355,13 @@ def _replace_matches(
 ) -> list[_S]:
     out_seqs = list[_S]()
     prev_end_idx = 0
-    next_idx_by_modality = {modality: 0 for modality in mm_item_counts}
+    next_idx_by_modality = defaultdict[str, int](lambda: 0)
 
     for match in _resolve_matches(prompt, matches):
         modality = match.modality
 
         item_idx = next_idx_by_modality[modality]
-        if item_idx >= mm_item_counts[modality]:
+        if item_idx >= mm_item_counts.get(modality, 0):
             continue
 
         start_idx = match.start_idx
@@ -513,7 +515,7 @@ class ProcessingCache:
             return obj.encode("utf-8")
         if isinstance(obj, bytes):
             return obj
-        if isinstance(obj, Image):
+        if isinstance(obj, Image.Image):
             return obj.tobytes()
 
         # Convertible to NumPy arrays
@@ -673,10 +675,14 @@ class BaseMultiModalProcessor(ABC):
         Given the original multi-modal items for this modality
         and HF-processed data, output the replacements to perform.
 
-        Note:
-            Even when the HF processor already performs replacement for us,
-            we still use this replacement information to determine
-            the placeholder token positions for each multi-modal item.
+        Notes:
+            - You should not assume that HF processor always performs prompt
+              replacement: in :meth:`_apply_hf_processor_missing`, this method
+              is called on text-only and multimodal-only inputs separately,
+              instead of passing them in the same call.
+            - The replacement information returned by this method is also used
+              to determine the placeholder token positions for each multi-modal
+              item.
         """
         raise NotImplementedError
@@ -710,6 +716,10 @@ class BaseMultiModalProcessor(ABC):
         mm_data: Mapping[str, object],
         mm_kwargs: Mapping[str, object],
     ) -> BatchFeature:
+        """
+        Call the HF processor on the prompt text and
+        associated multi-modal data.
+        """
         return self.ctx.call_hf_processor(
             self._get_hf_processor(**mm_kwargs),
             dict(text=prompt, **mm_data),
@@ -723,7 +733,8 @@ class BaseMultiModalProcessor(ABC):
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> tuple[list[int], MultiModalKwargs]:
         """
-        Apply the HF processor on the full prompt text and multi-modal data.
+        Wrapper of :meth:`_call_hf_processor` that applies
+        additional pre-processing and post-processing.
         """
         processor_data, passthrough_data = self._get_hf_mm_data(mm_items)
@@ -754,10 +765,11 @@ class BaseMultiModalProcessor(ABC):
         Apply the HF processor on the full prompt text, but only on the
         multi-modal data that are missing from the cache.
 
-        Note: We pass prompt text and multi-modal data into the HF processor
-        in separate calls to avoid HF prompt replacement being done for
-        cached items; instead, we rely on our own prompt replacement logic
-        for the full text.
+        Note:
+            We pass prompt text and multi-modal data into the HF processor
+            in separate calls to avoid HF prompt replacement being done for
+            cached items; instead, we rely on our own prompt replacement logic
+            (:meth:`_get_prompt_replacements`) for the full text.
         """
         mm_missing_counts = mm_missing_data_items.get_all_counts()
@@ -1010,6 +1022,36 @@ class BaseMultiModalProcessor(ABC):
             mm_placeholders=mm_placeholders,
         )
 
+    def _get_dummy_audios(
+        self,
+        *,
+        length: int,
+        num_audios: int,
+    ) -> list[npt.NDArray]:
+        audio = np.zeros((length, ))
+        return [audio] * num_audios
+
+    def _get_dummy_images(
+        self,
+        *,
+        width: int,
+        height: int,
+        num_images: int,
+    ) -> list[Image.Image]:
+        image = Image.new("RGB", (width, height), color=0)
+        return [image] * num_images
+
+    def _get_dummy_videos(
+        self,
+        *,
+        width: int,
+        height: int,
+        num_frames: int,
+        num_videos: int,
+    ) -> list[npt.NDArray]:
+        video = np.zeros((num_frames, width, height, 3))
+        return [video] * num_videos
+
     @abstractmethod
     def _get_dummy_mm_inputs(
         self,
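A compact sketch of what the new `_get_dummy_*` helpers above return, rewritten as free functions with the same shapes (names and sample sizes are illustrative only):

import numpy as np
from PIL import Image

def get_dummy_audios(*, length: int, num_audios: int) -> list:
    # One zero-filled waveform of `length` samples, repeated num_audios times.
    return [np.zeros((length, ))] * num_audios

def get_dummy_images(*, width: int, height: int, num_images: int) -> list:
    # One black RGB image of the requested size, repeated num_images times.
    return [Image.new("RGB", (width, height), color=0)] * num_images

audios = get_dummy_audios(length=16000, num_audios=2)
images = get_dummy_images(width=64, height=48, num_images=3)
assert audios[0].shape == (16000, ) and len(audios) == 2
assert images[0].size == (64, 48) and len(images) == 3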

View File

@@ -400,15 +400,19 @@ def repeat_and_pad_placeholder_tokens(
     placeholder_token_idx = 0
     for i, token in enumerate(prompt_token_ids):
         if token == placeholder_token_id:
+            curr_repeat_count = repeat_count[placeholder_token_idx]
             replacement_ids = repeat_and_pad_token(
                 placeholder_token_id,
-                repeat_count=repeat_count[placeholder_token_idx],
+                repeat_count=curr_repeat_count,
                 pad_token_left=pad_token_left,
                 pad_token_right=pad_token_right,
             )
+            offset = len(new_token_ids)
+            if pad_token_left is not None:
+                offset += 1
             placeholder_ranges.append({
-                "offset": len(new_token_ids),
-                "length": len(replacement_ids)
+                "offset": offset,
+                "length": curr_repeat_count,
             })
             new_token_ids.extend(replacement_ids)
             placeholder_token_idx += 1
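A worked example (token ids invented) of the bookkeeping fix above: with a left pad token, the recorded placeholder range starts after the pad and its length is the repeat count rather than the full padded replacement:

# Invented token ids: 32000 = placeholder, 1 = left pad, 2 = right pad.
placeholder_token_id, pad_token_left, pad_token_right = 32000, 1, 2
curr_repeat_count = 3
new_token_ids = [5, 6]                         # tokens emitted so far

replacement_ids = ([pad_token_left] +
                   [placeholder_token_id] * curr_repeat_count +
                   [pad_token_right])          # [1, 32000, 32000, 32000, 2]

offset = len(new_token_ids)
if pad_token_left is not None:
    offset += 1                                # skip over the left pad token

placeholder_range = {"offset": offset, "length": curr_repeat_count}
new_token_ids.extend(replacement_ids)

assert placeholder_range == {"offset": 3, "length": 3}
assert new_token_ids[offset:offset + curr_repeat_count] == [32000] * 3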

View File

@@ -647,10 +647,23 @@ class GPUModelRunner:
             self.mm_registry.get_max_tokens_per_item_by_modality(
                 self.model_config).values())
 
-        max_num_mm_items = min(
+        max_num_mm_items_encoder_budget = min(
             self.max_num_encoder_input_tokens,
             self.encoder_cache_size) // max_tokens_per_mm_item
 
+        max_mm_items_per_req = max(
+            self.mm_registry.get_mm_limits_per_prompt(
+                self.model_config).values())
+
+        # NOTE: We do not consider max_num_batched_tokens on purpose
+        # because the multimodal embeddings can be generated in advance
+        # and chunked prefilled.
+        max_num_mm_items_decoder_budget = self.max_num_reqs * \
+            max_mm_items_per_req
+
+        max_num_mm_items = min(max_num_mm_items_encoder_budget,
+                               max_num_mm_items_decoder_budget)
+
         # Dummy data definition in V0 may contain multiple multimodal items
         # (e.g, multiple images) for a single request, therefore here we
         # always replicate first item by max_num_mm_items times since in V1