diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 483f552b..83c1b9c8 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -7,6 +7,8 @@ vLLM supports a variety of generative Transformer models in `HuggingFace Transfo The following is the list of model architectures that are currently supported by vLLM. Alongside each architecture, we include some popular models that use it. +---- + Decoder-only Language Models ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. list-table:: @@ -186,6 +188,10 @@ Vision Language Models - Models - Example HuggingFace Models - :ref:`LoRA ` + * - :code:`Blip2ForConditionalGeneration` + - BLIP-2 + - :code:`Salesforce/blip2-opt-2.7b`, :code:`Salesforce/blip2-opt-6.7b`, etc. + - * - :code:`ChameleonForConditionalGeneration` - Chameleon - :code:`facebook/chameleon-7b` etc. @@ -215,6 +221,8 @@ Vision Language Models - :code:`openbmb/MiniCPM-V-2`, :code:`openbmb/MiniCPM-Llama3-V-2_5`, etc. - +---- + If your model uses one of the above model architectures, you can seamlessly run your model with vLLM. Otherwise, please refer to :ref:`Adding a New Model ` and :ref:`Enabling Multimodal Inputs ` for instructions on how to implement support for your model. diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index 8a636533..04ba1a96 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -106,6 +106,16 @@ def run_minicpmv(question): return llm, prompt +# BLIP-2 +def run_blip2(question): + + # BLIP-2 prompt format is inaccurate on HuggingFace model repository. + # See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa + prompt = f"Question: {question} Answer:" + llm = LLM(model="Salesforce/blip2-opt-2.7b") + return llm, prompt + + model_example_map = { "llava": run_llava, "llava-next": run_llava_next, @@ -114,6 +124,7 @@ model_example_map = { "paligemma": run_paligemma, "chameleon": run_chameleon, "minicpmv": run_minicpmv, + "blip-2": run_blip2, } diff --git a/examples/template_blip2.jinja b/examples/template_blip2.jinja new file mode 100644 index 00000000..fd41a7f7 --- /dev/null +++ b/examples/template_blip2.jinja @@ -0,0 +1,11 @@ +{%- for message in messages -%} + {%- if message['role'] == 'user' -%} + {{- 'Question: ' + message['content'] + ' ' -}} + {%- elif message['role'] == 'assistant' -%} + {{- 'Answer: ' + message['content'] + ' ' -}} + {%- endif -%} +{%- endfor -%} + +{%- if add_generation_prompt -%} + {{- 'Answer:' -}} +{% endif %} diff --git a/tests/models/test_blip2.py b/tests/models/test_blip2.py new file mode 100644 index 00000000..26afd57a --- /dev/null +++ b/tests/models/test_blip2.py @@ -0,0 +1,102 @@ +from typing import List, Optional, Tuple + +import pytest +from transformers import AutoTokenizer + +from vllm.multimodal.utils import rescale_image_size +from vllm.sequence import SampleLogprobs + +from ..conftest import IMAGE_ASSETS +from .utils import check_logprobs_close + +pytestmark = pytest.mark.vlm + +HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ + "stop_sign": + "Question: What's the content of the image? Answer:", + "cherry_blossom": + "Question: What is the season? 
Answer:",
+})
+
+
+def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
+                                         Optional[SampleLogprobs]],
+                      model: str):
+    """Sanitize vllm output to be comparable with hf output."""
+    _, output_str, out_logprobs = vllm_output
+
+    hf_output_str = output_str + "\n"
+
+    tokenizer = AutoTokenizer.from_pretrained(model)
+    hf_output_ids = tokenizer.encode(hf_output_str)
+    assert hf_output_ids[0] == tokenizer.bos_token_id
+    hf_output_ids = hf_output_ids[1:]
+
+    return hf_output_ids, hf_output_str, out_logprobs
+
+
+@pytest.mark.parametrize("model", ["Salesforce/blip2-opt-2.7b"])
+@pytest.mark.parametrize(
+    "size_factors",
+    [
+        # No image
+        [],
+        # Single-scale
+        [1.0],
+        # Single-scale, batched
+        [1.0, 1.0, 1.0],
+        # Multi-scale
+        [0.25, 0.5, 1.0],
+    ],
+)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
+                dtype: str, max_tokens: int, num_logprobs: int) -> None:
+    """Inference result should be the same between hf and vllm.
+
+    All the image fixtures for the test are under tests/images.
+    For the huggingface runner, we provide the PIL images as input.
+    For the vllm runner, we provide MultiModalData objects and the
+    corresponding vision language config as input.
+    Note that the text input is also adjusted to abide by the vllm contract.
+    The text output is sanitized to be comparable with the hf output.
+    """
+    images = [asset.pil_image for asset in image_assets]
+
+    inputs_per_image = [(
+        [prompt for _ in size_factors],
+        [rescale_image_size(image, factor) for factor in size_factors],
+    ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
+
+    # max_model_len should be greater than image_feature_size
+    with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model:
+        vllm_outputs_per_image = [
+            vllm_model.generate_greedy_logprobs(prompts,
+                                                max_tokens,
+                                                num_logprobs=num_logprobs,
+                                                images=images)
+            for prompts, images in inputs_per_image
+        ]
+
+    with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model:
+        hf_outputs_per_image = [
+            hf_model.generate_greedy_logprobs_limit(prompts,
+                                                    max_tokens,
+                                                    num_logprobs=num_logprobs,
+                                                    images=images)
+            for prompts, images in inputs_per_image
+        ]
+
+    for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
+                                        vllm_outputs_per_image):
+        check_logprobs_close(
+            outputs_0_lst=hf_outputs,
+            outputs_1_lst=[
+                vllm_to_hf_output(vllm_output, model)
+                for vllm_output in vllm_outputs
+            ],
+            name_0="hf",
+            name_1="vllm",
+        )
diff --git a/tests/models/test_fuyu.py b/tests/models/test_fuyu.py
index 25f63a3d..7d0f3be5 100644
--- a/tests/models/test_fuyu.py
+++ b/tests/models/test_fuyu.py
@@ -77,8 +77,8 @@ def run_test(
             vllm_model.generate_greedy_logprobs(prompts,
                                                 max_tokens,
                                                 num_logprobs=num_logprobs,
-                                                images=vllm_images)
-            for prompts, vllm_images in inputs_per_image
+                                                images=images)
+            for prompts, images in inputs_per_image
         ]
 
     with hf_runner(model, dtype=dtype) as hf_model:
@@ -89,9 +89,9 @@ def run_test(
         hf_model.generate_greedy_logprobs_limit(prompts,
                                                     max_tokens,
                                                     num_logprobs=num_logprobs,
-                                                    images=hf_images,
+                                                    images=images,
                                                     eos_token_id=eos_token_id)
-            for prompts, hf_images in inputs_per_image
+            for prompts, images in inputs_per_image
         ]
 
     for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
diff --git a/tests/models/test_minicpmv.py b/tests/models/test_minicpmv.py
index 9124fa7a..c57f0f8c 100644
--- a/tests/models/test_minicpmv.py
+++ b/tests/models/test_minicpmv.py
@@ -88,9 +88,9 @@ def run_test(
vllm_model.generate_greedy_logprobs(prompts, max_tokens, num_logprobs=num_logprobs, - images=vllm_images, + images=images, stop_token_ids=stop_token_ids) - for prompts, vllm_images in inputs_per_image + for prompts, images in inputs_per_image ] with hf_runner(model, dtype=dtype) as hf_model, torch.no_grad(): @@ -114,9 +114,9 @@ def run_test( hf_model.generate_greedy_logprobs_limit(prompts, max_tokens, num_logprobs=num_logprobs, - images=hf_images, + images=images, tokenizer=tokenizer) - for prompts, hf_images in inputs_per_image + for prompts, images in inputs_per_image ] for hf_outputs, vllm_outputs in zip(hf_outputs_per_image, diff --git a/tests/models/test_phi3v.py b/tests/models/test_phi3v.py index 9da25ab8..35ffe4ef 100644 --- a/tests/models/test_phi3v.py +++ b/tests/models/test_phi3v.py @@ -101,8 +101,8 @@ def run_test( vllm_model.generate_greedy_logprobs(prompts, max_tokens, num_logprobs=num_logprobs, - images=vllm_images) - for prompts, vllm_images in inputs_per_image + images=images) + for prompts, images in inputs_per_image ] # use eager mode for hf runner, since phi3_v didn't work with flash_attn @@ -114,9 +114,9 @@ def run_test( hf_model.generate_greedy_logprobs_limit(prompts, max_tokens, num_logprobs=num_logprobs, - images=hf_images, + images=images, eos_token_id=eos_token_id) - for prompts, hf_images in inputs_per_image + for prompts, images in inputs_per_image ] for hf_outputs, vllm_outputs in zip(hf_outputs_per_image, diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index ead64c0e..fe04c6db 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -16,6 +16,8 @@ _GENERATION_MODELS = { "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b "BloomForCausalLM": ("bloom", "BloomForCausalLM"), + "Blip2ForConditionalGeneration": + ("blip2", "Blip2ForConditionalGeneration"), "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"), "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"), @@ -56,8 +58,8 @@ _GENERATION_MODELS = { "OPTForCausalLM": ("opt", "OPTForCausalLM"), "OrionForCausalLM": ("orion", "OrionForCausalLM"), "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"), - "PaliGemmaForConditionalGeneration": - ("paligemma", "PaliGemmaForConditionalGeneration"), + "PaliGemmaForConditionalGeneration": ("paligemma", + "PaliGemmaForConditionalGeneration"), "PhiForCausalLM": ("phi", "PhiForCausalLM"), "Phi3ForCausalLM": ("llama", "LlamaForCausalLM"), "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), diff --git a/vllm/model_executor/models/blip.py b/vllm/model_executor/models/blip.py new file mode 100644 index 00000000..0b124d5e --- /dev/null +++ b/vllm/model_executor/models/blip.py @@ -0,0 +1,269 @@ +"""Minimal implementation of BlipVisionModel intended to be only used +within a vision language model.""" +from typing import Optional, Union + +import torch +import torch.nn as nn +from PIL import Image +from transformers import Blip2VisionConfig, BlipVisionConfig +from transformers.models.blip.modeling_blip import BlipAttention + +from vllm.config import ModelConfig +from vllm.inputs import LLMInputs +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.multimodal.image import 
(cached_get_tokenizer, + repeat_and_pad_image_tokens) +from vllm.sequence import SequenceData + + +def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int: + assert image_size % patch_size == 0 + return image_size // patch_size + + +def get_blip_num_patches(*, image_size: int, patch_size: int) -> int: + grid_length = get_blip_patch_grid_length(image_size=image_size, + patch_size=patch_size) + return grid_length * grid_length + + +def get_blip_image_feature_size( + hf_config: Union[BlipVisionConfig, Blip2VisionConfig], ) -> int: + return get_blip_num_patches(image_size=hf_config.image_size, + patch_size=hf_config.patch_size) + + +def get_max_blip_image_tokens( + hf_config: Union[BlipVisionConfig, Blip2VisionConfig], ) -> int: + return get_blip_image_feature_size(hf_config) + + +def dummy_seq_data_for_blip( + hf_config: Union[BlipVisionConfig, Blip2VisionConfig], + seq_len: int, + *, + image_token_id: int, + image_feature_size_override: Optional[int] = None, +): + if image_feature_size_override is None: + image_feature_size = get_blip_image_feature_size(hf_config) + else: + image_feature_size = image_feature_size_override + + token_ids = [image_token_id] * image_feature_size + token_ids += [0] * (seq_len - image_feature_size) + return SequenceData(token_ids) + + +def dummy_image_for_blip( + hf_config: Union[BlipVisionConfig, Blip2VisionConfig], + *, + image_width_override: Optional[int] = None, + image_height_override: Optional[int] = None, +): + width = height = hf_config.image_size + if image_width_override is not None: + width = image_width_override + if image_height_override is not None: + height = image_height_override + + image = Image.new("RGB", (width, height), color=0) + return {"image": image} + + +def input_processor_for_blip( + model_config: ModelConfig, + hf_config: Union[BlipVisionConfig, Blip2VisionConfig], + llm_inputs: LLMInputs, + *, + image_token_id: int, + image_feature_size_override: Optional[int] = None, +): + multi_modal_data = llm_inputs.get("multi_modal_data") + if multi_modal_data is None or "image" not in multi_modal_data: + return llm_inputs + + tokenizer = cached_get_tokenizer(model_config.tokenizer) + + if image_feature_size_override is None: + image_feature_size = get_blip_image_feature_size(hf_config) + else: + image_feature_size = image_feature_size_override + + new_prompt, new_token_ids = repeat_and_pad_image_tokens( + tokenizer, + llm_inputs.get("prompt"), + llm_inputs["prompt_token_ids"], + image_token_id=image_token_id, + repeat_count=image_feature_size, + ) + + # NOTE: Create a defensive copy of the original inputs + return LLMInputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) + + +# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa +class BlipVisionEmbeddings(nn.Module): + + def __init__(self, config: BlipVisionConfig): + super().__init__() + + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim)) + + self.patch_embedding = nn.Conv2d( + in_channels=3, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + ) + + self.num_patches = get_blip_num_patches(image_size=self.image_size, + patch_size=self.patch_size) + self.num_positions = self.num_patches + 1 + + self.position_embedding = nn.Parameter( + torch.randn(1, 
self.num_positions, self.embed_dim)) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + batch_size = pixel_values.shape[0] + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to( + dtype=target_dtype)) # shape = [*, width, grid, grid] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + + position_embeds = self.position_embedding.to(target_dtype) + embeddings = embeddings + position_embeds[:, :embeddings.size(1), :] + + return embeddings + + +class BlipMLP(nn.Module): + + def __init__(self, + config: BlipVisionConfig, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + + self.config = config + + self.activation_fn = get_act_fn(config.hidden_act) + self.fc1 = ColumnParallelLinear(config.hidden_size, + config.intermediate_size, + bias=True, + quant_config=quant_config) + self.fc2 = RowParallelLinear(config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quant_config) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + + return hidden_states + + +class BlipEncoderLayer(nn.Module): + + def __init__(self, + config: BlipVisionConfig, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + + self.self_attn = BlipAttention(config) + self.layer_norm1 = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.mlp = BlipMLP(config, quant_config=quant_config) + self.layer_norm2 = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, _ = self.self_attn(hidden_states=hidden_states) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class BlipEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self + attention layers. Each layer is a [`BlipEncoderLayer`]. 
+ + Args: + config: BlipConfig + """ + + def __init__(self, + config: BlipVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + num_hidden_layers_override: Optional[int] = None): + super().__init__() + + self.config = config + + if num_hidden_layers_override is None: + num_hidden_layers = config.num_hidden_layers + else: + num_hidden_layers = num_hidden_layers_override + + self.layers = nn.ModuleList([ + BlipEncoderLayer(config=config, quant_config=quant_config) + for _ in range(num_hidden_layers) + ]) + + def forward(self, inputs_embeds: torch.Tensor): + hidden_states = inputs_embeds + for encoder_layer in self.layers: + hidden_states = encoder_layer(hidden_states) + + return hidden_states + + +class BlipVisionModel(nn.Module): + config_class = BlipVisionConfig + main_input_name = "pixel_values" + + def __init__(self, + config: BlipVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + num_hidden_layers_override: Optional[int] = None): + super().__init__() + + self.config = config + + self.embeddings = BlipVisionEmbeddings(config) + self.encoder = BlipEncoder( + config=config, + quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers_override, + ) + self.post_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + hidden_states = self.embeddings(pixel_values) + hidden_states = self.encoder(inputs_embeds=hidden_states) + + return self.post_layernorm(hidden_states) diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py new file mode 100644 index 00000000..e00e6c08 --- /dev/null +++ b/vllm/model_executor/models/blip2.py @@ -0,0 +1,669 @@ +from typing import Iterable, List, Literal, Optional, Tuple, TypedDict + +import torch +import torch.nn as nn +from transformers import (Blip2Config, Blip2QFormerConfig, Blip2VisionConfig, + apply_chunking_to_forward) + +from vllm.attention import AttentionMetadata +from vllm.config import CacheConfig, MultiModalConfig +from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.opt import OPTModel +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData + +from .blip import (BlipVisionModel, dummy_image_for_blip, + get_max_blip_image_tokens) +from .interfaces import SupportsVision +from .utils import merge_vision_embeddings + +_KEYS_TO_MODIFY_MAPPING = { + "language_model.lm_head": "lm_head", + "language_model.model": "language_model", +} + + +class Blip2QFormerMultiHeadAttention(nn.Module): + + def __init__( + self, + config: Blip2QFormerConfig, + *, + quant_config: Optional[QuantizationConfig], + cache_config: Optional[CacheConfig], + is_cross_attention: bool = False, + ) -> None: + super().__init__() + + self.config = config + + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of " + f"the number of attention heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = 
config.num_attention_heads + self.attention_head_size = (config.hidden_size // + config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.scaling = self.attention_head_size**-0.5 + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + if is_cross_attention: + kv_hidden_size = config.encoder_hidden_size + else: + kv_hidden_size = config.hidden_size + self.key = nn.Linear(kv_hidden_size, self.all_head_size) + self.value = nn.Linear(kv_hidden_size, self.all_head_size) + + self.position_embedding_type = getattr(config, + "position_embedding_type", + "absolute") + if self.position_embedding_type != "absolute": + raise NotImplementedError("Unsupported position_embedding_type: " + f"{self.position_embedding_type}") + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + x = x.view(*x.size()[:-1], self.num_attention_heads, + self.attention_head_size) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + ): + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + key_layer = self.transpose_for_scores( + self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores( + self.value(encoder_hidden_states)) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + mixed_query_layer = self.query(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + attention_scores = torch.matmul(query_layer, + key_layer.transpose(-1, -2)) + attention_probs = torch.softmax(attention_scores * self.scaling, + dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
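+        # At inference time in vLLM the module is expected to run in eval
+        # mode, so this dropout is effectively a no-op; it is kept to mirror
+        # the original HF layer structure.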
+ attention_probs_dropped = self.dropout(attention_probs) + + context_layer = torch.matmul(attention_probs_dropped, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + context_layer = context_layer.view(*context_layer.size()[:-2], + self.all_head_size) + + return context_layer + + +class Blip2QFormerSelfOutput(nn.Module): + + def __init__(self, config: Blip2QFormerConfig) -> None: + super().__init__() + + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward( + self, + hidden_states: torch.Tensor, + input_tensor: torch.Tensor, + ) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class Blip2QFormerAttention(nn.Module): + + def __init__( + self, + config: Blip2QFormerConfig, + *, + quant_config: Optional[QuantizationConfig], + cache_config: Optional[CacheConfig], + is_cross_attention: bool = False, + ) -> None: + super().__init__() + + self.attention = Blip2QFormerMultiHeadAttention( + config, + quant_config=quant_config, + cache_config=cache_config, + is_cross_attention=is_cross_attention, + ) + + self.output = Blip2QFormerSelfOutput(config) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + ) -> Tuple[torch.Tensor]: + self_output = self.attention( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + ) + attention_output = self.output(self_output, hidden_states) + + return attention_output + + +class Blip2QFormerIntermediate(nn.Module): + + def __init__(self, config: Blip2QFormerConfig) -> None: + super().__init__() + + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + self.intermediate_act_fn = get_act_fn(config.hidden_act) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class Blip2QFormerOutput(nn.Module): + + def __init__(self, config: Blip2QFormerConfig) -> None: + super().__init__() + + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward( + self, + hidden_states: torch.Tensor, + input_tensor: torch.Tensor, + ) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class Blip2QFormerLayer(nn.Module): + + def __init__( + self, + config: Blip2QFormerConfig, + *, + quant_config: Optional[QuantizationConfig], + cache_config: Optional[CacheConfig], + layer_idx: int, + ) -> None: + super().__init__() + + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = Blip2QFormerAttention(config, + quant_config=quant_config, + cache_config=cache_config) + + self.layer_idx = layer_idx + + if layer_idx % config.cross_attention_frequency == 0: + self.crossattention = Blip2QFormerAttention( + config, + quant_config=quant_config, + cache_config=cache_config, + is_cross_attention=True) + self.has_cross_attention = True + else: + self.has_cross_attention = False + + 
self.intermediate_query = Blip2QFormerIntermediate(config) + self.output_query = Blip2QFormerOutput(config) + + def forward( + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor, + query_length: int, + ): + attention_output = self.attention(hidden_states) + + if query_length > 0: + query_attention_output = attention_output[:, :query_length, :] + + if self.has_cross_attention: + query_attention_output = self.crossattention( + query_attention_output, + encoder_hidden_states=encoder_hidden_states, + ) + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk_query, + self.chunk_size_feed_forward, + self.seq_len_dim, + query_attention_output, + ) + + if attention_output.shape[1] > query_length: + layer_output_text = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output[:, query_length:, :], + ) + layer_output = torch.cat([layer_output, layer_output_text], + dim=1) + else: + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output, + ) + + return layer_output + + def feed_forward_chunk(self, + attention_output: torch.Tensor) -> torch.Tensor: + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + def feed_forward_chunk_query( + self, attention_output: torch.Tensor) -> torch.Tensor: + intermediate_output = self.intermediate_query(attention_output) + layer_output = self.output_query(intermediate_output, attention_output) + return layer_output + + +class Blip2QFormerEncoder(nn.Module): + + def __init__( + self, + config: Blip2QFormerConfig, + *, + quant_config: Optional[QuantizationConfig], + cache_config: Optional[CacheConfig], + ) -> None: + super().__init__() + + self.config = config + + self.layer = nn.ModuleList([ + Blip2QFormerLayer(config, + quant_config=quant_config, + cache_config=cache_config, + layer_idx=layer_idx) + for layer_idx in range(config.num_hidden_layers) + ]) + + def forward( + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor, + query_length: int, + ) -> torch.Tensor: + for i in range(self.config.num_hidden_layers): + layer_module = self.layer[i] + + hidden_states = layer_module( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + query_length=query_length, + ) + + return hidden_states + + +# Adapted from https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1025 +class Blip2QFormerModel(nn.Module): + + def __init__( + self, + config: Blip2QFormerConfig, + *, + quant_config: Optional[QuantizationConfig], + cache_config: Optional[CacheConfig], + ) -> None: + super().__init__() + + self.config = config + + self.layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + self.encoder = Blip2QFormerEncoder(config, + quant_config=quant_config, + cache_config=cache_config) + + def forward( + self, + query_embeds: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor, + ) -> torch.Tensor: + query_length = query_embeds.shape[1] + + embedding_output = self.layernorm(query_embeds) + embedding_output = self.dropout(embedding_output) + + sequence_output = self.encoder( + embedding_output, + encoder_hidden_states=encoder_hidden_states, + query_length=query_length, + ) + + return sequence_output + + +class 
Blip2ImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: torch.Tensor + """Shape: (batch_size, num_channels, height, width)""" + + +Blip2ImageInputs = Blip2ImagePixelInputs + +# We use this internally as placeholders since there is no image token +# defined on the HuggingFace repo +BLIP2_IMAGE_TOKEN = "" +BLIP2_IMAGE_TOKEN_ID = 50265 + + +def get_blip2_image_feature_size(hf_config: Blip2Config) -> int: + return hf_config.num_query_tokens + + +def get_max_blip2_image_tokens(ctx: InputContext): + hf_config = ctx.get_hf_config(Blip2Config) + vision_config = hf_config.vision_config + + if isinstance(vision_config, Blip2VisionConfig): + return get_max_blip_image_tokens(vision_config) + + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + + +def dummy_data_for_blip2(ctx: InputContext, seq_len: int): + hf_config = ctx.get_hf_config(Blip2Config) + vision_config = hf_config.vision_config + + image_feature_size = get_blip2_image_feature_size(hf_config) + token_ids = [BLIP2_IMAGE_TOKEN_ID] * image_feature_size + token_ids += [0] * (seq_len - image_feature_size) + seq_data = SequenceData(token_ids) + + if isinstance(vision_config, Blip2VisionConfig): + mm_data = dummy_image_for_blip(vision_config) + + return seq_data, mm_data + + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + + +def input_processor_for_blip2(ctx: InputContext, llm_inputs: LLMInputs): + multi_modal_data = llm_inputs.get("multi_modal_data") + if multi_modal_data is None or "image" not in multi_modal_data: + return llm_inputs + + hf_config = ctx.get_hf_config(Blip2Config) + image_feature_size = get_blip2_image_feature_size(hf_config) + + # The original model places image tokens at the front + # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1514 + new_token_ids = [BLIP2_IMAGE_TOKEN_ID] * image_feature_size + new_token_ids += llm_inputs["prompt_token_ids"] + + new_prompt = llm_inputs.get("prompt") + if new_prompt is not None: + new_prompt = BLIP2_IMAGE_TOKEN * image_feature_size + new_prompt + + return LLMInputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) + + +@MULTIMODAL_REGISTRY.register_image_input_mapper() +@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_blip2_image_tokens) +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_blip2) +@INPUT_REGISTRY.register_input_processor(input_processor_for_blip2) +class Blip2ForConditionalGeneration(nn.Module, SupportsVision): + + def __init__(self, + config: Blip2Config, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None) -> None: + + super().__init__() + + self.config = config + self.multimodal_config = multimodal_config + + # TODO: Optionally initializes this for supporting embeddings. 
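+        # Overall flow (mirroring the HF reference implementation):
+        # pixel_values -> BlipVisionModel -> image features -> Q-Former
+        # (cross-attention with learned query tokens) -> language_projection
+        # -> soft prompt embeddings merged into the OPT input sequence.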
+        self.vision_model = BlipVisionModel(config.vision_config)
+
+        self.query_tokens = nn.Parameter(
+            torch.zeros(1, config.num_query_tokens,
+                        config.qformer_config.hidden_size))
+
+        self.qformer = Blip2QFormerModel(config.qformer_config,
+                                         cache_config=cache_config,
+                                         quant_config=quant_config)
+
+        self.language_projection = nn.Linear(
+            config.qformer_config.hidden_size,
+            config.text_config.hidden_size,
+            bias=True,
+        )
+
+        self.quant_config = quant_config
+
+        self.language_model = OPTModel(config.text_config, cache_config,
+                                       quant_config)
+
+        self.unpadded_vocab_size = config.text_config.vocab_size
+        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size)
+        self.sampler = Sampler()
+
+    def get_lm_head(self):
+        return self.language_model.decoder.embed_tokens
+
+    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
+        h = w = self.config.vision_config.image_size
+        expected_dims = (3, h, w)
+        actual_dims = tuple(data.shape[1:])
+
+        if actual_dims != expected_dims:
+            expected_expr = ("batch_size", *map(str, expected_dims))
+            raise ValueError(
+                f"The expected shape of pixel values is {expected_expr}. "
+                f"You supplied {tuple(data.shape)}.")
+
+        return data
+
+    def _parse_and_validate_image_input(
+            self, **kwargs: object) -> Optional[Blip2ImageInputs]:
+        pixel_values = kwargs.pop("pixel_values", None)
+
+        if pixel_values is None:
+            return None
+
+        if not isinstance(pixel_values, torch.Tensor):
+            raise ValueError("Incorrect type of pixel values. "
+                             f"Got type: {type(pixel_values)}")
+
+        return Blip2ImagePixelInputs(
+            type="pixel_values",
+            data=self._validate_pixel_values(pixel_values),
+        )
+
+    def _image_pixels_to_features(self, vision_model: BlipVisionModel,
+                                  pixel_values: torch.Tensor) -> torch.Tensor:
+
+        # NOTE: we skip the step of selecting the vision feature layer since
+        # this is already done inside the vision tower
+        image_features = vision_model(pixel_values)
+
+        return image_features
+
+    def _process_image_pixels(self,
+                              inputs: Blip2ImagePixelInputs) -> torch.Tensor:
+        assert self.vision_model is not None
+
+        pixel_values = inputs["data"]
+
+        return self._image_pixels_to_features(self.vision_model, pixel_values)
+
+    def _process_image_input(self,
+                             image_input: Blip2ImageInputs) -> torch.Tensor:
+        assert self.vision_model is not None
+        image_features = self._process_image_pixels(image_input)
+
+        query_tokens = self.query_tokens.expand(image_features.shape[0], -1,
+                                                -1)
+        query_output = self.qformer(
+            query_embeds=query_tokens,
+            encoder_hidden_states=image_features,
+        )
+
+        return self.language_projection(query_output)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        **kwargs: object,
+    ) -> SamplerOutput:
+        """Run forward pass for BLIP-2.
+
+        One key thing to understand is that the `input_ids` already account
+        for the positions of the to-be-inserted image embeddings.
+
+        Concretely, consider a text prompt:
+        `"Question: What's the content of the image? Answer:"`.
+
+        Tokenizer outputs:
+        `[2, 45641, 35, 653, 18, 5, 1383, 9, 5, 2274, 116, 31652, 35]`.
+
+        To reserve space in the KV cache, we have to insert placeholder tokens
+        before they are passed to the model, so the input processor prepends
+        dummy tokens (denoted as `50265`), resulting in:
+        `[50265, ..., 50265, 2, 45641, 35, ..., 31652, 35]`.
+
+        We insert 32 placeholder tokens because that is the number of query
+        embeddings produced by the Q-Former and fed to the language model.
+
+        This way, the `positions` and `attn_metadata` are consistent
+        with the `input_ids`.
+
+        Args:
+            input_ids: Flattened (concatenated) input_ids corresponding to a
+                batch.
+            pixel_values: The pixels in each input image.
+
+        See also:
+            :class:`Blip2ImageInputs`
+        """
+        image_input = self._parse_and_validate_image_input(**kwargs)
+
+        if image_input is not None:
+            vision_embeddings = self._process_image_input(image_input)
+            inputs_embeds = self.language_model.get_input_embeddings(input_ids)
+
+            inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
+                                                    vision_embeddings,
+                                                    BLIP2_IMAGE_TOKEN_ID)
+
+            input_ids = None
+        else:
+            inputs_embeds = None
+
+        hidden_states = self.language_model(input_ids,
+                                            positions,
+                                            kv_caches,
+                                            attn_metadata,
+                                            inputs_embeds=inputs_embeds)
+
+        return hidden_states
+
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.get_lm_head(), hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def sample(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        # Only doing this for the language model part for now.
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in weights:
+            if "lm_head.weight" in name:
+                continue
+            if "rotary_emb.inv_freq" in name:
+                continue
+            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
+                if key_to_modify in name:
+                    name = name.replace(key_to_modify, new_key)
+            use_default_weight_loading = False
+            if "vision" in name:
+                if self.vision_model is not None:
+                    # We only do sharding for language model and
+                    # not vision model for now.
+ use_default_weight_loading = True + else: + for (param_name, weight_name, + shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + param = params_dict[name.replace(weight_name, param_name)] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + use_default_weight_loading = True + if use_default_weight_loading: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index edc16710..a05090cd 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -237,14 +237,19 @@ class OPTDecoder(nn.Module): for _ in range(config.num_hidden_layers) ]) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - inputs_embeds = self.embed_tokens(input_ids) + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings(input_ids) pos_embeds = self.embed_positions(positions) if self.project_in is not None: inputs_embeds, _ = self.project_in(inputs_embeds) @@ -272,14 +277,22 @@ class OPTModel(nn.Module): super().__init__() self.decoder = OPTDecoder(config, cache_config, quant_config) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.decoder.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - return self.decoder(input_ids, positions, kv_caches, attn_metadata) + return self.decoder(input_ids, + positions, + kv_caches, + attn_metadata, + inputs_embeds=inputs_embeds) class OPTForCausalLM(nn.Module): diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index 0d435bd6..5abd0ad6 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -1,8 +1,9 @@ import sys from abc import ABC, abstractmethod from collections import UserDict, defaultdict -from typing import (Any, Callable, Dict, List, Optional, Type, TypedDict, - TypeVar, Union, cast) +from typing import Any, Callable, Dict, List, Optional +from typing import Sequence as GenericSequence +from typing import Type, TypedDict, TypeVar, Union, cast import torch import torch.types @@ -15,13 +16,13 @@ from vllm.logger import init_logger logger = init_logger(__name__) -NestedTensors = Union[List[torch.Tensor], torch.Tensor] +NestedTensors = Union[GenericSequence[torch.Tensor], torch.Tensor] """ Use a list instead of a tensor if the dimensions of each element do not match. Currently only supports up to singly nested list of tensors. 
""" -BatchedTensors = Union[List[NestedTensors], NestedTensors] +BatchedTensors = Union[GenericSequence[NestedTensors], NestedTensors] """ If each input tensor in the batch has the same size, this is a single batched tensor; otherwise, this is a list of :class:`NestedTensors` with one element @@ -53,7 +54,7 @@ class MultiModalInputs(_MultiModalInputsBase): # may be list rather than tensors if isinstance(tensors[0], list): return [[t.to(device=device) for t in tensor[0]] - for tensor in tensors] + for tensor in cast(List[List[torch.Tensor]], tensors)] tensors_ = cast(List[torch.Tensor], tensors)