[Model] Add support for Gemma 3 (#14660)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: Roger Wang <ywang@roblox.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Roger Wang <ywang@roblox.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Parent: 45f3f3f59e
Commit: c0c25e25fa
@@ -267,6 +267,11 @@ See [this page](#generative-models) for more information on how to use generativ
  * `google/gemma-2-9b`, `google/gemma-2-27b`, etc.
  * ✅︎
  * ✅︎
- * `Gemma3ForCausalLM`
  * Gemma 3
  * `google/gemma-3-1b-it`, etc.
  * ✅︎
  * ✅︎
- * `GlmForCausalLM`
  * GLM-4
  * `THUDM/glm-4-9b-chat-hf`, etc.
@@ -752,6 +757,13 @@ See [this page](#generative-models) for more information on how to use generativ
  *
  * ✅︎
  * ✅︎
- * `Gemma3ForConditionalGeneration`
  * Gemma 3
  * T + I<sup>+</sup>
  * `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc.
  * ✅︎
  * ✅︎
  * ✅︎\*
- * `GLM4VForCausalLM`<sup>^</sup>
  * GLM-4V
  * T + I
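The `T + I<sup>+</sup>` and trailing `\*` markers follow the table's existing footnote conventions: the `+` superscript indicates that multiple images per prompt are accepted, and the `\*` caveat on V1 support is spelled out in the note added further down (V1 currently falls back to causal-only attention over image tokens).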
@@ -937,6 +949,31 @@ For more details, please see: <gh-pr:4087#issuecomment-2250397630>
To use Qwen2.5-VL series models, you have to install Hugging Face Transformers library from source via `pip install git+https://github.com/huggingface/transformers`.
:::

:::{note}
To use Gemma3 series models, you have to install the Hugging Face Transformers library from source via
`pip install git+https://github.com/huggingface/transformers`.
The earliest commit that supports this is [`50d3530aa04e7a7d003e6b255a98f79fd0447357`](https://github.com/huggingface/transformers/commit/50d3530aa04e7a7d003e6b255a98f79fd0447357).

Both V0 and V1 support `Gemma3ForConditionalGeneration` for text-only inputs.
However, there are differences in how they handle text + image inputs:

V0 correctly implements the model's attention pattern:
- Uses bidirectional attention between the image tokens corresponding to the same image
- Uses causal attention for other tokens
- Implemented via (naive) PyTorch SDPA with masking tensors
- Note: May use significant memory for long prompts with images

V1 currently uses a simplified attention pattern:
- Uses causal attention for all tokens, including image tokens
- Generates reasonable outputs but does not match the original model's attention for text + image inputs
- Will be updated in the future to support the correct behavior

This limitation exists because the model's mixed attention pattern (bidirectional for images, causal otherwise) is not yet supported by vLLM's attention backends.

Additionally, vLLM's current Gemma 3 implementation does not support the pan-and-scan image pre-processing algorithm, which helps handle images with skewed aspect ratios by intelligently cropping them into multiple views.
Without this feature, model performance may degrade when processing images that deviate significantly from square dimensions.
:::

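To make the note concrete, here is a minimal offline-inference sketch for a text + image prompt. It mirrors the `run_gemma3` example added in this PR (same checkpoint name and `<start_of_image>` prompt format); the image path and sampling settings are placeholders, not part of the PR.

```python
from PIL import Image

from vllm import LLM, SamplingParams

# Mirrors the run_gemma3 example added in this PR; "image.jpg" is a placeholder.
llm = LLM(model="google/gemma-3-4b-it", max_model_len=2048, max_num_seqs=2)

prompt = ("<bos><start_of_turn>user\n"
          "<start_of_image>Describe this image.<end_of_turn>\n"
          "<start_of_turn>model\n")

outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"image": Image.open("image.jpg")},
    },
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)
```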
### Pooling Models

See [this page](pooling-models) for more information on how to use pooling models.
@@ -118,6 +118,23 @@ def run_fuyu(questions: list[str], modality: str):
    return llm, prompts, stop_token_ids


# Gemma 3
def run_gemma3(questions: list[str], modality: str):
    assert modality == "image"
    model_name = "google/gemma-3-4b-it"

    llm = LLM(model=model_name,
              max_model_len=2048,
              max_num_seqs=2,
              disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)

    prompts = [("<bos><start_of_turn>user\n"
                f"<start_of_image>{question}<end_of_turn>\n"
                "<start_of_turn>model\n") for question in questions]
    stop_token_ids = None
    return llm, prompts, stop_token_ids


# GLM-4v
def run_glm4v(questions: list[str], modality: str):
    assert modality == "image"
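The new `"gemma3"` runner is registered in `model_example_map` in a later hunk, so (assuming the driver code around the map is otherwise unchanged) it is selected through the example script's existing model-type argument like every other entry, and `args.disable_mm_preprocessor_cache` is forwarded the same way the other runners do.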
@@ -405,7 +422,7 @@ def run_mllama(questions: list[str], modality: str):
             "type": "image"
         }, {
             "type": "text",
-            "text": f"{question}"
+            "text": question
         }]
     }] for question in questions]
     prompts = tokenizer.apply_chat_template(messages,
@@ -664,6 +681,7 @@ model_example_map = {
    "deepseek_vl_v2": run_deepseek_vl2,
    "florence2": run_florence2,
    "fuyu": run_fuyu,
    "gemma3": run_gemma3,
    "glm4v": run_glm4v,
    "h2ovl_chat": run_h2ovl,
    "idefics3": run_idefics3,
@@ -80,6 +80,42 @@ def load_deepseek_vl2(question: str, image_urls: list[str]):
    )


def load_gemma3(question, image_urls: list[str]) -> ModelRequestData:
    model_name = "google/gemma-3-4b-it"

    llm = LLM(model=model_name,
              max_model_len=8192,
              max_num_seqs=2,
              limit_mm_per_prompt={"image": len(image_urls)})

    placeholders = [{"type": "image", "image": url} for url in image_urls]
    messages = [{
        "role":
        "user",
        "content": [
            *placeholders,
            {
                "type": "text",
                "text": question
            },
        ],
    }]

    processor = AutoProcessor.from_pretrained(model_name)

    prompt = processor.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)

    return ModelRequestData(
        llm=llm,
        prompt=prompt,
        stop_token_ids=None,
        image_data=[fetch_image(url) for url in image_urls],
        chat_template=None,
    )


def load_h2ovl(question: str, image_urls: list[str]) -> ModelRequestData:
    model_name = "h2oai/h2ovl-mississippi-800m"

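Compared with the single-image runner above, this multi-image helper builds the prompt with the Hugging Face `AutoProcessor.apply_chat_template` call instead of hand-writing the `<start_of_turn>` template, and it bounds how many images a single request may carry via `limit_mm_per_prompt={"image": len(image_urls)}`.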
@@ -496,6 +532,7 @@ def load_qwen2_5_vl(question, image_urls: list[str]) -> ModelRequestData:
model_example_map = {
    "aria": load_aria,
    "deepseek_vl_v2": load_deepseek_vl2,
    "gemma3": load_gemma3,
    "h2ovl_chat": load_h2ovl,
    "idefics3": load_idefics3,
    "internvl_chat": load_internvl,
@@ -162,6 +162,7 @@ def _test_processing_correctness(
    "deepseek-ai/deepseek-vl2-tiny",
    "microsoft/Florence-2-base",
    "adept/fuyu-8b",
    "google/gemma-3-4b-it",
    "THUDM/glm-4v-9b",
    "h2oai/h2ovl-mississippi-800m",
    "OpenGVLab/InternVL2-1B",
@@ -124,6 +124,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
    "FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
    "GemmaForCausalLM": _HfExamplesInfo("google/gemma-2b"),
    "Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
    "Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it",
                                         min_transformers_version="4.50"),
    "GlmForCausalLM": _HfExamplesInfo("THUDM/glm-4-9b-chat-hf"),
    "GPT2LMHeadModel": _HfExamplesInfo("gpt2"),
    "GPTBigCodeForCausalLM": _HfExamplesInfo("bigcode/starcoder"),
@@ -241,6 +243,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
    "DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny", # noqa: E501
                                               hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501
    "FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
    "Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it",
                                                      min_transformers_version="4.50"),
    "GLM4VForCausalLM": _HfExamplesInfo("THUDM/glm-4v-9b",
                                        trust_remote_code=True,
                                        hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501
@@ -350,10 +350,11 @@ class ModelConfig:
         if self.enforce_eager is None:
             self.enforce_eager = False
 
+        interleaved_attn_models = ["gemma2", "gemma3_text", "cohere2"]
         sliding_window = getattr(self.hf_text_config, "sliding_window", None)
         has_interleaved_attention = (sliding_window is not None) and (
             isinstance(sliding_window, list) or
-            (self.hf_text_config.model_type in ["gemma2", "cohere2"]))
+            (self.hf_text_config.model_type in interleaved_attn_models))
 
         if (not self.disable_sliding_window and has_interleaved_attention):
             if (backend :=
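`gemma3_text` joins this list because Gemma 3, like Gemma 2 and Cohere 2, interleaves sliding-window (local) layers with full-attention (global) layers. Below is a rough, self-contained sketch of the layer-classification rule used by `Gemma3Attention` later in this diff; the `sliding_window_pattern = 6` value is an assumption about the released checkpoints' HF config, not something this hunk pins down.

```python
# Hypothetical illustration of Gemma3Attention's rule:
#     is_sliding = bool((layer_idx + 1) % config.sliding_window_pattern)
sliding_window_pattern = 6  # assumed value; read it from the HF config in practice

for layer_idx in range(12):
    is_sliding = bool((layer_idx + 1) % sliding_window_pattern)
    print(f"layer {layer_idx:2d}: {'local' if is_sliding else 'global'}")
# Every sliding_window_pattern-th layer (indices 5, 11, ...) comes out "global";
# the rest use the interleaved sliding window.
```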
@@ -2501,11 +2502,11 @@ def _get_and_verify_dtype(
         dtype = dtype.lower()
         if dtype == "auto":
             if config_dtype == torch.float32:
-                if config.model_type == "gemma2":
+                if config.model_type in ("gemma2", "gemma3", "gemma3_text"):
                     logger.info(
-                        "For Gemma 2, we downcast float32 to bfloat16 instead "
-                        "of float16 by default. Please specify `dtype` if you "
-                        "want to use float16.")
+                        "For Gemma 2 and 3, we downcast float32 to bfloat16 "
+                        "instead of float16 by default. Please specify `dtype` "
+                        "if you want to use float16.")
                     torch_dtype = torch.bfloat16
                 else:
                     # Following the common practice, we use float16 for float32
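The log message above points at the escape hatch: pass `dtype` explicitly if the bfloat16 default is not what you want. A minimal sketch (the checkpoint name is taken from this PR's examples):

```python
from vllm import LLM

# Keep the default bfloat16 downcast by passing nothing, or force float16:
llm = LLM(model="google/gemma-3-1b-it", dtype="float16")
```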
@@ -2637,7 +2638,9 @@ def _get_and_verify_max_len(
         derived_max_model_len = default_max_len
 
     rope_scaling = getattr(hf_config, "rope_scaling", None)
-    if rope_scaling is not None:
+    # NOTE(woosuk): Gemma3's max_model_len (128K) is already scaled by RoPE
+    # scaling, so we skip applying the scaling factor again.
+    if rope_scaling is not None and "gemma3" not in hf_config.model_type:
         # No need to consider "type" key because of patch_rope_scaling when
         # loading HF config
         rope_type = rope_scaling["rope_type"]
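Because the checkpoint's advertised context length already includes the RoPE scaling (the 128K figure cited in the comment above), nothing extra needs to be configured; if anything, you may want to cap `max_model_len` below the full context to bound KV-cache memory. An illustrative sketch with an arbitrary cap:

```python
from vllm import LLM

# 32768 is an arbitrary illustrative cap, not a recommended value.
llm = LLM(model="google/gemma-3-4b-it", max_model_len=32768)
```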
@@ -433,6 +433,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
                return "<image>"
            if model_type == "aria":
                return "<|fim_prefix|><|img|><|fim_suffix|>"
            if model_type == "gemma3":
                return "<start_of_image>"

            raise TypeError(f"Unknown {modality} model type: {model_type}")
        elif modality == "audio":

vllm/model_executor/models/gemma3.py (new file, 533 lines)
@@ -0,0 +1,533 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright 2025 The vLLM team.
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Iterable, Optional, Set, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn
from transformers import Gemma3TextConfig

from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, extract_layer_index,
                    is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)

logger = init_logger(__name__)


class Gemma3MLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_activation: str,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quant_config)
        if hidden_activation != "gelu_pytorch_tanh":
            raise ValueError(
                "Gemma3 uses `gelu_pytorch_tanh` as the hidden activation "
                "function. Please set `hidden_act` and `hidden_activation` to "
                "`gelu_pytorch_tanh`.")
        self.act_fn = GeluAndMul(approximate="tanh")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class Gemma3Attention(nn.Module):

    def __init__(self,
                 config: Gemma3TextConfig,
                 hidden_size: int,
                 num_heads: int,
                 num_kv_heads: int,
                 head_dim: int,
                 max_position_embeddings: int,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 attn_logits_soft_cap: Optional[float] = None,
                 prefix: str = "") -> None:
        super().__init__()
        self.config = config
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = head_dim
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = config.query_pre_attn_scalar**-0.5

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.attention_bias,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=config.attention_bias,
            quant_config=quant_config,
        )

        self.q_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)

        # TODO(woosuk): Add reference to the original HF implementation.
        layer_idx = extract_layer_index(prefix)
        self.is_sliding = bool((layer_idx + 1) % config.sliding_window_pattern)
        # Initialize the rotary embedding.
        if self.is_sliding:
            # Local attention. Override the values in config.json.
            self.rope_theta = config.rope_local_base_freq
            self.rope_scaling = {"rope_type": "default"}
            self.sliding_window = config.interleaved_sliding_window
        else:
            # Global attention. Use the values in config.json.
            self.rope_theta = config.rope_theta
            self.rope_scaling = config.rope_scaling
            self.sliding_window = None
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=self.rope_theta,
            is_neox_style=True,
            rope_scaling=self.rope_scaling,
        )

        # Initialize the attention.
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              logits_soft_cap=attn_logits_soft_cap,
                              per_layer_sliding_window=self.sliding_window,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        q = q.unflatten(-1, (self.num_heads, self.head_dim))
        q = self.q_norm(q)
        q = q.flatten(-2, -1)
        k = k.unflatten(-1, (self.num_kv_heads, self.head_dim))
        k = self.k_norm(k)
        k = k.flatten(-2, -1)

        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)

        if not kwargs.get("has_images", False):
            # Fast path for text-only inputs. The performance for the text-only
            # inputs are not affected by the naive attention below.
            output, _ = self.o_proj(attn_output)
            return output

        # NOTE(woosuk): Gemma3 uses bidirectional attention between image tokens
        # that correspond to the same image while using causal attention
        # otherwise. Current attention backends cannot handle this pattern, so
        # we temporarily use a naive attention implementation with mask tensors.

        # We intentionally keep the attention backend as-is and only override
        # `attn_output` with the naive implementation's output. This minimizes
        # changes to existing model runners and attention backends. The call to
        # `self.attn(q, k, v)` is only used to populate the KV cache - its
        # output is discarded and overwritten below. While this duplicates
        # computation, it maintains compatibility.
        # TODO(woosuk): Optimize by implementing custom attention kernels.
        attn_output = self.naive_attn_with_masks(q,
                                                 k,
                                                 v,
                                                 out=attn_output,
                                                 **kwargs)
        output, _ = self.o_proj(attn_output)
        return output

    def naive_attn_with_masks(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        out: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        # NOTE(woosuk): As described in the comment above, this code is not
        # meant to be performant. It is only meant to be correct.
        q = q.view(-1, self.num_heads, self.head_dim)
        # Expand the key and value to handle GQA.
        num_queries_per_kv = self.num_heads // self.num_kv_heads
        k = k.view(-1, self.num_kv_heads, self.head_dim)
        k = k.repeat_interleave(num_queries_per_kv, dim=-2)
        v = v.view(-1, self.num_kv_heads, self.head_dim)
        v = v.repeat_interleave(num_queries_per_kv, dim=-2)

        if self.is_sliding:
            attn_masks = kwargs["local_attn_masks"]
        else:
            attn_masks = kwargs["global_attn_masks"]

        seq_lens = kwargs["seq_lens"]
        start_idx = 0
        for seq_len, attn_mask in zip(seq_lens, attn_masks):
            end_idx = start_idx + seq_len
            query = q[start_idx:end_idx].unsqueeze(0)
            key = k[start_idx:end_idx].unsqueeze(0)
            value = v[start_idx:end_idx].unsqueeze(0)

            # Transpose.
            query = query.transpose(1, 2)
            key = key.transpose(1, 2)
            value = value.transpose(1, 2)

            output = F.scaled_dot_product_attention(
                query,
                key,
                value,
                attn_mask,
                self.scaling,
            )
            output = output.transpose(1, 2).flatten(-2, -1)
            out[start_idx:end_idx] = output
            start_idx = end_idx
        return out


class Gemma3DecoderLayer(nn.Module):

    def __init__(
        self,
        config: Gemma3TextConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Gemma3Attention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
            max_position_embeddings=config.max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
            attn_logits_soft_cap=None,
            prefix=f"{prefix}.self_attn",
        )
        self.hidden_size = config.hidden_size
        self.mlp = Gemma3MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_activation=config.hidden_activation,
            quant_config=quant_config,
        )
        self.input_layernorm = GemmaRMSNorm(config.hidden_size,
                                            eps=config.rms_norm_eps)
        self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size,
                                                     eps=config.rms_norm_eps)
        self.pre_feedforward_layernorm = GemmaRMSNorm(config.hidden_size,
                                                      eps=config.rms_norm_eps)
        self.post_feedforward_layernorm = GemmaRMSNorm(config.hidden_size,
                                                       eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            **kwargs,
        )
        hidden_states = self.post_attention_layernorm(hidden_states)

        hidden_states, residual = self.pre_feedforward_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_feedforward_layernorm(hidden_states)
        return hidden_states, residual


@support_torch_compile
class Gemma3Model(nn.Module):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: Gemma3DecoderLayer(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.layers")
        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Normalize the embedding by sqrt(hidden_size)
        # The normalizer's data type should be downcasted to the model's
        # data type such as bfloat16, not float32.
        # See https://github.com/huggingface/transformers/pull/29402
        normalizer = self.config.hidden_size**0.5
        self.register_buffer("normalizer", torch.tensor(normalizer))
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        # NOTE(woosuk): Only apply the normalizer to the output of
        # vocab embedding. Don't apply it to the vision embedding.
        return self.embed_tokens(input_ids) * self.normalizer

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for layer in self.layers[self.start_layer:self.end_layer]:
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
                **kwargs,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if (self.quant_config is not None and
                    (scale_name := self.quant_config.get_cache_scale(name))):
                # Loading kv cache scales for compressed-tensors quantization
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                loaded_weight = loaded_weight[0]
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            for (param_name, shard_name, shard_id) in stacked_params_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)

        unloaded_params = params_dict.keys() - loaded_params
        if unloaded_params:
            logger.warning(
                "Some weights are not initialized from checkpoints: %s",
                unloaded_params)
        return loaded_params


class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        del lora_config  # Unused.
        super().__init__()
        self.config = config
        # currently all existing Gemma models have `tie_word_embeddings` enabled
        assert config.tie_word_embeddings
        self.quant_config = quant_config
        self.model = Gemma3Model(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "model"))
        self.logits_processor = LogitsProcessor(
            config.vocab_size, soft_cap=config.final_logit_softcapping)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds, **kwargs)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.model.embed_tokens, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)

vllm/model_executor/models/gemma3_mm.py (new file, 425 lines)
@@ -0,0 +1,425 @@
# SPDX-License-Identifier: Apache-2.0
from typing import (Any, Iterable, Literal, Mapping, Optional, Sequence, Set,
                    Tuple, TypedDict, Union)

import torch
from torch import nn
from transformers import BatchFeature, Gemma3Config, ProcessorMixin

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
                                    NestedTensors)
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo, PromptReplacement,
                                        PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                    maybe_prefix, merge_multimodal_embeddings)

logger = init_logger(__name__)


class Gemma3ImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, num_channels, height, width)`"""


Gemma3ImageInputs = Gemma3ImagePixelInputs


class Gemma3ProcessingInfo(BaseProcessingInfo):

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"image": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        hf_config = self.ctx.get_hf_config()
        return {"image": hf_config.mm_tokens_per_image}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Optional[ProcessorMixin],
    ) -> int:
        hf_config = self.ctx.get_hf_config()
        return hf_config.mm_tokens_per_image

    def get_image_size_with_most_features(self) -> ImageSize:
        # Result in the max possible feature size (h:w = 16:1)
        return ImageSize(height=8000, width=50)


class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]):

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        tokenizer = self.info.get_tokenizer()
        boi_token = tokenizer.boi_token

        num_images = mm_counts.get("image", 0)
        target_width, target_height = \
            self.info.get_image_size_with_most_features()

        mm_data = {
            "image":
            self._get_dummy_images(width=target_width,
                                   height=target_height,
                                   num_images=num_images)
        }
        return ProcessorInputs(
            prompt_text=" ".join([boi_token] * num_images),
            mm_data=mm_data,
        )


class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        # TODO(woosuk): Support pan-and-scan.
        img_kwargs = mm_kwargs.get("images_kwargs", {})
        img_kwargs["do_pan_and_scan"] = False
        mm_kwargs["images_kwargs"] = img_kwargs
        return super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(pixel_values=MultiModalFieldConfig.batched("image"))

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        tokenizer = self.info.get_tokenizer()
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        hf_config = self.info.get_hf_config()

        boi_token = tokenizer.boi_token
        image_token = tokenizer.image_token
        mm_tokens_per_image = hf_config.mm_tokens_per_image
        image_tokens_expanded = "".join([image_token] * mm_tokens_per_image)

        def get_replacement_gemma3(item_idx: int):
            return PromptUpdateDetails(
                full=hf_processor.full_image_sequence,
                features=image_tokens_expanded,
            )

        return [
            PromptReplacement(
                modality="image",
                target=boi_token,
                replacement=get_replacement_gemma3,
            )
        ]


class Gemma3MultiModalProjector(nn.Module):

    def __init__(self, config: Gemma3Config):
        super().__init__()

        self.mm_input_projection_weight = nn.Parameter(
            torch.zeros(config.vision_config.hidden_size,
                        config.text_config.hidden_size))

        self.mm_soft_emb_norm = GemmaRMSNorm(
            config.vision_config.hidden_size,
            eps=config.vision_config.layer_norm_eps)

        self.patches_per_image = int(config.vision_config.image_size //
                                     config.vision_config.patch_size)
        self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
        self.kernel_size = self.patches_per_image // self.tokens_per_side
        self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size,
                                     stride=self.kernel_size)

    def forward(self, vision_outputs: torch.Tensor):
        batch_size, _, seq_length = vision_outputs.shape

        reshaped_vision_outputs = vision_outputs.transpose(1, 2)
        reshaped_vision_outputs = reshaped_vision_outputs.reshape(
            batch_size, seq_length, self.patches_per_image,
            self.patches_per_image)
        reshaped_vision_outputs = reshaped_vision_outputs.contiguous()

        pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
        pooled_vision_outputs = pooled_vision_outputs.flatten(2)
        pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2)

        normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs)

        projected_vision_outputs = torch.matmul(
            normed_vision_outputs, self.mm_input_projection_weight)
        return projected_vision_outputs.type_as(vision_outputs)


@MULTIMODAL_REGISTRY.register_processor(Gemma3MultiModalProcessor,
                                        info=Gemma3ProcessingInfo,
                                        dummy_inputs=Gemma3DummyInputsBuilder)
class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal,
                                     SupportsPP):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.quant_config = quant_config
        self.multimodal_config = multimodal_config
        self.sliding_window = config.text_config.interleaved_sliding_window

        self.vision_tower = SiglipVisionModel(config.vision_config,
                                              quant_config,
                                              prefix=maybe_prefix(
                                                  prefix, "vision_tower"))
        self.multi_modal_projector = Gemma3MultiModalProjector(config)

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["Gemma3ForCausalLM"],
        )
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.language_model.logits_processor.scale *= logit_scale

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @property
    def sampler(self):
        return self.language_model.sampler

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            if d.shape != expected_dims:
                raise ValueError(
                    "The expected shape of pixel values per image per batch "
                    f"is {expected_dims}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[Gemma3ImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        assert image_embeds is None, "Gemma3 does not support image_embeds."
        if pixel_values is None:
            return None

        if not isinstance(pixel_values, (torch.Tensor, list[torch.Tensor])):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        pixel_values = flatten_bn(pixel_values, concat=True)
        return Gemma3ImagePixelInputs(
            type="pixel_values",
            data=self._validate_pixel_values(pixel_values),
        )

    def _image_pixels_to_features(
        self,
        vision_tower: SiglipVisionModel,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        target_dtype = vision_tower.get_input_embeddings().weight.dtype
        image_features = vision_tower(pixel_values.to(dtype=target_dtype))
        return image_features

    def _process_image_input(
        self,
        image_input: Gemma3ImageInputs,
    ) -> torch.Tensor:
        assert self.vision_tower is not None
        pixel_values = image_input["data"]
        vision_outputs = self._image_pixels_to_features(
            self.vision_tower,
            pixel_values,
        )
        return self.multi_modal_projector(vision_outputs)

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        if multimodal_embeddings is None:
            inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        else:
            inputs_embeds = self.language_model.get_input_embeddings(input_ids)
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                self.config.image_token_index)
        return inputs_embeds

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs: object) -> Union[SamplerOutput, IntermediateTensors]:
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)

            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            if vision_embeddings is not None:
                kwargs = self.prepare_attn_masks(
                    input_ids,
                    positions,
                    mask_dtype=vision_embeddings.dtype,
                    **kwargs)
            input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds,
                                                  **kwargs)

        return hidden_states

    def prepare_attn_masks(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        mask_dtype: torch.dtype,
        **kwargs,
    ):
        kwargs["has_images"] = True
        # NOTE(woosuk): Here, we distinguish the sequences by the position id 0.
        # This is a HACK. Fix this.
        start_idices = (positions == 0).cpu().nonzero()
        num_seqs = len(start_idices)
        seq_lens = []
        for i in range(num_seqs):
            start_idx = start_idices[i].item()
            if i < num_seqs - 1:
                end_idx = start_idices[i + 1].item()
            else:
                end_idx = len(input_ids)
            seq_lens.append(end_idx - start_idx)
        kwargs["seq_lens"] = seq_lens

        global_attn_masks = []
        local_attn_masks = []
        start_idx = 0
        for seq_len in seq_lens:
            end_idx = start_idx + seq_len
            input_token_ids = input_ids[start_idx:end_idx]
            start_idx = end_idx
            # Create a global causal mask.
            global_attn_mask = torch.empty(
                1,
                1,
                seq_len,
                seq_len,
                dtype=mask_dtype,
                device=input_ids.device,
            )
            global_attn_mask.fill_(float("-inf"))
            # Fill the lower triangle with 0.
            global_attn_mask = global_attn_mask.triu(diagonal=1)

            # Consider the bidirectional attention between image tokens.
            img_mask = torch.zeros_like(global_attn_mask)
            img_pos = (input_token_ids == self.config.image_token_index)
            img_mask[:, :, :, img_pos] += 1
            img_mask[:, :, img_pos, :] += 1
            global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
            global_attn_masks.append(global_attn_mask)

            # Create a local causal mask with sliding window (1024).
            local_attn_mask = torch.ones_like(global_attn_mask)
            local_attn_mask = torch.tril(local_attn_mask,
                                         diagonal=-self.sliding_window)
            local_attn_mask = torch.where(local_attn_mask == 0,
                                          global_attn_mask, float("-inf"))
            local_attn_masks.append(local_attn_mask)
        kwargs["global_attn_masks"] = global_attn_masks
        kwargs["local_attn_masks"] = local_attn_masks
        return kwargs

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)
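To make the two trickier pieces of `gemma3_mm.py` concrete, here is a small self-contained sketch (not part of the PR): it reproduces the projector's pooling arithmetic under assumed Gemma 3 vision settings (`image_size=896`, `patch_size=14`, `mm_tokens_per_image=256`), and builds the combined causal + bidirectional-image mask for a toy token sequence following the same recipe as `prepare_attn_masks`.

```python
import torch

# --- Projector arithmetic (assumed config values, not read from this diff) ---
image_size, patch_size, mm_tokens_per_image = 896, 14, 256
patches_per_image = image_size // patch_size         # 64
tokens_per_side = int(mm_tokens_per_image ** 0.5)    # 16
kernel_size = patches_per_image // tokens_per_side   # 4: 64x64 grid pooled to 16x16
print(patches_per_image, tokens_per_side, kernel_size)

# --- Toy version of the global mask built in prepare_attn_masks ---
IMG = 99  # stand-in for config.image_token_index
token_ids = torch.tensor([1, 2, IMG, IMG, IMG, 3, 4])
seq_len = token_ids.numel()

mask = torch.full((1, 1, seq_len, seq_len), float("-inf"))
mask = mask.triu(diagonal=1)  # causal: 0 on/below the diagonal, -inf above

img_pos = token_ids == IMG
img_mask = torch.zeros_like(mask)
img_mask[:, :, :, img_pos] += 1
img_mask[:, :, img_pos, :] += 1
# Positions where both query and key are image tokens become mutually visible.
mask = torch.where(img_mask == 2, 0.0, mask)
print(mask[0, 0])
```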
@@ -53,6 +53,7 @@ _TEXT_GENERATION_MODELS = {
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
@@ -161,6 +162,7 @@ _MULTIMODAL_MODELS = {
    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),  # noqa: E501
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),