[MODEL] Qwen Multimodal Support (Qwen-VL / Qwen-VL-Chat) (#8029)

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Alex Brooks 2024-09-05 06:48:10 -06:00 committed by GitHub
parent 8685ba1a1e
commit 9da25a88aa
8 changed files with 1111 additions and 209 deletions

View File

@@ -242,6 +242,11 @@ Multimodal Language Models
    - Image\ :sup:`+`
    - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc.
    -
* - :code:`QWenLMHeadModel`
- Qwen
- Image
- :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc.
-
  * - :code:`UltravoxModel`
    - Ultravox
    - Audio\ :sup:`E+`

View File

@@ -159,6 +159,20 @@ def run_blip2(question):
    return llm, prompt, stop_token_ids
# Qwen
def run_qwen_vl(question):
llm = LLM(
model="Qwen/Qwen-VL",
trust_remote_code=True,
max_num_seqs=5,
)
prompt = f"{question}Picture 1: <img></img>\n"
stop_token_ids = None
return llm, prompt, stop_token_ids
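As a quick illustration of how this example function would typically be driven (a sketch, not part of this commit; it assumes vLLM's multimodal generate API used elsewhere in this example script and a local image path):

# Hypothetical usage sketch: run the Qwen-VL prompt above against one image.
from PIL import Image

from vllm import SamplingParams


def demo_qwen_vl(question: str, image_path: str) -> str:
    llm, prompt, stop_token_ids = run_qwen_vl(question)
    image = Image.open(image_path).convert("RGB")
    sampling_params = SamplingParams(temperature=0.0,
                                     max_tokens=64,
                                     stop_token_ids=stop_token_ids)
    outputs = llm.generate(
        {
            "prompt": prompt,
            "multi_modal_data": {"image": image},
        },
        sampling_params=sampling_params)
    return outputs[0].outputs[0].text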
model_example_map = {
    "llava": run_llava,
    "llava-next": run_llava_next,
@@ -169,6 +183,7 @@ model_example_map = {
    "minicpmv": run_minicpmv,
    "blip-2": run_blip2,
    "internvl_chat": run_internvl,
"qwen_vl": run_qwen_vl,
}

View File

@@ -1,19 +1,154 @@
from typing import Type
import pathlib
from typing import List, Optional, Type
import pytest
from ..conftest import HfRunner, VllmRunner
from vllm.multimodal.utils import rescale_image_size
from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from .utils import check_logprobs_close
models = ["qwen/qwen-vl"]
pytestmark = pytest.mark.vlm
text_only_models = [
"Qwen/Qwen-7B-Chat" # Has no visual component
]
multimodal_models = ["Qwen/Qwen-VL"]
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign":
"Picture 1: <img></img>\nWhat's the content of the image?: ",
"cherry_blossom":
"Picture 1: <img></img>\nWhat is the season?: ",
})
@pytest.mark.parametrize("dtype", ["half"]) ### Tests for multimodal Qwen models
def run_test(
tmp_path: pathlib.PosixPath,
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
image_assets: _ImageAssets,
model: str,
*,
size_factors: List[float],
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
"""Inference result should be the same between hf and vllm.
    All the image fixtures for the test are under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding MultiModalConfig as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
images = [asset.pil_image for asset in image_assets]
# Export the images to a tempdir and substitute it into the hf prompt;
# the contents between <img>/</img> will be ignored by VLLM, but the
# transformers implementation for the visual transformer parses this to
# reload it in the forward call; the contents are treated as a URL or a
# local path.
for idx, asset in enumerate(image_assets):
image_tmp_path = tmp_path / f"{asset.name}.jpg"
asset.pil_image.save(image_tmp_path)
HF_IMAGE_PROMPTS[idx] = HF_IMAGE_PROMPTS[idx].replace(
"<img></img>", f"<img>{image_tmp_path}</img>")
inputs_per_image = [(
[prompt for _ in size_factors],
[rescale_image_size(image, factor) for factor in size_factors],
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
# NOTE: take care of the order. run vLLM first, and then run HF.
# vLLM needs a fresh new process without cuda initialization.
# if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method).
# max_model_len should be greater than image_feature_size
# Qwen encodes images into a fixed content size of 256
with vllm_runner(model,
max_model_len=300,
max_num_seqs=1,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True) as vllm_model:
vllm_outputs_per_image = [
vllm_model.generate_greedy_logprobs(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in inputs_per_image
]
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in inputs_per_image
]
for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
vllm_outputs_per_image):
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)
@pytest.mark.parametrize("model", multimodal_models)
@pytest.mark.parametrize(
"size_factors",
[
# No image
[],
# Single-scale
[1.0],
# Single-scale, batched
[1.0, 1.0, 1.0],
# Multi-scale
[0.25, 0.5, 1.0],
],
)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [8])
@pytest.mark.parametrize("num_logprobs", [5])
def test_multimodal_models(tmp_path, hf_runner, vllm_runner, image_assets,
model, size_factors, dtype, max_tokens,
num_logprobs) -> None:
run_test(
tmp_path,
hf_runner,
vllm_runner,
image_assets,
model,
size_factors=size_factors,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)
# Ensure that a text-only Qwen model can still be loaded and
# used for inference in VLLM without throwing.
@pytest.mark.parametrize("model", text_only_models)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("model", models) def test_text_only_qwen_model_can_be_loaded_and_run(
def test_text_only_qwen_model(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner], vllm_runner: Type[VllmRunner],
example_prompts, example_prompts,
model: str, model: str,
@ -22,27 +157,9 @@ def test_text_only_qwen_model(
max_tokens: int, max_tokens: int,
num_logprobs: int, num_logprobs: int,
): ):
# This test checks language inputs only, since the visual component
# for qwen-vl is still unsupported in VLLM. In the near-future, the
# implementation and this test will be extended to consider
# visual inputs as well.
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts,
max_tokens,
num_logprobs=num_logprobs,
)
    with vllm_runner(model, dtype=dtype) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy_logprobs(
        vllm_model.generate_greedy_logprobs(
            example_prompts,
            max_tokens,
            num_logprobs=num_logprobs,
        )
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)

View File

@@ -150,6 +150,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
        if model_type in ("blip-2", "chatglm", "fuyu", "paligemma"):
            # These models do not use image tokens in the prompt
            return None
if model_type == "qwen":
return f"Picture {current_count}: <img></img>"
        if model_type.startswith("llava"):
            return self._cached_token_str(self._tokenizer,
                                          hf_config.image_token_index)

View File

@@ -0,0 +1,273 @@
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
#
# Copyright 2023 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Shared resampler perceiver network used in multimodal models and
related helpers for sincos positional embeddings.
Example models: Qwen (Qwen-VL), Minicpmv2.0
"""
import math
from functools import partial
from typing import Callable, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.init import trunc_normal_
from vllm.model_executor.layers.linear import ReplicatedLinear
DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
def get_abs_pos(abs_pos: torch.Tensor, tgt_size: Union[torch.Tensor,
int]) -> torch.Tensor:
# abs_pos: L, C
# tgt_size: (H, W)
# return: M, C
src_size = int(math.sqrt(abs_pos.size(0)))
dtype = abs_pos.dtype
if isinstance(tgt_size, int):
tgt_size = (tgt_size, tgt_size)
if (src_size == tgt_size[0] and src_size == tgt_size[1]):
return abs_pos
return (F.interpolate(
abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
size=(tgt_size[0], tgt_size[1]),
mode="bicubic",
align_corners=False,
).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype))
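A small shape check for the helper above (a sketch under the assumption that it is run standalone on CPU with random data):

# get_abs_pos resizes a square grid of positional embeddings with bicubic
# interpolation; here a 16x16 grid of 1024-dim embeddings becomes 24x24.
import torch

abs_pos = torch.randn(16 * 16, 1024)   # (L, C), L must be a perfect square
resized = get_abs_pos(abs_pos, 24)     # int target -> (24, 24) grid
assert resized.shape == (24 * 24, 1024)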
# sin/cos positional embedding helpers are adapted from:
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
def get_1d_sincos_pos_embed_from_grid(
embed_dim: int, pos: np.ndarray,
version: Tuple[int, int] = (2, 0)) -> torch.Tensor:
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,) / (H, W)
out: (M, D) / (H, W, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float32)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,)
if version == (2, 0):
pos = pos.reshape(-1) # (M,)
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
else:
out = np.einsum("hw,d->hwd", pos, omega) # (H, W, D/2), outer product
emb_sin = np.sin(out) # (H, W, D/2)
emb_cos = np.cos(out) # (H, W, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (H, W, D)
return emb
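A tiny worked example of the 1D helper (sketch only; positions and dimensions are arbitrary):

import numpy as np

pos = np.arange(4, dtype=np.float32)             # positions 0..3
emb = get_1d_sincos_pos_embed_from_grid(8, pos)  # -> shape (4, 8)
assert emb.shape == (4, 8)
# Column 0 holds sin(pos * omega[0]) with omega[0] == 1, i.e. sin(pos).
assert np.allclose(emb[:, 0], np.sin(pos))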
def get_2d_sincos_pos_embed_from_grid(
embed_dim: int, grid: np.ndarray,
version: Tuple[int, int] = (2, 0)) -> torch.Tensor:
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(
embed_dim // 2, grid[0], version) # (H*W, D/2) or (H, W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(
embed_dim // 2, grid[1], version) # (H*W, D/2) or (H, W, D/2)
if version == (2, 0):
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
else:
emb = np.concatenate([emb_h, emb_w], axis=-1) # (H, W, D)
return emb
def get_2d_sincos_pos_embed(
embed_dim: int,
grid_size: Union[int, Tuple[int, int]],
cls_token: bool = False,
version: Tuple[int, int] = (2, 0),
) -> torch.Tensor:
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or
[1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
if isinstance(grid_size, int):
grid_h_size, grid_w_size = grid_size, grid_size
else:
grid_h_size, grid_w_size = grid_size[0], grid_size[1]
grid_h = np.arange(grid_h_size, dtype=np.float32)
grid_w = np.arange(grid_w_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
assert isinstance(grid, np.ndarray) and \
grid.shape == (2, grid_h_size, grid_w_size)
if version == (2, 0):
grid = grid.reshape([2, 1, grid_h_size, grid_w_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version)
if cls_token:
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed],
axis=0)
else:
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version)
return pos_embed
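And the 2D variant, again only as a shape sketch:

import numpy as np

pos_embed = get_2d_sincos_pos_embed(16, 4)   # 4x4 grid, 16-dim embeddings
assert pos_embed.shape == (4 * 4, 16)
# cls_token=True prepends an all-zero row for the class token.
assert get_2d_sincos_pos_embed(16, 4, cls_token=True).shape == (17, 16)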
class BaseResampler(nn.Module):
"""
    A 2D perceiver-resampler network with one cross-attention layer over
    (grid_size**2) learnable queries and 2D sincos pos_emb.
Outputs:
A tensor with the shape of (grid_size**2, embed_dim)
"""
def __init__(
self,
num_queries: int,
embed_dim: int,
num_heads: int,
kv_dim: Optional[int] = None,
norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
do_post_projection: bool = True,
) -> None:
super().__init__()
self.num_queries = num_queries
self.embed_dim = embed_dim
self.num_heads = num_heads
self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
trunc_normal_(self.query, std=0.02)
if kv_dim is not None and kv_dim != embed_dim:
self.kv_proj = ReplicatedLinear(kv_dim, embed_dim, bias=False)
else:
# Maintain the same return value with ReplicatedLinear.forward
self.kv_proj = lambda *args, **kwargs: ( # type: ignore # noqa
nn.Identity()(*args, **kwargs),
None,
)
self.attn = nn.MultiheadAttention(embed_dim, num_heads)
self.ln_q = norm_layer(embed_dim)
self.ln_kv = norm_layer(embed_dim)
self.do_post_projection = do_post_projection
self.ln_post = norm_layer(embed_dim) if do_post_projection else None
self.proj = nn.Parameter(
(embed_dim**-0.5) *
torch.randn(embed_dim, embed_dim)) if do_post_projection else None
def _init_weights(self, m: nn.Module) -> None:
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def _repeat(self, query, N: int):
return query.unsqueeze(1).repeat(1, N, 1)
class Resampler2(BaseResampler):
"""Resampler-perceiver network to be used for a variety of model types,
e.g., Qwen-vl / Minicpmv 2.0. The main difference is the addition of the
do_post_projection arg, which indicates whether or not there should be
a post layer normalization and projector after the attention. This is
present in minicpmv2.0, but not qwen-vl.
"""
def __init__(
self,
grid_size: int,
embed_dim: int,
num_heads: int,
kv_dim: Optional[int] = None,
norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
adaptive: bool = False,
do_post_projection: bool = True,
) -> None:
super().__init__(grid_size**2,
embed_dim,
num_heads,
kv_dim,
norm_layer,
do_post_projection=do_post_projection)
self.adaptive = adaptive
pos_embed_arr = get_2d_sincos_pos_embed(embed_dim,
grid_size,
version=(2, 0))
self.pos_embed = nn.Parameter(
torch.from_numpy(pos_embed_arr).requires_grad_(False))
self.apply(self._init_weights)
def forward(
self,
x: torch.Tensor,
tgt_sizes: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if tgt_sizes is None:
tgt_sizes = int(math.sqrt(x.size(1)))
if self.adaptive:
pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim,
tgt_sizes,
version=(2, 0))
pos_embed = torch.from_numpy(pos_embed_arr).to(device=x.device,
dtype=x.dtype)
else:
pos_embed = get_abs_pos(self.pos_embed,
tgt_sizes).to(device=x.device,
dtype=x.dtype)
x, _ = self.kv_proj(x)
x = self.ln_kv(x).permute(1, 0, 2)
N = x.shape[1]
q = self.ln_q(self.query)
out = self.attn(
self._repeat(q, N) + self.pos_embed.unsqueeze(1),
x + pos_embed.unsqueeze(1),
x,
attn_mask=attn_mask,
)[0]
x = out.permute(1, 0, 2)
if self.do_post_projection:
x = self.ln_post(x)
x = x @ self.proj
return x
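A minimal end-to-end shape sketch for Resampler2, assuming it is instantiated outside of vLLM's model loader (CPU, no tensor parallelism, and kv_dim=None so the identity kv projection is used):

import torch

resampler = Resampler2(grid_size=4,
                       embed_dim=64,
                       num_heads=8,
                       kv_dim=None,               # identity kv projection
                       do_post_projection=False)  # Qwen-VL style: no ln_post/proj
x = torch.randn(2, 16, 64)                        # (batch, patches, embed_dim)
out = resampler(x)
assert out.shape == (2, 16, 64)                   # grid_size**2 queries per image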

View File

@@ -51,7 +51,6 @@ _GENERATION_MODELS = {
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
"Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"), "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
"RWForCausalLM": ("falcon", "FalconForCausalLM"), "RWForCausalLM": ("falcon", "FalconForCausalLM"),
@ -88,6 +87,7 @@ _MULTIMODAL_MODELS = {
"PaliGemmaForConditionalGeneration"), "PaliGemmaForConditionalGeneration"),
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
"UltravoxModel": ("ultravox", "UltravoxModel"), "UltravoxModel": ("ultravox", "UltravoxModel"),
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
}
_CONDITIONAL_GENERATION_MODELS = {
    "BartModel": ("bart", "BartForConditionalGeneration"),

View File

@@ -26,11 +26,9 @@ import re
from array import array
from functools import partial
from typing import (Any, Callable, Iterable, List, Mapping, Optional, Tuple,
                    TypedDict, Union)
                    TypedDict)
import numpy as np
import torch
import torch.nn.functional as F
import torch.types
from PIL import Image
from torch import nn
@@ -44,6 +42,8 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.resampler import (Resampler2,
get_2d_sincos_pos_embed)
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
@@ -98,101 +98,6 @@ MiniCPMVImageInputs = MiniCPMVImagePixelInputs
DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
def get_abs_pos(abs_pos: torch.Tensor, tgt_size: torch.Tensor):
# abs_pos: L, C
# tgt_size: (H, W)
# return: M, C
src_size = int(math.sqrt(abs_pos.size(0)))
# tgt_size = int(math.sqrt(tgt_size))
dtype = abs_pos.dtype
return (F.interpolate(
abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
size=(tgt_size[0], tgt_size[1]),
mode="bicubic",
align_corners=False,
).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype))
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
def get_2d_sincos_pos_embed(
embed_dim: int,
grid_size: Union[int, Tuple[int, int]],
cls_token: bool = False,
version: Tuple[int, int] = (2, 0),
):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or
[1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
if isinstance(grid_size, int):
grid_h_size, grid_w_size = grid_size, grid_size
else:
grid_h_size, grid_w_size = grid_size[0], grid_size[1]
grid_h = np.arange(grid_h_size, dtype=np.float32)
grid_w = np.arange(grid_w_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
if version == (2, 0):
grid = grid.reshape([2, 1, grid_h_size, grid_w_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version)
if cls_token:
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed],
axis=0)
else:
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim: int,
grid: np.ndarray,
version: Tuple[int, int] = (2, 0)):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(
embed_dim // 2, grid[0], version) # (H*W, D/2) or (H, W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(
embed_dim // 2, grid[1], version) # (H*W, D/2) or (H, W, D/2)
if version == (2, 0):
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
else:
emb = np.concatenate([emb_h, emb_w], axis=-1) # (H, W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim: int,
pos: np.ndarray,
version: Tuple[int, int] = (2, 0)):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,) / (H, W)
out: (M, D) / (H, W, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float32)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,)
if version == (2, 0):
pos = pos.reshape(-1) # (M,)
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
else:
out = np.einsum("hw,d->hwd", pos, omega) # (H, W, D/2), outer product
emb_sin = np.sin(out) # (H, W, D/2)
emb_cos = np.cos(out) # (H, W, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (H, W, D)
return emb
class BaseResampler(nn.Module):
    """
    A 2D perceiver-resampler network with one cross attention layers by
@@ -245,62 +150,6 @@ class BaseResampler(nn.Module):
        return query.unsqueeze(1).repeat(1, N, 1)
class Resampler2(BaseResampler):
def __init__(
self,
grid_size: int,
embed_dim: int,
num_heads: int,
kv_dim: Optional[int] = None,
norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
adaptive: bool = False,
) -> None:
super().__init__(grid_size**2, embed_dim, num_heads, kv_dim,
norm_layer)
self.adaptive = adaptive
pos_embed_arr = get_2d_sincos_pos_embed(embed_dim,
grid_size,
version=(2, 0))
self.pos_embed = nn.Parameter(
torch.from_numpy(pos_embed_arr).float()).requires_grad_(False)
self.apply(self._init_weights)
def forward(
self,
x: torch.Tensor,
tgt_sizes: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
):
if self.adaptive:
pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim,
tgt_sizes,
version=(2, 0))
pos_embed = torch.from_numpy(pos_embed_arr).to(device=x.device,
dtype=x.dtype)
else:
pos_embed = get_abs_pos(self.pos_embed, tgt_sizes)
x, _ = self.kv_proj(x)
x = self.ln_kv(x).permute(1, 0, 2)
N = x.shape[1]
q = self.ln_q(self.query)
out = self.attn(
self._repeat(q, N) + self.pos_embed.unsqueeze(1),
x + pos_embed.unsqueeze(1),
x,
attn_mask=attn_mask,
)[0]
x = out.permute(1, 0, 2)
x = self.ln_post(x)
x = x @ self.proj
return x
class Resampler2_5(BaseResampler):
    def __init__(
@@ -782,7 +631,8 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
            num_heads=embed_dim // 128,
            grid_size=int(math.sqrt(self.config.query_num)),
            kv_dim=vision_dim,
            adaptive=True,
            adaptive=False,
            do_post_projection=True,
        )
        return resampler

View File

@@ -4,36 +4,402 @@
# Copyright (c) Alibaba Cloud.
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
"""Inference-only QWen model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple
import math
import re
from array import array
from functools import partial
from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
Optional, Tuple, TypedDict, Union)
import numpy as np
import torch
from PIL import Image
from torch import nn
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.resampler import Resampler2, get_abs_pos
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.utils import print_warning_once
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SequenceData)
from .utils import is_pp_missing_parameter, make_layers
from .utils import flatten_bn, is_pp_missing_parameter, make_layers
logger = init_logger(__name__)
# NOTE: Qwen models have a few other special tags, e.g., ref, bbox, quad;
# for the time being, these tags are not considered special at encoding
# time. This may change as vLLM's multimodal API changes in the future.
IMG_START = "<img>"
IMG_END = "</img>"
IMG_PAD = "<imgpad>"
# Image context is fixed at 256 for all images
MAX_QWEN_IMG_TOKENS = 256
# Image normalization params
CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
class QwenImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""
Shape: `(batch_size * num_images, 3, image_size, image_size)`
    Note that image_size is the value in the vision config to which we resize
    the image in the normalization transform. Currently multi-image support
can only be leveraged by passing image embeddings directly.
"""
class QwenImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor
"""Shape: `(batch_size * num_images, 256, hidden_size)`
`hidden_size` must match the hidden size of the language model backbone
and is stored in the visual config of the model if we have one.
"""
QwenImageInputs = Union[QwenImagePixelInputs, QwenImageEmbeddingInputs]
class VisualAttention(nn.Module):
"""self-attention layer class.
Self-attention layer takes input with size [s, b, h]
and returns output of the same size.
"""
def __init__(
self,
embed_dim: int,
num_heads: int,
bias: bool = True,
kdim: Optional[int] = None,
vdim: Optional[int] = None,
):
super().__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
self._qkv_same_embed_dim = self.kdim == embed_dim \
and self.vdim == embed_dim
self.num_heads = num_heads
# Per attention head and per partition values.
assert embed_dim % num_heads == 0
self.hidden_size_per_attention_head = embed_dim // num_heads
self.num_attention_heads_per_partition = num_heads
self.hidden_size_per_partition = embed_dim
# Strided linear layer.
assert self._qkv_same_embed_dim, \
'Visual Attention implementation only supports self-attention'
self.in_proj = nn.Linear(embed_dim, 3 * embed_dim)
self.out_proj = nn.Linear(embed_dim, embed_dim)
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
def forward(
self,
x: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# query/key/value: [sq, b, h]
sq, b, _ = x.size()
mixed_x_layer = self.in_proj(x)
# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
query_layer, key_layer, value_layer = mixed_x_layer.split(
self.hidden_size_per_attention_head, dim=-1)
# [sq, b, np, hn] -> [sq, b * np, hn]
query_layer = query_layer.view(
sq, b * self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head).transpose(0, 1)
# [sk, b, np, hn] -> [sk, b * np, hn]
key_layer = key_layer.view(
sq, b * self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head).transpose(0, 1)
q_scaled = query_layer / self.norm_factor
if attn_mask is not None:
attention_probs = torch.baddbmm(attn_mask, q_scaled,
key_layer.transpose(-2, -1))
else:
attention_probs = torch.bmm(q_scaled, key_layer.transpose(-2, -1))
attention_probs = attention_probs.softmax(dim=-1)
value_layer = value_layer.view(
sq, b * self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head).transpose(0, 1)
# matmul: [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer)
# change view [b, np, sq, hn]
context_layer = context_layer.view(
b, self.num_attention_heads_per_partition, sq,
self.hidden_size_per_attention_head)
# [b, np, sq, hn] --> [sq, b, np, hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
# [sq, b, np, hn] --> [sq, b, hp]
new_context_layer_shape = context_layer.size()[:-2] + \
(self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)
output = self.out_proj(context_layer)
return output
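For orientation, this layer consumes and produces sequence-first tensors; a standalone shape sketch with toy sizes:

import torch

attn = VisualAttention(embed_dim=64, num_heads=8)
x = torch.randn(16, 2, 64)           # [sq, b, h]: 16 patch tokens, batch of 2
assert attn(x).shape == (16, 2, 64)  # same layout on the way out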
class QwenVMLP(nn.Module):
"""MLP for the visual component of the Qwen model."""
def __init__(
self,
hidden_size: int,
intermediate_size: int,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.c_fc = ColumnParallelLinear(hidden_size,
intermediate_size,
bias=True,
quant_config=quant_config)
self.act_fn = get_act_fn("gelu", quant_config, intermediate_size)
self.c_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=True,
quant_config=quant_config,
)
def forward(self, x):
x, _ = self.c_fc(x)
x = self.act_fn(x)
x, _ = self.c_proj(x)
return x
class VisualAttentionBlock(nn.Module):
def __init__(
self,
d_model: int,
n_head: int,
mlp_ratio: float = 4.0,
norm_layer: Callable = nn.LayerNorm,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.ln_1 = norm_layer(d_model)
self.ln_2 = norm_layer(d_model)
mlp_width = int(d_model * mlp_ratio)
self.attn = VisualAttention(d_model, n_head)
self.mlp = QwenVMLP(
hidden_size=d_model,
intermediate_size=mlp_width,
quant_config=quant_config,
)
def attention(
self,
x: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None
return self.attn(x, attn_mask=attn_mask)
def forward(
self,
x: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
x = x + self.attention(self.ln_1(x), attn_mask=attn_mask)
x = x + self.mlp(self.ln_2(x))
return x
class TransformerBlock(nn.Module):
def __init__(
self,
width: int,
layers: int,
heads: int,
mlp_ratio: float = 4.0,
norm_layer: Callable = nn.LayerNorm,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.width = width
self.layers = layers
self.resblocks = nn.ModuleList([
VisualAttentionBlock(width,
heads,
mlp_ratio,
norm_layer=norm_layer,
quant_config=quant_config)
for _ in range(layers)
])
def get_cast_dtype(self) -> torch.dtype:
return self.resblocks[0].mlp.c_fc.weight.dtype
def get_cast_device(self) -> torch.device:
return self.resblocks[0].mlp.c_fc.weight.device
def forward(self,
x: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
for r in self.resblocks:
x = r(x, attn_mask=attn_mask)
return x
class VisionTransformer(nn.Module):
def __init__(self,
image_size: int,
patch_size: int,
width: int,
layers: int,
heads: int,
mlp_ratio: float,
n_queries: int = 256,
output_dim: int = 512,
image_start_id: int = 151857,
quant_config: Optional[QuantizationConfig] = None,
**kwargs):
super().__init__()
image_height, image_width = self.image_size = (image_size, image_size)
patch_height, patch_width = self.patch_size = (patch_size, patch_size)
self.grid_size = (image_height // patch_height,
image_width // patch_width)
self.output_dim = output_dim
self.conv1 = nn.Conv2d(in_channels=3,
out_channels=width,
kernel_size=patch_size,
stride=patch_size,
bias=False)
# class embeddings and positional embeddings
scale = width**-0.5
self.positional_embedding = nn.Parameter(scale *
torch.randn(256, width))
norm_layer = partial(nn.LayerNorm, eps=1e-6)
self.ln_pre = norm_layer(width)
self.transformer = TransformerBlock(width,
layers,
heads,
mlp_ratio,
norm_layer=norm_layer,
quant_config=quant_config)
self.attn_pool = Resampler2(
grid_size=int(math.sqrt(n_queries)),
embed_dim=output_dim,
num_heads=output_dim // 128,
kv_dim=width,
norm_layer=norm_layer,
adaptive=False,
do_post_projection=False,
).to(
device=self.positional_embedding.device,
dtype=self.positional_embedding.dtype,
)
self.ln_post = norm_layer(output_dim)
self.proj = nn.Parameter(
(output_dim**-0.5) * torch.randn(output_dim, output_dim))
self.image_start_id = image_start_id
self.image_end_id = image_start_id + 1
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.to(
dtype=self.transformer.get_cast_dtype(),
device=self.transformer.get_cast_device(),
)
# to patches
x = self.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1],
-1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
x = x + get_abs_pos(self.positional_embedding, int(math.sqrt(
x.size(1))))
x = self.ln_pre(x)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.attn_pool(x)
x = self.ln_post(x)
x = x @ self.proj
return x
def get_image_positions(self,
input_ids: torch.Tensor) -> Optional[torch.Tensor]:
"""Given the input IDs, extracts start/stop points corresponding to
images.
        Args:
            input_ids: Token IDs searched for image start/end markers.
        Returns:
            Optional torch tensor corresponding to start/stop pairs of images.
"""
if torch.any(input_ids == self.image_start_id):
bos_pos = torch.where(input_ids == self.image_start_id)
eos_pos = torch.where(input_ids == self.image_end_id)
return torch.stack((bos_pos[0], eos_pos[0]), dim=1)
return None
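A standalone illustration of the same start/stop extraction logic with toy token IDs (assuming IDs 151857/151858 for <img>/</img>, matching the default image_start_id above):

import torch

ids = torch.tensor([11, 151857, 5, 5, 151858, 22, 151857, 7, 151858])
starts = torch.where(ids == 151857)[0]
ends = torch.where(ids == 151858)[0]
img_pos = torch.stack((starts, ends), dim=1)
# -> tensor([[1, 4], [6, 8]]): one (<img> index, </img> index) row per image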
class QWenMLP(nn.Module):
    """MLP for the language component of the Qwen model, which contains a
    MergedColumnParallelLinear merging 2 outputs via silu activation."""
    def __init__(
        self,
@@ -56,7 +422,7 @@ class QWenMLP(nn.Module):
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()
    def forward(self, x):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.c_proj(x)
@@ -203,6 +569,9 @@ class QWenModel(nn.Module):
            lambda prefix: QWenBlock(config, cache_config, quant_config),
            prefix=f"{prefix}.h")
        self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.visual = VisionTransformer(**config.visual,
quant_config=quant_config) if hasattr(
config, "visual") else None
    def forward(
        self,
@@ -211,9 +580,33 @@ class QWenModel(nn.Module):
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
pixel_values: Optional[QwenImageInputs],
    ) -> torch.Tensor:
img_pos = None
# If pixel / visual embeddings are provided, this is a visual model
if pixel_values is not None and self.visual is not None:
if pixel_values["type"] != "image_embeds":
image_embeds = self.visual(pixel_values["data"])
else:
image_embeds = pixel_values["data"]
# features should be of shape (# images, 256, hidden_dim)
img_pos = self.visual.get_image_positions(input_ids)
if isinstance(
img_pos,
np.ndarray) and img_pos.shape[0] != image_embeds.shape[0]:
raise ValueError(
f"Number of placeholders: {img_pos.shape[0]} "
f"does not match number of images {image_embeds.shape[0]}."
)
        if get_pp_group().is_first_rank:
            hidden_states = self.wte(input_ids)
            # Merge the image embeddings into the hidden states if we actually
            # have visual features and the corresponding image tokens
if img_pos is not None:
for idx, (img_bos, img_eos) in enumerate(img_pos):
hidden_states[img_bos + 1:img_eos] = image_embeds[idx]
            residual = None
        else:
            assert intermediate_tensors is not None
@@ -237,16 +630,241 @@ class QWenModel(nn.Module):
        return hidden_states
class QWenLMHeadModel(nn.Module):
def get_image_text(image_num: int, padding: bool) -> str:
"""Retrieves a placeholder text that when tokenized, will be expanded with
image pads.
Args:
image_num: The number of the image that we want a text prompt for.
Images should be indexed starting at 1.
padding: Whether or not padding should be manually added.
Returns:
Text placeholder prompt for the image being considered.
"""
image_start = f"Picture {image_num}: {IMG_START}"
image_end = f"{IMG_END}\n"
if not padding:
return f"{image_start}{image_end}"
return f"{image_start}{MAX_QWEN_IMG_TOKENS * IMG_PAD}{image_end}"
def input_processor_for_qwen(ctx: InputContext,
llm_inputs: LLMInputs) -> LLMInputs:
"""Processes the inputs, which may or may not be multimodal.
Multimodal inputs will only be processed if the model has a "visual"
component in its model config, otherwise they'll be ignored.
Args:
ctx: Context of the loaded model.
llm_inputs: LLM inputs which may have a multi_modal_data attribute.
Returns:
If the model is language only or not multimodal inputs were provided,
returns llm_inputs unmodified. Otherwise, processes the multimodal
images / image embeddings and adds the fixed-length image placeholders.
"""
multi_modal_data = llm_inputs.get("multi_modal_data")
# Only process images if we have multimodal data and a visual config
hf_config = ctx.get_hf_config()
if (multi_modal_data is None or "image" not in multi_modal_data
or not hasattr(hf_config, "visual")):
return llm_inputs
prompt = llm_inputs.get("prompt")
prompt_token_ids = llm_inputs["prompt_token_ids"]
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(model_config.tokenizer,
trust_remote_code=True)
image_data = multi_modal_data["image"]
if isinstance(image_data, torch.Tensor):
num_dims = len(image_data.shape)
if num_dims < 2 or num_dims > 3:
            raise ValueError(
                f"Expected image embeds to have 2 or 3 dimensions, "
                f"got {num_dims}")
num_images = 1 if num_dims == 2 else image_data.shape[0]
else:
# TODO - handle multiple image inputs once the API is solidified
num_images = 1
if prompt is None:
prompt = tokenizer.decode(prompt_token_ids)
# Drops anything between <img>/</img> tags; encoding with the tokenizer
# will automatically add the image pads for the context.
new_prompt, num_matched_images = re.subn(
r"(Picture \d*: <img>).*?(<\/img>\n)",
r"\1\2",
prompt,
)
if num_matched_images != num_images:
logger.warning(
"Number of matched image placeholders %s doesn't match the number "
"of expected images %s; check your placeholder formatting.",
num_matched_images, num_images)
new_prompt_token_ids = tokenizer.encode(new_prompt)
return LLMInputs(prompt=new_prompt,
prompt_token_ids=new_prompt_token_ids,
multi_modal_data=multi_modal_data)
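The regex above only strips whatever sits between the image tags; a standalone sketch of that behavior (the file path and question are made up for illustration):

import re

prompt = "Picture 1: <img>/tmp/stop_sign.jpg</img>\nWhat's in the image?: "
new_prompt, n = re.subn(r"(Picture \d*: <img>).*?(<\/img>\n)", r"\1\2", prompt)
assert new_prompt == "Picture 1: <img></img>\nWhat's in the image?: "
assert n == 1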
def input_mapper_for_qwen(ctx: InputContext, data: object) -> MultiModalInputs:
"""Maps the input data to its MultiModalInputs (if any).
Args:
ctx: Context of the loaded model.
data: data potentially containing image/image embeddings to be mapped
to pixel_values in .forward() for a visual QWenLMHeadModel model.
Returns:
MultiModalInputs containing the stacked normalized images tensor or
image embeddings.
"""
# Early exit if we have provided an image to a language only Qwen model
hf_config = ctx.get_hf_config()
if not hasattr(hf_config, "visual"):
logger.warning(
"Images were provided but this model has no visual config; "
"multimodal inputs will not be forwarded to the model.")
return MultiModalInputs()
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(model_config.tokenizer,
trust_remote_code=True)
image_pair_tok = tokenizer.encode(IMG_START + IMG_END,
add_special_tokens=False,
return_tensors="pt").squeeze()
image_start_id = image_pair_tok[0]
image_end_id = image_pair_tok[-1]
if (image_start_id + 1) != image_end_id:
        raise ValueError(
            f"Found image end ID {image_end_id}, but expected "
            f"{image_start_id + 1}")
    if len(image_pair_tok) != (MAX_QWEN_IMG_TOKENS + 2):
        raise ValueError(
            f"Expected image context length of {MAX_QWEN_IMG_TOKENS}, "
            f"but got {len(image_pair_tok) - 2}")
hf_config = ctx.get_hf_config()
image_size = hf_config.visual["image_size"]
img_emb_size = hf_config.visual["output_dim"]
if isinstance(data, torch.Tensor):
# It's expected that our values have already been processed
# by the visual transformer; shape is expected to be:
# (# images, 256, hidden_size)
if len(data.shape) == 2:
# Assume only one image embed was provided; unsqueeze the extra dim
data = data.unsqueeze(0)
if len(data.shape) != 3 or data.shape[
1] != MAX_QWEN_IMG_TOKENS or data.shape[2] != img_emb_size:
raise ValueError(
"Expected image embeds to be a tensor of shape"
f"[# images, {MAX_QWEN_IMG_TOKENS}, {img_emb_size}], but "
f"received shape [{data.shape}]")
pixel_values = data
else:
transform = build_normalization_transform(image_size)
# TODO - handle multiple image inputs once the API is solidified
transformed_images = [transform(data)]
pixel_values = torch.stack(transformed_images, dim=0)
return MultiModalInputs({"pixel_values": pixel_values})
def build_normalization_transform(image_size: int) -> transforms.Compose:
"""Builds a normalization transform which can be applied to one or
more input images from which we want to extract visual features.
Args:
image_size: size of the image to be processed for visual embeddings.
Returns:
Callable transform for normalizing and resizing one RGB image.
"""
return transforms.Compose([
transforms.Resize((image_size, image_size),
interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
])
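Usage sketch (assuming Qwen-VL's 448x448 visual input size; any RGB image works since it is resized):

from PIL import Image

transform = build_normalization_transform(448)
pixel_tensor = transform(Image.new("RGB", (640, 480)))
assert pixel_tensor.shape == (3, 448, 448)   # (C, H, W), CLIP-normalized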
def dummy_data_for_qwen(
ctx: InputContext,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Tuple[SequenceData, Optional[Dict]]:
"""Build dummy data for warming up Qwen models; this will only contain text
matching the defaults for VLLM unless the model has a visual config.
Args:
ctx: Context of the loaded model.
seq_len: Number of tokens in the text sequence.
mm_counts: multimodal data counts.
Returns:
Tuple containing sequential and multimodal data.
"""
hf_config = ctx.get_hf_config()
# The presence of a visual config indicates this is a multimodal model.
# If we don't have it, the model is considered an LLM for warmup purposes.
if not hasattr(hf_config, "visual"):
seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * seq_len))
mm_data = None
return seq_data, mm_data
# We have a visual component - use images to warm up
num_images = mm_counts["image"]
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(model_config.tokenizer,
trust_remote_code=True)
# Build the image prompts with no imgpads; the tokenizer will add img pads
image_prompt = ''.join(
[get_image_text(idx, False) for idx in range(1, num_images + 1)])
toks = tokenizer.encode(image_prompt, add_special_tokens=False)
# Make sure we actually get the fixed context size per tok padding
num_pads = toks.count(tokenizer.encode(IMG_PAD)[0])
if num_pads != (num_images * MAX_QWEN_IMG_TOKENS):
raise ValueError(
f"Tokenized dummy data should encode {MAX_QWEN_IMG_TOKENS} pads"
f" per image, but got {num_pads} pads for {num_images} image(s)"
" in total. Are you using a qwen tokenizer?")
# Ensure the number of tokens is at minimum the sequence length provided
if len(toks) < seq_len:
toks += [0] * (seq_len - len(toks))
# Build the input images; width/height doesn't actually matter here since
# the data will get resized and the # of tokens per image is constant
image = Image.new("RGB", (224, 224), color=0)
mm_data = {"image": image if num_images == 1 else [image] * num_images}
return SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, toks)), mm_data
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen)
@MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen)
@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen)
class QWenLMHeadModel(nn.Module, SupportsMultiModal):
    def __init__(
        self,
        config: PretrainedConfig,
multimodal_config: MultiModalConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
self.multimodal_config = multimodal_config
        self.quant_config = quant_config
        self.transformer = QWenModel(config, cache_config, quant_config)
        self.lm_head = ParallelLMHead(config.vocab_size,
@@ -257,16 +875,47 @@ class QWenLMHeadModel(nn.Module):
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
    def _get_image_input_type(
            self,
            pixel_values: Optional[torch.Tensor]) -> Optional[QwenImageInputs]:
        """Determines if the provided pixel_values are normalized pixel values
        or image embeddings.
        Args:
            pixel_values: Optional data to be processed into visual embeddings.
        Returns:
            None or the QwenImageInputs type used to determine whether or not
            the visual transformer needs to process the pixel_values.
"""
if pixel_values is not None and self.transformer.visual is not None:
pixel_values = flatten_bn(pixel_values)
if len(pixel_values.shape) == 3 and pixel_values.shape[
1] == MAX_QWEN_IMG_TOKENS and pixel_values.shape[
2] == self.config.visual["output_dim"]:
return QwenImageEmbeddingInputs(
type="image_embeds",
data=pixel_values,
)
else:
# If we have the wrong shape, assume we still need to process
return QwenImagePixelInputs(
type="pixel_values",
data=pixel_values,
)
return None
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
pixel_values: Optional[torch.Tensor] = None) -> torch.Tensor:
pixel_values = self._get_image_input_type(pixel_values)
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata, intermediate_tensors)
                                         attn_metadata, intermediate_tensors,
                                         pixel_values)
        return hidden_states
    def make_empty_intermediate_tensors(
@@ -328,15 +977,6 @@ class QWenLMHeadModel(nn.Module):
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
# Skip loading visual weights to support Qwen-VL models
# in cases with text-only inputs
# TODO: add support for Qwen-VL
if (name not in params_dict
and name.startswith("transformer.visual.")):
print_warning_once(
"Only text inputs are allowed. Images won't be handled "
"until Qwen-VL models are fully supported.")
continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue