[VLM][Model] Add test for InternViT vision encoder (#7409)
parent 398521ad19, commit aae6927be0
@@ -24,7 +24,9 @@ from vllm.assets.image import ImageAsset
 from vllm.config import TokenizerPoolConfig
 from vllm.connections import global_http_connection
 from vllm.distributed import (destroy_distributed_environment,
-                              destroy_model_parallel)
+                              destroy_model_parallel,
+                              init_distributed_environment,
+                              initialize_model_parallel)
 from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
                          to_enc_dec_tuple_list, zip_enc_dec_prompts)
 from vllm.logger import init_logger
@@ -90,6 +92,21 @@ def init_test_http_connection():
     global_http_connection.reuse_client = False
 
 
+@pytest.fixture
+def dist_init():
+    temp_file = tempfile.mkstemp()[1]
+    init_distributed_environment(
+        world_size=1,
+        rank=0,
+        distributed_init_method=f"file://{temp_file}",
+        local_rank=0,
+        backend="nccl",
+    )
+    initialize_model_parallel(1, 1)
+    yield
+    cleanup()
+
+
 def cleanup():
     destroy_model_parallel()
     destroy_distributed_environment()
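The dist_init fixture added above is picked up by pytest purely by argument name; the new test module below requests it the same way. As a minimal sketch (hypothetical test name, not part of this commit) of how a single-GPU model test consumes it:

# Sketch only: pytest injects dist_init (defined in conftest.py above) by name,
# running init_distributed_environment/initialize_model_parallel before the test
# body and cleanup() afterwards.
import torch


def test_single_gpu_forward(dist_init):
    # the single-rank process group is active here, so vLLM modules that
    # expect tensor-parallel state can be constructed and run on one GPU
    assert torch.cuda.is_available()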
tests/models/test_intern_vit.py (new file, 80 lines)
@@ -0,0 +1,80 @@
+from typing import Optional
+
+import pytest
+import torch
+import torch.nn as nn
+from huggingface_hub import snapshot_download
+from transformers import AutoConfig, AutoModel, CLIPImageProcessor
+
+from vllm.model_executor.models.intern_vit import InternVisionModel
+
+from ..conftest import _ImageAssets, cleanup
+
+pytestmark = pytest.mark.vlm
+
+# we use snapshot_download to prevent conflicts between
+# dynamic_module and trust_remote_code for hf_runner
+DOWNLOAD_PATTERN = ["*.json", "*.py", "*.safetensors", "*.txt", "*.model"]
+models = [
+    snapshot_download("OpenGVLab/InternViT-300M-448px",
+                      allow_patterns=DOWNLOAD_PATTERN),
+    snapshot_download("OpenGVLab/InternViT-6B-448px-V1-5",
+                      allow_patterns=DOWNLOAD_PATTERN),
+]
+
+
+def run_intern_vit_test(
+    image_assets: _ImageAssets,
+    model: str,
+    *,
+    dtype: str,
+    distributed_executor_backend: Optional[str] = None,
+):
+    img_processor = CLIPImageProcessor.from_pretrained(model)
+    images = [asset.pil_image for asset in image_assets]
+    pixel_values = [
+        img_processor(images, return_tensors='pt').pixel_values.to(dtype)
+        for images in images
+    ]
+
+    config = AutoConfig.from_pretrained(model, trust_remote_code=True)
+    if not getattr(config, "norm_type", None):
+        config.norm_type = "rms_norm"
+
+    hf_model = AutoModel.from_pretrained(model,
+                                         torch_dtype=dtype,
+                                         trust_remote_code=True).to("cuda")
+    hf_outputs_per_image = [
+        hf_model(pixel_value.to("cuda")).last_hidden_state
+        for pixel_value in pixel_values
+    ]
+
+    vllm_model = InternVisionModel(config)
+    vllm_model.load_weights(hf_model.state_dict().items())
+
+    del hf_model
+    cleanup()
+
+    vllm_model = vllm_model.to("cuda", dtype)
+    vllm_outputs_per_image = [
+        vllm_model(pixel_values=pixel_value.to("cuda"))
+        for pixel_value in pixel_values
+    ]
+    del vllm_model
+    cleanup()
+
+    cos_similar = nn.CosineSimilarity(dim=-1)
+    for vllm_output, hf_output in zip(vllm_outputs_per_image,
+                                      hf_outputs_per_image):
+        assert cos_similar(vllm_output, hf_output).mean() > 0.99
+
+
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("dtype", [torch.half])
+@torch.inference_mode()
+def test_models(dist_init, image_assets, model, dtype: str) -> None:
+    run_intern_vit_test(
+        image_assets,
+        model,
+        dtype=dtype,
+    )
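Not part of the commit: one way to exercise just the new test from a vLLM source checkout, assuming a CUDA GPU and network access for the HuggingFace snapshot downloads (the vlm marker comes from the pytestmark declared above):

# sketch: run only the InternViT encoder test via pytest's Python API
import sys

import pytest

if __name__ == "__main__":
    sys.exit(
        pytest.main(["-v", "-m", "vlm", "tests/models/test_intern_vit.py"]))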