[Model] Initial support for BLIP-2 (#5920)

Co-authored-by: ywang96 <ywang@roblox.com>
This commit is contained in:
Cyrus Leung 2024-07-27 19:53:07 +08:00 committed by GitHub
parent ecb33a28cb
commit 1ad86acf17
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 1107 additions and 21 deletions

View File

@ -7,6 +7,8 @@ vLLM supports a variety of generative Transformer models in `HuggingFace Transfo
The following is the list of model architectures that are currently supported by vLLM. The following is the list of model architectures that are currently supported by vLLM.
Alongside each architecture, we include some popular models that use it. Alongside each architecture, we include some popular models that use it.
----
Decoder-only Language Models Decoder-only Language Models
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. list-table:: .. list-table::
@ -186,6 +188,10 @@ Vision Language Models
- Models - Models
- Example HuggingFace Models - Example HuggingFace Models
- :ref:`LoRA <lora>` - :ref:`LoRA <lora>`
* - :code:`Blip2ForConditionalGeneration`
- BLIP-2
- :code:`Salesforce/blip2-opt-2.7b`, :code:`Salesforce/blip2-opt-6.7b`, etc.
-
* - :code:`ChameleonForConditionalGeneration` * - :code:`ChameleonForConditionalGeneration`
- Chameleon - Chameleon
- :code:`facebook/chameleon-7b` etc. - :code:`facebook/chameleon-7b` etc.
@ -215,6 +221,8 @@ Vision Language Models
- :code:`openbmb/MiniCPM-V-2`, :code:`openbmb/MiniCPM-Llama3-V-2_5`, etc. - :code:`openbmb/MiniCPM-V-2`, :code:`openbmb/MiniCPM-Llama3-V-2_5`, etc.
- -
----
If your model uses one of the above model architectures, you can seamlessly run your model with vLLM. If your model uses one of the above model architectures, you can seamlessly run your model with vLLM.
Otherwise, please refer to :ref:`Adding a New Model <adding_a_new_model>` and :ref:`Enabling Multimodal Inputs <enabling_multimodal_inputs>` Otherwise, please refer to :ref:`Adding a New Model <adding_a_new_model>` and :ref:`Enabling Multimodal Inputs <enabling_multimodal_inputs>`
for instructions on how to implement support for your model. for instructions on how to implement support for your model.

View File

@ -106,6 +106,16 @@ def run_minicpmv(question):
return llm, prompt return llm, prompt
# BLIP-2
def run_blip2(question):
# BLIP-2 prompt format is inaccurate on HuggingFace model repository.
# See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
prompt = f"Question: {question} Answer:"
llm = LLM(model="Salesforce/blip2-opt-2.7b")
return llm, prompt
model_example_map = { model_example_map = {
"llava": run_llava, "llava": run_llava,
"llava-next": run_llava_next, "llava-next": run_llava_next,
@ -114,6 +124,7 @@ model_example_map = {
"paligemma": run_paligemma, "paligemma": run_paligemma,
"chameleon": run_chameleon, "chameleon": run_chameleon,
"minicpmv": run_minicpmv, "minicpmv": run_minicpmv,
"blip-2": run_blip2,
} }

View File

@ -0,0 +1,11 @@
{%- for message in messages -%}
{%- if message['role'] == 'user' -%}
{{- 'Question: ' + message['content'] + ' ' -}}
{%- elif message['role'] == 'assistant' -%}
{{- 'Answer: ' + message['content'] + ' ' -}}
{%- endif -%}
{%- endfor -%}
{%- if add_generation_prompt -%}
{{- 'Answer:' -}}
{% endif %}

102
tests/models/test_blip2.py Normal file
View File

@ -0,0 +1,102 @@
from typing import List, Optional, Tuple
import pytest
from transformers import AutoTokenizer
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
from ..conftest import IMAGE_ASSETS
from .utils import check_logprobs_close
pytestmark = pytest.mark.vlm
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign":
"Question: What's the content of the image? Answer:",
"cherry_blossom":
"Question: What is the season? Answer:",
})
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
Optional[SampleLogprobs]],
model: str):
"""Sanitize vllm output to be comparable with hf output."""
_, output_str, out_logprobs = vllm_output
hf_output_str = output_str + "\n"
tokenizer = AutoTokenizer.from_pretrained(model)
hf_output_ids = tokenizer.encode(hf_output_str)
assert hf_output_ids[0] == tokenizer.bos_token_id
hf_output_ids = hf_output_ids[1:]
return hf_output_ids, hf_output_str, out_logprobs
@pytest.mark.parametrize("model", ["Salesforce/blip2-opt-2.7b"])
@pytest.mark.parametrize(
"size_factors",
[
# No image
[],
# Single-scale
[1.0],
# Single-scale, batched
[1.0, 1.0, 1.0],
# Multi-scale
[0.25, 0.5, 1.0],
],
)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
dtype: str, max_tokens: int, num_logprobs: int) -> None:
"""Inference result should be the same between hf and vllm.
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalData objects and corresponding
vision language config as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
images = [asset.pil_image for asset in image_assets]
inputs_per_image = [(
[prompt for _ in size_factors],
[rescale_image_size(image, factor) for factor in size_factors],
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
# max_model_len should be greater than image_feature_size
with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model:
vllm_outputs_per_image = [
vllm_model.generate_greedy_logprobs(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in inputs_per_image
]
with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model:
hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in inputs_per_image
]
for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
vllm_outputs_per_image):
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=[
vllm_to_hf_output(vllm_output, model)
for vllm_output in vllm_outputs
],
name_0="hf",
name_1="vllm",
)

View File

@ -77,8 +77,8 @@ def run_test(
vllm_model.generate_greedy_logprobs(prompts, vllm_model.generate_greedy_logprobs(prompts,
max_tokens, max_tokens,
num_logprobs=num_logprobs, num_logprobs=num_logprobs,
images=vllm_images) images=images)
for prompts, vllm_images in inputs_per_image for prompts, images in inputs_per_image
] ]
with hf_runner(model, dtype=dtype) as hf_model: with hf_runner(model, dtype=dtype) as hf_model:
@ -89,9 +89,9 @@ def run_test(
hf_model.generate_greedy_logprobs_limit(prompts, hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens, max_tokens,
num_logprobs=num_logprobs, num_logprobs=num_logprobs,
images=hf_images, images=images,
eos_token_id=eos_token_id) eos_token_id=eos_token_id)
for prompts, hf_images in inputs_per_image for prompts, images in inputs_per_image
] ]
for hf_outputs, vllm_outputs in zip(hf_outputs_per_image, for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,

View File

@ -88,9 +88,9 @@ def run_test(
vllm_model.generate_greedy_logprobs(prompts, vllm_model.generate_greedy_logprobs(prompts,
max_tokens, max_tokens,
num_logprobs=num_logprobs, num_logprobs=num_logprobs,
images=vllm_images, images=images,
stop_token_ids=stop_token_ids) stop_token_ids=stop_token_ids)
for prompts, vllm_images in inputs_per_image for prompts, images in inputs_per_image
] ]
with hf_runner(model, dtype=dtype) as hf_model, torch.no_grad(): with hf_runner(model, dtype=dtype) as hf_model, torch.no_grad():
@ -114,9 +114,9 @@ def run_test(
hf_model.generate_greedy_logprobs_limit(prompts, hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens, max_tokens,
num_logprobs=num_logprobs, num_logprobs=num_logprobs,
images=hf_images, images=images,
tokenizer=tokenizer) tokenizer=tokenizer)
for prompts, hf_images in inputs_per_image for prompts, images in inputs_per_image
] ]
for hf_outputs, vllm_outputs in zip(hf_outputs_per_image, for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,

View File

@ -101,8 +101,8 @@ def run_test(
vllm_model.generate_greedy_logprobs(prompts, vllm_model.generate_greedy_logprobs(prompts,
max_tokens, max_tokens,
num_logprobs=num_logprobs, num_logprobs=num_logprobs,
images=vllm_images) images=images)
for prompts, vllm_images in inputs_per_image for prompts, images in inputs_per_image
] ]
# use eager mode for hf runner, since phi3_v didn't work with flash_attn # use eager mode for hf runner, since phi3_v didn't work with flash_attn
@ -114,9 +114,9 @@ def run_test(
hf_model.generate_greedy_logprobs_limit(prompts, hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens, max_tokens,
num_logprobs=num_logprobs, num_logprobs=num_logprobs,
images=hf_images, images=images,
eos_token_id=eos_token_id) eos_token_id=eos_token_id)
for prompts, hf_images in inputs_per_image for prompts, images in inputs_per_image
] ]
for hf_outputs, vllm_outputs in zip(hf_outputs_per_image, for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,

View File

@ -16,6 +16,8 @@ _GENERATION_MODELS = {
"BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b
"BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b
"BloomForCausalLM": ("bloom", "BloomForCausalLM"), "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
"Blip2ForConditionalGeneration":
("blip2", "Blip2ForConditionalGeneration"),
"ChameleonForConditionalGeneration": "ChameleonForConditionalGeneration":
("chameleon", "ChameleonForConditionalGeneration"), ("chameleon", "ChameleonForConditionalGeneration"),
"ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"), "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
@ -56,8 +58,8 @@ _GENERATION_MODELS = {
"OPTForCausalLM": ("opt", "OPTForCausalLM"), "OPTForCausalLM": ("opt", "OPTForCausalLM"),
"OrionForCausalLM": ("orion", "OrionForCausalLM"), "OrionForCausalLM": ("orion", "OrionForCausalLM"),
"PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"), "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
"PaliGemmaForConditionalGeneration": "PaliGemmaForConditionalGeneration": ("paligemma",
("paligemma", "PaliGemmaForConditionalGeneration"), "PaliGemmaForConditionalGeneration"),
"PhiForCausalLM": ("phi", "PhiForCausalLM"), "PhiForCausalLM": ("phi", "PhiForCausalLM"),
"Phi3ForCausalLM": ("llama", "LlamaForCausalLM"), "Phi3ForCausalLM": ("llama", "LlamaForCausalLM"),
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),

View File

@ -0,0 +1,269 @@
"""Minimal implementation of BlipVisionModel intended to be only used
within a vision language model."""
from typing import Optional, Union
import torch
import torch.nn as nn
from PIL import Image
from transformers import Blip2VisionConfig, BlipVisionConfig
from transformers.models.blip.modeling_blip import BlipAttention
from vllm.config import ModelConfig
from vllm.inputs import LLMInputs
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal.image import (cached_get_tokenizer,
repeat_and_pad_image_tokens)
from vllm.sequence import SequenceData
def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
assert image_size % patch_size == 0
return image_size // patch_size
def get_blip_num_patches(*, image_size: int, patch_size: int) -> int:
grid_length = get_blip_patch_grid_length(image_size=image_size,
patch_size=patch_size)
return grid_length * grid_length
def get_blip_image_feature_size(
hf_config: Union[BlipVisionConfig, Blip2VisionConfig], ) -> int:
return get_blip_num_patches(image_size=hf_config.image_size,
patch_size=hf_config.patch_size)
def get_max_blip_image_tokens(
hf_config: Union[BlipVisionConfig, Blip2VisionConfig], ) -> int:
return get_blip_image_feature_size(hf_config)
def dummy_seq_data_for_blip(
hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
seq_len: int,
*,
image_token_id: int,
image_feature_size_override: Optional[int] = None,
):
if image_feature_size_override is None:
image_feature_size = get_blip_image_feature_size(hf_config)
else:
image_feature_size = image_feature_size_override
token_ids = [image_token_id] * image_feature_size
token_ids += [0] * (seq_len - image_feature_size)
return SequenceData(token_ids)
def dummy_image_for_blip(
hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
*,
image_width_override: Optional[int] = None,
image_height_override: Optional[int] = None,
):
width = height = hf_config.image_size
if image_width_override is not None:
width = image_width_override
if image_height_override is not None:
height = image_height_override
image = Image.new("RGB", (width, height), color=0)
return {"image": image}
def input_processor_for_blip(
model_config: ModelConfig,
hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
llm_inputs: LLMInputs,
*,
image_token_id: int,
image_feature_size_override: Optional[int] = None,
):
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs
tokenizer = cached_get_tokenizer(model_config.tokenizer)
if image_feature_size_override is None:
image_feature_size = get_blip_image_feature_size(hf_config)
else:
image_feature_size = image_feature_size_override
new_prompt, new_token_ids = repeat_and_pad_image_tokens(
tokenizer,
llm_inputs.get("prompt"),
llm_inputs["prompt_token_ids"],
image_token_id=image_token_id,
repeat_count=image_feature_size,
)
# NOTE: Create a defensive copy of the original inputs
return LLMInputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa
class BlipVisionEmbeddings(nn.Module):
def __init__(self, config: BlipVisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))
self.patch_embedding = nn.Conv2d(
in_channels=3,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
)
self.num_patches = get_blip_num_patches(image_size=self.image_size,
patch_size=self.patch_size)
self.num_positions = self.num_patches + 1
self.position_embedding = nn.Parameter(
torch.randn(1, self.num_positions, self.embed_dim))
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(pixel_values.to(
dtype=target_dtype)) # shape = [*, width, grid, grid]
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
position_embeds = self.position_embedding.to(target_dtype)
embeddings = embeddings + position_embeds[:, :embeddings.size(1), :]
return embeddings
class BlipMLP(nn.Module):
def __init__(self,
config: BlipVisionConfig,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
self.activation_fn = get_act_fn(config.hidden_act)
self.fc1 = ColumnParallelLinear(config.hidden_size,
config.intermediate_size,
bias=True,
quant_config=quant_config)
self.fc2 = RowParallelLinear(config.intermediate_size,
config.hidden_size,
bias=True,
quant_config=quant_config)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states, _ = self.fc2(hidden_states)
return hidden_states
class BlipEncoderLayer(nn.Module):
def __init__(self,
config: BlipVisionConfig,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.self_attn = BlipAttention(config)
self.layer_norm1 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.mlp = BlipMLP(config, quant_config=quant_config)
self.layer_norm2 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states, _ = self.self_attn(hidden_states=hidden_states)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class BlipEncoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self
attention layers. Each layer is a [`BlipEncoderLayer`].
Args:
config: BlipConfig
"""
def __init__(self,
config: BlipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
num_hidden_layers_override: Optional[int] = None):
super().__init__()
self.config = config
if num_hidden_layers_override is None:
num_hidden_layers = config.num_hidden_layers
else:
num_hidden_layers = num_hidden_layers_override
self.layers = nn.ModuleList([
BlipEncoderLayer(config=config, quant_config=quant_config)
for _ in range(num_hidden_layers)
])
def forward(self, inputs_embeds: torch.Tensor):
hidden_states = inputs_embeds
for encoder_layer in self.layers:
hidden_states = encoder_layer(hidden_states)
return hidden_states
class BlipVisionModel(nn.Module):
config_class = BlipVisionConfig
main_input_name = "pixel_values"
def __init__(self,
config: BlipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
num_hidden_layers_override: Optional[int] = None):
super().__init__()
self.config = config
self.embeddings = BlipVisionEmbeddings(config)
self.encoder = BlipEncoder(
config=config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override,
)
self.post_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
hidden_states = self.embeddings(pixel_values)
hidden_states = self.encoder(inputs_embeds=hidden_states)
return self.post_layernorm(hidden_states)

View File

@ -0,0 +1,669 @@
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
import torch
import torch.nn as nn
from transformers import (Blip2Config, Blip2QFormerConfig, Blip2VisionConfig,
apply_chunking_to_forward)
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.opt import OPTModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
from .blip import (BlipVisionModel, dummy_image_for_blip,
get_max_blip_image_tokens)
from .interfaces import SupportsVision
from .utils import merge_vision_embeddings
_KEYS_TO_MODIFY_MAPPING = {
"language_model.lm_head": "lm_head",
"language_model.model": "language_model",
}
class Blip2QFormerMultiHeadAttention(nn.Module):
def __init__(
self,
config: Blip2QFormerConfig,
*,
quant_config: Optional[QuantizationConfig],
cache_config: Optional[CacheConfig],
is_cross_attention: bool = False,
) -> None:
super().__init__()
self.config = config
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
f"The hidden size ({config.hidden_size}) is not a multiple of "
f"the number of attention heads ({config.num_attention_heads})"
)
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = (config.hidden_size //
config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.scaling = self.attention_head_size**-0.5
self.query = nn.Linear(config.hidden_size, self.all_head_size)
if is_cross_attention:
kv_hidden_size = config.encoder_hidden_size
else:
kv_hidden_size = config.hidden_size
self.key = nn.Linear(kv_hidden_size, self.all_head_size)
self.value = nn.Linear(kv_hidden_size, self.all_head_size)
self.position_embedding_type = getattr(config,
"position_embedding_type",
"absolute")
if self.position_embedding_type != "absolute":
raise NotImplementedError("Unsupported position_embedding_type: "
f"{self.position_embedding_type}")
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x):
x = x.view(*x.size()[:-1], self.num_attention_heads,
self.attention_head_size)
return x.permute(0, 2, 1, 3)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
):
is_cross_attention = encoder_hidden_states is not None
if is_cross_attention:
key_layer = self.transpose_for_scores(
self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(
self.value(encoder_hidden_states))
else:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
mixed_query_layer = self.query(hidden_states)
query_layer = self.transpose_for_scores(mixed_query_layer)
attention_scores = torch.matmul(query_layer,
key_layer.transpose(-1, -2))
attention_probs = torch.softmax(attention_scores * self.scaling,
dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs_dropped = self.dropout(attention_probs)
context_layer = torch.matmul(attention_probs_dropped, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
context_layer = context_layer.view(*context_layer.size()[:-2],
self.all_head_size)
return context_layer
class Blip2QFormerSelfOutput(nn.Module):
def __init__(self, config: Blip2QFormerConfig) -> None:
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(
self,
hidden_states: torch.Tensor,
input_tensor: torch.Tensor,
) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class Blip2QFormerAttention(nn.Module):
def __init__(
self,
config: Blip2QFormerConfig,
*,
quant_config: Optional[QuantizationConfig],
cache_config: Optional[CacheConfig],
is_cross_attention: bool = False,
) -> None:
super().__init__()
self.attention = Blip2QFormerMultiHeadAttention(
config,
quant_config=quant_config,
cache_config=cache_config,
is_cross_attention=is_cross_attention,
)
self.output = Blip2QFormerSelfOutput(config)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
) -> Tuple[torch.Tensor]:
self_output = self.attention(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
)
attention_output = self.output(self_output, hidden_states)
return attention_output
class Blip2QFormerIntermediate(nn.Module):
def __init__(self, config: Blip2QFormerConfig) -> None:
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
self.intermediate_act_fn = get_act_fn(config.hidden_act)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class Blip2QFormerOutput(nn.Module):
def __init__(self, config: Blip2QFormerConfig) -> None:
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(
self,
hidden_states: torch.Tensor,
input_tensor: torch.Tensor,
) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class Blip2QFormerLayer(nn.Module):
def __init__(
self,
config: Blip2QFormerConfig,
*,
quant_config: Optional[QuantizationConfig],
cache_config: Optional[CacheConfig],
layer_idx: int,
) -> None:
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = Blip2QFormerAttention(config,
quant_config=quant_config,
cache_config=cache_config)
self.layer_idx = layer_idx
if layer_idx % config.cross_attention_frequency == 0:
self.crossattention = Blip2QFormerAttention(
config,
quant_config=quant_config,
cache_config=cache_config,
is_cross_attention=True)
self.has_cross_attention = True
else:
self.has_cross_attention = False
self.intermediate_query = Blip2QFormerIntermediate(config)
self.output_query = Blip2QFormerOutput(config)
def forward(
self,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor,
query_length: int,
):
attention_output = self.attention(hidden_states)
if query_length > 0:
query_attention_output = attention_output[:, :query_length, :]
if self.has_cross_attention:
query_attention_output = self.crossattention(
query_attention_output,
encoder_hidden_states=encoder_hidden_states,
)
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk_query,
self.chunk_size_feed_forward,
self.seq_len_dim,
query_attention_output,
)
if attention_output.shape[1] > query_length:
layer_output_text = apply_chunking_to_forward(
self.feed_forward_chunk,
self.chunk_size_feed_forward,
self.seq_len_dim,
attention_output[:, query_length:, :],
)
layer_output = torch.cat([layer_output, layer_output_text],
dim=1)
else:
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk,
self.chunk_size_feed_forward,
self.seq_len_dim,
attention_output,
)
return layer_output
def feed_forward_chunk(self,
attention_output: torch.Tensor) -> torch.Tensor:
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
return layer_output
def feed_forward_chunk_query(
self, attention_output: torch.Tensor) -> torch.Tensor:
intermediate_output = self.intermediate_query(attention_output)
layer_output = self.output_query(intermediate_output, attention_output)
return layer_output
class Blip2QFormerEncoder(nn.Module):
def __init__(
self,
config: Blip2QFormerConfig,
*,
quant_config: Optional[QuantizationConfig],
cache_config: Optional[CacheConfig],
) -> None:
super().__init__()
self.config = config
self.layer = nn.ModuleList([
Blip2QFormerLayer(config,
quant_config=quant_config,
cache_config=cache_config,
layer_idx=layer_idx)
for layer_idx in range(config.num_hidden_layers)
])
def forward(
self,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor,
query_length: int,
) -> torch.Tensor:
for i in range(self.config.num_hidden_layers):
layer_module = self.layer[i]
hidden_states = layer_module(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
query_length=query_length,
)
return hidden_states
# Adapted from https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1025
class Blip2QFormerModel(nn.Module):
def __init__(
self,
config: Blip2QFormerConfig,
*,
quant_config: Optional[QuantizationConfig],
cache_config: Optional[CacheConfig],
) -> None:
super().__init__()
self.config = config
self.layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.encoder = Blip2QFormerEncoder(config,
quant_config=quant_config,
cache_config=cache_config)
def forward(
self,
query_embeds: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor,
) -> torch.Tensor:
query_length = query_embeds.shape[1]
embedding_output = self.layernorm(query_embeds)
embedding_output = self.dropout(embedding_output)
sequence_output = self.encoder(
embedding_output,
encoder_hidden_states=encoder_hidden_states,
query_length=query_length,
)
return sequence_output
class Blip2ImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: (batch_size, num_channels, height, width)"""
Blip2ImageInputs = Blip2ImagePixelInputs
# We use this internally as placeholders since there is no image token
# defined on the HuggingFace repo
BLIP2_IMAGE_TOKEN = "<image>"
BLIP2_IMAGE_TOKEN_ID = 50265
def get_blip2_image_feature_size(hf_config: Blip2Config) -> int:
return hf_config.num_query_tokens
def get_max_blip2_image_tokens(ctx: InputContext):
hf_config = ctx.get_hf_config(Blip2Config)
vision_config = hf_config.vision_config
if isinstance(vision_config, Blip2VisionConfig):
return get_max_blip_image_tokens(vision_config)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
def dummy_data_for_blip2(ctx: InputContext, seq_len: int):
hf_config = ctx.get_hf_config(Blip2Config)
vision_config = hf_config.vision_config
image_feature_size = get_blip2_image_feature_size(hf_config)
token_ids = [BLIP2_IMAGE_TOKEN_ID] * image_feature_size
token_ids += [0] * (seq_len - image_feature_size)
seq_data = SequenceData(token_ids)
if isinstance(vision_config, Blip2VisionConfig):
mm_data = dummy_image_for_blip(vision_config)
return seq_data, mm_data
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
def input_processor_for_blip2(ctx: InputContext, llm_inputs: LLMInputs):
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs
hf_config = ctx.get_hf_config(Blip2Config)
image_feature_size = get_blip2_image_feature_size(hf_config)
# The original model places image tokens at the front
# https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1514
new_token_ids = [BLIP2_IMAGE_TOKEN_ID] * image_feature_size
new_token_ids += llm_inputs["prompt_token_ids"]
new_prompt = llm_inputs.get("prompt")
if new_prompt is not None:
new_prompt = BLIP2_IMAGE_TOKEN * image_feature_size + new_prompt
return LLMInputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_blip2_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_blip2)
@INPUT_REGISTRY.register_input_processor(input_processor_for_blip2)
class Blip2ForConditionalGeneration(nn.Module, SupportsVision):
def __init__(self,
config: Blip2Config,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__()
self.config = config
self.multimodal_config = multimodal_config
# TODO: Optionally initializes this for supporting embeddings.
self.vision_model = BlipVisionModel(config.vision_config)
self.query_tokens = nn.Parameter(
torch.zeros(1, config.num_query_tokens,
config.qformer_config.hidden_size))
self.qformer = Blip2QFormerModel(config.qformer_config,
cache_config=cache_config,
quant_config=quant_config)
self.language_projection = nn.Linear(
config.qformer_config.hidden_size,
config.text_config.hidden_size,
bias=True,
)
self.quant_config = quant_config
self.language_model = OPTModel(config.text_config, cache_config,
quant_config)
self.unpadded_vocab_size = config.text_config.vocab_size
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size)
self.sampler = Sampler()
def get_lm_head(self):
return self.language_model.decoder.embed_tokens
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
actual_dims = tuple(data.shape[1:])
if actual_dims != expected_dims:
expected_expr = ("batch_size", *map(str, expected_dims))
raise ValueError(
f"The expected shape of pixel values is {expected_expr}. "
f"You supplied {tuple(data.shape)}.")
return data
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[Blip2ImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
if pixel_values is None:
return None
if not isinstance(pixel_values, torch.Tensor):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
return Blip2ImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
)
def _image_pixels_to_features(self, vision_model: BlipVisionModel,
pixel_values: torch.Tensor) -> torch.Tensor:
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
image_features = vision_model(pixel_values)
return image_features
def _process_image_pixels(self,
inputs: Blip2ImagePixelInputs) -> torch.Tensor:
assert self.vision_model is not None
pixel_values = inputs["data"]
return self._image_pixels_to_features(self.vision_model, pixel_values)
def _process_image_input(self,
image_input: Blip2ImageInputs) -> torch.Tensor:
assert self.vision_model is not None
image_features = self._process_image_pixels(image_input)
query_tokens = self.query_tokens.expand(image_features.shape[0], -1,
-1)
query_output = self.qformer(
query_embeds=query_tokens,
encoder_hidden_states=image_features,
)
return self.language_projection(query_output)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object,
) -> SamplerOutput:
"""Run forward pass for BLIP-2.
One key thing to understand is the `input_ids` already accounts for the
positions of the to-be-inserted image embeddings.
Concretely, consider a text prompt:
`"Question: What's the content of the image? Answer:"`.
Tokenizer outputs:
`[2, 45641, 35, 653, 18, 5, 1383, 9, 5, 2274, 116, 31652, 35]`.
To reserve space in KV cache, we have to insert placeholder tokens
before they are inputted to the model, so the input processor prepends
dummy tokens (denoted as `50265`), resulting in:
`[50265, ..., 50265, 2, 45641, 35, ..., 31652, 35]`.
We insert 32 tokens since it corresponds to the number of query
embeddings outputted by the Q-Former and inputted to the language model.
This way, the `positions` and `attn_metadata` are consistent
with the `input_ids`.
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
pixel_values: The pixels in each input image.
See also:
:class:`Blip2ImageInputs`
"""
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
vision_embeddings,
BLIP2_IMAGE_TOKEN_ID)
input_ids = None
else:
inputs_embeds = None
hidden_states = self.language_model(input_ids,
positions,
kv_caches,
attn_metadata,
inputs_embeds=inputs_embeds)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.get_lm_head(), hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# only doing this for language model part for now.
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "lm_head.weight" in name:
continue
if "rotary_emb.inv_freq" in name:
continue
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in name:
name = name.replace(key_to_modify, new_key)
use_default_weight_loading = False
if "vision" in name:
if self.vision_model is not None:
# We only do sharding for language model and
# not vision model for now.
use_default_weight_loading = True
else:
for (param_name, weight_name,
shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
param = params_dict[name.replace(weight_name, param_name)]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
use_default_weight_loading = True
if use_default_weight_loading:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@ -237,14 +237,19 @@ class OPTDecoder(nn.Module):
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.embed_tokens(input_ids) if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings(input_ids)
pos_embeds = self.embed_positions(positions) pos_embeds = self.embed_positions(positions)
if self.project_in is not None: if self.project_in is not None:
inputs_embeds, _ = self.project_in(inputs_embeds) inputs_embeds, _ = self.project_in(inputs_embeds)
@ -272,14 +277,22 @@ class OPTModel(nn.Module):
super().__init__() super().__init__()
self.decoder = OPTDecoder(config, cache_config, quant_config) self.decoder = OPTDecoder(config, cache_config, quant_config)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.decoder.get_input_embeddings(input_ids)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
return self.decoder(input_ids, positions, kv_caches, attn_metadata) return self.decoder(input_ids,
positions,
kv_caches,
attn_metadata,
inputs_embeds=inputs_embeds)
class OPTForCausalLM(nn.Module): class OPTForCausalLM(nn.Module):

View File

@ -1,8 +1,9 @@
import sys import sys
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import UserDict, defaultdict from collections import UserDict, defaultdict
from typing import (Any, Callable, Dict, List, Optional, Type, TypedDict, from typing import Any, Callable, Dict, List, Optional
TypeVar, Union, cast) from typing import Sequence as GenericSequence
from typing import Type, TypedDict, TypeVar, Union, cast
import torch import torch
import torch.types import torch.types
@ -15,13 +16,13 @@ from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
NestedTensors = Union[List[torch.Tensor], torch.Tensor] NestedTensors = Union[GenericSequence[torch.Tensor], torch.Tensor]
""" """
Use a list instead of a tensor if the dimensions of each element do not match. Use a list instead of a tensor if the dimensions of each element do not match.
Currently only supports up to singly nested list of tensors. Currently only supports up to singly nested list of tensors.
""" """
BatchedTensors = Union[List[NestedTensors], NestedTensors] BatchedTensors = Union[GenericSequence[NestedTensors], NestedTensors]
""" """
If each input tensor in the batch has the same size, this is a single batched If each input tensor in the batch has the same size, this is a single batched
tensor; otherwise, this is a list of :class:`NestedTensors` with one element tensor; otherwise, this is a list of :class:`NestedTensors` with one element
@ -53,7 +54,7 @@ class MultiModalInputs(_MultiModalInputsBase):
# may be list rather than tensors # may be list rather than tensors
if isinstance(tensors[0], list): if isinstance(tensors[0], list):
return [[t.to(device=device) for t in tensor[0]] return [[t.to(device=device) for t in tensor[0]]
for tensor in tensors] for tensor in cast(List[List[torch.Tensor]], tensors)]
tensors_ = cast(List[torch.Tensor], tensors) tensors_ = cast(List[torch.Tensor], tensors)