[VLM] Clean up Phi-4-MM ViT implementation (#14812)
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
parent 3453b964a3
commit def232e122
@@ -8,6 +8,7 @@ pytest-shard
 
 # testing utils
 awscli
+backoff # required for phi4mm test
 decord # required for video tests
 einops # required for MPT, qwen-vl and Mamba
 httpx
@@ -33,6 +33,8 @@ audioread==3.0.1
     # via librosa
 awscli==1.35.23
     # via -r requirements/test.in
+backoff==2.2.1
+    # via -r requirements/test.in
 bitsandbytes==0.45.3
     # via -r requirements/test.in
 black==24.10.0
tests/models/decoder_only/vision_language/test_phi4mm.py (new file, 229 lines)
@@ -0,0 +1,229 @@
# SPDX-License-Identifier: Apache-2.0

import os
import re
from typing import Optional

import pytest
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer

from vllm.lora.request import LoRARequest
from vllm.multimodal.image import rescale_image_size
from vllm.platforms import current_platform
from vllm.sequence import SampleLogprobs

from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
from ....utils import large_gpu_test
from ...utils import check_logprobs_close

HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
    "stop_sign":
    "<|user|>\n<|image_1|>\nWhat's the content of the image?<|end|>\n<|assistant|>\n",  # noqa: E501
    "cherry_blossom":
    "<|user|>\n<|image_1|>\nPlease infer the season with reason in details.<|end|>\n<|assistant|>\n",  # noqa: E501
})
HF_MULTIIMAGE_IMAGE_PROMPT = "<|user|>\n<|image_1|>\n<|image_2|>\nDescribe these images.<|end|>\n<|assistant|>\n"  # noqa: E501

model_path = snapshot_download("microsoft/Phi-4-multimodal-instruct")
# Since the vision-lora and speech-lora co-exist with the base model,
# we have to manually specify the path of the lora weights.
vision_lora_path = os.path.join(model_path, "vision-lora")
models = [model_path]


def vllm_to_hf_output(vllm_output: tuple[list[int], str,
                                         Optional[SampleLogprobs]],
                      model: str):
    """Sanitize vllm output to be comparable with hf output."""
    _, output_str, out_logprobs = vllm_output

    output_str_without_image = re.sub(r"(<\|image_\d+\|>)+", "", output_str)
    assert output_str_without_image[0] == " "
    output_str_without_image = output_str_without_image[1:]

    hf_output_str = output_str_without_image + "<|end|><|endoftext|>"

    tokenizer = AutoTokenizer.from_pretrained(model)
    hf_output_ids = tokenizer.encode(output_str_without_image)
    assert hf_output_ids[0] == 1
    hf_output_ids = hf_output_ids[1:]

    return hf_output_ids, hf_output_str, out_logprobs


target_dtype = "half"

# ROCm Triton FA can run into shared memory issues with these models,
# use other backends in the meantime
# FIXME (mattwong, gshtrasb, hongxiayan)
if current_platform.is_rocm():
    os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"


def run_test(
    hf_runner: type[HfRunner],
    vllm_runner: type[VllmRunner],
    inputs: list[tuple[list[str], PromptImageInput]],
    model: str,
    *,
    max_model_len: int,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    mm_limit: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Inference result should be the same between hf and vllm.

    All the image fixtures for the test are from IMAGE_ASSETS.
    For huggingface runner, we provide the PIL images as input.
    For vllm runner, we provide MultiModalDataDict objects
    and corresponding MultiModalConfig as input.
    Note, the text input is also adjusted to abide by vllm contract.
    The text output is sanitized to be able to compare with hf.
    """
    # NOTE: take care of the order. run vLLM first, and then run HF.
    # vLLM needs a fresh new process without cuda initialization.
    # if we run HF first, the cuda initialization will be done and it
    # will hurt multiprocessing backend with fork method (the default method).
    # max_model_len should be greater than image_feature_size
    with vllm_runner(
            model,
            task="generate",
            max_model_len=max_model_len,
            max_num_seqs=2,
            dtype=dtype,
            limit_mm_per_prompt={"image": mm_limit},
            tensor_parallel_size=tensor_parallel_size,
            distributed_executor_backend=distributed_executor_backend,
            enable_lora=True,
            max_lora_rank=320,
            lora_extra_vocab_size=0,
            gpu_memory_utilization=0.8,  # set to 0.8 to avoid OOM in CI
            enforce_eager=True,
    ) as vllm_model:
        lora_request = LoRARequest("vision", 1, vision_lora_path)
        vllm_model.model.llm_engine.add_lora(lora_request=lora_request)
        vllm_outputs_per_case = [
            vllm_model.generate_greedy_logprobs(prompts,
                                                max_tokens,
                                                num_logprobs=num_logprobs,
                                                images=images)
            for prompts, images in inputs
        ]

    # use eager mode for hf runner, since phi3_v didn't work with flash_attn
    hf_model_kwargs = {"_attn_implementation": "eager"}
    with hf_runner(model, dtype=dtype,
                   model_kwargs=hf_model_kwargs) as hf_model:
        eos_token_id = hf_model.processor.tokenizer.eos_token_id
        hf_outputs_per_case = [
            hf_model.generate_greedy_logprobs_limit(prompts,
                                                    max_tokens,
                                                    num_logprobs=num_logprobs,
                                                    images=images,
                                                    eos_token_id=eos_token_id,
                                                    num_logits_to_keep=0)
            for prompts, images in inputs
        ]

    for hf_outputs, vllm_outputs in zip(hf_outputs_per_case,
                                        vllm_outputs_per_case):
        check_logprobs_close(
            outputs_0_lst=hf_outputs,
            outputs_1_lst=vllm_outputs,
            name_0="hf",
            name_1="vllm",
        )


# Since we use _attn_implementation="eager" for hf_runner, there is more
# significant numerical difference. The basic `logprobs=5` fails to pass.
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "size_factors",
    [
        # No image
        [],
        # Single-scale
        [1.0],
        # Single-scale, batched
        [1.0, 1.0, 1.0],
        # Multi-scale
        [0.7, 0.75, 1.0],
    ],
)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_model_len", [4096])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [10])
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
                dtype: str, max_model_len: int, max_tokens: int,
                num_logprobs: int) -> None:
    images = [asset.pil_image for asset in image_assets]

    inputs_per_image = [(
        [prompt for _ in size_factors],
        [rescale_image_size(image, factor) for factor in size_factors],
    ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]

    run_test(
        hf_runner,
        vllm_runner,
        inputs_per_image,
        model,
        dtype=dtype,
        max_model_len=max_model_len,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        mm_limit=1,
        tensor_parallel_size=1,
    )


@large_gpu_test(min_gb=48)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "size_factors",
    [
        # No image
        # [],
        # Single-scale
        [1.0],
        # Single-scale, batched
        [1.0, 1.0, 1.0],
        # Multi-scale
        [0.25, 0.5, 1.0],
    ],
)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_model_len", [10000])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [10])
@pytest.mark.xfail(
    reason="Phi-4-MM multi-image inference is divergent with hf model.")
def test_multi_images_models(hf_runner, vllm_runner, image_assets, model,
                             size_factors, dtype: str, max_model_len: int,
                             max_tokens: int, num_logprobs: int) -> None:
    images = [asset.pil_image for asset in image_assets]

    inputs_per_case = [
        ([HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors],
         [[rescale_image_size(image, factor) for image in images]
          for factor in size_factors])
    ]

    run_test(
        hf_runner,
        vllm_runner,
        inputs_per_case,
        model,
        dtype=dtype,
        max_model_len=max_model_len,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        mm_limit=2,
        tensor_parallel_size=1,
    )
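For reference, the sanitization performed by vllm_to_hf_output can be exercised in isolation. The snippet below is a self-contained illustration on a made-up string (no model download needed): it strips the image placeholder tokens with the same regex, drops the leading space, and appends the HF end-of-turn markers.

import re

# Hypothetical vLLM output containing an image placeholder token.
vllm_text = "<|image_1|> The image shows a stop sign."

# Same regex as in vllm_to_hf_output: drop all <|image_N|> placeholders.
stripped = re.sub(r"(<\|image_\d+\|>)+", "", vllm_text)
assert stripped[0] == " "  # placeholder is followed by a leading space

# HF-comparable form: remove the leading space, append end-of-turn markers.
hf_text = stripped[1:] + "<|end|><|endoftext|>"
print(hf_text)  # The image shows a stop sign.<|end|><|endoftext|>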
@@ -60,7 +60,7 @@ class AriaVisionTransformer(Idefics3VisionTransformer, SupportsQuant):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
-        super().__init__(config, quant_config, prefix)
+        super().__init__(config, quant_config=quant_config, prefix=prefix)
         # Unlike Idefics3VisionTransformer which uses LayerNorm after the
         # final layer, Aria omits this normalization, so we replace it with an
         # Identity layer
@@ -512,7 +512,7 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
         self.config = config
         self.vision_tower = AriaVisionTransformer(
             config.vision_config,
-            quant_config,
+            quant_config=quant_config,
             prefix=f"{prefix}.vision_tower",
         )
         self.multi_modal_projector = AriaProjector(config)
@@ -113,7 +113,7 @@ class Idefics2VisionAttention(nn.Module):
 
     def __init__(
         self,
-        config: Idefics2Config,
+        config: Idefics2VisionConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
@@ -164,7 +164,7 @@ class Idefics2VisionMLP(nn.Module):
 
     def __init__(
        self,
-        config: Idefics2Config,
+        config: Idefics2VisionConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
@@ -249,16 +249,24 @@ class Idefics2Encoder(nn.Module):
         self,
         config: Idefics2Config,
         quant_config: Optional[QuantizationConfig] = None,
+        *,
+        num_hidden_layers_override: Optional[int] = None,
         prefix: str = "",
     ) -> None:
         super().__init__()
 
         self.config = config
+
+        if num_hidden_layers_override is None:
+            num_hidden_layers = config.num_hidden_layers
+        else:
+            num_hidden_layers = num_hidden_layers_override
+
         self.layers = nn.ModuleList([
             Idefics2EncoderLayer(config,
                                  quant_config=quant_config,
                                  prefix=f"{prefix}.layers.{layer_idx}")
-            for layer_idx in range(config.num_hidden_layers)
+            for layer_idx in range(num_hidden_layers)
         ])
 
     def forward(
@@ -287,6 +295,9 @@ class Idefics2VisionTransformer(nn.Module):
         self,
         config: Idefics2VisionConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        *,
+        num_hidden_layers_override: Optional[int] = None,
+        require_post_norm: bool = True,
         prefix: str = "",
     ) -> None:
         super().__init__()
@@ -294,11 +305,24 @@ class Idefics2VisionTransformer(nn.Module):
         embed_dim = config.hidden_size
         self.config = config
         self.embeddings = Idefics2VisionEmbeddings(config)
-        self.encoder = Idefics2Encoder(config,
-                                       quant_config=quant_config,
-                                       prefix=f"{prefix}.encoder")
-        self.post_layernorm = nn.LayerNorm(embed_dim,
-                                           eps=config.layer_norm_eps)
+        self.encoder = Idefics2Encoder(
+            config,
+            quant_config=quant_config,
+            num_hidden_layers_override=num_hidden_layers_override,
+            prefix=f"{prefix}.encoder")
+
+        num_hidden_layers = config.num_hidden_layers
+        if len(self.encoder.layers) > config.num_hidden_layers:
+            raise ValueError(
+                f"The original encoder only has {num_hidden_layers} "
+                f"layers, but you requested {len(self.encoder.layers)} layers."
+            )
+
+        self.require_post_norm = require_post_norm
+        self.post_layernorm = nn.LayerNorm(
+            embed_dim,
+            eps=config.layer_norm_eps,
+        ) if require_post_norm else nn.Identity()
 
     def get_input_embeddings(self):
         return self.embeddings
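The two new keyword arguments work together: num_hidden_layers_override shrinks the encoder, and require_post_norm=False replaces the final LayerNorm with an identity so a truncated tower returns raw hidden states. Below is a minimal, self-contained sketch of that pattern using plain torch.nn modules; TinyEncoder is a simplified stand-in invented for illustration, not the actual vLLM class.

from typing import Optional

import torch
import torch.nn as nn


class TinyEncoder(nn.Module):
    """Simplified stand-in for the truncation + optional post-norm pattern."""

    def __init__(self,
                 hidden_size: int = 8,
                 total_layers: int = 4,
                 num_hidden_layers_override: Optional[int] = None,
                 require_post_norm: bool = True) -> None:
        super().__init__()
        num_layers = (total_layers if num_hidden_layers_override is None else
                      num_hidden_layers_override)
        if num_layers > total_layers:
            raise ValueError(f"The original encoder only has {total_layers} "
                             f"layers, but you requested {num_layers} layers.")
        self.layers = nn.ModuleList(
            [nn.Linear(hidden_size, hidden_size) for _ in range(num_layers)])
        # When raw features are wanted, the post-norm degrades to a no-op.
        self.post_layernorm = (nn.LayerNorm(hidden_size)
                               if require_post_norm else nn.Identity())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x)
        return self.post_layernorm(x)


truncated = TinyEncoder(num_hidden_layers_override=3, require_post_norm=False)
print(len(truncated.layers), type(truncated.post_layernorm).__name__)  # 3 Identity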
@@ -328,7 +352,24 @@ class Idefics2VisionTransformer(nn.Module):
         ]
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
+        layer_count = len(self.encoder.layers)
+
         for name, loaded_weight in weights:
+            # skip pooling header
+            if name.startswith("head."):
+                continue
+
+            # post_layernorm is optional
+            if (name.startswith("post_layernorm.")
+                    and not self.require_post_norm):
+                continue
+
+            # omit layers when num_hidden_layers_override is set
+            if name.startswith("encoder.layers."):
+                layer_idx = int(name.split(".")[2])
+                if layer_idx >= layer_count:
+                    continue
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
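To make the new skip rules concrete, here is an illustrative-only predicate applied to a few hypothetical checkpoint key names; the should_skip helper and the key strings are invented for demonstration, while the real logic lives inline in load_weights above.

def should_skip(name: str, layer_count: int, require_post_norm: bool) -> bool:
    # Pooling head weights are never used by the vision tower.
    if name.startswith("head."):
        return True
    # post_layernorm is optional: skip its weights when it is an Identity.
    if name.startswith("post_layernorm.") and not require_post_norm:
        return True
    # Drop layers beyond the truncation point set by num_hidden_layers_override.
    if name.startswith("encoder.layers."):
        return int(name.split(".")[2]) >= layer_count
    return False


for key in ("head.probe", "post_layernorm.weight",
            "encoder.layers.25.mlp.fc1.weight",
            "encoder.layers.26.mlp.fc1.weight"):
    print(key, should_skip(key, layer_count=26, require_post_norm=False))
# head.probe -> True, post_layernorm.weight -> True,
# encoder.layers.25... -> False (kept), encoder.layers.26... -> True (dropped)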
@@ -11,7 +11,7 @@ import torch
 import torch.nn as nn
 import torchvision.transforms as T
 from PIL import Image
-from transformers import PretrainedConfig
+from transformers import PretrainedConfig, SiglipVisionConfig
 from transformers.utils import logging
 
 from vllm.config import VllmConfig
@@ -32,10 +32,10 @@ from vllm.multimodal.inputs import MultiModalInputs, NestedTensors
 from vllm.sequence import IntermediateTensors, SequenceData
 from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
 
+from .idefics2_vision_model import Idefics2VisionTransformer
 from .interfaces import SupportsLoRA, SupportsMultiModal
 from .phi4mm_audio import AudioEmbedding
 from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
-from .vision_siglip_navit import get_siglip_vision_model
 
 # <|endoftext10|> (see vocab.json in hf model)
 _IMAGE_PLACEHOLDER_TOKEN_ID = 200010
@@ -339,6 +339,33 @@ def preprocess(images, dynamic_hd_size, vit_resolution, vit_patch_size):
     return data
 
 
+def get_navit_vision_model(layer_idx: int = -1, **kwargs):
+    vision_config = {
+        "hidden_size": 1152,
+        "image_size": 448,
+        "intermediate_size": 4304,
+        "model_type": "siglip_vision_model",
+        "num_attention_heads": 16,
+        "num_hidden_layers": 27,
+        "patch_size": 14,
+    }
+
+    model_config = SiglipVisionConfig(**vision_config, **kwargs)
+    if layer_idx < 0:
+        num_hidden_layers = model_config.num_hidden_layers \
+            + layer_idx + 1
+    else:
+        num_hidden_layers = layer_idx + 1
+
+    vision_model = Idefics2VisionTransformer(
+        config=model_config,
+        require_post_norm=False,
+        num_hidden_layers_override=num_hidden_layers,
+    )
+
+    return vision_model
+
+
 class Phi4MMImageEncoder(nn.Module):
     """Image embedding."""
 
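As a worked example of the layer_idx arithmetic in get_navit_vision_model (the SigLIP config above has 27 hidden layers): Phi4MMImageEncoder passes layer_idx = -2 in the following hunk, so the override keeps 26 layers and the truncated encoder's final output corresponds to the hidden_states[-2] that the previous implementation indexed into.

# layer_idx -> number of encoder layers kept (27 total per the config above)
total_layers = 27
for layer_idx in (-1, -2, 12):
    if layer_idx < 0:
        kept = total_layers + layer_idx + 1  # -1 -> 27 (all), -2 -> 26
    else:
        kept = layer_idx + 1                 # 12 -> first 13 layers
    print(layer_idx, kept)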
@@ -362,8 +389,7 @@ class Phi4MMImageEncoder(nn.Module):
         self.layer_idx = -2
         self.type_feature = 'patch'
 
-        self.img_processor = get_siglip_vision_model(
-            _flash_attn_2_enabled=True)
+        self.img_processor = get_navit_vision_model(layer_idx=self.layer_idx)
 
         pe_weight = self.img_processor.embeddings.position_embedding.weight
         L, D = pe_weight.size()
@@ -430,16 +456,11 @@ class Phi4MMImageEncoder(nn.Module):
     def get_img_features(self,
                          img_embeds: torch.FloatTensor,
                          attention_mask=None) -> torch.FloatTensor:
-        LAYER_IDX = self.layer_idx
-        TYPE_FEATURE = self.type_feature
 
-        img_processor_output = self.img_processor(
-            img_embeds,
-            output_hidden_states=True,
-            patch_attention_mask=attention_mask)
-        img_feature = img_processor_output.hidden_states[LAYER_IDX]
+        img_feature = self.img_processor(img_embeds,
+                                         patch_attention_mask=attention_mask)
 
-        if TYPE_FEATURE == "patch":
+        if self.type_feature == "patch":
             patch_feature = img_feature
 
             use_token_compression = self.image_token_compression is not None
(One additional file diff suppressed because it is too large.)