[Model] Initial support for BLIP-2 (#5920)
Co-authored-by: ywang96 <ywang@roblox.com>
This commit is contained in:
parent
ecb33a28cb
commit
1ad86acf17
@ -7,6 +7,8 @@ vLLM supports a variety of generative Transformer models in `HuggingFace Transfo
|
||||
The following is the list of model architectures that are currently supported by vLLM.
|
||||
Alongside each architecture, we include some popular models that use it.
|
||||
|
||||
----
|
||||
|
||||
Decoder-only Language Models
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
.. list-table::
|
||||
@ -186,6 +188,10 @@ Vision Language Models
|
||||
- Models
|
||||
- Example HuggingFace Models
|
||||
- :ref:`LoRA <lora>`
|
||||
* - :code:`Blip2ForConditionalGeneration`
|
||||
- BLIP-2
|
||||
- :code:`Salesforce/blip2-opt-2.7b`, :code:`Salesforce/blip2-opt-6.7b`, etc.
|
||||
-
|
||||
* - :code:`ChameleonForConditionalGeneration`
|
||||
- Chameleon
|
||||
- :code:`facebook/chameleon-7b` etc.
|
||||
@ -215,6 +221,8 @@ Vision Language Models
|
||||
- :code:`openbmb/MiniCPM-V-2`, :code:`openbmb/MiniCPM-Llama3-V-2_5`, etc.
|
||||
-
|
||||
|
||||
----
|
||||
|
||||
If your model uses one of the above model architectures, you can seamlessly run your model with vLLM.
|
||||
Otherwise, please refer to :ref:`Adding a New Model <adding_a_new_model>` and :ref:`Enabling Multimodal Inputs <enabling_multimodal_inputs>`
|
||||
for instructions on how to implement support for your model.
|
||||
|
@ -106,6 +106,16 @@ def run_minicpmv(question):
|
||||
return llm, prompt
|
||||
|
||||
|
||||
# BLIP-2
|
||||
def run_blip2(question):
|
||||
|
||||
# BLIP-2 prompt format is inaccurate on HuggingFace model repository.
|
||||
# See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
|
||||
prompt = f"Question: {question} Answer:"
|
||||
llm = LLM(model="Salesforce/blip2-opt-2.7b")
|
||||
return llm, prompt
|
||||
|
||||
|
||||
model_example_map = {
|
||||
"llava": run_llava,
|
||||
"llava-next": run_llava_next,
|
||||
@ -114,6 +124,7 @@ model_example_map = {
|
||||
"paligemma": run_paligemma,
|
||||
"chameleon": run_chameleon,
|
||||
"minicpmv": run_minicpmv,
|
||||
"blip-2": run_blip2,
|
||||
}
|
||||
|
||||
|
||||
|
11
examples/template_blip2.jinja
Normal file
11
examples/template_blip2.jinja
Normal file
@ -0,0 +1,11 @@
|
||||
{%- for message in messages -%}
|
||||
{%- if message['role'] == 'user' -%}
|
||||
{{- 'Question: ' + message['content'] + ' ' -}}
|
||||
{%- elif message['role'] == 'assistant' -%}
|
||||
{{- 'Answer: ' + message['content'] + ' ' -}}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
|
||||
{%- if add_generation_prompt -%}
|
||||
{{- 'Answer:' -}}
|
||||
{% endif %}
|
102
tests/models/test_blip2.py
Normal file
102
tests/models/test_blip2.py
Normal file
@ -0,0 +1,102 @@
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm.multimodal.utils import rescale_image_size
|
||||
from vllm.sequence import SampleLogprobs
|
||||
|
||||
from ..conftest import IMAGE_ASSETS
|
||||
from .utils import check_logprobs_close
|
||||
|
||||
pytestmark = pytest.mark.vlm
|
||||
|
||||
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
|
||||
"stop_sign":
|
||||
"Question: What's the content of the image? Answer:",
|
||||
"cherry_blossom":
|
||||
"Question: What is the season? Answer:",
|
||||
})
|
||||
|
||||
|
||||
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
|
||||
Optional[SampleLogprobs]],
|
||||
model: str):
|
||||
"""Sanitize vllm output to be comparable with hf output."""
|
||||
_, output_str, out_logprobs = vllm_output
|
||||
|
||||
hf_output_str = output_str + "\n"
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model)
|
||||
hf_output_ids = tokenizer.encode(hf_output_str)
|
||||
assert hf_output_ids[0] == tokenizer.bos_token_id
|
||||
hf_output_ids = hf_output_ids[1:]
|
||||
|
||||
return hf_output_ids, hf_output_str, out_logprobs
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", ["Salesforce/blip2-opt-2.7b"])
|
||||
@pytest.mark.parametrize(
|
||||
"size_factors",
|
||||
[
|
||||
# No image
|
||||
[],
|
||||
# Single-scale
|
||||
[1.0],
|
||||
# Single-scale, batched
|
||||
[1.0, 1.0, 1.0],
|
||||
# Multi-scale
|
||||
[0.25, 0.5, 1.0],
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
|
||||
dtype: str, max_tokens: int, num_logprobs: int) -> None:
|
||||
"""Inference result should be the same between hf and vllm.
|
||||
|
||||
All the image fixtures for the test is under tests/images.
|
||||
For huggingface runner, we provide the PIL images as input.
|
||||
For vllm runner, we provide MultiModalData objects and corresponding
|
||||
vision language config as input.
|
||||
Note, the text input is also adjusted to abide by vllm contract.
|
||||
The text output is sanitized to be able to compare with hf.
|
||||
"""
|
||||
images = [asset.pil_image for asset in image_assets]
|
||||
|
||||
inputs_per_image = [(
|
||||
[prompt for _ in size_factors],
|
||||
[rescale_image_size(image, factor) for factor in size_factors],
|
||||
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
|
||||
|
||||
# max_model_len should be greater than image_feature_size
|
||||
with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model:
|
||||
vllm_outputs_per_image = [
|
||||
vllm_model.generate_greedy_logprobs(prompts,
|
||||
max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
images=images)
|
||||
for prompts, images in inputs_per_image
|
||||
]
|
||||
|
||||
with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model:
|
||||
hf_outputs_per_image = [
|
||||
hf_model.generate_greedy_logprobs_limit(prompts,
|
||||
max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
images=images)
|
||||
for prompts, images in inputs_per_image
|
||||
]
|
||||
|
||||
for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
|
||||
vllm_outputs_per_image):
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=[
|
||||
vllm_to_hf_output(vllm_output, model)
|
||||
for vllm_output in vllm_outputs
|
||||
],
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
)
|
@ -77,8 +77,8 @@ def run_test(
|
||||
vllm_model.generate_greedy_logprobs(prompts,
|
||||
max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
images=vllm_images)
|
||||
for prompts, vllm_images in inputs_per_image
|
||||
images=images)
|
||||
for prompts, images in inputs_per_image
|
||||
]
|
||||
|
||||
with hf_runner(model, dtype=dtype) as hf_model:
|
||||
@ -89,9 +89,9 @@ def run_test(
|
||||
hf_model.generate_greedy_logprobs_limit(prompts,
|
||||
max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
images=hf_images,
|
||||
images=images,
|
||||
eos_token_id=eos_token_id)
|
||||
for prompts, hf_images in inputs_per_image
|
||||
for prompts, images in inputs_per_image
|
||||
]
|
||||
|
||||
for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
|
||||
|
@ -88,9 +88,9 @@ def run_test(
|
||||
vllm_model.generate_greedy_logprobs(prompts,
|
||||
max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
images=vllm_images,
|
||||
images=images,
|
||||
stop_token_ids=stop_token_ids)
|
||||
for prompts, vllm_images in inputs_per_image
|
||||
for prompts, images in inputs_per_image
|
||||
]
|
||||
|
||||
with hf_runner(model, dtype=dtype) as hf_model, torch.no_grad():
|
||||
@ -114,9 +114,9 @@ def run_test(
|
||||
hf_model.generate_greedy_logprobs_limit(prompts,
|
||||
max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
images=hf_images,
|
||||
images=images,
|
||||
tokenizer=tokenizer)
|
||||
for prompts, hf_images in inputs_per_image
|
||||
for prompts, images in inputs_per_image
|
||||
]
|
||||
|
||||
for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
|
||||
|
@ -101,8 +101,8 @@ def run_test(
|
||||
vllm_model.generate_greedy_logprobs(prompts,
|
||||
max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
images=vllm_images)
|
||||
for prompts, vllm_images in inputs_per_image
|
||||
images=images)
|
||||
for prompts, images in inputs_per_image
|
||||
]
|
||||
|
||||
# use eager mode for hf runner, since phi3_v didn't work with flash_attn
|
||||
@ -114,9 +114,9 @@ def run_test(
|
||||
hf_model.generate_greedy_logprobs_limit(prompts,
|
||||
max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
images=hf_images,
|
||||
images=images,
|
||||
eos_token_id=eos_token_id)
|
||||
for prompts, hf_images in inputs_per_image
|
||||
for prompts, images in inputs_per_image
|
||||
]
|
||||
|
||||
for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
|
||||
|
@ -16,6 +16,8 @@ _GENERATION_MODELS = {
|
||||
"BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b
|
||||
"BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b
|
||||
"BloomForCausalLM": ("bloom", "BloomForCausalLM"),
|
||||
"Blip2ForConditionalGeneration":
|
||||
("blip2", "Blip2ForConditionalGeneration"),
|
||||
"ChameleonForConditionalGeneration":
|
||||
("chameleon", "ChameleonForConditionalGeneration"),
|
||||
"ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
|
||||
@ -56,8 +58,8 @@ _GENERATION_MODELS = {
|
||||
"OPTForCausalLM": ("opt", "OPTForCausalLM"),
|
||||
"OrionForCausalLM": ("orion", "OrionForCausalLM"),
|
||||
"PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
|
||||
"PaliGemmaForConditionalGeneration":
|
||||
("paligemma", "PaliGemmaForConditionalGeneration"),
|
||||
"PaliGemmaForConditionalGeneration": ("paligemma",
|
||||
"PaliGemmaForConditionalGeneration"),
|
||||
"PhiForCausalLM": ("phi", "PhiForCausalLM"),
|
||||
"Phi3ForCausalLM": ("llama", "LlamaForCausalLM"),
|
||||
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
|
||||
|
269
vllm/model_executor/models/blip.py
Normal file
269
vllm/model_executor/models/blip.py
Normal file
@ -0,0 +1,269 @@
|
||||
"""Minimal implementation of BlipVisionModel intended to be only used
|
||||
within a vision language model."""
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from PIL import Image
|
||||
from transformers import Blip2VisionConfig, BlipVisionConfig
|
||||
from transformers.models.blip.modeling_blip import BlipAttention
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.inputs import LLMInputs
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.multimodal.image import (cached_get_tokenizer,
|
||||
repeat_and_pad_image_tokens)
|
||||
from vllm.sequence import SequenceData
|
||||
|
||||
|
||||
def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
|
||||
assert image_size % patch_size == 0
|
||||
return image_size // patch_size
|
||||
|
||||
|
||||
def get_blip_num_patches(*, image_size: int, patch_size: int) -> int:
|
||||
grid_length = get_blip_patch_grid_length(image_size=image_size,
|
||||
patch_size=patch_size)
|
||||
return grid_length * grid_length
|
||||
|
||||
|
||||
def get_blip_image_feature_size(
|
||||
hf_config: Union[BlipVisionConfig, Blip2VisionConfig], ) -> int:
|
||||
return get_blip_num_patches(image_size=hf_config.image_size,
|
||||
patch_size=hf_config.patch_size)
|
||||
|
||||
|
||||
def get_max_blip_image_tokens(
|
||||
hf_config: Union[BlipVisionConfig, Blip2VisionConfig], ) -> int:
|
||||
return get_blip_image_feature_size(hf_config)
|
||||
|
||||
|
||||
def dummy_seq_data_for_blip(
|
||||
hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
|
||||
seq_len: int,
|
||||
*,
|
||||
image_token_id: int,
|
||||
image_feature_size_override: Optional[int] = None,
|
||||
):
|
||||
if image_feature_size_override is None:
|
||||
image_feature_size = get_blip_image_feature_size(hf_config)
|
||||
else:
|
||||
image_feature_size = image_feature_size_override
|
||||
|
||||
token_ids = [image_token_id] * image_feature_size
|
||||
token_ids += [0] * (seq_len - image_feature_size)
|
||||
return SequenceData(token_ids)
|
||||
|
||||
|
||||
def dummy_image_for_blip(
|
||||
hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
|
||||
*,
|
||||
image_width_override: Optional[int] = None,
|
||||
image_height_override: Optional[int] = None,
|
||||
):
|
||||
width = height = hf_config.image_size
|
||||
if image_width_override is not None:
|
||||
width = image_width_override
|
||||
if image_height_override is not None:
|
||||
height = image_height_override
|
||||
|
||||
image = Image.new("RGB", (width, height), color=0)
|
||||
return {"image": image}
|
||||
|
||||
|
||||
def input_processor_for_blip(
|
||||
model_config: ModelConfig,
|
||||
hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
|
||||
llm_inputs: LLMInputs,
|
||||
*,
|
||||
image_token_id: int,
|
||||
image_feature_size_override: Optional[int] = None,
|
||||
):
|
||||
multi_modal_data = llm_inputs.get("multi_modal_data")
|
||||
if multi_modal_data is None or "image" not in multi_modal_data:
|
||||
return llm_inputs
|
||||
|
||||
tokenizer = cached_get_tokenizer(model_config.tokenizer)
|
||||
|
||||
if image_feature_size_override is None:
|
||||
image_feature_size = get_blip_image_feature_size(hf_config)
|
||||
else:
|
||||
image_feature_size = image_feature_size_override
|
||||
|
||||
new_prompt, new_token_ids = repeat_and_pad_image_tokens(
|
||||
tokenizer,
|
||||
llm_inputs.get("prompt"),
|
||||
llm_inputs["prompt_token_ids"],
|
||||
image_token_id=image_token_id,
|
||||
repeat_count=image_feature_size,
|
||||
)
|
||||
|
||||
# NOTE: Create a defensive copy of the original inputs
|
||||
return LLMInputs(prompt_token_ids=new_token_ids,
|
||||
prompt=new_prompt,
|
||||
multi_modal_data=multi_modal_data)
|
||||
|
||||
|
||||
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa
|
||||
class BlipVisionEmbeddings(nn.Module):
|
||||
|
||||
def __init__(self, config: BlipVisionConfig):
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.image_size = config.image_size
|
||||
self.patch_size = config.patch_size
|
||||
|
||||
self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))
|
||||
|
||||
self.patch_embedding = nn.Conv2d(
|
||||
in_channels=3,
|
||||
out_channels=self.embed_dim,
|
||||
kernel_size=self.patch_size,
|
||||
stride=self.patch_size,
|
||||
)
|
||||
|
||||
self.num_patches = get_blip_num_patches(image_size=self.image_size,
|
||||
patch_size=self.patch_size)
|
||||
self.num_positions = self.num_patches + 1
|
||||
|
||||
self.position_embedding = nn.Parameter(
|
||||
torch.randn(1, self.num_positions, self.embed_dim))
|
||||
|
||||
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
||||
batch_size = pixel_values.shape[0]
|
||||
target_dtype = self.patch_embedding.weight.dtype
|
||||
patch_embeds = self.patch_embedding(pixel_values.to(
|
||||
dtype=target_dtype)) # shape = [*, width, grid, grid]
|
||||
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
|
||||
|
||||
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
|
||||
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
|
||||
|
||||
position_embeds = self.position_embedding.to(target_dtype)
|
||||
embeddings = embeddings + position_embeds[:, :embeddings.size(1), :]
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
class BlipMLP(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: BlipVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
|
||||
self.activation_fn = get_act_fn(config.hidden_act)
|
||||
self.fc1 = ColumnParallelLinear(config.hidden_size,
|
||||
config.intermediate_size,
|
||||
bias=True,
|
||||
quant_config=quant_config)
|
||||
self.fc2 = RowParallelLinear(config.intermediate_size,
|
||||
config.hidden_size,
|
||||
bias=True,
|
||||
quant_config=quant_config)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states, _ = self.fc1(hidden_states)
|
||||
hidden_states = self.activation_fn(hidden_states)
|
||||
hidden_states, _ = self.fc2(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BlipEncoderLayer(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: BlipVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
super().__init__()
|
||||
|
||||
self.self_attn = BlipAttention(config)
|
||||
self.layer_norm1 = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
self.mlp = BlipMLP(config, quant_config=quant_config)
|
||||
self.layer_norm2 = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.layer_norm1(hidden_states)
|
||||
hidden_states, _ = self.self_attn(hidden_states=hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.layer_norm2(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BlipEncoder(nn.Module):
|
||||
"""
|
||||
Transformer encoder consisting of `config.num_hidden_layers` self
|
||||
attention layers. Each layer is a [`BlipEncoderLayer`].
|
||||
|
||||
Args:
|
||||
config: BlipConfig
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
config: BlipVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
num_hidden_layers_override: Optional[int] = None):
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
|
||||
if num_hidden_layers_override is None:
|
||||
num_hidden_layers = config.num_hidden_layers
|
||||
else:
|
||||
num_hidden_layers = num_hidden_layers_override
|
||||
|
||||
self.layers = nn.ModuleList([
|
||||
BlipEncoderLayer(config=config, quant_config=quant_config)
|
||||
for _ in range(num_hidden_layers)
|
||||
])
|
||||
|
||||
def forward(self, inputs_embeds: torch.Tensor):
|
||||
hidden_states = inputs_embeds
|
||||
for encoder_layer in self.layers:
|
||||
hidden_states = encoder_layer(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BlipVisionModel(nn.Module):
|
||||
config_class = BlipVisionConfig
|
||||
main_input_name = "pixel_values"
|
||||
|
||||
def __init__(self,
|
||||
config: BlipVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
num_hidden_layers_override: Optional[int] = None):
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
|
||||
self.embeddings = BlipVisionEmbeddings(config)
|
||||
self.encoder = BlipEncoder(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
num_hidden_layers_override=num_hidden_layers_override,
|
||||
)
|
||||
self.post_layernorm = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
|
||||
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.embeddings(pixel_values)
|
||||
hidden_states = self.encoder(inputs_embeds=hidden_states)
|
||||
|
||||
return self.post_layernorm(hidden_states)
|
669
vllm/model_executor/models/blip2.py
Normal file
669
vllm/model_executor/models/blip2.py
Normal file
@ -0,0 +1,669 @@
|
||||
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import (Blip2Config, Blip2QFormerConfig, Blip2VisionConfig,
|
||||
apply_chunking_to_forward)
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig, MultiModalConfig
|
||||
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.opt import OPTModel
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
|
||||
|
||||
from .blip import (BlipVisionModel, dummy_image_for_blip,
|
||||
get_max_blip_image_tokens)
|
||||
from .interfaces import SupportsVision
|
||||
from .utils import merge_vision_embeddings
|
||||
|
||||
_KEYS_TO_MODIFY_MAPPING = {
|
||||
"language_model.lm_head": "lm_head",
|
||||
"language_model.model": "language_model",
|
||||
}
|
||||
|
||||
|
||||
class Blip2QFormerMultiHeadAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Blip2QFormerConfig,
|
||||
*,
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
cache_config: Optional[CacheConfig],
|
||||
is_cross_attention: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
|
||||
if config.hidden_size % config.num_attention_heads != 0:
|
||||
raise ValueError(
|
||||
f"The hidden size ({config.hidden_size}) is not a multiple of "
|
||||
f"the number of attention heads ({config.num_attention_heads})"
|
||||
)
|
||||
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.attention_head_size = (config.hidden_size //
|
||||
config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
self.scaling = self.attention_head_size**-0.5
|
||||
|
||||
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
if is_cross_attention:
|
||||
kv_hidden_size = config.encoder_hidden_size
|
||||
else:
|
||||
kv_hidden_size = config.hidden_size
|
||||
self.key = nn.Linear(kv_hidden_size, self.all_head_size)
|
||||
self.value = nn.Linear(kv_hidden_size, self.all_head_size)
|
||||
|
||||
self.position_embedding_type = getattr(config,
|
||||
"position_embedding_type",
|
||||
"absolute")
|
||||
if self.position_embedding_type != "absolute":
|
||||
raise NotImplementedError("Unsupported position_embedding_type: "
|
||||
f"{self.position_embedding_type}")
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
|
||||
def transpose_for_scores(self, x):
|
||||
x = x.view(*x.size()[:-1], self.num_attention_heads,
|
||||
self.attention_head_size)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
is_cross_attention = encoder_hidden_states is not None
|
||||
|
||||
if is_cross_attention:
|
||||
key_layer = self.transpose_for_scores(
|
||||
self.key(encoder_hidden_states))
|
||||
value_layer = self.transpose_for_scores(
|
||||
self.value(encoder_hidden_states))
|
||||
else:
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
|
||||
attention_scores = torch.matmul(query_layer,
|
||||
key_layer.transpose(-1, -2))
|
||||
attention_probs = torch.softmax(attention_scores * self.scaling,
|
||||
dim=-1)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attention_probs_dropped = self.dropout(attention_probs)
|
||||
|
||||
context_layer = torch.matmul(attention_probs_dropped, value_layer)
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
context_layer = context_layer.view(*context_layer.size()[:-2],
|
||||
self.all_head_size)
|
||||
|
||||
return context_layer
|
||||
|
||||
|
||||
class Blip2QFormerSelfOutput(nn.Module):
|
||||
|
||||
def __init__(self, config: Blip2QFormerConfig) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.LayerNorm = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
input_tensor: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Blip2QFormerAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Blip2QFormerConfig,
|
||||
*,
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
cache_config: Optional[CacheConfig],
|
||||
is_cross_attention: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.attention = Blip2QFormerMultiHeadAttention(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
cache_config=cache_config,
|
||||
is_cross_attention=is_cross_attention,
|
||||
)
|
||||
|
||||
self.output = Blip2QFormerSelfOutput(config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
) -> Tuple[torch.Tensor]:
|
||||
self_output = self.attention(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
attention_output = self.output(self_output, hidden_states)
|
||||
|
||||
return attention_output
|
||||
|
||||
|
||||
class Blip2QFormerIntermediate(nn.Module):
|
||||
|
||||
def __init__(self, config: Blip2QFormerConfig) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||
self.intermediate_act_fn = get_act_fn(config.hidden_act)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Blip2QFormerOutput(nn.Module):
|
||||
|
||||
def __init__(self, config: Blip2QFormerConfig) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
||||
self.LayerNorm = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
input_tensor: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Blip2QFormerLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Blip2QFormerConfig,
|
||||
*,
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
cache_config: Optional[CacheConfig],
|
||||
layer_idx: int,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
self.seq_len_dim = 1
|
||||
self.attention = Blip2QFormerAttention(config,
|
||||
quant_config=quant_config,
|
||||
cache_config=cache_config)
|
||||
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
if layer_idx % config.cross_attention_frequency == 0:
|
||||
self.crossattention = Blip2QFormerAttention(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
cache_config=cache_config,
|
||||
is_cross_attention=True)
|
||||
self.has_cross_attention = True
|
||||
else:
|
||||
self.has_cross_attention = False
|
||||
|
||||
self.intermediate_query = Blip2QFormerIntermediate(config)
|
||||
self.output_query = Blip2QFormerOutput(config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor,
|
||||
query_length: int,
|
||||
):
|
||||
attention_output = self.attention(hidden_states)
|
||||
|
||||
if query_length > 0:
|
||||
query_attention_output = attention_output[:, :query_length, :]
|
||||
|
||||
if self.has_cross_attention:
|
||||
query_attention_output = self.crossattention(
|
||||
query_attention_output,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
|
||||
layer_output = apply_chunking_to_forward(
|
||||
self.feed_forward_chunk_query,
|
||||
self.chunk_size_feed_forward,
|
||||
self.seq_len_dim,
|
||||
query_attention_output,
|
||||
)
|
||||
|
||||
if attention_output.shape[1] > query_length:
|
||||
layer_output_text = apply_chunking_to_forward(
|
||||
self.feed_forward_chunk,
|
||||
self.chunk_size_feed_forward,
|
||||
self.seq_len_dim,
|
||||
attention_output[:, query_length:, :],
|
||||
)
|
||||
layer_output = torch.cat([layer_output, layer_output_text],
|
||||
dim=1)
|
||||
else:
|
||||
layer_output = apply_chunking_to_forward(
|
||||
self.feed_forward_chunk,
|
||||
self.chunk_size_feed_forward,
|
||||
self.seq_len_dim,
|
||||
attention_output,
|
||||
)
|
||||
|
||||
return layer_output
|
||||
|
||||
def feed_forward_chunk(self,
|
||||
attention_output: torch.Tensor) -> torch.Tensor:
|
||||
intermediate_output = self.intermediate(attention_output)
|
||||
layer_output = self.output(intermediate_output, attention_output)
|
||||
return layer_output
|
||||
|
||||
def feed_forward_chunk_query(
|
||||
self, attention_output: torch.Tensor) -> torch.Tensor:
|
||||
intermediate_output = self.intermediate_query(attention_output)
|
||||
layer_output = self.output_query(intermediate_output, attention_output)
|
||||
return layer_output
|
||||
|
||||
|
||||
class Blip2QFormerEncoder(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Blip2QFormerConfig,
|
||||
*,
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
cache_config: Optional[CacheConfig],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
|
||||
self.layer = nn.ModuleList([
|
||||
Blip2QFormerLayer(config,
|
||||
quant_config=quant_config,
|
||||
cache_config=cache_config,
|
||||
layer_idx=layer_idx)
|
||||
for layer_idx in range(config.num_hidden_layers)
|
||||
])
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor,
|
||||
query_length: int,
|
||||
) -> torch.Tensor:
|
||||
for i in range(self.config.num_hidden_layers):
|
||||
layer_module = self.layer[i]
|
||||
|
||||
hidden_states = layer_module(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
query_length=query_length,
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Adapted from https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1025
|
||||
class Blip2QFormerModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Blip2QFormerConfig,
|
||||
*,
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
cache_config: Optional[CacheConfig],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
|
||||
self.layernorm = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
self.encoder = Blip2QFormerEncoder(config,
|
||||
quant_config=quant_config,
|
||||
cache_config=cache_config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query_embeds: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor,
|
||||
) -> torch.Tensor:
|
||||
query_length = query_embeds.shape[1]
|
||||
|
||||
embedding_output = self.layernorm(query_embeds)
|
||||
embedding_output = self.dropout(embedding_output)
|
||||
|
||||
sequence_output = self.encoder(
|
||||
embedding_output,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
query_length=query_length,
|
||||
)
|
||||
|
||||
return sequence_output
|
||||
|
||||
|
||||
class Blip2ImagePixelInputs(TypedDict):
|
||||
type: Literal["pixel_values"]
|
||||
data: torch.Tensor
|
||||
"""Shape: (batch_size, num_channels, height, width)"""
|
||||
|
||||
|
||||
Blip2ImageInputs = Blip2ImagePixelInputs
|
||||
|
||||
# We use this internally as placeholders since there is no image token
|
||||
# defined on the HuggingFace repo
|
||||
BLIP2_IMAGE_TOKEN = "<image>"
|
||||
BLIP2_IMAGE_TOKEN_ID = 50265
|
||||
|
||||
|
||||
def get_blip2_image_feature_size(hf_config: Blip2Config) -> int:
|
||||
return hf_config.num_query_tokens
|
||||
|
||||
|
||||
def get_max_blip2_image_tokens(ctx: InputContext):
|
||||
hf_config = ctx.get_hf_config(Blip2Config)
|
||||
vision_config = hf_config.vision_config
|
||||
|
||||
if isinstance(vision_config, Blip2VisionConfig):
|
||||
return get_max_blip_image_tokens(vision_config)
|
||||
|
||||
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
def dummy_data_for_blip2(ctx: InputContext, seq_len: int):
|
||||
hf_config = ctx.get_hf_config(Blip2Config)
|
||||
vision_config = hf_config.vision_config
|
||||
|
||||
image_feature_size = get_blip2_image_feature_size(hf_config)
|
||||
token_ids = [BLIP2_IMAGE_TOKEN_ID] * image_feature_size
|
||||
token_ids += [0] * (seq_len - image_feature_size)
|
||||
seq_data = SequenceData(token_ids)
|
||||
|
||||
if isinstance(vision_config, Blip2VisionConfig):
|
||||
mm_data = dummy_image_for_blip(vision_config)
|
||||
|
||||
return seq_data, mm_data
|
||||
|
||||
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
def input_processor_for_blip2(ctx: InputContext, llm_inputs: LLMInputs):
|
||||
multi_modal_data = llm_inputs.get("multi_modal_data")
|
||||
if multi_modal_data is None or "image" not in multi_modal_data:
|
||||
return llm_inputs
|
||||
|
||||
hf_config = ctx.get_hf_config(Blip2Config)
|
||||
image_feature_size = get_blip2_image_feature_size(hf_config)
|
||||
|
||||
# The original model places image tokens at the front
|
||||
# https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1514
|
||||
new_token_ids = [BLIP2_IMAGE_TOKEN_ID] * image_feature_size
|
||||
new_token_ids += llm_inputs["prompt_token_ids"]
|
||||
|
||||
new_prompt = llm_inputs.get("prompt")
|
||||
if new_prompt is not None:
|
||||
new_prompt = BLIP2_IMAGE_TOKEN * image_feature_size + new_prompt
|
||||
|
||||
return LLMInputs(prompt_token_ids=new_token_ids,
|
||||
prompt=new_prompt,
|
||||
multi_modal_data=multi_modal_data)
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_image_input_mapper()
|
||||
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_blip2_image_tokens)
|
||||
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_blip2)
|
||||
@INPUT_REGISTRY.register_input_processor(input_processor_for_blip2)
|
||||
class Blip2ForConditionalGeneration(nn.Module, SupportsVision):
|
||||
|
||||
def __init__(self,
|
||||
config: Blip2Config,
|
||||
multimodal_config: MultiModalConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None) -> None:
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.multimodal_config = multimodal_config
|
||||
|
||||
# TODO: Optionally initializes this for supporting embeddings.
|
||||
self.vision_model = BlipVisionModel(config.vision_config)
|
||||
|
||||
self.query_tokens = nn.Parameter(
|
||||
torch.zeros(1, config.num_query_tokens,
|
||||
config.qformer_config.hidden_size))
|
||||
|
||||
self.qformer = Blip2QFormerModel(config.qformer_config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
|
||||
self.language_projection = nn.Linear(
|
||||
config.qformer_config.hidden_size,
|
||||
config.text_config.hidden_size,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
self.quant_config = quant_config
|
||||
|
||||
self.language_model = OPTModel(config.text_config, cache_config,
|
||||
quant_config)
|
||||
|
||||
self.unpadded_vocab_size = config.text_config.vocab_size
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size)
|
||||
self.sampler = Sampler()
|
||||
|
||||
def get_lm_head(self):
|
||||
return self.language_model.decoder.embed_tokens
|
||||
|
||||
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
|
||||
h = w = self.config.vision_config.image_size
|
||||
expected_dims = (3, h, w)
|
||||
actual_dims = tuple(data.shape[1:])
|
||||
|
||||
if actual_dims != expected_dims:
|
||||
expected_expr = ("batch_size", *map(str, expected_dims))
|
||||
raise ValueError(
|
||||
f"The expected shape of pixel values is {expected_expr}. "
|
||||
f"You supplied {tuple(data.shape)}.")
|
||||
|
||||
return data
|
||||
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object) -> Optional[Blip2ImageInputs]:
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
|
||||
if pixel_values is None:
|
||||
return None
|
||||
|
||||
if not isinstance(pixel_values, torch.Tensor):
|
||||
raise ValueError("Incorrect type of pixel values. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
|
||||
return Blip2ImagePixelInputs(
|
||||
type="pixel_values",
|
||||
data=self._validate_pixel_values(pixel_values),
|
||||
)
|
||||
|
||||
def _image_pixels_to_features(self, vision_model: BlipVisionModel,
|
||||
pixel_values: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
# NOTE: we skip the step to select the vision feature layer since
|
||||
# this is already done inside the vision tower
|
||||
image_features = vision_model(pixel_values)
|
||||
|
||||
return image_features
|
||||
|
||||
def _process_image_pixels(self,
|
||||
inputs: Blip2ImagePixelInputs) -> torch.Tensor:
|
||||
assert self.vision_model is not None
|
||||
|
||||
pixel_values = inputs["data"]
|
||||
|
||||
return self._image_pixels_to_features(self.vision_model, pixel_values)
|
||||
|
||||
def _process_image_input(self,
|
||||
image_input: Blip2ImageInputs) -> torch.Tensor:
|
||||
assert self.vision_model is not None
|
||||
image_features = self._process_image_pixels(image_input)
|
||||
|
||||
query_tokens = self.query_tokens.expand(image_features.shape[0], -1,
|
||||
-1)
|
||||
query_output = self.qformer(
|
||||
query_embeds=query_tokens,
|
||||
encoder_hidden_states=image_features,
|
||||
)
|
||||
|
||||
return self.language_projection(query_output)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
**kwargs: object,
|
||||
) -> SamplerOutput:
|
||||
"""Run forward pass for BLIP-2.
|
||||
|
||||
One key thing to understand is the `input_ids` already accounts for the
|
||||
positions of the to-be-inserted image embeddings.
|
||||
|
||||
Concretely, consider a text prompt:
|
||||
`"Question: What's the content of the image? Answer:"`.
|
||||
|
||||
Tokenizer outputs:
|
||||
`[2, 45641, 35, 653, 18, 5, 1383, 9, 5, 2274, 116, 31652, 35]`.
|
||||
|
||||
To reserve space in KV cache, we have to insert placeholder tokens
|
||||
before they are inputted to the model, so the input processor prepends
|
||||
dummy tokens (denoted as `50265`), resulting in:
|
||||
`[50265, ..., 50265, 2, 45641, 35, ..., 31652, 35]`.
|
||||
|
||||
We insert 32 tokens since it corresponds to the number of query
|
||||
embeddings outputted by the Q-Former and inputted to the language model.
|
||||
|
||||
This way, the `positions` and `attn_metadata` are consistent
|
||||
with the `input_ids`.
|
||||
|
||||
Args:
|
||||
input_ids: Flattened (concatenated) input_ids corresponding to a
|
||||
batch.
|
||||
pixel_values: The pixels in each input image.
|
||||
|
||||
See also:
|
||||
:class:`Blip2ImageInputs`
|
||||
"""
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
|
||||
if image_input is not None:
|
||||
vision_embeddings = self._process_image_input(image_input)
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
|
||||
inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
|
||||
vision_embeddings,
|
||||
BLIP2_IMAGE_TOKEN_ID)
|
||||
|
||||
input_ids = None
|
||||
else:
|
||||
inputs_embeds = None
|
||||
|
||||
hidden_states = self.language_model(input_ids,
|
||||
positions,
|
||||
kv_caches,
|
||||
attn_metadata,
|
||||
inputs_embeds=inputs_embeds)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.get_lm_head(), hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[SamplerOutput]:
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
# only doing this for language model part for now.
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
if "lm_head.weight" in name:
|
||||
continue
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
|
||||
if key_to_modify in name:
|
||||
name = name.replace(key_to_modify, new_key)
|
||||
use_default_weight_loading = False
|
||||
if "vision" in name:
|
||||
if self.vision_model is not None:
|
||||
# We only do sharding for language model and
|
||||
# not vision model for now.
|
||||
use_default_weight_loading = True
|
||||
else:
|
||||
for (param_name, weight_name,
|
||||
shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
param = params_dict[name.replace(weight_name, param_name)]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
use_default_weight_loading = True
|
||||
if use_default_weight_loading:
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
@ -237,14 +237,19 @@ class OPTDecoder(nn.Module):
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings(input_ids)
|
||||
pos_embeds = self.embed_positions(positions)
|
||||
if self.project_in is not None:
|
||||
inputs_embeds, _ = self.project_in(inputs_embeds)
|
||||
@ -272,14 +277,22 @@ class OPTModel(nn.Module):
|
||||
super().__init__()
|
||||
self.decoder = OPTDecoder(config, cache_config, quant_config)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.decoder.get_input_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
return self.decoder(input_ids, positions, kv_caches, attn_metadata)
|
||||
return self.decoder(input_ids,
|
||||
positions,
|
||||
kv_caches,
|
||||
attn_metadata,
|
||||
inputs_embeds=inputs_embeds)
|
||||
|
||||
|
||||
class OPTForCausalLM(nn.Module):
|
||||
|
@ -1,8 +1,9 @@
|
||||
import sys
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import UserDict, defaultdict
|
||||
from typing import (Any, Callable, Dict, List, Optional, Type, TypedDict,
|
||||
TypeVar, Union, cast)
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from typing import Sequence as GenericSequence
|
||||
from typing import Type, TypedDict, TypeVar, Union, cast
|
||||
|
||||
import torch
|
||||
import torch.types
|
||||
@ -15,13 +16,13 @@ from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
NestedTensors = Union[List[torch.Tensor], torch.Tensor]
|
||||
NestedTensors = Union[GenericSequence[torch.Tensor], torch.Tensor]
|
||||
"""
|
||||
Use a list instead of a tensor if the dimensions of each element do not match.
|
||||
Currently only supports up to singly nested list of tensors.
|
||||
"""
|
||||
|
||||
BatchedTensors = Union[List[NestedTensors], NestedTensors]
|
||||
BatchedTensors = Union[GenericSequence[NestedTensors], NestedTensors]
|
||||
"""
|
||||
If each input tensor in the batch has the same size, this is a single batched
|
||||
tensor; otherwise, this is a list of :class:`NestedTensors` with one element
|
||||
@ -53,7 +54,7 @@ class MultiModalInputs(_MultiModalInputsBase):
|
||||
# may be list rather than tensors
|
||||
if isinstance(tensors[0], list):
|
||||
return [[t.to(device=device) for t in tensor[0]]
|
||||
for tensor in tensors]
|
||||
for tensor in cast(List[List[torch.Tensor]], tensors)]
|
||||
|
||||
tensors_ = cast(List[torch.Tensor], tensors)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user