[VLM] Merged multi-modal processor for GLM4V (#12449)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>

parent fe743b798d
commit 86222a3dab
@@ -719,7 +719,7 @@ See [this page](#generative-models) for more information on how to use generative models.
  * `THUDM/glm-4v-9b` etc.
  * ✅︎
  * ✅︎
  *
  * ✅︎
- * `H2OVLChatModel`
  * H2OVL
  * T + I<sup>E+</sup>
@@ -106,7 +106,9 @@ def run_glm4v(question: str, modality: str):
              trust_remote_code=True,
              enforce_eager=True,
              disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
    prompt = question
    prompt = f"<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>\
        {question}<|assistant|>"

    stop_token_ids = [151329, 151336, 151338]
    return llm, prompt, stop_token_ids
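
The new example prompt wraps the question in GLM-4V's image placeholder markers. Below is a minimal sketch, not part of the diff, of how that prompt format can be fed to vLLM's offline API; the image path is a hypothetical placeholder.

# Sketch only: exercising the GLM-4V prompt format from the example above.
from PIL import Image

from vllm import LLM, SamplingParams

llm = LLM(model="THUDM/glm-4v-9b",
          max_model_len=2048,
          trust_remote_code=True,
          enforce_eager=True)

question = "What is shown in this image?"
prompt = f"<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>{question}<|assistant|>"

outputs = llm.generate(
    {
        "prompt": prompt,
        # "example.jpg" is a hypothetical sample image.
        "multi_modal_data": {"image": Image.open("example.jpg")},
    },
    SamplingParams(max_tokens=64, stop_token_ids=[151329, 151336, 151338]),
)
print(outputs[0].outputs[0].text)
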
@@ -147,6 +147,7 @@ def _test_processing_correctness(
    "facebook/chameleon-7b",
    "deepseek-ai/deepseek-vl2-tiny",
    "adept/fuyu-8b",
    "THUDM/glm-4v-9b",
    "h2oai/h2ovl-mississippi-800m",
    "OpenGVLab/InternVL2-1B",
    "HuggingFaceM4/Idefics3-8B-Llama3",
@@ -4,20 +4,21 @@
# https://github.com/THUDM/CogAgent
"""Inference-only CogAgent model compatible with THUDM weights."""
from argparse import Namespace
from array import array
from typing import (Dict, Iterable, List, Mapping, Optional, Set, Tuple,
                    TypedDict)
from typing import (Iterable, List, Mapping, Optional, Sequence, Set, Tuple,
                    TypedDict, Union)

import torch
from PIL import Image
from torch import nn
from torch.nn import LayerNorm
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from transformers import PreTrainedTokenizer, TensorType
from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import TextInput

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                         InputContext, token_inputs)
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
@@ -35,194 +36,233 @@ from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (ModalityData, MultiModalKwargs,
                                    NestedTensors)
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
                           SequenceData)
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo, BatchFeature,
                                        BoundPromptReplacement,
                                        MultiModalFieldConfig,
                                        PlaceholderFeaturesInfo,
                                        PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import ChatGLMConfig

from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
                    maybe_prefix, merge_multimodal_embeddings)

logger = init_logger(__name__)

IMAGE_TOKEN_ID = 151329


def build_normalization_transform(image_size: int) -> transforms.Compose:
    """
    Build a normalization transform which can be applied to one or
    more input images from which we want to extract visual features.

    Args:
        image_size: size of the image to be processed for visual embeddings.

    Returns:
        Callable transform for normalizing and resizing one RGB image.
    """

    return transforms.Compose([
        transforms.Resize(
            (image_size, image_size),
            interpolation=InterpolationMode.BICUBIC,
        ),
        transforms.ToTensor(),
        transforms.Normalize(
            (0.48145466, 0.4578275, 0.40821073),
            (0.26862954, 0.26130258, 0.27577711),
        ),
    ])


def calculate_image_placeholder(vision_config):
    return (vision_config["image_size"] // vision_config["patch_size"] // 2)**2
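
For intuition, here is a worked example of calculate_image_placeholder with GLM-4V-style values (the numbers are assumed for illustration, not taken from the diff).

# Illustrative only: plugging assumed GLM-4V-style values into the formula.
vision_config = {"image_size": 1120, "patch_size": 14}
tokens = (vision_config["image_size"] // vision_config["patch_size"] // 2) ** 2
print(tokens)  # (1120 // 14 // 2) ** 2 = 40 ** 2 = 1600 placeholder tokens per image
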
def mm_input_mapper_for_glmv(
    ctx: InputContext,
    data: ModalityData[object],
) -> Dict:
    model_config = ctx.model_config
    tokenizer = cached_get_tokenizer(
        model_config.tokenizer,
        trust_remote_code=model_config.trust_remote_code)
    if tokenizer is None:
        raise RuntimeError("No HuggingFace processor is available "
                           "to process the image object")
    try:
        raw_batch_data = tokenizer.apply_chat_template(
            conversation=[{
                "role": "user",
                "image": data
            }],
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
            return_dict=True).data
    except Exception:
        logger.error("Failed to process image (%s)", data)
        raise
    pixel_values = raw_batch_data['images']

    return MultiModalKwargs({'pixel_values': pixel_values})


def merge_glm_vision_embeddings(
    input_ids: torch.Tensor,
    inputs_embeds: torch.Tensor,
    vision_embeddings: torch.Tensor,
    boi_token_id: int,
    eoi_token_id: int,
) -> torch.Tensor:

    boi_positions = (input_ids == boi_token_id).nonzero(as_tuple=True)[0]
    eoi_positions = (input_ids == eoi_token_id).nonzero(as_tuple=True)[0]

    mask = torch.zeros_like(input_ids, dtype=torch.bool)

    for boi_pos, eoi_pos in zip(boi_positions, eoi_positions):
        assert boi_pos < eoi_pos
        mask[boi_pos:eoi_pos + 1] = True
    inputs_embeds[mask] = vision_embeddings.view(-1,
                                                 vision_embeddings.shape[-1])
    return inputs_embeds


class GLMImagePixelInputs(TypedDict):
    pixel_values: torch.Tensor
    """Shape: `(batch_size, num_channels, height, width)`"""


def get_max_glmv_image_tokens(ctx: InputContext):
    hf_config = ctx.get_hf_config(ChatGLMConfig)
class GLM4VProcessor:
    """
    This model doesn't define its own HF processor,
    so we implement our own one here.

    vision_config = getattr(hf_config, 'vision_config', None)
    if vision_config is None:
        return 1
    elif isinstance(vision_config, dict):
        return calculate_image_placeholder(vision_config)
    """

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)
    def __init__(
        self,
        config: ChatGLMConfig,
        tokenizer: PreTrainedTokenizer,
    ) -> None:
        super().__init__()

        self.config = config
        self.tokenizer = tokenizer

        if hasattr(self.config, "vision_config"):
            self.image_transform = build_normalization_transform(
                config.vision_config["image_size"])
        else:
            self.image_transform = None

    def __call__(
        self,
        text: Optional[Union[TextInput, list[TextInput]]] = None,
        images: Optional[Union[ImageInput, list[ImageInput]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
    ) -> BatchFeature:
        if text is None:
            text = []
        if not isinstance(text, list):
            text = [text]
        if images is None:
            images = []
        if not isinstance(images, list):
            images = [images]
        text_inputs = self.tokenizer(text)
        if len(images) == 0:
            image_inputs = {}
        else:
            if self.image_transform is None:
                raise ValueError("This model does not support image inputs")

            pixel_values = [self.image_transform(image) for image in images]
            image_inputs = {"pixel_values": torch.stack(pixel_values)}

        return BatchFeature(
            {
                **text_inputs,
                **image_inputs,
            },
            tensor_type=return_tensors,
        )
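
A rough usage sketch of the hand-rolled GLM4VProcessor above, not part of the diff. It assumes the repo's remote-code config exposes vision_config as a dict with image_size, and the image path is hypothetical.

# Sketch only: driving GLM4VProcessor directly.
from PIL import Image
from transformers import AutoConfig, AutoTokenizer

config = AutoConfig.from_pretrained("THUDM/glm-4v-9b", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4v-9b",
                                          trust_remote_code=True)

processor = GLM4VProcessor(config, tokenizer)
batch = processor(
    text="<|begin_of_image|><|endoftext|><|end_of_image|>Describe the image.",
    images=Image.open("example.jpg"),  # hypothetical sample image
    return_tensors="pt",
)
# Expect tokenized text fields plus a stacked "pixel_values" tensor of shape
# (num_images, 3, image_size, image_size).
print(batch.keys())
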
def dummy_data_for_glmv(ctx: InputContext, seq_len: int,
                        mm_counts: Mapping[str, int]) -> DummyData:
    hf_config = ctx.get_hf_config(ChatGLMConfig)
    vision_config = getattr(hf_config, 'vision_config', None)
class GLM4VProcessingInfo(BaseProcessingInfo):

    if vision_config is None:
        token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * seq_len)
        seq_data = SequenceData(token_ids)
        return DummyData(seq_data, None)
    elif isinstance(vision_config, dict):
        image_size = vision_config["image_size"]
        image_placeholder_length = calculate_image_placeholder(vision_config)
        token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [hf_config.boi_token_id] +
                          [0] * image_placeholder_length +
                          [hf_config.eoi_token_id])
        token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
                           [0] * (seq_len - image_placeholder_length - 2))
        seq_data = SequenceData(token_ids)
    def __init__(self, ctx):
        super().__init__(ctx)
        self._pre_calculate()

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"image": 1}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:

        return {"image": self.image_token_num + 2}

    def _pre_calculate(self):
        hf_config = self.get_hf_config()
        vision_config = hf_config.vision_config
        self.image_token_num = calculate_image_placeholder(vision_config)
        self.image_size = vision_config["image_size"]

    def get_num_image_tokens(self) -> int:
        return self.image_token_num + 2

    def get_image_size(self) -> ImageSize:

        return ImageSize(height=self.image_size, width=self.image_size)

    def get_hf_processor(self) -> GLM4VProcessor:
        return GLM4VProcessor(
            self.get_hf_config(),
            self.get_tokenizer(),
        )
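
A back-of-the-envelope check of the token budget reported by GLM4VProcessingInfo, using the illustrative 1600-token value assumed earlier: the per-image maximum is the placeholder run plus the two bracketing markers.

# Illustrative values only, not taken from the diff.
image_token_num = 1600                      # calculate_image_placeholder(vision_config), assumed
max_tokens_per_image = image_token_num + 2  # + <|begin_of_image|> and <|end_of_image|>
assert max_tokens_per_image == 1602
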
class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]):

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        num_images = mm_counts.get("image", 0)
        target_width, target_height = self.info.get_image_size()

        mm_data = {
            "image": Image.new("RGB", (image_size, image_size), color=0)
            "image":
            self._get_dummy_images(width=target_width,
                                   height=target_height,
                                   num_images=num_images)
        }
        text = "<|begin_of_image|><|endoftext|><|end_of_image|>"
        return ProcessorInputs(
            prompt_text=text,
            mm_data=mm_data,
        )


class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]):

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(pixel_values=MultiModalFieldConfig.batched("image"))

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:

        def get_replacement(item_idx: int):
            image_tokens = self.info.image_token_num
            return [IMAGE_TOKEN_ID] * image_tokens

        return [
            PromptReplacement(
                modality="image",
                target=[IMAGE_TOKEN_ID],
                replacement=get_replacement,
            ),
        ]
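
A toy illustration, not part of the diff, of what the PromptReplacement above does: each single image token in the tokenized prompt is expanded into a run of image_token_num placeholder tokens that the vision embeddings later fill. The token count here is made up for brevity.

IMAGE_TOKEN_ID = 151329
image_token_num = 4               # the real value comes from the vision config
prompt_ids = [100, IMAGE_TOKEN_ID, 200]

expanded = []
for tok in prompt_ids:
    expanded.extend([IMAGE_TOKEN_ID] * image_token_num if tok == IMAGE_TOKEN_ID else [tok])

print(expanded)  # [100, 151329, 151329, 151329, 151329, 200]
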
    def _apply_prompt_replacements(
        self,
        token_ids: list[int],
        mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
        mm_item_counts: Mapping[str, int],
    ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
        token_ids, text, placeholders = super()._apply_prompt_replacements(
            token_ids=token_ids,
            mm_prompt_repls=mm_prompt_repls,
            mm_item_counts=mm_item_counts,
        )
        hf_config = self.info.get_hf_config()
        boi_token_id = hf_config.boi_token_id
        eoi_token_id = hf_config.eoi_token_id
        placeholders = {
            modality: [
                PlaceholderFeaturesInfo(
                    modality=p.modality,
                    item_idx=p.item_idx,
                    start_idx=p.start_idx - 1,
                    tokens=[boi_token_id] + p.tokens + [eoi_token_id],
                ) for p in ps
            ]
            for modality, ps in placeholders.items()
        }

        return DummyData(seq_data, mm_data)

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


def find_all_positions(input_ids: List[int], target: int) -> List[int]:
    return [index for index, value in enumerate(input_ids) if value == target]


def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs):
    multi_modal_data = inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return inputs

    hf_config = ctx.get_hf_config(ChatGLMConfig)
    vision_config = getattr(hf_config, 'vision_config', None)

    if vision_config is None:
        return inputs
    elif isinstance(vision_config, dict):
        image_placeholder_length = calculate_image_placeholder(vision_config)
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    input_ids = inputs["prompt_token_ids"]

    tokenizer = cached_get_tokenizer(
        ctx.model_config.model,
        trust_remote_code=ctx.model_config.trust_remote_code)

    try:
        raw_batch_data = tokenizer.apply_chat_template(
            conversation=[{
                "role": "user",
                "image": multi_modal_data["image"],
                "content": inputs['prompt'],
            }],
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
            return_dict=True,
        ).data
    except Exception:
        logger.error("Failed to process content (%s)", inputs['prompt'])
        raise
    input_ids = raw_batch_data['input_ids'][0].tolist()

    boi_token_id = hf_config.boi_token_id
    eoi_token_id = hf_config.eoi_token_id
    boi_positions = find_all_positions(input_ids, boi_token_id)
    eoi_positions = find_all_positions(input_ids, eoi_token_id)

    assert len(boi_positions) == len(eoi_positions)

    new_input_ids = []
    final_processed_position = 0

    for boi_position, eoi_position in zip(boi_positions, eoi_positions):
        assert boi_position < eoi_position
        new_input_ids.extend(input_ids[final_processed_position:boi_position +
                                       1])
        new_input_ids.extend([input_ids[boi_position + 1]] *
                             image_placeholder_length)
        final_processed_position = eoi_position

    new_input_ids.extend(input_ids[final_processed_position:])

    prompt = inputs.get("prompt")
    if prompt is None:
        prompt = tokenizer.decode(new_input_ids)

    return token_inputs(
        prompt_token_ids=new_input_ids,
        prompt=prompt,
        multi_modal_data=multi_modal_data,
    )
        return token_ids, text, placeholders


class GLMAttention(nn.Module):
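
A toy illustration, not part of the diff, of the placeholder widening done in _apply_prompt_replacements above: the base class records only the run of image tokens, and the override grows it by one position on each side so the <|begin_of_image|> and <|end_of_image|> markers are covered too. The boi/eoi token ids are assumed here for illustration.

boi, img, eoi = 151339, 151329, 151340   # ids assumed for illustration
prompt = [100, boi, img, img, eoi, 200]

base_start, base_tokens = 2, [img, img]            # what the superclass reports
widened_start = base_start - 1                     # start_idx=p.start_idx - 1
widened_tokens = [boi] + base_tokens + [eoi]       # tokens=[boi] + p.tokens + [eoi]

assert prompt[widened_start:widened_start + len(widened_tokens)] == widened_tokens
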
@@ -572,12 +612,16 @@ class ChatGLMModel(nn.Module):
    ) -> torch.Tensor:
        inputs_embeds = self.embedding(input_ids)
        if multimodal_embeddings is not None:
            inputs_embeds = merge_glm_vision_embeddings(
            inputs_embeds = merge_multimodal_embeddings(
                input_ids=input_ids,
                inputs_embeds=inputs_embeds,
                vision_embeddings=multimodal_embeddings,
                boi_token_id=self.config.boi_token_id,
                eoi_token_id=self.config.eoi_token_id)
                multimodal_embeddings=multimodal_embeddings,
                placeholder_token_id=[
                    self.config.boi_token_id,
                    IMAGE_TOKEN_ID,
                    self.config.eoi_token_id,
                ],
            )
        return inputs_embeds

    def forward(
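
A simplified sketch, not vLLM's actual implementation, of what the merge_multimodal_embeddings call above achieves: every position whose token id is one of the placeholder ids (boi, image token, eoi) has its text embedding overwritten by the flattened vision embeddings, in order.

import torch


def merge_sketch(input_ids: torch.Tensor, inputs_embeds: torch.Tensor,
                 vision_embeds: torch.Tensor,
                 placeholder_ids: list[int]) -> torch.Tensor:
    # Mark the placeholder positions, then scatter the vision features there.
    mask = torch.isin(input_ids,
                      torch.tensor(placeholder_ids, device=input_ids.device))
    inputs_embeds[mask] = vision_embeds.reshape(-1, vision_embeds.shape[-1])
    return inputs_embeds
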
@@ -593,14 +637,12 @@ class ChatGLMModel(nn.Module):

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        if intermediate_tensors is None and inputs_embeds is None:
        if intermediate_tensors is not None:
            inputs_embeds = intermediate_tensors["hidden_states"]
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None
        else:
            inputs_embeds = intermediate_tensors["hidden_states"]

        # Run encoder.
        hidden_states = self.encoder(
            hidden_states=inputs_embeds,
@@ -763,11 +805,21 @@ class ChatGLMV(ChatGLMBaseModel, SupportsMultiModal):
            connector="transformer.vision.linear_proj",
            tower_model="transformer.vision.transformer")

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        return self.transformer.get_multimodal_embeddings(**kwargs)

@MULTIMODAL_REGISTRY.register_image_input_mapper(mm_input_mapper_for_glmv)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_glmv_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_glmv)
@INPUT_REGISTRY.register_input_processor(input_processor_for_glmv)
    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        return self.transformer.get_input_embeddings(input_ids,
                                                     multimodal_embeddings)


@MULTIMODAL_REGISTRY.register_processor(GLM4VMultiModalProcessor,
                                        info=GLM4VProcessingInfo,
                                        dummy_inputs=GLM4VDummyInputsBuilder)
class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
                         SupportsMultiModal):
    # Ensure that the LoRA support check passes when the class is not