[Model] Upgrade Aria to transformers 4.48 (#12203)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung authored on 2025-01-20 17:58:48 +08:00; committed by GitHub
parent 3127e975fb
commit b37d82791e
10 changed files with 178 additions and 379 deletions


@ -26,11 +26,8 @@ def run_aria(question: str, modality: str):
# NOTE: Need L40 (or equivalent) to avoid OOM
llm = LLM(model=model_name,
tokenizer_mode="slow",
dtype="bfloat16",
max_model_len=4096,
max_num_seqs=2,
trust_remote_code=True,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
prompt = (f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>\n{question}"
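For orientation, here is a minimal offline-inference sketch using the simplified constructor; the image file, question, and sampling settings are illustrative assumptions, and the chat suffix follows the prompt_formatter used in the tests below.

from PIL import Image
from vllm import LLM, SamplingParams

# Assumption: any local RGB image is fine for a smoke test.
image = Image.open("example.jpg").convert("RGB")

# tokenizer_mode="slow" and trust_remote_code=True are no longer needed
# now that the HF-native Aria implementation is used.
llm = LLM(model="rhymes-ai/Aria",
          dtype="bfloat16",
          max_model_len=4096,
          max_num_seqs=2)

question = "What is in this image?"
prompt = (f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>\n{question}"
          "<|im_end|>\n<|im_start|>assistant\n")

outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image}},
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)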


@ -10,7 +10,6 @@ from typing import Type
import pytest
from transformers import AutoModelForVision2Seq
from transformers import __version__ as TRANSFORMERS_VERSION
from transformers.utils import is_flash_attn_2_available
from vllm.platforms import current_platform
from vllm.utils import identity
@ -140,9 +139,7 @@ VLM_TEST_SETTINGS = {
#### Extended model tests
"aria": VLMTestInfo(
models=["rhymes-ai/Aria"],
tokenizer_mode="slow",
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
dtype="bfloat16",
prompt_formatter=lambda img_prompt: f"<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n ", # noqa: E501
img_idx_to_prompt=lambda idx: "<fim_prefix><|img|><fim_suffix>\n",
max_model_len=4096,
@ -158,8 +155,8 @@ VLM_TEST_SETTINGS = {
max_tokens=64,
marks=[
pytest.mark.skipif(
not is_flash_attn_2_available(),
reason="Model needs flash-attn for numeric convergence.",
TRANSFORMERS_VERSION < "4.48.0",
reason="HF model requires transformers>=4.48.0",
),
large_gpu_mark(min_gb=64),
],
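The skipif above compares raw version strings, which is adequate for the current 4.4x series; the registry helpers added later in this change use packaging.version, which stays correct across multi-digit components. A quick illustration with made-up versions:

from packaging.version import Version

# String comparison is lexicographic and can mis-order components:
print("4.9.0" < "4.48.0")                     # False as strings
print(Version("4.9.0") < Version("4.48.0"))   # True as versions

# For the versions relevant here, both approaches agree:
print("4.47.1" < "4.48.0")                    # True
print(Version("4.47.1") < Version("4.48.0"))  # True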


@ -11,6 +11,7 @@ from vllm.multimodal.processing import ProcessingCache
from vllm.multimodal.utils import cached_get_tokenizer
from ....multimodal.utils import random_audio, random_image, random_video
from ...registry import HF_EXAMPLE_MODELS
def _test_processing_correctness(
@ -20,12 +21,9 @@ def _test_processing_correctness(
num_batches: int,
simplify_rate: float,
):
if model_id == "TIGER-Lab/Mantis-8B-siglip-llama3":
hf_overrides = {"architectures": ["MantisForConditionalGeneration"]}
elif model_id == "deepseek-ai/deepseek-vl2-tiny":
hf_overrides = {"architectures": ["DeepseekVLV2ForCausalLM"]}
else:
hf_overrides = {}
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
model_info.check_available_online(on_fail="skip")
model_info.check_transformers_version(on_fail="skip")
limit_mm_per_prompt = {
modality: 3 if supports_multi else 1
@ -41,7 +39,7 @@ def _test_processing_correctness(
seed=0,
dtype="float16",
revision=None,
hf_overrides=hf_overrides,
hf_overrides=model_info.hf_overrides,
limit_mm_per_prompt=limit_mm_per_prompt,
)


@ -1,5 +1,9 @@
from dataclasses import dataclass, field
from typing import AbstractSet, Mapping, Optional
from typing import AbstractSet, Any, Literal, Mapping, Optional
import pytest
from packaging.version import Version
from transformers import __version__ as TRANSFORMERS_VERSION
@dataclass(frozen=True)
@ -38,6 +42,50 @@ class _HfExamplesInfo:
trust_remote_code: bool = False
"""The ``trust_remote_code`` level required to load the model."""
hf_overrides: dict[str, Any] = field(default_factory=dict)
"""The ``hf_overrides`` required to load the model."""
def check_transformers_version(
self,
*,
on_fail: Literal["error", "skip"],
) -> None:
"""
If the installed transformers version does not meet the requirements,
perform the given action.
"""
if self.min_transformers_version is None:
return
current_version = TRANSFORMERS_VERSION
required_version = self.min_transformers_version
if Version(current_version) < Version(required_version):
msg = (
f"You have `transformers=={current_version}` installed, but "
f"`transformers>={required_version}` is required to run this "
"model")
if on_fail == "error":
raise RuntimeError(msg)
else:
pytest.skip(msg)
def check_available_online(
self,
*,
on_fail: Literal["error", "skip"],
) -> None:
"""
If the model is not available online, perform the given action.
"""
if not self.is_available_online:
msg = "Model is not available online"
if on_fail == "error":
raise RuntimeError(msg)
else:
pytest.skip(msg)
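Both helpers branch on on_fail; outside a pytest session only the "error" path is observable. A small sketch of that path, assuming this module is importable as tests.models.registry and using a deliberately unsatisfiable version requirement:

from tests.models.registry import _HfExamplesInfo

# No installed transformers can satisfy this, so the check must fail.
info = _HfExamplesInfo("rhymes-ai/Aria", min_transformers_version="999.0")

try:
    info.check_transformers_version(on_fail="error")
except RuntimeError as exc:
    print(exc)  # You have `transformers==<installed>` installed, but ...

# With on_fail="skip", the same condition calls pytest.skip(msg) instead,
# which is only meaningful inside a running test.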
# yapf: disable
_TEXT_GENERATION_EXAMPLE_MODELS = {
@ -48,8 +96,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code=True),
"ArcticForCausalLM": _HfExamplesInfo("Snowflake/snowflake-arctic-instruct",
trust_remote_code=True),
"AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria",
trust_remote_code=True),
"BaiChuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan-7B",
trust_remote_code=True),
"BaichuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan2-7B-chat",
@ -176,6 +222,8 @@ _CROSS_ENCODER_EXAMPLE_MODELS = {
_MULTIMODAL_EXAMPLE_MODELS = {
# [Decoder-only]
"AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria",
min_transformers_version="4.48"),
"Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b"), # noqa: E501
"ChameleonForConditionalGeneration": _HfExamplesInfo("facebook/chameleon-7b"), # noqa: E501
"ChatGLMModel": _HfExamplesInfo("THUDM/glm-4v-9b",
@ -183,7 +231,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
trust_remote_code=True),
"ChatGLMForConditionalGeneration": _HfExamplesInfo("chatglm2-6b",
is_available_online=False),
"DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny"), # noqa: E501
"DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny", # noqa: E501
hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501
"FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
"H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m"),
"InternVLChatModel": _HfExamplesInfo("OpenGVLab/InternVL2-1B",
@ -194,7 +243,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"LlavaNextForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-v1.6-mistral-7b-hf"), # noqa: E501
"LlavaNextVideoForConditionalGeneration": _HfExamplesInfo("llava-hf/LLaVA-NeXT-Video-7B-hf"), # noqa: E501
"LlavaOnevisionForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-onevision-qwen2-0.5b-ov-hf"), # noqa: E501
"MantisForConditionalGeneration": _HfExamplesInfo("TIGER-Lab/Mantis-8B-siglip-llama3"), # noqa: E501
"MantisForConditionalGeneration": _HfExamplesInfo("TIGER-Lab/Mantis-8B-siglip-llama3", # noqa: E501
hf_overrides={"architectures": ["MantisForConditionalGeneration"]}), # noqa: E501
"MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-Llama3-V-2_5",
trust_remote_code=True),
"MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924",
@ -247,5 +297,12 @@ class HfExampleModels:
def get_hf_info(self, model_arch: str) -> _HfExamplesInfo:
return self.hf_models[model_arch]
def find_hf_info(self, model_id: str) -> _HfExamplesInfo:
for info in self.hf_models.values():
if info.default == model_id:
return info
raise ValueError(f"No example model defined for {model_id}")
HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS)
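A typical consumer resolves an entry either by architecture or by model id and applies the gating helpers up front, mirroring _test_processing_correctness above (import path assumed):

from tests.models.registry import HF_EXAMPLE_MODELS

aria = HF_EXAMPLE_MODELS.get_hf_info("AriaForConditionalGeneration")
mantis = HF_EXAMPLE_MODELS.find_hf_info("TIGER-Lab/Mantis-8B-siglip-llama3")

print(aria.min_transformers_version)  # "4.48"
print(mantis.hf_overrides)  # {"architectures": ["MantisForConditionalGeneration"]}

# Inside pytest, these skip the test when the model cannot be exercised here.
aria.check_available_online(on_fail="skip")
aria.check_transformers_version(on_fail="skip")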


@ -1,9 +1,7 @@
from unittest.mock import patch
import pytest
from packaging.version import Version
from transformers import PretrainedConfig
from transformers import __version__ as TRANSFORMERS_VERSION
from vllm import LLM
@ -13,16 +11,8 @@ from .registry import HF_EXAMPLE_MODELS
@pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs())
def test_can_initialize(model_arch):
model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
if not model_info.is_available_online:
pytest.skip("Model is not available online")
if model_info.min_transformers_version is not None:
current_version = TRANSFORMERS_VERSION
required_version = model_info.min_transformers_version
if Version(current_version) < Version(required_version):
pytest.skip(
f"You have `transformers=={current_version}` installed, but "
f"`transformers>={required_version}` is required to run this "
"model")
model_info.check_available_online(on_fail="skip")
model_info.check_transformers_version(on_fail="skip")
# Avoid OOM
def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:


@ -21,6 +21,9 @@ from .registry import HF_EXAMPLE_MODELS
@pytest.mark.parametrize("model_arch", ModelRegistry.get_supported_archs())
def test_registry_imports(model_arch):
model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
model_info.check_transformers_version(on_fail="skip")
# Ensure all model classes can be imported successfully
model_cls, _ = ModelRegistry.resolve_model_cls(model_arch)
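With that gate in place, resolving an architecture is a one-liner; a minimal standalone sketch (this particular architecture needs transformers>=4.48 installed):

from vllm import ModelRegistry

model_cls, arch = ModelRegistry.resolve_model_cls("AriaForConditionalGeneration")
print(arch, model_cls.__name__)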


@ -1,9 +1,11 @@
from typing import (Callable, Iterable, List, Mapping, Optional, Set, Tuple,
TypedDict, Union)
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
Union)
import torch
import torch.nn as nn
from transformers import BatchFeature, PretrainedConfig
from transformers import AriaConfig, AriaTextConfig, BatchFeature
from transformers.models.aria.modeling_aria import AriaCrossAttention
from transformers.models.aria.processing_aria import AriaProcessor
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, QuantizationConfig, VllmConfig
@ -26,10 +28,11 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.aria import (AriaMoELMConfig,
AriaVisionConfig)
from .idefics2_vision_model import Idefics2VisionTransformer
# yapf: disable
from .idefics2_vision_model import (
Idefics2VisionTransformer as Idefics3VisionTransformer)
# yapf: enable
from .interfaces import SupportsMultiModal
from .llama import LlamaDecoderLayer, LlamaMLP, LlamaModel
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
@ -47,87 +50,22 @@ class AriaImagePixelInputs(TypedDict):
"""
class AriaVisionTransformer(Idefics2VisionTransformer):
"""
AriaVisionTransformer is a modified version of Idefics2VisionTransformer
that replaces the post-layernorm with an identity layer.
"""
class AriaProjectorMLP(nn.Module):
def __init__(
self,
config: AriaVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__(config, quant_config, prefix)
self.post_layernorm = nn.Identity()
class AriaVisionModel(nn.Module):
config_class = AriaVisionConfig
def __init__(
self,
config: AriaVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
prefix: str = "",
in_features: int,
hidden_features: int,
output_dim: int,
) -> None:
super().__init__()
self.vision_model = AriaVisionTransformer(
config,
quant_config,
prefix=f"{prefix}.vision_model",
)
def forward(
self,
pixel_values: torch.Tensor,
pixel_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
patch_attention_mask = self._create_patch_attention_mask(pixel_mask)
vit_oup = self.vision_model(
pixel_values=pixel_values,
patch_attention_mask=patch_attention_mask,
)
image_atts = self._create_image_attention_mask(patch_attention_mask)
return vit_oup, image_atts
def _create_patch_attention_mask(
self, pixel_mask: Optional[torch.Tensor]) -> torch.Tensor:
if pixel_mask is None:
return None
patches_subgrid = pixel_mask.unfold(
dimension=1,
size=self.vision_model.config.patch_size,
step=self.vision_model.config.patch_size,
).unfold(
dimension=2,
size=self.vision_model.config.patch_size,
step=self.vision_model.config.patch_size,
)
return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
def _create_image_attention_mask(
self, patch_attention_mask: torch.Tensor) -> torch.Tensor:
if patch_attention_mask is None:
return None
flattened_mask = patch_attention_mask.flatten(1)
return torch.logical_not(flattened_mask)
class FFN(nn.Module):
def __init__(self, embed_dim: int, ff_dim: int, output_dim: int) -> None:
super().__init__()
self.linear_in = ColumnParallelLinear(embed_dim, ff_dim, bias=False)
self.linear_out = RowParallelLinear(ff_dim, output_dim, bias=False)
self.linear_in = ColumnParallelLinear(in_features,
hidden_features,
bias=False)
self.linear_out = RowParallelLinear(hidden_features,
output_dim,
bias=False)
self.act = get_act_fn("gelu_new")
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@ -137,46 +75,6 @@ class FFN(nn.Module):
return hidden_states
class CrossAttention(nn.Module):
def __init__(self, kv_dim: int, embed_dim: int, num_heads: int) -> None:
super().__init__()
self.num_heads = num_heads
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
self.k_proj = nn.Linear(kv_dim, embed_dim, bias=False)
self.v_proj = nn.Linear(kv_dim, embed_dim, bias=False)
self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
self.linear = nn.Linear(embed_dim, embed_dim)
self.layer_norm = nn.LayerNorm(embed_dim)
self.ln_kv = nn.LayerNorm(kv_dim)
def forward(
self,
x: torch.Tensor,
hidden_states: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
normed_hidden_states = self.layer_norm(hidden_states)
query = self.q_proj(normed_hidden_states).permute(1, 0, 2)
x = self.ln_kv(x)
key = self.k_proj(x).permute(1, 0, 2)
value = self.v_proj(x).permute(1, 0, 2)
attn_output, _ = self.multihead_attn(query,
key,
value,
attn_mask=attn_mask)
attn_output = attn_output.permute(1, 0, 2)
attn_output = self.linear(attn_output)
return attn_output
class AriaProjector(nn.Module):
"""
A projection module with one cross attention layer and one FFN layer, which
@ -198,42 +96,42 @@ class AriaProjector(nn.Module):
A tensor with the shape of (batch_size, query_number, output_dim)
"""
def __init__(
self,
patch_to_query_dict: dict[int, int],
embed_dim: int,
num_heads: int,
kv_dim: int,
ff_dim: int,
output_dim: int,
norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
) -> None:
def __init__(self, config: AriaConfig) -> None:
super().__init__()
self.patch_to_query_dict = patch_to_query_dict
self.embed_dim = embed_dim
self.num_heads = num_heads
self.patch_to_query_dict = config.projector_patch_to_query_dict
self.in_features = config.vision_config.hidden_size
self.num_heads = config.vision_config.num_attention_heads
self.kv_dim = config.vision_config.hidden_size
self.hidden_features = config.text_config.hidden_size
self.output_dim = config.text_config.hidden_size
self.query = nn.Parameter(
torch.empty(max(patch_to_query_dict.values()), self.embed_dim))
torch.empty(config.max_value_projector_patch_to_query_dict,
self.in_features))
self.cross_attn = CrossAttention(kv_dim, embed_dim, num_heads)
self.cross_attn = AriaCrossAttention(config)
self.ln_ffn = norm_layer(embed_dim)
self.ffn = FFN(embed_dim, ff_dim, output_dim)
self.layer_norm = nn.LayerNorm(self.in_features)
self.feed_forward = AriaProjectorMLP(self.in_features,
self.hidden_features,
self.output_dim)
def forward(
self,
x: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
bs = x.shape[0]
queries = self.query.unsqueeze(0).repeat(bs, 1, 1)
batch_size, num_patches = x.shape[0], x.shape[1]
query_num = self.patch_to_query_dict.get(x.shape[1], None)
assert (query_num is not None
), f"Query number for {x.shape[1]} patches is not provided"
if num_patches not in self.patch_to_query_dict:
raise KeyError(f"Number of patches {num_patches} not found in "
"patch_to_query_dict amongst possible values "
f"{self.patch_to_query_dict.keys()}.")
queries = queries[:, :query_num, :]
query_num = self.patch_to_query_dict[num_patches]
queries = self.query[:query_num].unsqueeze(0).repeat(batch_size, 1, 1)
if attn_mask is not None:
attn_mask = attn_mask.repeat_interleave(self.num_heads, 0)
@ -241,7 +139,7 @@ class AriaProjector(nn.Module):
attention_out = self.cross_attn(x, queries, attn_mask=attn_mask)
out = self.ffn(self.ln_ffn(attention_out))
out = self.feed_forward(self.layer_norm(attention_out))
return out
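The main behavioral change in forward is that the query count is looked up by patch count and the learned queries are sliced before being repeated, with a KeyError replacing the old assert. A toy reproduction of just that indexing, using the old config's default dict values and an arbitrary toy hidden size:

import torch

patch_to_query_dict = {1225: 128, 4900: 256}  # num_patches -> num_queries
in_features = 32                              # toy hidden size
query = torch.randn(max(patch_to_query_dict.values()), in_features)

def select_queries(x: torch.Tensor) -> torch.Tensor:
    batch_size, num_patches = x.shape[0], x.shape[1]
    if num_patches not in patch_to_query_dict:
        raise KeyError(f"Number of patches {num_patches} not found in "
                       "patch_to_query_dict amongst possible values "
                       f"{patch_to_query_dict.keys()}.")
    query_num = patch_to_query_dict[num_patches]
    return query[:query_num].unsqueeze(0).repeat(batch_size, 1, 1)

print(select_queries(torch.randn(2, 1225, in_features)).shape)  # (2, 128, 32)
print(select_queries(torch.randn(2, 4900, in_features)).shape)  # (2, 256, 32)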
@ -278,7 +176,7 @@ class AriaFusedMoE(FusedMoE):
param.data.copy_(loaded_weight.transpose(1, 2))
class MoELayer(nn.Module):
class AriaTextMoELayer(nn.Module):
"""
Mixture of Experts (MoE) Layer for the AriaMoE model.
@ -289,7 +187,7 @@ class MoELayer(nn.Module):
def __init__(
self,
config: AriaMoELMConfig,
config: AriaTextConfig,
quant_config: Optional[QuantizationConfig],
) -> None:
super().__init__()
@ -303,15 +201,16 @@ class MoELayer(nn.Module):
num_experts=config.moe_num_experts,
top_k=config.moe_topk,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
intermediate_size=config.intermediate_size,
quant_config=quant_config,
reduce_results=True,
)
self.shared_experts = LlamaMLP(
config.hidden_size,
config.moe_intermediate_size * config.moe_num_shared_experts,
config.intermediate_size * config.moe_num_shared_experts,
"silu",
quant_config=quant_config,
bias=config.mlp_bias,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@ -329,13 +228,13 @@ class MoELayer(nn.Module):
router_output = torch.nn.functional.linear(hidden_states,
self.router_weight)
shared_expert_output = self.shared_experts(hidden_states)
sparse_expert_output = self.experts(hidden_states, router_output)
shared_expert_output = self.shared_experts(hidden_states)
return sparse_expert_output + shared_expert_output
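For intuition, a plain-torch sketch of what the router output drives; the expert count and top-k follow the old config defaults of 8 and 2, and the real path goes through vLLM's FusedMoE kernels rather than this toy:

import torch

num_tokens, hidden_size, num_experts, top_k = 4, 16, 8, 2
hidden_states = torch.randn(num_tokens, hidden_size)
router_weight = torch.randn(num_experts, hidden_size)

router_output = torch.nn.functional.linear(hidden_states, router_weight)
topk_weights, topk_ids = router_output.softmax(dim=-1).topk(top_k, dim=-1)

print(topk_ids)                  # which 2 of the 8 experts each token uses
print(topk_weights.sum(dim=-1))  # routing mass carried by the chosen experts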
class MoEDecoderLayer(LlamaDecoderLayer):
class AriaTextDecoderLayer(LlamaDecoderLayer):
"""
Custom Decoder Layer for the AriaMoE model which modifies the standard
`LlamaDecoderLayer` by replacing the traditional MLP with a Mixture of
@ -344,16 +243,16 @@ class MoEDecoderLayer(LlamaDecoderLayer):
def __init__(
self,
config: AriaMoELMConfig,
config: AriaTextConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__(config, cache_config, quant_config, prefix)
self.mlp = MoELayer(config, quant_config=quant_config)
self.mlp = AriaTextMoELayer(config, quant_config=quant_config)
class AriaMoELMModel(LlamaModel):
class AriaTextModel(LlamaModel):
"""
Custom LlamaModel for the AriaMoE model which modifies the standard
LlamaModel by replacing the `LlamaDecoderLayer` with `MoEDecoderLayer`.
@ -362,7 +261,7 @@ class AriaMoELMModel(LlamaModel):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config,
prefix=prefix,
layer_type=MoEDecoderLayer)
layer_type=AriaTextDecoderLayer)
# Adapted from LlamaModel.load_weights with the modification of adding
# the expert weights mapping to `stacked_params_mapping`
@ -434,25 +333,23 @@ class AriaMoELMModel(LlamaModel):
return loaded_params
def build_mm_projector(config: PretrainedConfig):
return AriaProjector(
patch_to_query_dict=config.projector_patch_to_query_dict,
embed_dim=config.vision_config.hidden_size,
num_heads=config.vision_config.num_attention_heads,
kv_dim=config.vision_config.hidden_size,
ff_dim=config.text_config.hidden_size,
output_dim=config.text_config.hidden_size,
)
class AriaProcessingInfo(BaseProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config()
return self.ctx.get_hf_config(AriaConfig)
def get_vision_config(self) -> AriaVisionConfig:
def get_vision_config(self):
return self.get_hf_config().vision_config
def get_hf_processor(self):
processor = self.ctx.get_hf_processor(AriaProcessor)
# Patch for https://github.com/huggingface/transformers/issues/35768
processor.tokenizer.image_token = "<|img|>"
processor.image_token = "<|img|>"
return processor
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
@ -554,10 +451,14 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
quant_config = vllm_config.quant_config
self.config = config
self.vision_tower = AriaVisionModel(config.vision_config)
self.multi_modal_projector = build_mm_projector(config)
self.vision_tower = Idefics3VisionTransformer(
config.vision_config,
quant_config,
prefix=f"{prefix}.vision_tower",
)
self.multi_modal_projector = AriaProjector(config)
self.vocab_size = config.text_config.vocab_size
self.language_model = AriaMoELMModel(
self.language_model = AriaTextModel(
vllm_config=vllm_config.with_hf_config(config.text_config),
prefix=maybe_prefix(prefix, "language_model.model"),
)
@ -608,6 +509,22 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
pixel_mask=pixel_mask,
)
def _create_patch_attention_mask(
self, pixel_mask: Optional[torch.Tensor]) -> torch.Tensor:
if pixel_mask is None:
return None
patches_subgrid = pixel_mask.unfold(
dimension=1,
size=self.vision_tower.config.patch_size,
step=self.vision_tower.config.patch_size,
).unfold(
dimension=2,
size=self.vision_tower.config.patch_size,
step=self.vision_tower.config.patch_size,
)
return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
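A toy run of this unfold-based pooling; the patch size and image size are deliberately tiny, whereas the real values come from the vision config:

import torch

patch_size = 2
pixel_mask = torch.zeros(1, 6, 6, dtype=torch.int)  # (batch, height, width)
pixel_mask[:, :4, :4] = 1                           # valid (unpadded) region

patches_subgrid = pixel_mask.unfold(1, patch_size, patch_size) \
                            .unfold(2, patch_size, patch_size)
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()

print(patch_attention_mask.shape)  # torch.Size([1, 3, 3]), one flag per patch
print(patch_attention_mask[0])     # True where a patch overlaps the valid region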
def _process_image_input(
self, image_input: AriaImagePixelInputs
) -> Tuple[torch.Tensor, torch.Tensor]:
@ -616,9 +533,18 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
pixel_values = image_input['pixel_values']
pixel_mask = image_input['pixel_mask']
image_feature, image_attn_mask = self.vision_tower(
pixel_values, pixel_mask=pixel_mask)
return self.multi_modal_projector(image_feature, image_attn_mask)
patch_attention_mask = self._create_patch_attention_mask(pixel_mask)
image_outputs = self.vision_tower(
pixel_values=pixel_values,
patch_attention_mask=patch_attention_mask,
)
image_attn_mask = None
if patch_attention_mask is not None:
flattened_mask = patch_attention_mask.flatten(1)
image_attn_mask = torch.logical_not(flattened_mask)
return self.multi_modal_projector(image_outputs, image_attn_mask)
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
image_input = self._parse_and_validate_image_input(**kwargs)
@ -683,6 +609,5 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)
loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)


@ -22,10 +22,10 @@ from vllm.envs import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger
# yapf conflicts with isort for this block
# yapf: disable
from vllm.transformers_utils.configs import (AriaConfig, ChatGLMConfig,
Cohere2Config, DbrxConfig,
DeepseekVLV2Config, EAGLEConfig,
ExaoneConfig, H2OVLChatConfig,
from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
DbrxConfig, DeepseekVLV2Config,
EAGLEConfig, ExaoneConfig,
H2OVLChatConfig,
InternVLChatConfig, JAISConfig,
MedusaConfig, MllamaConfig,
MLPSpeculatorConfig, MPTConfig,
@ -52,7 +52,6 @@ _CONFIG_REGISTRY_OVERRIDE_HF: Dict[str, Type[PretrainedConfig]] = {
}
_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
"aria": AriaConfig,
"chatglm": ChatGLMConfig,
"cohere2": Cohere2Config,
"dbrx": DbrxConfig,

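With the custom AriaConfig dropped from the registry, the configuration now resolves through transformers itself; a minimal sanity check one might run with transformers>=4.48 installed (not part of this diff):

from transformers import AriaConfig

config = AriaConfig()  # default sub-configs; no trust_remote_code needed
print(type(config.text_config).__name__, type(config.vision_config).__name__)
print(config.projector_patch_to_query_dict)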

@ -1,4 +1,3 @@
from vllm.transformers_utils.configs.aria import AriaConfig
from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
from vllm.transformers_utils.configs.cohere2 import Cohere2Config
from vllm.transformers_utils.configs.dbrx import DbrxConfig
@ -24,7 +23,6 @@ from vllm.transformers_utils.configs.telechat2 import Telechat2Config
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
__all__ = [
"AriaConfig",
"ChatGLMConfig",
"Cohere2Config",
"DbrxConfig",


@ -1,165 +0,0 @@
# Copyright 2024 Rhymes AI. All rights reserved.
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Mapping
from transformers import PretrainedConfig
from transformers.models.idefics2.configuration_idefics2 import (
Idefics2VisionConfig)
from transformers.models.llama.configuration_llama import LlamaConfig
from vllm.logger import init_logger
logger = init_logger(__name__)
class AriaVisionConfig(Idefics2VisionConfig):
model_type = "aria_vision_model"
class AriaMoELMConfig(LlamaConfig):
"""
Configuration class for AriaMoE language model.
This class extends the LlamaConfig to include additional parameters specific
to the Mixture of Experts (MoE) architecture.
"""
model_type = "aria_moe_lm"
def __init__(
self,
moe_intermediate_size: int = 4096,
moe_num_experts: int = 8,
moe_topk: int = 2,
moe_num_shared_experts: int = 2,
**kwargs,
):
"""
Initialize the AriaMoELMConfig.
Args:
moe_intermediate_size (int): The intermediate size for MoE layers.
Default is 4096.
moe_num_experts (int): The number of experts in the MoE layer.
Default is 8.
moe_topk (int): The number of top experts to route to for each
token. Default is 2.
moe_num_shared_experts (int): The number of shared experts. Default
is 2.
**kwargs: Additional keyword arguments to be passed to the parent
LlamaConfig.
"""
super().__init__(**kwargs)
self.moe_intermediate_size = moe_intermediate_size
self.moe_num_experts = moe_num_experts
self.moe_topk = moe_topk
self.moe_num_shared_experts = moe_num_shared_experts
class AriaConfig(PretrainedConfig):
"""
Configuration class for Aria model.
This class handles the configuration for both vision and text components of
the Aria model,
as well as additional parameters for image token handling and projector
mapping.
Args:
vision_config (AriaVisionConfig or dict): Configuration for the vision
component.
text_config (AriaMoELMConfig or dict): Configuration for the text
component.
projector_patch_to_query_dict (dict): Mapping of patch sizes to query
dimensions.
ignore_index (int): Index to ignore in loss calculation.
image_token_index (int): Index used to represent image tokens.
**kwargs: Additional keyword arguments passed to the parent class.
Attributes:
model_type (str): Type of the model, set to "aria".
is_composition (bool): Whether the model is a composition of multiple
components.
ignore_index (int): Index to ignore in loss calculation.
image_token_index (int): Index used to represent image tokens.
projector_patch_to_query_dict (dict): Mapping of patch sizes to query
dimensions.
vision_config (AriaVisionConfig): Configuration for the vision
component.
text_config (AriaMoELMConfig): Configuration for the text component.
"""
model_type = "aria"
is_composition = False
def __init__(
self,
vision_config: AriaVisionConfig = AriaVisionConfig(), # noqa: B008
text_config: AriaMoELMConfig = AriaMoELMConfig(), # noqa: B008
projector_patch_to_query_dict: Mapping[int, int] = {
1225: 128,
4900: 256,
},
ignore_index=-100,
image_token_index=32000,
tie_word_embeddings=False,
**kwargs,
):
super().__init__(**kwargs)
self.ignore_index = ignore_index
self.image_token_index = image_token_index
self.tie_word_embeddings = tie_word_embeddings
attn_implementation = kwargs.pop("attn_implementation", None)
# Set the default attention implementation to flash_attention_2 if not
# specified
self._attn_implementation = ("flash_attention_2"
if attn_implementation is None else
attn_implementation)
# Convert the keys and values of projector_patch_to_query_dict to
# integers
# This ensures consistency even if they were provided as strings
self.projector_patch_to_query_dict = {
int(k): int(v)
for k, v in projector_patch_to_query_dict.items()
}
if isinstance(vision_config, dict) and "model_type" in vision_config:
vision_config = AriaVisionConfig(**vision_config)
if attn_implementation is None:
vision_attn_implementation = "flash_attention_2"
elif attn_implementation == "sdpa":
logger.warning("SDPA is not supported for vit, using "
"flash_attention_2 instead")
vision_attn_implementation = "flash_attention_2"
else:
vision_attn_implementation = attn_implementation
vision_config._attn_implementation = vision_attn_implementation
self.vision_config = vision_config
if isinstance(text_config, dict) and "model_type" in text_config:
text_attn_implementation = ("sdpa" if attn_implementation is None
else attn_implementation)
text_config = AriaMoELMConfig(**text_config)
text_config._attn_implementation = text_attn_implementation
self.text_config = text_config
# This is needed for the static kv cache
self.num_hidden_layers = self.text_config.num_hidden_layers
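The key/value coercion in this deleted config guards against serialized configs, where integer dict keys come back as strings (for example after a config.json round trip); a small illustration:

import json

projector_patch_to_query_dict = {1225: 128, 4900: 256}

# A JSON round trip stringifies the integer keys.
restored = json.loads(json.dumps(projector_patch_to_query_dict))
print(restored)    # {'1225': 128, '4900': 256}

# Hence the coercion back to integers in __init__ above.
normalized = {int(k): int(v) for k, v in restored.items()}
print(normalized)  # {1225: 128, 4900: 256}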