[Model][VLM] Initialize support for Mono-InternVL model (#9528)
Commit: bb392ea2d2
Parent: 9dbcce84a7
@@ -376,7 +376,7 @@ Text Generation
   * - :code:`InternVLChatModel`
     - InternVL2
     - T + I\ :sup:`E+`
-    - :code:`OpenGVLab/InternVL2-4B`, :code:`OpenGVLab/InternVL2-8B`, etc.
+    - :code:`OpenGVLab/Mono-InternVL-2B`, :code:`OpenGVLab/InternVL2-4B`, :code:`OpenGVLab/InternVL2-8B`, etc.
     -
     - ✅︎
   * - :code:`LlavaForConditionalGeneration`
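For reference, a minimal offline-inference sketch (not part of this diff) for trying the newly listed checkpoint through vLLM's standard LLM API. The prompt template mirrors the test prompts below; bfloat16 is used because the model card reports NaNs under fp16, and the image path is a placeholder.

# Hypothetical usage sketch, not included in this commit.
from PIL import Image

from vllm import LLM, SamplingParams

llm = LLM(
    model="OpenGVLab/Mono-InternVL-2B",
    trust_remote_code=True,
    dtype="bfloat16",  # fp16 reportedly produces NaNs for this checkpoint
)

prompt = ("<|im_start|>User\n<image>\nWhat is in this image?<|im_end|>\n"
          "<|im_start|>Assistant\n")
image = Image.open("example.jpg")  # placeholder image path

outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image}},
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)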
@@ -7,7 +7,6 @@ from PIL.Image import Image
 from transformers import AutoConfig

 from vllm.multimodal.utils import rescale_image_size
-from vllm.platforms import current_platform

 from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
                           _ImageAssets)
@@ -19,15 +18,20 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
     "cherry_blossom":
     "<|im_start|>User\n<image>\nWhat is the season?<|im_end|>\n<|im_start|>Assistant\n",  # noqa: E501
 })
-HF_MULTIIMAGE_IMAGE_PROMPT = "<|im_start|>User\nImage-1: <image>\nImage-2: <image>\nDescribe the two images in detail.<|im_end|>\n<|im_start|>Assistant\n"  # noqa: E501
+HF_MULTIIMAGE_IMAGE_PROMPT = "<|im_start|>User\nImage-1: <image>\nImage-2: <image>\nDescribe the two images in short.<|im_end|>\n<|im_start|>Assistant\n"  # noqa: E501

 models = [
     "OpenGVLab/InternVL2-1B",
     "OpenGVLab/InternVL2-2B",
+    # NOTE: Mono-InternVL-2B doesn't work with fp16,
+    # it will result NaN during inference.
+    # See: https://huggingface.co/OpenGVLab/Mono-InternVL-2B/discussions/9
+    "OpenGVLab/Mono-InternVL-2B",
     # Broken due to outdated implementation of Phi-3
     # See: https://huggingface.co/OpenGVLab/InternVL2-4B/discussions/3
     # "OpenGVLab/InternVL2-4B",
 ]
+target_dtype = "bfloat16"


 # adapted from https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py
@@ -52,9 +56,15 @@ def generate(

     input_embeds = input_embeds.reshape(B, N, C)

-    outputs = self.language_model.generate(
+    forward_kwargs = dict(
         inputs_embeds=input_embeds,
         attention_mask=attention_mask,
+    )
+    if getattr(self, "use_visual_token_mask", False):
+        visual_token_mask = selected.reshape(B, N, 1).to(input_embeds.dtype)
+        forward_kwargs["visual_token_mask"] = visual_token_mask
+    outputs = self.language_model.generate(
+        **forward_kwargs,
         **generate_kwargs,
     )

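For context, in the HF reference code this patched generate is adapted from, `selected` is the boolean mask of IMG_CONTEXT placeholder positions (roughly `input_ids == self.img_context_token_id`). A toy sketch of the mask the new branch builds, using a made-up token ID:

# Toy illustration only; 999 stands in for the real IMG_CONTEXT token ID.
import torch

input_ids = torch.tensor([[1, 999, 999, 42, 7]])  # (B=1, N=5)
selected = input_ids == 999                       # image-placeholder positions
B, N = input_ids.shape
visual_token_mask = selected.reshape(B, N, 1).to(torch.bfloat16)
print(visual_token_mask.squeeze(-1))              # tensor([[0., 1., 1., 0., 0.]], dtype=torch.bfloat16)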
@@ -243,11 +253,6 @@ def run_awq_test(
     )


-target_dtype = "half"
-if current_platform.is_cpu():
-    target_dtype = "bfloat16"
-
-
 @pytest.mark.parametrize("model", models)
 @pytest.mark.parametrize(
     "size_factors",
@@ -97,6 +97,37 @@ class InternVisionEmbeddings(nn.Module):
         return embeddings


+class InternVisionPatchModel(nn.Module):
+
+    def __init__(self, config: PretrainedConfig):
+        super().__init__()
+        self.config = config
+        self.embeddings = InternVisionEmbeddings(config)
+
+    def get_input_embeddings(self):
+        return self.embeddings
+
+    def forward(
+        self,
+        pixel_values: Optional[torch.Tensor] = None,
+        pixel_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        if pixel_values is None and pixel_embeds is None:
+            raise ValueError(
+                'You have to specify pixel_values or pixel_embeds')
+
+        if pixel_embeds is not None:
+            hidden_states = pixel_embeds
+        elif pixel_values is not None:
+            if pixel_values.ndim == 4:
+                hidden_states = self.embeddings(pixel_values)
+            else:
+                raise ValueError(
+                    f'wrong pixel_values size: {pixel_values.shape}')
+
+        return hidden_states
+
+
 class InternParallelAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

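Unlike InternVisionModel, the new InternVisionPatchModel stops at the embedding stage and runs no transformer blocks, so Mono-InternVL feeds raw patch embeddings straight into the language model. A shape-only sketch with made-up sizes (the real InternVisionEmbeddings also prepends a class token and adds position embeddings):

# Shape sketch with illustrative numbers, not vLLM code:
# 448x448 input, 14x14 patches, hidden size 1024.
import torch
import torch.nn as nn

patch_embed = nn.Conv2d(3, 1024, kernel_size=14, stride=14)    # patchify step only

pixel_values = torch.randn(2, 3, 448, 448)                     # (batch, channels, H, W)
patches = patch_embed(pixel_values).flatten(2).transpose(1, 2)
print(patches.shape)                                           # torch.Size([2, 1024, 1024]) = (batch, tokens, hidden)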
vllm/model_executor/models/internlm2_ve.py (new file, 166 lines)
@@ -0,0 +1,166 @@
+# -*- coding: utf-8 -*-
+from typing import List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from transformers import PretrainedConfig
+
+from vllm.attention import AttentionMetadata
+from vllm.config import CacheConfig
+from vllm.distributed import get_pp_group
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.models.internlm2 import (InternLM2Attention,
+                                                  InternLM2ForCausalLM,
+                                                  InternLM2MLP, InternLM2Model)
+from vllm.sequence import IntermediateTensors
+
+from .utils import make_layers
+
+
+class InternLM2VEDecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        max_position_embeddings = getattr(config, "max_position_embeddings",
+                                          8192)
+        self.attention = InternLM2Attention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            max_position_embeddings=max_position_embeddings,
+            cache_config=cache_config,
+            quant_config=quant_config,
+        )
+        self.feed_forward = InternLM2MLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            quant_config=quant_config,
+        )
+        self.feed_forward_ve = InternLM2MLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            quant_config=quant_config,
+        )
+        self.attention_norm = RMSNorm(config.hidden_size,
+                                      eps=config.rms_norm_eps)
+        self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        residual: Optional[torch.Tensor],
+        visual_token_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.attention_norm(hidden_states)
+        else:
+            hidden_states, residual = self.attention_norm(
+                hidden_states, residual)
+        hidden_states = self.attention(
+            positions=positions,
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            attn_metadata=attn_metadata,
+        )
+
+        # Fully Connected
+        hidden_states, residual = self.ffn_norm(hidden_states, residual)
+        if visual_token_mask is not None and visual_token_mask.any():
+            visual_token_mask = visual_token_mask.repeat(
+                1, self.hidden_size).bool()
+            text_token_mask = ~visual_token_mask
+            hidden_states[visual_token_mask] = self.feed_forward_ve(
+                hidden_states[visual_token_mask].reshape(
+                    -1, self.hidden_size)).flatten()
+            if text_token_mask.any():
+                hidden_states[text_token_mask] = self.feed_forward(
+                    hidden_states[text_token_mask].reshape(
+                        -1, self.hidden_size)).flatten()
+        else:
+            hidden_states = self.feed_forward(hidden_states)
+        return hidden_states, residual
+
+
+class InternLM2VEModel(InternLM2Model):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(config, cache_config, quant_config)
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda prefix: InternLM2VEDecoderLayer(config, cache_config,
+                                                   quant_config),
+            prefix=f"{prefix}.layers")
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        visual_token_mask: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.tok_embeddings(input_ids)
+            residual = None
+        else:
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
+        for i in range(self.start_layer, self.end_layer):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                kv_caches[i - self.start_layer],
+                attn_metadata,
+                residual,
+                visual_token_mask=visual_token_mask,
+            )
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({
+                "hidden_states": hidden_states,
+                "residual": residual
+            })
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
+
+
+class InternLM2VEForCausalLM(InternLM2ForCausalLM):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__(config, cache_config, quant_config)
+        self.model = InternLM2VEModel(config, cache_config, quant_config)
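The core idea in InternLM2VEDecoderLayer.forward is that each token is routed to one of two MLPs: feed_forward_ve for visual tokens and feed_forward for text tokens. A self-contained toy of that masked routing, with small dimensions and plain nn.Linear standing in for InternLM2MLP:

# Toy routing sketch; sizes and the Linear "experts" are illustrative stand-ins.
import torch
import torch.nn as nn

hidden_size = 8
feed_forward = nn.Linear(hidden_size, hidden_size)     # text expert
feed_forward_ve = nn.Linear(hidden_size, hidden_size)  # visual expert

hidden_states = torch.randn(5, hidden_size)            # 5 tokens
visual_token_mask = torch.tensor([[0], [1], [1], [0], [0]], dtype=torch.bool)

mask = visual_token_mask.repeat(1, hidden_size)        # expand per-token flag to a per-feature mask
out = hidden_states.clone()
out[mask] = feed_forward_ve(hidden_states[mask].reshape(-1, hidden_size)).flatten()
out[~mask] = feed_forward(hidden_states[~mask].reshape(-1, hidden_size)).flatten()

Each token therefore passes through exactly one FFN, which is how Mono-InternVL embeds the visual pathway inside the language model instead of using a full ViT encoder.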
@@ -21,7 +21,8 @@ from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
                          token_inputs)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
-from vllm.model_executor.models.intern_vit import InternVisionModel
+from vllm.model_executor.models.intern_vit import (InternVisionModel,
+                                                   InternVisionPatchModel)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.base import MultiModalInputs
@@ -427,13 +428,9 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
         self.downsample_ratio = config.downsample_ratio
         self.ps_version = config.ps_version

-        vision_feature_layer = self.select_layer
-        if vision_feature_layer < 0:
-            num_hidden_layers = config.vision_config.num_hidden_layers \
-                + vision_feature_layer + 1
-        else:
-            num_hidden_layers = vision_feature_layer + 1
-        self.vision_model = self._init_vision_model(config, num_hidden_layers)
+        self.llm_arch_name = config.text_config.architectures[0]
+        self.is_mono = self.llm_arch_name == 'InternLM2VEForCausalLM'
+        self.vision_model = self._init_vision_model(config, self.is_mono)

         self.language_model = init_vllm_registered_model(
             config.text_config, cache_config, quant_config)
@@ -451,10 +448,19 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):

         return Sampler()

-    def _init_vision_model(self, config: PretrainedConfig,
-                           num_hidden_layers: int):
-        return InternVisionModel(config.vision_config,
-                                 num_hidden_layers_override=num_hidden_layers)
+    def _init_vision_model(self, config: PretrainedConfig, is_mono: bool):
+        if not is_mono:
+            vision_feature_layer = self.select_layer
+            if vision_feature_layer < 0:
+                num_hidden_layers = config.vision_config.num_hidden_layers \
+                    + vision_feature_layer + 1
+            else:
+                num_hidden_layers = vision_feature_layer + 1
+            return InternVisionModel(
+                config.vision_config,
+                num_hidden_layers_override=num_hidden_layers)
+        else:
+            return InternVisionPatchModel(config.vision_config)

     def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
         vit_hidden_size = config.vision_config.hidden_size
@@ -562,6 +568,14 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):

         return image_embeds

+    def _get_visual_token_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
+        if self.is_mono:
+            visual_token_mask = (
+                input_ids == self.img_context_token_id).reshape(-1, 1)
+        else:
+            visual_token_mask = None
+        return visual_token_mask
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -574,6 +588,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
         if intermediate_tensors is not None:
             input_ids = None
             inputs_embeds = None
+            visual_token_mask = None
         else:
             image_input = self._parse_and_validate_image_input(**kwargs)
             if image_input is not None:
@@ -583,16 +598,24 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
                 inputs_embeds = merge_multimodal_embeddings(
                     input_ids, inputs_embeds, vision_embeddings,
                     self.img_context_token_id)
+                visual_token_mask = self._get_visual_token_mask(input_ids)
                 input_ids = None
             else:
                 inputs_embeds = None
+                visual_token_mask = None

-        hidden_states = self.language_model.model(input_ids,
-                                                  positions,
-                                                  kv_caches,
-                                                  attn_metadata,
-                                                  intermediate_tensors,
-                                                  inputs_embeds=inputs_embeds)
+        forward_kwargs = {
+            "input_ids": input_ids,
+            "positions": positions,
+            "kv_caches": kv_caches,
+            "attn_metadata": attn_metadata,
+            "intermediate_tensors": intermediate_tensors,
+            "inputs_embeds": inputs_embeds,
+        }
+        if self.is_mono:
+            forward_kwargs.update({"visual_token_mask": visual_token_mask})
+
+        hidden_states = self.language_model.model(**forward_kwargs)
         return hidden_states

     def compute_logits(
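Taken together, forward now assembles its arguments into a dict and adds visual_token_mask only on the mono path, since only InternLM2VEForCausalLM accepts that keyword. A standalone sketch of the dispatch pattern (a toy helper, not vLLM code):

# Toy helper mirroring the dispatch above; not part of vLLM.
from typing import Any, Dict, Optional

import torch


def build_forward_kwargs(is_mono: bool,
                         inputs_embeds: Optional[torch.Tensor],
                         visual_token_mask: Optional[torch.Tensor],
                         **common: Any) -> Dict[str, Any]:
    kwargs: Dict[str, Any] = {"inputs_embeds": inputs_embeds, **common}
    if is_mono:
        # Only the visual-expert language model understands this argument.
        kwargs["visual_token_mask"] = visual_token_mask
    return kwargs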
@@ -47,6 +47,7 @@ _TEXT_GENERATION_MODELS = {
     "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
     "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
     "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
+    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
     "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
     "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
     "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),