[MODEL] Qwen Multimodal Support (Qwen-VL / Qwen-VL-Chat) (#8029)
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com> Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
8685ba1a1e
commit
9da25a88aa
@ -242,6 +242,11 @@ Multimodal Language Models
|
|||||||
- Image\ :sup:`+`
|
- Image\ :sup:`+`
|
||||||
- :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc.
|
- :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc.
|
||||||
-
|
-
|
||||||
|
* - :code:`QWenLMHeadModel`
|
||||||
|
- Qwen
|
||||||
|
- Image
|
||||||
|
- :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc.
|
||||||
|
-
|
||||||
* - :code:`UltravoxModel`
|
* - :code:`UltravoxModel`
|
||||||
- Ultravox
|
- Ultravox
|
||||||
- Audio\ :sup:`E+`
|
- Audio\ :sup:`E+`
|
||||||
|
@ -159,6 +159,20 @@ def run_blip2(question):
|
|||||||
return llm, prompt, stop_token_ids
|
return llm, prompt, stop_token_ids
|
||||||
|
|
||||||
|
|
||||||
|
# Qwen
|
||||||
|
def run_qwen_vl(question):
|
||||||
|
|
||||||
|
llm = LLM(
|
||||||
|
model="Qwen/Qwen-VL",
|
||||||
|
trust_remote_code=True,
|
||||||
|
max_num_seqs=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt = f"{question}Picture 1: <img></img>\n"
|
||||||
|
stop_token_ids = None
|
||||||
|
return llm, prompt, stop_token_ids
|
||||||
|
|
||||||
|
|
||||||
model_example_map = {
|
model_example_map = {
|
||||||
"llava": run_llava,
|
"llava": run_llava,
|
||||||
"llava-next": run_llava_next,
|
"llava-next": run_llava_next,
|
||||||
@ -169,6 +183,7 @@ model_example_map = {
|
|||||||
"minicpmv": run_minicpmv,
|
"minicpmv": run_minicpmv,
|
||||||
"blip-2": run_blip2,
|
"blip-2": run_blip2,
|
||||||
"internvl_chat": run_internvl,
|
"internvl_chat": run_internvl,
|
||||||
|
"qwen_vl": run_qwen_vl,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,19 +1,154 @@
|
|||||||
from typing import Type
|
import pathlib
|
||||||
|
from typing import List, Optional, Type
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from ..conftest import HfRunner, VllmRunner
|
from vllm.multimodal.utils import rescale_image_size
|
||||||
|
|
||||||
|
from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
|
||||||
from .utils import check_logprobs_close
|
from .utils import check_logprobs_close
|
||||||
|
|
||||||
models = ["qwen/qwen-vl"]
|
pytestmark = pytest.mark.vlm
|
||||||
|
|
||||||
|
text_only_models = [
|
||||||
|
"Qwen/Qwen-7B-Chat" # Has no visual component
|
||||||
|
]
|
||||||
|
|
||||||
|
multimodal_models = ["Qwen/Qwen-VL"]
|
||||||
|
|
||||||
|
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
|
||||||
|
"stop_sign":
|
||||||
|
"Picture 1: <img></img>\nWhat's the content of the image?: ",
|
||||||
|
"cherry_blossom":
|
||||||
|
"Picture 1: <img></img>\nWhat is the season?: ",
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("dtype", ["half"])
|
### Tests for multimodal Qwen models
|
||||||
|
def run_test(
|
||||||
|
tmp_path: pathlib.PosixPath,
|
||||||
|
hf_runner: Type[HfRunner],
|
||||||
|
vllm_runner: Type[VllmRunner],
|
||||||
|
image_assets: _ImageAssets,
|
||||||
|
model: str,
|
||||||
|
*,
|
||||||
|
size_factors: List[float],
|
||||||
|
dtype: str,
|
||||||
|
max_tokens: int,
|
||||||
|
num_logprobs: int,
|
||||||
|
tensor_parallel_size: int,
|
||||||
|
distributed_executor_backend: Optional[str] = None,
|
||||||
|
):
|
||||||
|
"""Inference result should be the same between hf and vllm.
|
||||||
|
|
||||||
|
All the image fixtures for the test is under tests/images.
|
||||||
|
For huggingface runner, we provide the PIL images as input.
|
||||||
|
For vllm runner, we provide MultiModalDataDict objects
|
||||||
|
and corresponding MultiModalConfig as input.
|
||||||
|
Note, the text input is also adjusted to abide by vllm contract.
|
||||||
|
The text output is sanitized to be able to compare with hf.
|
||||||
|
"""
|
||||||
|
images = [asset.pil_image for asset in image_assets]
|
||||||
|
|
||||||
|
# Export the images to a tempdir and substitute it into the hf prompt;
|
||||||
|
# the contents between <img>/</img> will be ignored by VLLM, but the
|
||||||
|
# transformers implementation for the visual transformer parses this to
|
||||||
|
# reload it in the forward call; the contents are treated as a URL or a
|
||||||
|
# local path.
|
||||||
|
for idx, asset in enumerate(image_assets):
|
||||||
|
image_tmp_path = tmp_path / f"{asset.name}.jpg"
|
||||||
|
asset.pil_image.save(image_tmp_path)
|
||||||
|
HF_IMAGE_PROMPTS[idx] = HF_IMAGE_PROMPTS[idx].replace(
|
||||||
|
"<img></img>", f"<img>{image_tmp_path}</img>")
|
||||||
|
|
||||||
|
inputs_per_image = [(
|
||||||
|
[prompt for _ in size_factors],
|
||||||
|
[rescale_image_size(image, factor) for factor in size_factors],
|
||||||
|
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
|
||||||
|
|
||||||
|
# NOTE: take care of the order. run vLLM first, and then run HF.
|
||||||
|
# vLLM needs a fresh new process without cuda initialization.
|
||||||
|
# if we run HF first, the cuda initialization will be done and it
|
||||||
|
# will hurt multiprocessing backend with fork method (the default method).
|
||||||
|
|
||||||
|
# max_model_len should be greater than image_feature_size
|
||||||
|
# Qwen encodes images into a fixed content size of 256
|
||||||
|
with vllm_runner(model,
|
||||||
|
max_model_len=300,
|
||||||
|
max_num_seqs=1,
|
||||||
|
dtype=dtype,
|
||||||
|
tensor_parallel_size=tensor_parallel_size,
|
||||||
|
distributed_executor_backend=distributed_executor_backend,
|
||||||
|
enforce_eager=True) as vllm_model:
|
||||||
|
vllm_outputs_per_image = [
|
||||||
|
vllm_model.generate_greedy_logprobs(prompts,
|
||||||
|
max_tokens,
|
||||||
|
num_logprobs=num_logprobs,
|
||||||
|
images=images)
|
||||||
|
for prompts, images in inputs_per_image
|
||||||
|
]
|
||||||
|
|
||||||
|
with hf_runner(model, dtype=dtype) as hf_model:
|
||||||
|
hf_outputs_per_image = [
|
||||||
|
hf_model.generate_greedy_logprobs_limit(prompts,
|
||||||
|
max_tokens,
|
||||||
|
num_logprobs=num_logprobs,
|
||||||
|
images=images)
|
||||||
|
for prompts, images in inputs_per_image
|
||||||
|
]
|
||||||
|
|
||||||
|
for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
|
||||||
|
vllm_outputs_per_image):
|
||||||
|
|
||||||
|
check_logprobs_close(
|
||||||
|
outputs_0_lst=hf_outputs,
|
||||||
|
outputs_1_lst=vllm_outputs,
|
||||||
|
name_0="hf",
|
||||||
|
name_1="vllm",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model", multimodal_models)
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"size_factors",
|
||||||
|
[
|
||||||
|
# No image
|
||||||
|
[],
|
||||||
|
# Single-scale
|
||||||
|
[1.0],
|
||||||
|
# Single-scale, batched
|
||||||
|
[1.0, 1.0, 1.0],
|
||||||
|
# Multi-scale
|
||||||
|
[0.25, 0.5, 1.0],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||||
|
@pytest.mark.parametrize("max_tokens", [8])
|
||||||
|
@pytest.mark.parametrize("num_logprobs", [5])
|
||||||
|
def test_multimodal_models(tmp_path, hf_runner, vllm_runner, image_assets,
|
||||||
|
model, size_factors, dtype, max_tokens,
|
||||||
|
num_logprobs) -> None:
|
||||||
|
run_test(
|
||||||
|
tmp_path,
|
||||||
|
hf_runner,
|
||||||
|
vllm_runner,
|
||||||
|
image_assets,
|
||||||
|
model,
|
||||||
|
size_factors=size_factors,
|
||||||
|
dtype=dtype,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
num_logprobs=num_logprobs,
|
||||||
|
tensor_parallel_size=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Ensure that a text-only Qwen model can still be loaded and
|
||||||
|
# used for inference in VLLM without throwing.
|
||||||
|
@pytest.mark.parametrize("model", text_only_models)
|
||||||
|
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||||
@pytest.mark.parametrize("max_tokens", [32])
|
@pytest.mark.parametrize("max_tokens", [32])
|
||||||
@pytest.mark.parametrize("num_logprobs", [5])
|
@pytest.mark.parametrize("num_logprobs", [5])
|
||||||
@pytest.mark.parametrize("model", models)
|
def test_text_only_qwen_model_can_be_loaded_and_run(
|
||||||
def test_text_only_qwen_model(
|
|
||||||
hf_runner: Type[HfRunner],
|
|
||||||
vllm_runner: Type[VllmRunner],
|
vllm_runner: Type[VllmRunner],
|
||||||
example_prompts,
|
example_prompts,
|
||||||
model: str,
|
model: str,
|
||||||
@ -22,27 +157,9 @@ def test_text_only_qwen_model(
|
|||||||
max_tokens: int,
|
max_tokens: int,
|
||||||
num_logprobs: int,
|
num_logprobs: int,
|
||||||
):
|
):
|
||||||
# This test checks language inputs only, since the visual component
|
|
||||||
# for qwen-vl is still unsupported in VLLM. In the near-future, the
|
|
||||||
# implementation and this test will be extended to consider
|
|
||||||
# visual inputs as well.
|
|
||||||
with hf_runner(model, dtype=dtype) as hf_model:
|
|
||||||
hf_outputs = hf_model.generate_greedy_logprobs_limit(
|
|
||||||
example_prompts,
|
|
||||||
max_tokens,
|
|
||||||
num_logprobs=num_logprobs,
|
|
||||||
)
|
|
||||||
|
|
||||||
with vllm_runner(model, dtype=dtype) as vllm_model:
|
with vllm_runner(model, dtype=dtype) as vllm_model:
|
||||||
vllm_outputs = vllm_model.generate_greedy_logprobs(
|
vllm_model.generate_greedy_logprobs(
|
||||||
example_prompts,
|
example_prompts,
|
||||||
max_tokens,
|
max_tokens,
|
||||||
num_logprobs=num_logprobs,
|
num_logprobs=num_logprobs,
|
||||||
)
|
)
|
||||||
|
|
||||||
check_logprobs_close(
|
|
||||||
outputs_0_lst=hf_outputs,
|
|
||||||
outputs_1_lst=vllm_outputs,
|
|
||||||
name_0="hf",
|
|
||||||
name_1="vllm",
|
|
||||||
)
|
|
||||||
|
@ -150,6 +150,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
|||||||
if model_type in ("blip-2", "chatglm", "fuyu", "paligemma"):
|
if model_type in ("blip-2", "chatglm", "fuyu", "paligemma"):
|
||||||
# These models do not use image tokens in the prompt
|
# These models do not use image tokens in the prompt
|
||||||
return None
|
return None
|
||||||
|
if model_type == "qwen":
|
||||||
|
return f"Picture {current_count}: <img></img>"
|
||||||
if model_type.startswith("llava"):
|
if model_type.startswith("llava"):
|
||||||
return self._cached_token_str(self._tokenizer,
|
return self._cached_token_str(self._tokenizer,
|
||||||
hf_config.image_token_index)
|
hf_config.image_token_index)
|
||||||
|
273
vllm/model_executor/layers/resampler.py
Normal file
273
vllm/model_executor/layers/resampler.py
Normal file
@ -0,0 +1,273 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Adapted from
|
||||||
|
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
|
||||||
|
# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
|
||||||
|
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
|
||||||
|
#
|
||||||
|
# Copyright 2023 The Qwen team.
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||||
|
# and OPT implementations in this library. It has been modified from its
|
||||||
|
# original forms to accommodate minor architectural differences compared
|
||||||
|
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Shared resampler perceiver network used in multimodal models and
|
||||||
|
related helpers for sincos positional embeddings.
|
||||||
|
|
||||||
|
Example models: Qwen (Qwen-VL), Minicpmv2.0
|
||||||
|
"""
|
||||||
|
import math
|
||||||
|
from functools import partial
|
||||||
|
from typing import Callable, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn.init import trunc_normal_
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||||
|
|
||||||
|
DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
|
||||||
|
|
||||||
|
|
||||||
|
def get_abs_pos(abs_pos: torch.Tensor, tgt_size: Union[torch.Tensor,
|
||||||
|
int]) -> torch.Tensor:
|
||||||
|
# abs_pos: L, C
|
||||||
|
# tgt_size: (H, W)
|
||||||
|
# return: M, C
|
||||||
|
src_size = int(math.sqrt(abs_pos.size(0)))
|
||||||
|
dtype = abs_pos.dtype
|
||||||
|
if isinstance(tgt_size, int):
|
||||||
|
tgt_size = (tgt_size, tgt_size)
|
||||||
|
if (src_size == tgt_size[0] and src_size == tgt_size[1]):
|
||||||
|
return abs_pos
|
||||||
|
return (F.interpolate(
|
||||||
|
abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
|
||||||
|
size=(tgt_size[0], tgt_size[1]),
|
||||||
|
mode="bicubic",
|
||||||
|
align_corners=False,
|
||||||
|
).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype))
|
||||||
|
|
||||||
|
|
||||||
|
# sin/cos positional embedding helpers are adapted from:
|
||||||
|
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
|
||||||
|
def get_1d_sincos_pos_embed_from_grid(
|
||||||
|
embed_dim: int, pos: np.ndarray,
|
||||||
|
version: Tuple[int, int] = (2, 0)) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
embed_dim: output dimension for each position
|
||||||
|
pos: a list of positions to be encoded: size (M,) / (H, W)
|
||||||
|
out: (M, D) / (H, W, D)
|
||||||
|
"""
|
||||||
|
assert embed_dim % 2 == 0
|
||||||
|
omega = np.arange(embed_dim // 2, dtype=np.float32)
|
||||||
|
omega /= embed_dim / 2.0
|
||||||
|
omega = 1.0 / 10000**omega # (D/2,)
|
||||||
|
|
||||||
|
if version == (2, 0):
|
||||||
|
pos = pos.reshape(-1) # (M,)
|
||||||
|
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
||||||
|
emb_sin = np.sin(out) # (M, D/2)
|
||||||
|
emb_cos = np.cos(out) # (M, D/2)
|
||||||
|
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
||||||
|
else:
|
||||||
|
out = np.einsum("hw,d->hwd", pos, omega) # (H, W, D/2), outer product
|
||||||
|
emb_sin = np.sin(out) # (H, W, D/2)
|
||||||
|
emb_cos = np.cos(out) # (H, W, D/2)
|
||||||
|
emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (H, W, D)
|
||||||
|
return emb
|
||||||
|
|
||||||
|
|
||||||
|
def get_2d_sincos_pos_embed_from_grid(
|
||||||
|
embed_dim: int, grid: np.ndarray,
|
||||||
|
version: Tuple[int, int] = (2, 0)) -> torch.Tensor:
|
||||||
|
assert embed_dim % 2 == 0
|
||||||
|
|
||||||
|
# use half of dimensions to encode grid_h
|
||||||
|
emb_h = get_1d_sincos_pos_embed_from_grid(
|
||||||
|
embed_dim // 2, grid[0], version) # (H*W, D/2) or (H, W, D/2)
|
||||||
|
emb_w = get_1d_sincos_pos_embed_from_grid(
|
||||||
|
embed_dim // 2, grid[1], version) # (H*W, D/2) or (H, W, D/2)
|
||||||
|
|
||||||
|
if version == (2, 0):
|
||||||
|
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
||||||
|
else:
|
||||||
|
emb = np.concatenate([emb_h, emb_w], axis=-1) # (H, W, D)
|
||||||
|
return emb
|
||||||
|
|
||||||
|
|
||||||
|
def get_2d_sincos_pos_embed(
|
||||||
|
embed_dim: int,
|
||||||
|
grid_size: Union[int, Tuple[int, int]],
|
||||||
|
cls_token: bool = False,
|
||||||
|
version: Tuple[int, int] = (2, 0),
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
grid_size: int of the grid height and width
|
||||||
|
return:
|
||||||
|
pos_embed: [grid_size*grid_size, embed_dim] or
|
||||||
|
[1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
||||||
|
"""
|
||||||
|
if isinstance(grid_size, int):
|
||||||
|
grid_h_size, grid_w_size = grid_size, grid_size
|
||||||
|
else:
|
||||||
|
grid_h_size, grid_w_size = grid_size[0], grid_size[1]
|
||||||
|
|
||||||
|
grid_h = np.arange(grid_h_size, dtype=np.float32)
|
||||||
|
grid_w = np.arange(grid_w_size, dtype=np.float32)
|
||||||
|
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
||||||
|
grid = np.stack(grid, axis=0)
|
||||||
|
assert isinstance(grid, np.ndarray) and \
|
||||||
|
grid.shape == (2, grid_h_size, grid_w_size)
|
||||||
|
|
||||||
|
if version == (2, 0):
|
||||||
|
grid = grid.reshape([2, 1, grid_h_size, grid_w_size])
|
||||||
|
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version)
|
||||||
|
if cls_token:
|
||||||
|
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed],
|
||||||
|
axis=0)
|
||||||
|
else:
|
||||||
|
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version)
|
||||||
|
return pos_embed
|
||||||
|
|
||||||
|
|
||||||
|
class BaseResampler(nn.Module):
|
||||||
|
"""
|
||||||
|
A 2D perceiver-resampler network with one cross attention layers by
|
||||||
|
(grid_size**2) learnable queries and 2d sincos pos_emb.
|
||||||
|
Outputs:
|
||||||
|
A tensor with the shape of (grid_size**2, embed_dim)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_queries: int,
|
||||||
|
embed_dim: int,
|
||||||
|
num_heads: int,
|
||||||
|
kv_dim: Optional[int] = None,
|
||||||
|
norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
|
||||||
|
do_post_projection: bool = True,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.num_queries = num_queries
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
|
||||||
|
self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
|
||||||
|
trunc_normal_(self.query, std=0.02)
|
||||||
|
if kv_dim is not None and kv_dim != embed_dim:
|
||||||
|
self.kv_proj = ReplicatedLinear(kv_dim, embed_dim, bias=False)
|
||||||
|
else:
|
||||||
|
# Maintain the same return value with ReplicatedLinear.forward
|
||||||
|
self.kv_proj = lambda *args, **kwargs: ( # type: ignore # noqa
|
||||||
|
nn.Identity()(*args, **kwargs),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
self.attn = nn.MultiheadAttention(embed_dim, num_heads)
|
||||||
|
self.ln_q = norm_layer(embed_dim)
|
||||||
|
self.ln_kv = norm_layer(embed_dim)
|
||||||
|
self.do_post_projection = do_post_projection
|
||||||
|
self.ln_post = norm_layer(embed_dim) if do_post_projection else None
|
||||||
|
self.proj = nn.Parameter(
|
||||||
|
(embed_dim**-0.5) *
|
||||||
|
torch.randn(embed_dim, embed_dim)) if do_post_projection else None
|
||||||
|
|
||||||
|
def _init_weights(self, m: nn.Module) -> None:
|
||||||
|
if isinstance(m, nn.Linear):
|
||||||
|
trunc_normal_(m.weight, std=0.02)
|
||||||
|
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
elif isinstance(m, nn.LayerNorm):
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
nn.init.constant_(m.weight, 1.0)
|
||||||
|
|
||||||
|
def _repeat(self, query, N: int):
|
||||||
|
return query.unsqueeze(1).repeat(1, N, 1)
|
||||||
|
|
||||||
|
|
||||||
|
class Resampler2(BaseResampler):
|
||||||
|
"""Resampler-perceiver network to be used for a variety of model types,
|
||||||
|
e.g., Qwen-vl / Minicpmv 2.0. The main difference is the addition of the
|
||||||
|
do_post_projection arg, which indicates whether or not there should be
|
||||||
|
a post layer normalization and projector after the attention. This is
|
||||||
|
present in minicpmv2.0, but not qwen-vl.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
grid_size: int,
|
||||||
|
embed_dim: int,
|
||||||
|
num_heads: int,
|
||||||
|
kv_dim: Optional[int] = None,
|
||||||
|
norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
|
||||||
|
adaptive: bool = False,
|
||||||
|
do_post_projection: bool = True,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(grid_size**2,
|
||||||
|
embed_dim,
|
||||||
|
num_heads,
|
||||||
|
kv_dim,
|
||||||
|
norm_layer,
|
||||||
|
do_post_projection=do_post_projection)
|
||||||
|
|
||||||
|
self.adaptive = adaptive
|
||||||
|
pos_embed_arr = get_2d_sincos_pos_embed(embed_dim,
|
||||||
|
grid_size,
|
||||||
|
version=(2, 0))
|
||||||
|
|
||||||
|
self.pos_embed = nn.Parameter(
|
||||||
|
torch.from_numpy(pos_embed_arr).requires_grad_(False))
|
||||||
|
|
||||||
|
self.apply(self._init_weights)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
tgt_sizes: Optional[torch.Tensor] = None,
|
||||||
|
attn_mask: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if tgt_sizes is None:
|
||||||
|
tgt_sizes = int(math.sqrt(x.size(1)))
|
||||||
|
if self.adaptive:
|
||||||
|
pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim,
|
||||||
|
tgt_sizes,
|
||||||
|
version=(2, 0))
|
||||||
|
pos_embed = torch.from_numpy(pos_embed_arr).to(device=x.device,
|
||||||
|
dtype=x.dtype)
|
||||||
|
else:
|
||||||
|
pos_embed = get_abs_pos(self.pos_embed,
|
||||||
|
tgt_sizes).to(device=x.device,
|
||||||
|
dtype=x.dtype)
|
||||||
|
|
||||||
|
x, _ = self.kv_proj(x)
|
||||||
|
x = self.ln_kv(x).permute(1, 0, 2)
|
||||||
|
|
||||||
|
N = x.shape[1]
|
||||||
|
q = self.ln_q(self.query)
|
||||||
|
out = self.attn(
|
||||||
|
self._repeat(q, N) + self.pos_embed.unsqueeze(1),
|
||||||
|
x + pos_embed.unsqueeze(1),
|
||||||
|
x,
|
||||||
|
attn_mask=attn_mask,
|
||||||
|
)[0]
|
||||||
|
x = out.permute(1, 0, 2)
|
||||||
|
if self.do_post_projection:
|
||||||
|
x = self.ln_post(x)
|
||||||
|
x = x @ self.proj
|
||||||
|
return x
|
@ -51,7 +51,6 @@ _GENERATION_MODELS = {
|
|||||||
"PhiForCausalLM": ("phi", "PhiForCausalLM"),
|
"PhiForCausalLM": ("phi", "PhiForCausalLM"),
|
||||||
"Phi3ForCausalLM": ("llama", "LlamaForCausalLM"),
|
"Phi3ForCausalLM": ("llama", "LlamaForCausalLM"),
|
||||||
"PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
|
"PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
|
||||||
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
|
|
||||||
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
|
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
|
||||||
"Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
|
"Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
|
||||||
"RWForCausalLM": ("falcon", "FalconForCausalLM"),
|
"RWForCausalLM": ("falcon", "FalconForCausalLM"),
|
||||||
@ -88,6 +87,7 @@ _MULTIMODAL_MODELS = {
|
|||||||
"PaliGemmaForConditionalGeneration"),
|
"PaliGemmaForConditionalGeneration"),
|
||||||
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
|
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
|
||||||
"UltravoxModel": ("ultravox", "UltravoxModel"),
|
"UltravoxModel": ("ultravox", "UltravoxModel"),
|
||||||
|
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
|
||||||
}
|
}
|
||||||
_CONDITIONAL_GENERATION_MODELS = {
|
_CONDITIONAL_GENERATION_MODELS = {
|
||||||
"BartModel": ("bart", "BartForConditionalGeneration"),
|
"BartModel": ("bart", "BartForConditionalGeneration"),
|
||||||
|
@ -26,11 +26,9 @@ import re
|
|||||||
from array import array
|
from array import array
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import (Any, Callable, Iterable, List, Mapping, Optional, Tuple,
|
from typing import (Any, Callable, Iterable, List, Mapping, Optional, Tuple,
|
||||||
TypedDict, Union)
|
TypedDict)
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
|
||||||
import torch.types
|
import torch.types
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@ -44,6 +42,8 @@ from vllm.logger import init_logger
|
|||||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
|
from vllm.model_executor.layers.resampler import (Resampler2,
|
||||||
|
get_2d_sincos_pos_embed)
|
||||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||||
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
|
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
|
||||||
@ -98,101 +98,6 @@ MiniCPMVImageInputs = MiniCPMVImagePixelInputs
|
|||||||
DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
|
DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
|
||||||
|
|
||||||
|
|
||||||
def get_abs_pos(abs_pos: torch.Tensor, tgt_size: torch.Tensor):
|
|
||||||
# abs_pos: L, C
|
|
||||||
# tgt_size: (H, W)
|
|
||||||
# return: M, C
|
|
||||||
src_size = int(math.sqrt(abs_pos.size(0)))
|
|
||||||
# tgt_size = int(math.sqrt(tgt_size))
|
|
||||||
dtype = abs_pos.dtype
|
|
||||||
|
|
||||||
return (F.interpolate(
|
|
||||||
abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
|
|
||||||
size=(tgt_size[0], tgt_size[1]),
|
|
||||||
mode="bicubic",
|
|
||||||
align_corners=False,
|
|
||||||
).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype))
|
|
||||||
|
|
||||||
|
|
||||||
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
|
|
||||||
def get_2d_sincos_pos_embed(
|
|
||||||
embed_dim: int,
|
|
||||||
grid_size: Union[int, Tuple[int, int]],
|
|
||||||
cls_token: bool = False,
|
|
||||||
version: Tuple[int, int] = (2, 0),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
grid_size: int of the grid height and width
|
|
||||||
return:
|
|
||||||
pos_embed: [grid_size*grid_size, embed_dim] or
|
|
||||||
[1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
|
||||||
"""
|
|
||||||
if isinstance(grid_size, int):
|
|
||||||
grid_h_size, grid_w_size = grid_size, grid_size
|
|
||||||
else:
|
|
||||||
grid_h_size, grid_w_size = grid_size[0], grid_size[1]
|
|
||||||
|
|
||||||
grid_h = np.arange(grid_h_size, dtype=np.float32)
|
|
||||||
grid_w = np.arange(grid_w_size, dtype=np.float32)
|
|
||||||
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
|
||||||
grid = np.stack(grid, axis=0)
|
|
||||||
|
|
||||||
if version == (2, 0):
|
|
||||||
grid = grid.reshape([2, 1, grid_h_size, grid_w_size])
|
|
||||||
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version)
|
|
||||||
if cls_token:
|
|
||||||
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed],
|
|
||||||
axis=0)
|
|
||||||
else:
|
|
||||||
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version)
|
|
||||||
return pos_embed
|
|
||||||
|
|
||||||
|
|
||||||
def get_2d_sincos_pos_embed_from_grid(embed_dim: int,
|
|
||||||
grid: np.ndarray,
|
|
||||||
version: Tuple[int, int] = (2, 0)):
|
|
||||||
assert embed_dim % 2 == 0
|
|
||||||
|
|
||||||
# use half of dimensions to encode grid_h
|
|
||||||
emb_h = get_1d_sincos_pos_embed_from_grid(
|
|
||||||
embed_dim // 2, grid[0], version) # (H*W, D/2) or (H, W, D/2)
|
|
||||||
emb_w = get_1d_sincos_pos_embed_from_grid(
|
|
||||||
embed_dim // 2, grid[1], version) # (H*W, D/2) or (H, W, D/2)
|
|
||||||
|
|
||||||
if version == (2, 0):
|
|
||||||
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
|
||||||
else:
|
|
||||||
emb = np.concatenate([emb_h, emb_w], axis=-1) # (H, W, D)
|
|
||||||
return emb
|
|
||||||
|
|
||||||
|
|
||||||
def get_1d_sincos_pos_embed_from_grid(embed_dim: int,
|
|
||||||
pos: np.ndarray,
|
|
||||||
version: Tuple[int, int] = (2, 0)):
|
|
||||||
"""
|
|
||||||
embed_dim: output dimension for each position
|
|
||||||
pos: a list of positions to be encoded: size (M,) / (H, W)
|
|
||||||
out: (M, D) / (H, W, D)
|
|
||||||
"""
|
|
||||||
assert embed_dim % 2 == 0
|
|
||||||
omega = np.arange(embed_dim // 2, dtype=np.float32)
|
|
||||||
omega /= embed_dim / 2.0
|
|
||||||
omega = 1.0 / 10000**omega # (D/2,)
|
|
||||||
|
|
||||||
if version == (2, 0):
|
|
||||||
pos = pos.reshape(-1) # (M,)
|
|
||||||
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
|
||||||
emb_sin = np.sin(out) # (M, D/2)
|
|
||||||
emb_cos = np.cos(out) # (M, D/2)
|
|
||||||
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
|
||||||
else:
|
|
||||||
out = np.einsum("hw,d->hwd", pos, omega) # (H, W, D/2), outer product
|
|
||||||
emb_sin = np.sin(out) # (H, W, D/2)
|
|
||||||
emb_cos = np.cos(out) # (H, W, D/2)
|
|
||||||
emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (H, W, D)
|
|
||||||
return emb
|
|
||||||
|
|
||||||
|
|
||||||
class BaseResampler(nn.Module):
|
class BaseResampler(nn.Module):
|
||||||
"""
|
"""
|
||||||
A 2D perceiver-resampler network with one cross attention layers by
|
A 2D perceiver-resampler network with one cross attention layers by
|
||||||
@ -245,62 +150,6 @@ class BaseResampler(nn.Module):
|
|||||||
return query.unsqueeze(1).repeat(1, N, 1)
|
return query.unsqueeze(1).repeat(1, N, 1)
|
||||||
|
|
||||||
|
|
||||||
class Resampler2(BaseResampler):
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
grid_size: int,
|
|
||||||
embed_dim: int,
|
|
||||||
num_heads: int,
|
|
||||||
kv_dim: Optional[int] = None,
|
|
||||||
norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
|
|
||||||
adaptive: bool = False,
|
|
||||||
) -> None:
|
|
||||||
super().__init__(grid_size**2, embed_dim, num_heads, kv_dim,
|
|
||||||
norm_layer)
|
|
||||||
|
|
||||||
self.adaptive = adaptive
|
|
||||||
pos_embed_arr = get_2d_sincos_pos_embed(embed_dim,
|
|
||||||
grid_size,
|
|
||||||
version=(2, 0))
|
|
||||||
self.pos_embed = nn.Parameter(
|
|
||||||
torch.from_numpy(pos_embed_arr).float()).requires_grad_(False)
|
|
||||||
|
|
||||||
self.apply(self._init_weights)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
x: torch.Tensor,
|
|
||||||
tgt_sizes: torch.Tensor,
|
|
||||||
attn_mask: Optional[torch.Tensor] = None,
|
|
||||||
):
|
|
||||||
if self.adaptive:
|
|
||||||
pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim,
|
|
||||||
tgt_sizes,
|
|
||||||
version=(2, 0))
|
|
||||||
pos_embed = torch.from_numpy(pos_embed_arr).to(device=x.device,
|
|
||||||
dtype=x.dtype)
|
|
||||||
else:
|
|
||||||
pos_embed = get_abs_pos(self.pos_embed, tgt_sizes)
|
|
||||||
|
|
||||||
x, _ = self.kv_proj(x)
|
|
||||||
x = self.ln_kv(x).permute(1, 0, 2)
|
|
||||||
|
|
||||||
N = x.shape[1]
|
|
||||||
q = self.ln_q(self.query)
|
|
||||||
out = self.attn(
|
|
||||||
self._repeat(q, N) + self.pos_embed.unsqueeze(1),
|
|
||||||
x + pos_embed.unsqueeze(1),
|
|
||||||
x,
|
|
||||||
attn_mask=attn_mask,
|
|
||||||
)[0]
|
|
||||||
x = out.permute(1, 0, 2)
|
|
||||||
|
|
||||||
x = self.ln_post(x)
|
|
||||||
x = x @ self.proj
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class Resampler2_5(BaseResampler):
|
class Resampler2_5(BaseResampler):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -782,7 +631,8 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
|
|||||||
num_heads=embed_dim // 128,
|
num_heads=embed_dim // 128,
|
||||||
grid_size=int(math.sqrt(self.config.query_num)),
|
grid_size=int(math.sqrt(self.config.query_num)),
|
||||||
kv_dim=vision_dim,
|
kv_dim=vision_dim,
|
||||||
adaptive=True,
|
adaptive=False,
|
||||||
|
do_post_projection=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
return resampler
|
return resampler
|
||||||
|
@ -4,36 +4,402 @@
|
|||||||
# Copyright (c) Alibaba Cloud.
|
# Copyright (c) Alibaba Cloud.
|
||||||
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
|
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
|
||||||
"""Inference-only QWen model compatible with HuggingFace weights."""
|
"""Inference-only QWen model compatible with HuggingFace weights."""
|
||||||
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
|
||||||
|
|
||||||
|
import math
|
||||||
|
import re
|
||||||
|
from array import array
|
||||||
|
from functools import partial
|
||||||
|
from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
|
||||||
|
Optional, Tuple, TypedDict, Union)
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from PIL import Image
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
from torchvision import transforms
|
||||||
|
from torchvision.transforms import InterpolationMode
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
from vllm.config import CacheConfig
|
from vllm.config import CacheConfig, MultiModalConfig
|
||||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
|
MergedColumnParallelLinear,
|
||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
QuantizationConfig)
|
QuantizationConfig)
|
||||||
|
from vllm.model_executor.layers.resampler import Resampler2, get_abs_pos
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
ParallelLMHead, VocabParallelEmbedding)
|
ParallelLMHead, VocabParallelEmbedding)
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
|
from vllm.model_executor.models.interfaces import SupportsMultiModal
|
||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
from vllm.utils import print_warning_once
|
from vllm.multimodal.base import MultiModalInputs
|
||||||
|
from vllm.multimodal.utils import cached_get_tokenizer
|
||||||
|
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
|
||||||
|
SequenceData)
|
||||||
|
|
||||||
from .utils import is_pp_missing_parameter, make_layers
|
from .utils import flatten_bn, is_pp_missing_parameter, make_layers
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
# NOTE: Qwen models have a few other special tags, e.g., ref, bbox, quad;
|
||||||
|
# for the time being, these tags are not considered as special at encoding
|
||||||
|
# time. This may change as VLLMs multimodal API changes in the future.
|
||||||
|
IMG_START = "<img>"
|
||||||
|
IMG_END = "</img>"
|
||||||
|
IMG_PAD = "<imgpad>"
|
||||||
|
# Image context is fixed at 256 for all images
|
||||||
|
MAX_QWEN_IMG_TOKENS = 256
|
||||||
|
# Image normalization params
|
||||||
|
CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
|
||||||
|
CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
|
||||||
|
|
||||||
|
|
||||||
|
class QwenImagePixelInputs(TypedDict):
|
||||||
|
type: Literal["pixel_values"]
|
||||||
|
data: torch.Tensor
|
||||||
|
"""
|
||||||
|
Shape: `(batch_size * num_images, 3, image_size, image_size)`
|
||||||
|
|
||||||
|
Note that image_size is the value in the vision config to which we resize
|
||||||
|
the image to in the normalization transform. Currently multi-image support
|
||||||
|
can only be leveraged by passing image embeddings directly.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class QwenImageEmbeddingInputs(TypedDict):
|
||||||
|
type: Literal["image_embeds"]
|
||||||
|
data: torch.Tensor
|
||||||
|
"""Shape: `(batch_size * num_images, 256, hidden_size)`
|
||||||
|
|
||||||
|
`hidden_size` must match the hidden size of the language model backbone
|
||||||
|
and is stored in the visual config of the model if we have one.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
QwenImageInputs = Union[QwenImagePixelInputs, QwenImageEmbeddingInputs]
|
||||||
|
|
||||||
|
|
||||||
|
class VisualAttention(nn.Module):
|
||||||
|
"""self-attention layer class.
|
||||||
|
Self-attention layer takes input with size [s, b, h]
|
||||||
|
and returns output of the same size.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embed_dim: int,
|
||||||
|
num_heads: int,
|
||||||
|
bias: bool = True,
|
||||||
|
kdim: Optional[int] = None,
|
||||||
|
vdim: Optional[int] = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
self.kdim = kdim if kdim is not None else embed_dim
|
||||||
|
self.vdim = vdim if vdim is not None else embed_dim
|
||||||
|
self._qkv_same_embed_dim = self.kdim == embed_dim \
|
||||||
|
and self.vdim == embed_dim
|
||||||
|
|
||||||
|
self.num_heads = num_heads
|
||||||
|
|
||||||
|
# Per attention head and per partition values.
|
||||||
|
assert embed_dim % num_heads == 0
|
||||||
|
self.hidden_size_per_attention_head = embed_dim // num_heads
|
||||||
|
self.num_attention_heads_per_partition = num_heads
|
||||||
|
self.hidden_size_per_partition = embed_dim
|
||||||
|
|
||||||
|
# Strided linear layer.
|
||||||
|
assert self._qkv_same_embed_dim, \
|
||||||
|
'Visual Attention implementation only supports self-attention'
|
||||||
|
self.in_proj = nn.Linear(embed_dim, 3 * embed_dim)
|
||||||
|
self.out_proj = nn.Linear(embed_dim, embed_dim)
|
||||||
|
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
attn_mask: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# query/key/value: [sq, b, h]
|
||||||
|
sq, b, _ = x.size()
|
||||||
|
mixed_x_layer = self.in_proj(x)
|
||||||
|
|
||||||
|
# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
|
||||||
|
new_tensor_shape = mixed_x_layer.size()[:-1] + \
|
||||||
|
(self.num_attention_heads_per_partition,
|
||||||
|
3 * self.hidden_size_per_attention_head)
|
||||||
|
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
|
||||||
|
|
||||||
|
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
|
||||||
|
query_layer, key_layer, value_layer = mixed_x_layer.split(
|
||||||
|
self.hidden_size_per_attention_head, dim=-1)
|
||||||
|
|
||||||
|
# [sq, b, np, hn] -> [sq, b * np, hn]
|
||||||
|
query_layer = query_layer.view(
|
||||||
|
sq, b * self.num_attention_heads_per_partition,
|
||||||
|
self.hidden_size_per_attention_head).transpose(0, 1)
|
||||||
|
# [sk, b, np, hn] -> [sk, b * np, hn]
|
||||||
|
key_layer = key_layer.view(
|
||||||
|
sq, b * self.num_attention_heads_per_partition,
|
||||||
|
self.hidden_size_per_attention_head).transpose(0, 1)
|
||||||
|
|
||||||
|
q_scaled = query_layer / self.norm_factor
|
||||||
|
if attn_mask is not None:
|
||||||
|
attention_probs = torch.baddbmm(attn_mask, q_scaled,
|
||||||
|
key_layer.transpose(-2, -1))
|
||||||
|
else:
|
||||||
|
attention_probs = torch.bmm(q_scaled, key_layer.transpose(-2, -1))
|
||||||
|
attention_probs = attention_probs.softmax(dim=-1)
|
||||||
|
|
||||||
|
value_layer = value_layer.view(
|
||||||
|
sq, b * self.num_attention_heads_per_partition,
|
||||||
|
self.hidden_size_per_attention_head).transpose(0, 1)
|
||||||
|
|
||||||
|
# matmul: [b * np, sq, hn]
|
||||||
|
context_layer = torch.bmm(attention_probs, value_layer)
|
||||||
|
|
||||||
|
# change view [b, np, sq, hn]
|
||||||
|
context_layer = context_layer.view(
|
||||||
|
b, self.num_attention_heads_per_partition, sq,
|
||||||
|
self.hidden_size_per_attention_head)
|
||||||
|
|
||||||
|
# [b, np, sq, hn] --> [sq, b, np, hn]
|
||||||
|
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
|
||||||
|
|
||||||
|
# [sq, b, np, hn] --> [sq, b, hp]
|
||||||
|
new_context_layer_shape = context_layer.size()[:-2] + \
|
||||||
|
(self.hidden_size_per_partition,)
|
||||||
|
context_layer = context_layer.view(*new_context_layer_shape)
|
||||||
|
|
||||||
|
output = self.out_proj(context_layer)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class QwenVMLP(nn.Module):
|
||||||
|
"""MLP for the visual component of the Qwen model."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
intermediate_size: int,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.c_fc = ColumnParallelLinear(hidden_size,
|
||||||
|
intermediate_size,
|
||||||
|
bias=True,
|
||||||
|
quant_config=quant_config)
|
||||||
|
self.act_fn = get_act_fn("gelu", quant_config, intermediate_size)
|
||||||
|
self.c_proj = RowParallelLinear(
|
||||||
|
intermediate_size,
|
||||||
|
hidden_size,
|
||||||
|
bias=True,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x, _ = self.c_fc(x)
|
||||||
|
x = self.act_fn(x)
|
||||||
|
x, _ = self.c_proj(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class VisualAttentionBlock(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
d_model: int,
|
||||||
|
n_head: int,
|
||||||
|
mlp_ratio: float = 4.0,
|
||||||
|
norm_layer: Callable = nn.LayerNorm,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.ln_1 = norm_layer(d_model)
|
||||||
|
self.ln_2 = norm_layer(d_model)
|
||||||
|
mlp_width = int(d_model * mlp_ratio)
|
||||||
|
self.attn = VisualAttention(d_model, n_head)
|
||||||
|
self.mlp = QwenVMLP(
|
||||||
|
hidden_size=d_model,
|
||||||
|
intermediate_size=mlp_width,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
def attention(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
attn_mask: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None
|
||||||
|
return self.attn(x, attn_mask=attn_mask)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
attn_mask: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
x = x + self.attention(self.ln_1(x), attn_mask=attn_mask)
|
||||||
|
x = x + self.mlp(self.ln_2(x))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerBlock(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
width: int,
|
||||||
|
layers: int,
|
||||||
|
heads: int,
|
||||||
|
mlp_ratio: float = 4.0,
|
||||||
|
norm_layer: Callable = nn.LayerNorm,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.width = width
|
||||||
|
self.layers = layers
|
||||||
|
|
||||||
|
self.resblocks = nn.ModuleList([
|
||||||
|
VisualAttentionBlock(width,
|
||||||
|
heads,
|
||||||
|
mlp_ratio,
|
||||||
|
norm_layer=norm_layer,
|
||||||
|
quant_config=quant_config)
|
||||||
|
for _ in range(layers)
|
||||||
|
])
|
||||||
|
|
||||||
|
def get_cast_dtype(self) -> torch.dtype:
|
||||||
|
return self.resblocks[0].mlp.c_fc.weight.dtype
|
||||||
|
|
||||||
|
def get_cast_device(self) -> torch.device:
|
||||||
|
return self.resblocks[0].mlp.c_fc.weight.device
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||||
|
for r in self.resblocks:
|
||||||
|
x = r(x, attn_mask=attn_mask)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class VisionTransformer(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
image_size: int,
|
||||||
|
patch_size: int,
|
||||||
|
width: int,
|
||||||
|
layers: int,
|
||||||
|
heads: int,
|
||||||
|
mlp_ratio: float,
|
||||||
|
n_queries: int = 256,
|
||||||
|
output_dim: int = 512,
|
||||||
|
image_start_id: int = 151857,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
**kwargs):
|
||||||
|
super().__init__()
|
||||||
|
image_height, image_width = self.image_size = (image_size, image_size)
|
||||||
|
patch_height, patch_width = self.patch_size = (patch_size, patch_size)
|
||||||
|
self.grid_size = (image_height // patch_height,
|
||||||
|
image_width // patch_width)
|
||||||
|
self.output_dim = output_dim
|
||||||
|
self.conv1 = nn.Conv2d(in_channels=3,
|
||||||
|
out_channels=width,
|
||||||
|
kernel_size=patch_size,
|
||||||
|
stride=patch_size,
|
||||||
|
bias=False)
|
||||||
|
|
||||||
|
# class embeddings and positional embeddings
|
||||||
|
scale = width**-0.5
|
||||||
|
self.positional_embedding = nn.Parameter(scale *
|
||||||
|
torch.randn(256, width))
|
||||||
|
|
||||||
|
norm_layer = partial(nn.LayerNorm, eps=1e-6)
|
||||||
|
|
||||||
|
self.ln_pre = norm_layer(width)
|
||||||
|
self.transformer = TransformerBlock(width,
|
||||||
|
layers,
|
||||||
|
heads,
|
||||||
|
mlp_ratio,
|
||||||
|
norm_layer=norm_layer,
|
||||||
|
quant_config=quant_config)
|
||||||
|
|
||||||
|
self.attn_pool = Resampler2(
|
||||||
|
grid_size=int(math.sqrt(n_queries)),
|
||||||
|
embed_dim=output_dim,
|
||||||
|
num_heads=output_dim // 128,
|
||||||
|
kv_dim=width,
|
||||||
|
norm_layer=norm_layer,
|
||||||
|
adaptive=False,
|
||||||
|
do_post_projection=False,
|
||||||
|
).to(
|
||||||
|
device=self.positional_embedding.device,
|
||||||
|
dtype=self.positional_embedding.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.ln_post = norm_layer(output_dim)
|
||||||
|
self.proj = nn.Parameter(
|
||||||
|
(output_dim**-0.5) * torch.randn(output_dim, output_dim))
|
||||||
|
self.image_start_id = image_start_id
|
||||||
|
self.image_end_id = image_start_id + 1
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
x = x.to(
|
||||||
|
dtype=self.transformer.get_cast_dtype(),
|
||||||
|
device=self.transformer.get_cast_device(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# to patches
|
||||||
|
x = self.conv1(x) # shape = [*, width, grid, grid]
|
||||||
|
x = x.reshape(x.shape[0], x.shape[1],
|
||||||
|
-1) # shape = [*, width, grid ** 2]
|
||||||
|
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
||||||
|
|
||||||
|
x = x + get_abs_pos(self.positional_embedding, int(math.sqrt(
|
||||||
|
x.size(1))))
|
||||||
|
|
||||||
|
x = self.ln_pre(x)
|
||||||
|
|
||||||
|
x = x.permute(1, 0, 2) # NLD -> LND
|
||||||
|
x = self.transformer(x)
|
||||||
|
x = x.permute(1, 0, 2) # LND -> NLD
|
||||||
|
|
||||||
|
x = self.attn_pool(x)
|
||||||
|
x = self.ln_post(x)
|
||||||
|
x = x @ self.proj
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def get_image_positions(self,
|
||||||
|
input_ids: torch.Tensor) -> Optional[torch.Tensor]:
|
||||||
|
"""Given the input IDs, extracts start/stop points corresponding to
|
||||||
|
images.
|
||||||
|
|
||||||
|
args:
|
||||||
|
Returns:
|
||||||
|
Optional torch tensor corresponding to start/stop pairs of images.
|
||||||
|
"""
|
||||||
|
if torch.any(input_ids == self.image_start_id):
|
||||||
|
bos_pos = torch.where(input_ids == self.image_start_id)
|
||||||
|
eos_pos = torch.where(input_ids == self.image_end_id)
|
||||||
|
return torch.stack((bos_pos[0], eos_pos[0]), dim=1)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class QWenMLP(nn.Module):
|
class QWenMLP(nn.Module):
|
||||||
|
"""MLP for the language component of the Qwen model, which contains a
|
||||||
|
MergedColumnParallelLinear merging 2 outputs via silu activation."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -56,7 +422,7 @@ class QWenMLP(nn.Module):
|
|||||||
"Only silu is supported for now.")
|
"Only silu is supported for now.")
|
||||||
self.act_fn = SiluAndMul()
|
self.act_fn = SiluAndMul()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
gate_up, _ = self.gate_up_proj(x)
|
gate_up, _ = self.gate_up_proj(x)
|
||||||
x = self.act_fn(gate_up)
|
x = self.act_fn(gate_up)
|
||||||
x, _ = self.c_proj(x)
|
x, _ = self.c_proj(x)
|
||||||
@ -203,6 +569,9 @@ class QWenModel(nn.Module):
|
|||||||
lambda prefix: QWenBlock(config, cache_config, quant_config),
|
lambda prefix: QWenBlock(config, cache_config, quant_config),
|
||||||
prefix=f"{prefix}.h")
|
prefix=f"{prefix}.h")
|
||||||
self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||||
|
self.visual = VisionTransformer(**config.visual,
|
||||||
|
quant_config=quant_config) if hasattr(
|
||||||
|
config, "visual") else None
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -211,9 +580,33 @@ class QWenModel(nn.Module):
|
|||||||
kv_caches: List[torch.Tensor],
|
kv_caches: List[torch.Tensor],
|
||||||
attn_metadata: AttentionMetadata,
|
attn_metadata: AttentionMetadata,
|
||||||
intermediate_tensors: Optional[IntermediateTensors],
|
intermediate_tensors: Optional[IntermediateTensors],
|
||||||
|
pixel_values: Optional[QwenImageInputs],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
img_pos = None
|
||||||
|
# If pixel / visual embeddings are provided, this is a visual model
|
||||||
|
if pixel_values is not None and self.visual is not None:
|
||||||
|
if pixel_values["type"] != "image_embeds":
|
||||||
|
image_embeds = self.visual(pixel_values["data"])
|
||||||
|
else:
|
||||||
|
image_embeds = pixel_values["data"]
|
||||||
|
|
||||||
|
# features should be of shape (# images, 256, hidden_dim)
|
||||||
|
img_pos = self.visual.get_image_positions(input_ids)
|
||||||
|
if isinstance(
|
||||||
|
img_pos,
|
||||||
|
np.ndarray) and img_pos.shape[0] != image_embeds.shape[0]:
|
||||||
|
raise ValueError(
|
||||||
|
f"Number of placeholders: {img_pos.shape[0]} "
|
||||||
|
f"does not match number of images {image_embeds.shape[0]}."
|
||||||
|
)
|
||||||
|
|
||||||
if get_pp_group().is_first_rank:
|
if get_pp_group().is_first_rank:
|
||||||
hidden_states = self.wte(input_ids)
|
hidden_states = self.wte(input_ids)
|
||||||
|
# Merge the image embeddings into the hidden states if actually have
|
||||||
|
# visual features and the corresponding image tokens
|
||||||
|
if img_pos is not None:
|
||||||
|
for idx, (img_bos, img_eos) in enumerate(img_pos):
|
||||||
|
hidden_states[img_bos + 1:img_eos] = image_embeds[idx]
|
||||||
residual = None
|
residual = None
|
||||||
else:
|
else:
|
||||||
assert intermediate_tensors is not None
|
assert intermediate_tensors is not None
|
||||||
@ -237,16 +630,241 @@ class QWenModel(nn.Module):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class QWenLMHeadModel(nn.Module):
|
def get_image_text(image_num: int, padding: bool) -> str:
|
||||||
|
"""Retrieves a placeholder text that when tokenized, will be expanded with
|
||||||
|
image pads.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_num: The number of the image that we want a text prompt for.
|
||||||
|
Images should be indexed starting at 1.
|
||||||
|
padding: Whether or not padding should be manually added.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Text placeholder prompt for the image being considered.
|
||||||
|
"""
|
||||||
|
image_start = f"Picture {image_num}: {IMG_START}"
|
||||||
|
image_end = f"{IMG_END}\n"
|
||||||
|
if not padding:
|
||||||
|
return f"{image_start}{image_end}"
|
||||||
|
return f"{image_start}{MAX_QWEN_IMG_TOKENS * IMG_PAD}{image_end}"
|
||||||
|
|
||||||
|
|
||||||
|
def input_processor_for_qwen(ctx: InputContext,
|
||||||
|
llm_inputs: LLMInputs) -> LLMInputs:
|
||||||
|
"""Processes the inputs, which may or may not be multimodal.
|
||||||
|
Multimodal inputs will only be processed if the model has a "visual"
|
||||||
|
component in its model config, otherwise they'll be ignored.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ctx: Context of the loaded model.
|
||||||
|
llm_inputs: LLM inputs which may have a multi_modal_data attribute.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
If the model is language only or not multimodal inputs were provided,
|
||||||
|
returns llm_inputs unmodified. Otherwise, processes the multimodal
|
||||||
|
images / image embeddings and adds the fixed-length image placeholders.
|
||||||
|
"""
|
||||||
|
multi_modal_data = llm_inputs.get("multi_modal_data")
|
||||||
|
|
||||||
|
# Only process images if we have multimodal data and a visual config
|
||||||
|
hf_config = ctx.get_hf_config()
|
||||||
|
if (multi_modal_data is None or "image" not in multi_modal_data
|
||||||
|
or not hasattr(hf_config, "visual")):
|
||||||
|
return llm_inputs
|
||||||
|
|
||||||
|
prompt = llm_inputs.get("prompt")
|
||||||
|
prompt_token_ids = llm_inputs["prompt_token_ids"]
|
||||||
|
model_config = ctx.model_config
|
||||||
|
tokenizer = cached_get_tokenizer(model_config.tokenizer,
|
||||||
|
trust_remote_code=True)
|
||||||
|
image_data = multi_modal_data["image"]
|
||||||
|
if isinstance(image_data, torch.Tensor):
|
||||||
|
num_dims = len(image_data.shape)
|
||||||
|
if num_dims < 2 or num_dims > 3:
|
||||||
|
raise ValueError(
|
||||||
|
f"Expected img embeds to be have 3 dimensions, got {num_dims}")
|
||||||
|
num_images = 1 if num_dims == 2 else image_data.shape[0]
|
||||||
|
else:
|
||||||
|
# TODO - handle multiple image inputs once the API is solidified
|
||||||
|
num_images = 1
|
||||||
|
|
||||||
|
if prompt is None:
|
||||||
|
prompt = tokenizer.decode(prompt_token_ids)
|
||||||
|
|
||||||
|
# Drops anything between <img>/</img> tags; encoding with the tokenizer
|
||||||
|
# will automatically add the image pads for the context.
|
||||||
|
new_prompt, num_matched_images = re.subn(
|
||||||
|
r"(Picture \d*: <img>).*?(<\/img>\n)",
|
||||||
|
r"\1\2",
|
||||||
|
prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
if num_matched_images != num_images:
|
||||||
|
logger.warning(
|
||||||
|
"Number of matched image placeholders %s doesn't match the number "
|
||||||
|
"of expected images %s; check your placeholder formatting.",
|
||||||
|
num_matched_images, num_images)
|
||||||
|
|
||||||
|
new_prompt_token_ids = tokenizer.encode(new_prompt)
|
||||||
|
|
||||||
|
return LLMInputs(prompt=new_prompt,
|
||||||
|
prompt_token_ids=new_prompt_token_ids,
|
||||||
|
multi_modal_data=multi_modal_data)
|
||||||
|
|
||||||
|
|
||||||
|
def input_mapper_for_qwen(ctx: InputContext, data: object) -> MultiModalInputs:
|
||||||
|
"""Maps the input data to its MultiModalInputs (if any).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ctx: Context of the loaded model.
|
||||||
|
data: data potentially containing image/image embeddings to be mapped
|
||||||
|
to pixel_values in .forward() for a visual QWenLMHeadModel model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MultiModalInputs containing the stacked normalized images tensor or
|
||||||
|
image embeddings.
|
||||||
|
"""
|
||||||
|
# Early exit if we have provided an image to a language only Qwen model
|
||||||
|
hf_config = ctx.get_hf_config()
|
||||||
|
if not hasattr(hf_config, "visual"):
|
||||||
|
logger.warning(
|
||||||
|
"Images were provided but this model has no visual config; "
|
||||||
|
"multimodal inputs will not be forwarded to the model.")
|
||||||
|
return MultiModalInputs()
|
||||||
|
|
||||||
|
model_config = ctx.model_config
|
||||||
|
tokenizer = cached_get_tokenizer(model_config.tokenizer,
|
||||||
|
trust_remote_code=True)
|
||||||
|
|
||||||
|
image_pair_tok = tokenizer.encode(IMG_START + IMG_END,
|
||||||
|
add_special_tokens=False,
|
||||||
|
return_tensors="pt").squeeze()
|
||||||
|
image_start_id = image_pair_tok[0]
|
||||||
|
image_end_id = image_pair_tok[-1]
|
||||||
|
if (image_start_id + 1) != image_end_id:
|
||||||
|
raise ValueError(
|
||||||
|
f"Found image end ID {image_end_id}, but expected {IMG_START} + 1")
|
||||||
|
if len(image_pair_tok) != (MAX_QWEN_IMG_TOKENS + 2):
|
||||||
|
raise ValueError(
|
||||||
|
f"Expected image context length of {MAX_QWEN_IMG_TOKENS}, "
|
||||||
|
f"but got {image_pair_tok - 2}")
|
||||||
|
|
||||||
|
hf_config = ctx.get_hf_config()
|
||||||
|
image_size = hf_config.visual["image_size"]
|
||||||
|
img_emb_size = hf_config.visual["output_dim"]
|
||||||
|
|
||||||
|
if isinstance(data, torch.Tensor):
|
||||||
|
# It's expected that our values have already been processed
|
||||||
|
# by the visual transformer; shape is expected to be:
|
||||||
|
# (# images, 256, hidden_size)
|
||||||
|
if len(data.shape) == 2:
|
||||||
|
# Assume only one image embed was provided; unsqueeze the extra dim
|
||||||
|
data = data.unsqueeze(0)
|
||||||
|
if len(data.shape) != 3 or data.shape[
|
||||||
|
1] != MAX_QWEN_IMG_TOKENS or data.shape[2] != img_emb_size:
|
||||||
|
raise ValueError(
|
||||||
|
"Expected image embeds to be a tensor of shape"
|
||||||
|
f"[# images, {MAX_QWEN_IMG_TOKENS}, {img_emb_size}], but "
|
||||||
|
f"received shape [{data.shape}]")
|
||||||
|
pixel_values = data
|
||||||
|
|
||||||
|
else:
|
||||||
|
transform = build_normalization_transform(image_size)
|
||||||
|
# TODO - handle multiple image inputs once the API is solidified
|
||||||
|
transformed_images = [transform(data)]
|
||||||
|
pixel_values = torch.stack(transformed_images, dim=0)
|
||||||
|
return MultiModalInputs({"pixel_values": pixel_values})
|
||||||
|
|
||||||
|
|
||||||
|
def build_normalization_transform(image_size: int) -> transforms.Compose:
|
||||||
|
"""Builds a normalization transform which can be applied to one or
|
||||||
|
more input images from which we want to extract visual features.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_size: size of the image to be processed for visual embeddings.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Callable transform for normalizing and resizing one RGB image.
|
||||||
|
"""
|
||||||
|
return transforms.Compose([
|
||||||
|
transforms.Resize((image_size, image_size),
|
||||||
|
interpolation=InterpolationMode.BICUBIC),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
def dummy_data_for_qwen(
|
||||||
|
ctx: InputContext,
|
||||||
|
seq_len: int,
|
||||||
|
mm_counts: Mapping[str, int],
|
||||||
|
) -> Tuple[SequenceData, Optional[Dict]]:
|
||||||
|
"""Build dummy data for warming up Qwen models; this will only contain text
|
||||||
|
matching the defaults for VLLM unless the model has a visual config.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ctx: Context of the loaded model.
|
||||||
|
seq_len: Number of tokens in the text sequence.
|
||||||
|
mm_counts: multimodal data counts.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple containing sequential and multimodal data.
|
||||||
|
"""
|
||||||
|
hf_config = ctx.get_hf_config()
|
||||||
|
|
||||||
|
# The presence of a visual config indicates this is a multimodal model.
|
||||||
|
# If we don't have it, the model is considered an LLM for warmup purposes.
|
||||||
|
if not hasattr(hf_config, "visual"):
|
||||||
|
seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * seq_len))
|
||||||
|
mm_data = None
|
||||||
|
return seq_data, mm_data
|
||||||
|
|
||||||
|
# We have a visual component - use images to warm up
|
||||||
|
num_images = mm_counts["image"]
|
||||||
|
model_config = ctx.model_config
|
||||||
|
tokenizer = cached_get_tokenizer(model_config.tokenizer,
|
||||||
|
trust_remote_code=True)
|
||||||
|
|
||||||
|
# Build the image prompts with no imgpads; the tokenizer will add img pads
|
||||||
|
image_prompt = ''.join(
|
||||||
|
[get_image_text(idx, False) for idx in range(1, num_images + 1)])
|
||||||
|
toks = tokenizer.encode(image_prompt, add_special_tokens=False)
|
||||||
|
|
||||||
|
# Make sure we actually get the fixed context size per tok padding
|
||||||
|
num_pads = toks.count(tokenizer.encode(IMG_PAD)[0])
|
||||||
|
if num_pads != (num_images * MAX_QWEN_IMG_TOKENS):
|
||||||
|
raise ValueError(
|
||||||
|
f"Tokenized dummy data should encode {MAX_QWEN_IMG_TOKENS} pads"
|
||||||
|
f" per image, but got {num_pads} pads for {num_images} image(s)"
|
||||||
|
" in total. Are you using a qwen tokenizer?")
|
||||||
|
|
||||||
|
# Ensure the number of tokens is at minimum the sequence length provided
|
||||||
|
if len(toks) < seq_len:
|
||||||
|
toks += [0] * (seq_len - len(toks))
|
||||||
|
|
||||||
|
# Build the input images; width/height doesn't actually matter here since
|
||||||
|
# the data will get resized and the # of tokens per image is constant
|
||||||
|
image = Image.new("RGB", (224, 224), color=0)
|
||||||
|
mm_data = {"image": image if num_images == 1 else [image] * num_images}
|
||||||
|
return SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, toks)), mm_data
|
||||||
|
|
||||||
|
|
||||||
|
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen)
|
||||||
|
@MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS)
|
||||||
|
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen)
|
||||||
|
@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen)
|
||||||
|
class QWenLMHeadModel(nn.Module, SupportsMultiModal):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
|
multimodal_config: MultiModalConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.multimodal_config = multimodal_config
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.transformer = QWenModel(config, cache_config, quant_config)
|
self.transformer = QWenModel(config, cache_config, quant_config)
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||||
@ -257,16 +875,47 @@ class QWenLMHeadModel(nn.Module):
|
|||||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||||
self.sampler = Sampler()
|
self.sampler = Sampler()
|
||||||
|
|
||||||
def forward(
|
def _get_image_input_type(
|
||||||
self,
|
self,
|
||||||
|
pixel_values: Optional[torch.Tensor]) -> Optional[QwenImageInputs]:
|
||||||
|
"""Determines if the provided pixel_values are normalized pixel values
|
||||||
|
or image embeddings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pixel_values: Optional data to processed into visual embeddings.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None of the QwenImageInputs type used to determine whether or not
|
||||||
|
the visual transformer needs to process the pixel_values.
|
||||||
|
"""
|
||||||
|
if pixel_values is not None and self.transformer.visual is not None:
|
||||||
|
pixel_values = flatten_bn(pixel_values)
|
||||||
|
if len(pixel_values.shape) == 3 and pixel_values.shape[
|
||||||
|
1] == MAX_QWEN_IMG_TOKENS and pixel_values.shape[
|
||||||
|
2] == self.config.visual["output_dim"]:
|
||||||
|
return QwenImageEmbeddingInputs(
|
||||||
|
type="image_embeds",
|
||||||
|
data=pixel_values,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# If we have the wrong shape, assume we still need to process
|
||||||
|
return QwenImagePixelInputs(
|
||||||
|
type="pixel_values",
|
||||||
|
data=pixel_values,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
kv_caches: List[torch.Tensor],
|
kv_caches: List[torch.Tensor],
|
||||||
attn_metadata: AttentionMetadata,
|
attn_metadata: AttentionMetadata,
|
||||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
) -> torch.Tensor:
|
pixel_values: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||||
|
pixel_values = self._get_image_input_type(pixel_values)
|
||||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||||
attn_metadata, intermediate_tensors)
|
attn_metadata, intermediate_tensors,
|
||||||
|
pixel_values)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
def make_empty_intermediate_tensors(
|
def make_empty_intermediate_tensors(
|
||||||
@ -328,15 +977,6 @@ class QWenLMHeadModel(nn.Module):
|
|||||||
# Skip loading extra bias for GPTQ models.
|
# Skip loading extra bias for GPTQ models.
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
continue
|
continue
|
||||||
# Skip loading visual weights to support Qwen-VL models
|
|
||||||
# in cases with text-only inputs
|
|
||||||
# TODO: add support for Qwen-VL
|
|
||||||
if (name not in params_dict
|
|
||||||
and name.startswith("transformer.visual.")):
|
|
||||||
print_warning_once(
|
|
||||||
"Only text inputs are allowed. Images won't be handled "
|
|
||||||
"until Qwen-VL models are fully supported.")
|
|
||||||
continue
|
|
||||||
# Skip layers on other devices.
|
# Skip layers on other devices.
|
||||||
if is_pp_missing_parameter(name, self):
|
if is_pp_missing_parameter(name, self):
|
||||||
continue
|
continue
|
||||||
|
Loading…
x
Reference in New Issue
Block a user