[Core] Update dtype detection and defaults (#14858)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-03-19 13:49:33 +08:00 committed by GitHub
parent 8b3e94a357
commit f690372b68
22 changed files with 175 additions and 227 deletions

View File

@ -60,7 +60,7 @@ class TestSetting:
# embedding model
TestSetting(
model="BAAI/bge-multilingual-gemma2",
model_args=["--task", "embed"],
model_args=["--task", "embed", "--dtype", "bfloat16"],
pp_size=1,
tp_size=1,
attn_backend="FLASH_ATTN",

View File

@ -14,8 +14,8 @@ import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import snapshot_download
from PIL import Image
from transformers import (AutoModelForCausalLM, AutoTokenizer, BatchEncoding,
BatchFeature)
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
BatchEncoding, BatchFeature)
from transformers.models.auto.auto_factory import _BaseAutoModelClass
from tests.models.utils import (TokensTextLogprobs,
@ -23,7 +23,7 @@ from tests.models.utils import (TokensTextLogprobs,
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.config import TaskOption, TokenizerPoolConfig
from vllm.config import TaskOption, TokenizerPoolConfig, _get_and_verify_dtype
from vllm.connections import global_http_connection
from vllm.distributed import (cleanup_dist_env_and_memory,
init_distributed_environment,
@ -34,8 +34,7 @@ from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
identity, is_list_of)
from vllm.utils import cuda_device_count_stateless, is_list_of
logger = init_logger(__name__)
@ -271,14 +270,18 @@ _R = TypeVar("_R")
class HfRunner:
def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
def get_default_device(self):
from vllm.platforms import current_platform
return ("cpu" if current_platform.is_cpu()
or current_platform.is_openvino() else "cuda")
def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
if x is None or isinstance(x, (bool, )):
return x
if device is None:
device = "cpu" if current_platform.is_cpu(
) or current_platform.is_openvino() else "cuda"
device = self.device
if isinstance(x, dict):
return {k: self.wrap_device(v, device) for k, v in x.items()}
@ -291,45 +294,59 @@ class HfRunner:
def __init__(
self,
model_name: str,
dtype: str = "half",
dtype: str = "auto",
*,
model_kwargs: Optional[dict[str, Any]] = None,
is_sentence_transformer: bool = False,
is_cross_encoder: bool = False,
skip_tokenizer_init: bool = False,
auto_cls: type[_BaseAutoModelClass] = AutoModelForCausalLM,
postprocess_inputs: Callable[..., BatchEncoding] = identity,
) -> None:
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
self.model_name = model_name
self.config = AutoConfig.from_pretrained(
model_name,
trust_remote_code=True,
)
self.device = self.get_default_device()
self.dtype = torch_dtype = _get_and_verify_dtype(self.config, dtype)
model_kwargs = model_kwargs if model_kwargs is not None else {}
model_kwargs.setdefault("torch_dtype", torch_dtype)
if is_sentence_transformer:
# Lazy init required for AMD CI
from sentence_transformers import SentenceTransformer
self.model = self.wrap_device(
SentenceTransformer(
model_name,
device="cpu",
trust_remote_code=True,
).to(dtype=torch_dtype))
self.model = SentenceTransformer(
model_name,
device=self.device,
model_kwargs=model_kwargs,
trust_remote_code=True,
)
elif is_cross_encoder:
# Lazy init required for AMD CI
from sentence_transformers import CrossEncoder
self.model = CrossEncoder(model_name,
device="cpu",
trust_remote_code=True)
self.model.model = self.wrap_device(self.model.model)\
.to(dtype=torch_dtype)
self.model = CrossEncoder(
model_name,
device=self.device,
automodel_args=model_kwargs,
trust_remote_code=True,
)
else:
model_kwargs = model_kwargs if model_kwargs is not None else {}
self.model = self.wrap_device(
auto_cls.from_pretrained(
model_name,
torch_dtype=torch_dtype,
trust_remote_code=True,
**model_kwargs,
))
model = auto_cls.from_pretrained(
model_name,
trust_remote_code=True,
**model_kwargs,
)
if (getattr(model, "quantization_method", None) != "bitsandbytes"
and len({p.device
for p in model.parameters()}) < 2):
model = model.to(self.device)
self.model = model
if not skip_tokenizer_init:
self.tokenizer = AutoTokenizer.from_pretrained(
@ -349,16 +366,13 @@ class HfRunner:
if skip_tokenizer_init:
self.tokenizer = self.processor.tokenizer
self.dtype = dtype
self.postprocess_inputs = postprocess_inputs
def get_inputs(
self,
prompts: list[str],
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None,
) -> list[BatchEncoding]:
) -> list[Union[BatchFeature, BatchEncoding]]:
if images is not None:
assert len(prompts) == len(images)
@ -368,7 +382,7 @@ class HfRunner:
if audios is not None:
assert len(prompts) == len(audios)
all_inputs: list[BatchEncoding] = []
all_inputs: list[Union[BatchFeature, BatchEncoding]] = []
for i, prompt in enumerate(prompts):
processor_kwargs: dict[str, Any] = {
"text": prompt,
@ -384,7 +398,8 @@ class HfRunner:
processor_kwargs["sampling_rate"] = sr
inputs = self.processor(**processor_kwargs)
inputs = self.postprocess_inputs(inputs, dtype=self.dtype)
if isinstance(inputs, BatchFeature):
inputs = inputs.to(dtype=self.dtype)
all_inputs.append(inputs)
@ -417,7 +432,7 @@ class HfRunner:
outputs: list[tuple[list[list[int]], list[str]]] = []
for inputs in all_inputs:
output_ids = self.model.generate(
**self.wrap_device(inputs, device=self.model.device.type),
**self.wrap_device(inputs),
use_cache=True,
**kwargs,
)
@ -488,7 +503,7 @@ class HfRunner:
all_logprobs: list[list[torch.Tensor]] = []
for inputs in all_inputs:
output = self.model.generate(
**self.wrap_device(inputs, device=self.model.device.type),
**self.wrap_device(inputs),
use_cache=True,
do_sample=False,
max_new_tokens=max_tokens,
@ -569,7 +584,7 @@ class HfRunner:
for inputs in all_inputs:
output = self.model.generate(
**self.wrap_device(inputs, device=self.model.device.type),
**self.wrap_device(inputs),
use_cache=True,
do_sample=False,
max_new_tokens=max_tokens,
@ -620,19 +635,15 @@ class HfRunner:
if images is not None and images[i] is not None:
processor_kwargs["images"] = images[i]
encoder_inputs = self.wrap_device(
self.processor(**processor_kwargs),
device=self.model.device.type,
)
encoder_inputs = self.processor(**processor_kwargs)
encoder_inputs = self.wrap_device(encoder_inputs)
if decoder_prompt is None:
decoder_input_ids = None
else:
decoder_input_ids = self.wrap_device(
self.tokenizer(decoder_prompt,
return_tensors="pt").input_ids,
device=self.model.device.type,
)
decoder_inputs = self.tokenizer(decoder_prompt,
return_tensors="pt")
decoder_input_ids = self.wrap_device(decoder_inputs.input_ids)
output = self.model.generate(
decoder_input_ids=decoder_input_ids,
@ -684,6 +695,7 @@ class VllmRunner:
"""
The default value of some arguments have been modified from
:class:`~vllm.LLM` as follows:
- `trust_remote_code`: Set to `True` instead of `False` for convenience.
- `seed`: Set to `0` instead of `None` for test reproducibility.
- `max_model_len`: Set to `1024` instead of `None` to reduce memory usage.
@ -701,10 +713,8 @@ class VllmRunner:
tokenizer_mode: str = "auto",
trust_remote_code: bool = True,
seed: Optional[int] = 0,
# Use smaller max model length, otherwise bigger model cannot run due
# to kv cache size limit.
max_model_len: int = 1024,
dtype: str = "half",
dtype: str = "auto",
disable_log_stats: bool = True,
tensor_parallel_size: int = 1,
block_size: int = 16,
@ -1110,4 +1120,4 @@ def pytest_collection_modifyitems(config, items):
skip_optional = pytest.mark.skip(reason="need --optional option to run")
for item in items:
if "optional" in item.keywords:
item.add_marker(skip_optional)
item.add_marker(skip_optional)
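
As an illustration of the new flow (not part of the diff itself): HfRunner now resolves its dtype from the checkpoint config instead of hard-coding "half", picks a default device from the current platform, and only moves the model when it is neither bitsandbytes-quantized nor already sharded across devices. A minimal sketch, assuming vLLM and transformers are installed; the model name is only an example:

from transformers import AutoConfig, AutoModelForCausalLM
from vllm.config import _get_and_verify_dtype
from vllm.platforms import current_platform

model_name = "facebook/opt-125m"  # illustrative only
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)

# dtype="auto" resolves against the checkpoint's torch_dtype, with the
# text_config/vision_config fallbacks added to _get_and_verify_dtype
# later in this diff.
torch_dtype = _get_and_verify_dtype(config, "auto")

# Mirrors HfRunner.get_default_device().
device = ("cpu" if current_platform.is_cpu()
          or current_platform.is_openvino() else "cuda")

model = AutoModelForCausalLM.from_pretrained(model_name,
                                             torch_dtype=torch_dtype,
                                             trust_remote_code=True)

# Only move the model if it is not bitsandbytes-quantized and all of its
# parameters live on a single device (same check as in the diff above).
if (getattr(model, "quantization_method", None) != "bitsandbytes"
        and len({p.device for p in model.parameters()}) < 2):
    model = model.to(device)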

View File

@ -64,7 +64,6 @@ def test_multi_chat():
def test_chat_multi_image(image_urls: list[str]):
llm = LLM(
model="microsoft/Phi-3.5-vision-instruct",
dtype="bfloat16",
max_model_len=4096,
max_num_seqs=5,
enforce_eager=True,

View File

@ -18,8 +18,6 @@ TEST_AUDIO_URLS = [
@pytest.fixture(scope="module")
def server():
args = [
"--dtype",
"bfloat16",
"--max-model-len",
"2048",
"--max-num-seqs",

View File

@ -24,8 +24,6 @@ def server():
args = [
"--task",
"generate",
"--dtype",
"bfloat16",
"--max-model-len",
"32768",
"--max-num-seqs",

View File

@ -25,8 +25,6 @@ def server():
args = [
"--task",
"generate",
"--dtype",
"bfloat16",
"--max-model-len",
"2048",
"--max-num-seqs",

View File

@ -28,8 +28,6 @@ def server():
args = [
"--task",
"embed",
"--dtype",
"bfloat16",
"--max-model-len",
"2048",
"--max-num-seqs",

View File

@ -34,7 +34,7 @@ def phi3v_model_config():
tokenizer=PHI3V_MODEL_ID,
tokenizer_mode="auto",
trust_remote_code=True,
dtype="bfloat16",
dtype="auto",
seed=0,
limit_mm_per_prompt={
"image": 2,
@ -58,7 +58,7 @@ def mllama_model_config():
tokenizer=MLLAMA_MODEL_ID,
tokenizer_mode="auto",
trust_remote_code=True,
dtype="bfloat16",
dtype="auto",
seed=0,
limit_mm_per_prompt={
"image": 2,
@ -669,7 +669,7 @@ def test_multimodal_image_parsing_matches_hf(model, image_url):
tokenizer=MLLAMA_MODEL_ID,
tokenizer_mode="auto",
trust_remote_code=True,
dtype="bfloat16",
dtype="auto",
seed=0,
limit_mm_per_prompt={
"image": 2,

View File

@ -5,11 +5,10 @@ from typing import Optional
import numpy as np
import pytest
import pytest_asyncio
from transformers import AutoModel, AutoTokenizer, BatchEncoding
from transformers import AutoModel, AutoTokenizer
from vllm.multimodal.audio import resample_audio
from vllm.sequence import SampleLogprobs
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from ....conftest import HfRunner, VllmRunner
from ....utils import RemoteOpenAIServer
@ -107,8 +106,6 @@ def run_test(
**kwargs,
):
"""Inference result should be the same between hf and vllm."""
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
# NOTE: take care of the order. run vLLM first, and then run HF.
# vLLM needs a fresh new process without cuda initialization.
# if we run HF first, the cuda initialization will be done and it
@ -124,15 +121,7 @@ def run_test(
for vllm_prompt, _, audio in prompts_and_audios
]
def process(hf_inputs: BatchEncoding, **kwargs):
hf_inputs["audio_values"] = hf_inputs["audio_values"] \
.to(torch_dtype) # type: ignore
return hf_inputs
with hf_runner(model,
dtype=dtype,
postprocess_inputs=process,
auto_cls=AutoModel) as hf_model:
with hf_runner(model, dtype=dtype, auto_cls=AutoModel) as hf_model:
hf_outputs_per_audio = [
hf_model.generate_greedy_logprobs_limit(
[hf_prompt],

View File

@ -122,9 +122,6 @@ VLM_TEST_SETTINGS = {
"cherry_blossom": "What is in the picture?",
}),
auto_cls=AutoModelForImageTextToText,
postprocess_inputs=model_utils.cast_dtype_post_processor(
"pixel_values"
),
vllm_output_post_proc=model_utils.paligemma_vllm_to_hf_output,
dtype="bfloat16",
marks=[pytest.mark.skip(reason="vLLM does not support PrefixLM attention mask")], # noqa: E501
@ -179,7 +176,6 @@ VLM_TEST_SETTINGS = {
# "cherry_blossom": "<vlm_image>Please infer the season with reason.", # noqa: E501
# }),
# multi_image_prompt="<vlm_image><vlm_image>Describe the two images shortly.", # noqa: E501
# postprocess_inputs=model_utils.cast_dtype_post_processor("pixel_values"), # noqa: E501
# stop_str=["<|im_end|>"],
# image_size_factors=[(0.10, 0.15)],
# max_tokens=64,
@ -200,9 +196,6 @@ VLM_TEST_SETTINGS = {
max_model_len=4096,
max_num_seqs=2,
auto_cls=AutoModelForImageTextToText,
postprocess_inputs=model_utils.cast_dtype_post_processor(
"pixel_values"
),
# For chameleon, we only compare the sequences
vllm_output_post_proc = lambda vllm_output, model: vllm_output[:2],
hf_output_post_proc = lambda hf_output, model: hf_output[:2],
@ -222,7 +215,6 @@ VLM_TEST_SETTINGS = {
}),
multi_image_prompt="image_1:<image>\nimage_2:<image>\nWhich image can we see the car and the tower?", # noqa: E501
patch_hf_runner=model_utils.deepseekvl2_patch_hf_runner,
postprocess_inputs=model_utils.cast_dtype_post_processor("images"),
hf_output_post_proc=model_utils.deepseekvl2_trunc_hf_output,
stop_str=["<end▁of▁sentence>", "<begin▁of▁sentence>"], # noqa: E501
image_size_factors=[(), (1.0, ), (1.0, 1.0, 1.0), (0.1, 0.5, 1.0)],
@ -258,7 +250,6 @@ VLM_TEST_SETTINGS = {
max_model_len=4096,
max_num_seqs=2,
auto_cls=AutoModelForImageTextToText,
dtype="bfloat16",
vllm_runner_kwargs={"mm_processor_kwargs": {"do_pan_and_scan": True}},
patch_hf_runner=model_utils.gemma3_patch_hf_runner,
),
@ -272,7 +263,6 @@ VLM_TEST_SETTINGS = {
}),
max_model_len=2048,
max_num_seqs=2,
dtype="bfloat16",
get_stop_token_ids=lambda tok: [151329, 151336, 151338],
patch_hf_runner=model_utils.glm4v_patch_hf_runner,
# The image embeddings match with HF but the outputs of the language
@ -295,7 +285,6 @@ VLM_TEST_SETTINGS = {
}),
multi_image_prompt="Image-1: <image>\nImage-2: <image>\nDescribe the two images in short.", # noqa: E501
max_model_len=8192,
dtype="bfloat16",
use_tokenizer_eos=True,
num_logprobs=10,
patch_hf_runner=model_utils.h2ovl_patch_hf_runner,
@ -324,10 +313,6 @@ VLM_TEST_SETTINGS = {
}),
multi_image_prompt="Image-1: <image>\nImage-2: <image>\nDescribe the two images in short.", # noqa: E501
max_model_len=4096,
# NOTE: Mono-InternVL-2B doesn't work with fp16,
# it will result NaN during inference.
# See: https://huggingface.co/OpenGVLab/Mono-InternVL-2B/discussions/9
dtype="bfloat16",
use_tokenizer_eos=True,
patch_hf_runner=model_utils.internvl_patch_hf_runner,
),
@ -351,9 +336,6 @@ VLM_TEST_SETTINGS = {
prompt_formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501
num_video_frames=16,
max_model_len=16384,
postprocess_inputs=model_utils.cast_dtype_post_processor(
"pixel_values_videos"
),
auto_cls=AutoModelForVision2Seq,
vllm_output_post_proc=model_utils.llava_onevision_vllm_to_hf_output,
custom_test_opts=[CustomTestOptions(
@ -378,9 +360,6 @@ VLM_TEST_SETTINGS = {
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"<|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501
max_model_len=4096,
postprocess_inputs=model_utils.cast_dtype_post_processor(
"pixel_values"
),
get_stop_token_ids=lambda tok: [128009],
auto_cls=AutoModelForImageTextToText,
vllm_output_post_proc=model_utils.mantis_vllm_to_hf_output,
@ -400,8 +379,8 @@ VLM_TEST_SETTINGS = {
max_model_len=4096,
max_num_seqs=2,
get_stop_token_ids=lambda tok: [tok.eos_id, tok.eot_id],
postprocess_inputs=model_utils.wrap_inputs_post_processor,
hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
patch_hf_runner=model_utils.minicpmv_25_patch_hf_runner,
),
"minicpmo_26": VLMTestInfo(
models=["openbmb/MiniCPM-o-2_6"],
@ -411,11 +390,8 @@ VLM_TEST_SETTINGS = {
max_model_len=4096,
max_num_seqs=2,
get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']), # noqa: E501
postprocess_inputs=model_utils.ignore_inputs_post_processor(
"image_sizes"
),
hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
patch_hf_runner=model_utils.minicpmo_patch_hf_runner
patch_hf_runner=model_utils.minicpmo_26_patch_hf_runner,
),
"minicpmv_26": VLMTestInfo(
models=["openbmb/MiniCPM-V-2_6"],
@ -425,10 +401,8 @@ VLM_TEST_SETTINGS = {
max_model_len=4096,
max_num_seqs=2,
get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']), # noqa: E501
postprocess_inputs=model_utils.ignore_inputs_post_processor(
"image_sizes"
),
hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
patch_hf_runner=model_utils.minicpmv_26_patch_hf_runner,
),
"molmo": VLMTestInfo(
models=["allenai/Molmo-7B-D-0924"],
@ -437,7 +411,6 @@ VLM_TEST_SETTINGS = {
max_model_len=4096,
max_num_seqs=2,
patch_hf_runner=model_utils.molmo_patch_hf_runner,
postprocess_inputs=model_utils.molmo_post_processor,
),
# Tests for phi3v currently live in another file because of a bug in
# transformers. Once this issue is fixed, we can enable them here instead.
@ -482,9 +455,6 @@ VLM_TEST_SETTINGS = {
prompt_formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:",
max_model_len=4096,
auto_cls=AutoModelForImageTextToText,
postprocess_inputs=model_utils.cast_dtype_post_processor(
"pixel_values"
),
vllm_output_post_proc = lambda vllm_output, model: vllm_output[:2],
hf_output_post_proc = lambda hf_output, model: hf_output[:2],
comparator=check_outputs_equal,
@ -529,9 +499,6 @@ VLM_TEST_SETTINGS = {
test_type=VLMTestType.CUSTOM_INPUTS,
max_model_len=16384,
max_num_seqs=2,
postprocess_inputs=model_utils.cast_dtype_post_processor(
"pixel_values"
),
auto_cls=AutoModelForVision2Seq,
vllm_output_post_proc=model_utils.llava_onevision_vllm_to_hf_output,
custom_test_opts=[CustomTestOptions(

View File

@ -4,7 +4,6 @@ from typing import Any, Callable, Optional, Union
import torch
from PIL.Image import Image
from transformers import BatchEncoding
from transformers.models.auto.auto_factory import _BaseAutoModelClass
from vllm.config import TaskOption
@ -31,7 +30,6 @@ def run_test(
vllm_output_post_proc: Optional[Callable[[RunnerOutput, str], Any]],
auto_cls: type[_BaseAutoModelClass],
use_tokenizer_eos: bool,
postprocess_inputs: Callable[[BatchEncoding], BatchEncoding],
comparator: Callable[..., None],
get_stop_token_ids: Optional[Callable[[AnyTokenizer], list[int]]],
stop_str: Optional[list[str]],
@ -101,7 +99,6 @@ def run_test(
hf_model = hf_runner(model,
dtype=dtype,
auto_cls=auto_cls,
postprocess_inputs=postprocess_inputs,
model_kwargs=hf_model_kwargs)
# Some models need to patch things like the model processor, e.g., internvl

View File

@ -6,16 +6,15 @@ typically specific to a small subset of models.
import re
import types
from pathlib import PosixPath
from typing import Callable, Optional, Union
from typing import Optional, Union
import torch
from PIL.Image import Image
from transformers import (AutoConfig, AutoTokenizer, BatchEncoding,
from transformers import (AutoConfig, AutoTokenizer, BatchFeature,
GenerationConfig)
from vllm.sequence import SampleLogprobs
from vllm.transformers_utils.tokenizer import patch_padding_side
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from .....conftest import HfRunner, ImageAsset, _ImageAssets
from .types import RunnerOutput
@ -211,40 +210,6 @@ def get_llava_embeddings(image_assets: _ImageAssets):
return [asset.image_embeds for asset in image_assets]
####### postprocessors to run on HF BatchEncoding
def cast_dtype_post_processor(
hf_inp_key: str) -> Callable[[BatchEncoding, str], BatchEncoding]:
"""Gets a handle to a post processor which converts a given key into a
target data type."""
def process(hf_inputs: BatchEncoding, dtype: str):
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
hf_inputs[hf_inp_key] = hf_inputs[hf_inp_key].to(torch_dtype)
return hf_inputs
return process
def ignore_inputs_post_processor(
hf_inp_key: str) -> Callable[[BatchEncoding, str], BatchEncoding]:
"""Gets a handle to a post processor which ignores a given key."""
def process(hf_inputs: BatchEncoding, dtype: str):
del hf_inputs[hf_inp_key]
return hf_inputs
return process
def wrap_inputs_post_processor(hf_inputs: BatchEncoding, dtype: str):
return {"model_inputs": hf_inputs}
def molmo_post_processor(hf_inputs: BatchEncoding, dtype: str):
hf_inputs = cast_dtype_post_processor("images")(hf_inputs, dtype)
return {k: v.unsqueeze(0) for k, v in hf_inputs.items()}
####### Prompt path encoders for models that need models on disk
def qwen_prompt_path_encoder(
tmp_path: PosixPath, prompt: str, assets: Union[list[ImageAsset],
@ -295,8 +260,7 @@ def deepseekvl2_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
for k in inputs.keys() # noqa
if k not in ("seq_lens", "sft_format")
}
inputs = BatchEncoding(data=inputs, tensor_type="pt")
return inputs
return BatchFeature(data=inputs, tensor_type="pt")
hf_model.processor = processor
hf_model.model.get_output_embeddings = lambda: \
@ -529,10 +493,52 @@ def mantis_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
return hf_model
def minicpmo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
def minicpmv_25_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
orig_generate = hf_model.model.generate
def _generate(self, *args, **kwargs):
def _generate(
self,
*args,
input_ids=None,
pixel_values=None,
image_sizes=None,
image_bound=None,
tgt_sizes=None,
**kwargs,
):
model_inputs = {
"input_ids": input_ids,
"pixel_values": pixel_values,
"image_sizes": image_sizes,
"image_bound": image_bound,
"tgt_sizes": tgt_sizes,
}
for k in list(model_inputs.keys()):
if model_inputs[k] is None:
model_inputs.pop(k)
return orig_generate(model_inputs, *args, decode_text=False, **kwargs)
hf_model.model.generate = types.MethodType(_generate, hf_model.model)
return hf_model
def minicpmo_26_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
orig_generate = hf_model.model.generate
def _generate(self, *args, image_sizes=None, **kwargs):
return orig_generate(*args, decode_text=False, **kwargs)
hf_model.model.generate = types.MethodType(_generate, hf_model.model)
return hf_model
def minicpmv_26_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
orig_generate = hf_model.model.generate
def _generate(self, *args, image_sizes=None, **kwargs):
return orig_generate(*args, decode_text=False, **kwargs)
hf_model.model.generate = types.MethodType(_generate, hf_model.model)
@ -551,10 +557,11 @@ def molmo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
def _generate(self, max_new_tokens=None, do_sample=None, **kwargs):
batch = {
k: kwargs.pop(k)
k: kwargs.pop(k).unsqueeze(0)
for k in ("input_ids", "images", "image_input_idx", "image_masks")
if k in kwargs
}
batch = BatchFeature(batch).to(dtype=self.dtype)
return self.generate_from_batch(
batch,

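The removed postprocess_inputs hooks are superseded by patch_hf_runner helpers such as minicpmv_25_patch_hf_runner above, which capture the model's original generate() and rebind a kwargs-repacking wrapper onto the instance with types.MethodType. A stripped-down sketch of that binding pattern, using a purely hypothetical DummyModel:

import types


class DummyModel:
    # Stand-in for a HF model whose generate() expects a single inputs dict.
    def generate(self, model_inputs, decode_text=True):
        return {"inputs": model_inputs, "decode_text": decode_text}


model = DummyModel()
orig_generate = model.generate  # keep the original bound method


def _generate(self, *args, input_ids=None, pixel_values=None, **kwargs):
    # Collect the per-modality tensors into a single dict, drop entries that
    # were not provided, then defer to the original method.
    model_inputs = {"input_ids": input_ids, "pixel_values": pixel_values}
    model_inputs = {k: v for k, v in model_inputs.items() if v is not None}
    return orig_generate(model_inputs, *args, decode_text=False, **kwargs)


model.generate = types.MethodType(_generate, model)
print(model.generate(input_ids=[1, 2, 3]))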
View File

@ -8,13 +8,12 @@ from typing import Any, Callable, NamedTuple, Optional, Union
import torch
from PIL.Image import Image
from pytest import MarkDecorator
from transformers import AutoModelForCausalLM, BatchEncoding
from transformers import AutoModelForCausalLM
from transformers.models.auto.auto_factory import _BaseAutoModelClass
from vllm.config import TaskOption
from vllm.sequence import SampleLogprobs
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import identity
from .....conftest import IMAGE_ASSETS, HfRunner, ImageAsset, _ImageAssets
from ....utils import check_logprobs_close
@ -110,11 +109,6 @@ class VLMTestInfo(NamedTuple):
# Indicates we should explicitly pass the EOS from the tokenizer
use_tokenizer_eos: bool = False
auto_cls: type[_BaseAutoModelClass] = AutoModelForCausalLM
# Callable to pass to the HF runner to run on inputs; for now, we also pass
# the data type to input post processing, because almost all of the uses of
# postprocess_inputs are to fix the data types of BatchEncoding values.
postprocess_inputs: Callable[[BatchEncoding, str],
BatchEncoding] = identity
patch_hf_runner: Optional[Callable[[HfRunner], HfRunner]] = None
# Post processors that if defined, will run oun the outputs of the
@ -130,7 +124,7 @@ class VLMTestInfo(NamedTuple):
# is all combinations of .models + all fields below
max_tokens: Union[int, tuple[int]] = 128
num_logprobs: Union[int, tuple[int]] = 5
dtype: Union[str, Iterable[str]] = "half"
dtype: Union[str, Union[list[str], tuple[str, ...]]] = "auto"
distributed_executor_backend: Optional[Union[str, Iterable[str]]] = None
# Only expanded in video tests
num_video_frames: Union[int, tuple[int]] = 16
@ -171,7 +165,6 @@ class VLMTestInfo(NamedTuple):
"vllm_output_post_proc": self.vllm_output_post_proc,
"auto_cls": self.auto_cls,
"use_tokenizer_eos": self.use_tokenizer_eos,
"postprocess_inputs": self.postprocess_inputs,
"comparator": self.comparator,
"get_stop_token_ids": self.get_stop_token_ids,
"hf_model_kwargs": self.hf_model_kwargs,

View File

@ -1,12 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
from functools import partial
from typing import Callable
import pytest
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import BatchEncoding, Qwen2VLForConditionalGeneration
from transformers import Qwen2VLForConditionalGeneration
from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
from ....utils import large_gpu_test
@ -75,10 +75,6 @@ def apply_chat_template_and_add_eos(
return prompt
def postprocess_inputs(hf_model: HfRunner, inputs: BatchEncoding, **kwargs):
return hf_model.model.prepare_inputs_for_generation(**inputs, **kwargs)
def _run_test(
hf_runner: type[HfRunner],
vllm_runner: type[VllmRunner],
@ -118,14 +114,8 @@ def _run_test(
with hf_runner(model,
dtype=dtype,
auto_cls=Qwen2VLForConditionalGeneration) as hf_model:
hf_model.postprocess_inputs = partial(
postprocess_inputs,
hf_model,
cache_position=torch.arange(
0,
1, # 1 for batch size
requires_grad=False),
use_cache=False)
prompts = []
for text, image, embed_text in zip(input_texts, input_images,
embed_texts):
# dse requires non-standard input processing
@ -133,20 +123,34 @@ def _run_test(
messages = get_messages(image, text, embed_text)
prompt = apply_chat_template_and_add_eos(
messages, hf_model.processor.apply_chat_template)
inputs = hf_model.get_inputs(
prompts=[[prompt]],
images=[[image]],
)
with torch.no_grad():
prompts.append(prompt)
all_inputs = hf_model.get_inputs(
prompts=prompts,
images=input_images,
)
with torch.no_grad():
all_outputs = []
for inputs in all_inputs:
inputs = hf_model.model.prepare_inputs_for_generation(
**inputs,
cache_position=torch.arange(1), # 1 for batch size
use_cache=False,
)
outputs = hf_model.model(
**hf_model.wrap_device(inputs[0],
device=hf_model.model.device.type),
**hf_model.wrap_device(inputs),
return_dict=True,
output_hidden_states=True,
)
pooled_output = torch.nn.functional.normalize(
outputs.hidden_states[-1][0, -1], p=2, dim=-1)
hf_outputs.append(pooled_output.tolist())
pooled_output = F.normalize(outputs.hidden_states[-1][0, -1],
p=2,
dim=-1)
all_outputs.append(pooled_output.tolist())
hf_outputs = all_outputs
check_embeddings_close(
embeddings_0_lst=hf_outputs,

View File

@ -86,8 +86,7 @@ def _run_test(
for inputs in all_inputs:
# Based on: https://huggingface.co/royokong/e5-v
outputs = hf_model.model(
**hf_model.wrap_device(inputs,
device=hf_model.model.device.type),
**hf_model.wrap_device(inputs),
return_dict=True,
output_hidden_states=True,
)

View File

@ -53,8 +53,7 @@ def _run_test(
for inputs in all_inputs:
# Based on: https://github.com/TIGER-AI-Lab/VLM2Vec/blob/db3b951bccabba220c1f53ab46a734e50dd2fc08/src/model.py
outputs = hf_model.model(
**hf_model.wrap_device(inputs,
device=hf_model.model.device.type),
**hf_model.wrap_device(inputs),
return_dict=True,
output_hidden_states=True,
)

View File

@ -4,8 +4,7 @@ from typing import Optional, overload
import pytest
import torch
from transformers import (AutoConfig, AutoModelForImageTextToText,
AutoTokenizer, BatchEncoding)
from transformers import AutoConfig, AutoModelForImageTextToText, AutoTokenizer
from vllm import LLM, SamplingParams
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
@ -227,13 +226,9 @@ def _run_test(
for prompts, images in inputs
]
def process(hf_inputs: BatchEncoding, **kwargs):
return hf_inputs
with hf_runner(model,
dtype=dtype,
model_kwargs={"device_map": "auto"},
postprocess_inputs=process,
auto_cls=AutoModelForImageTextToText) as hf_model:
hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,

View File

@ -2,7 +2,7 @@
import warnings
from collections.abc import Sequence
from typing import Optional, Union
from typing import Any, Optional, Union
import torch
@ -254,9 +254,9 @@ def check_logprobs_close(
def build_model_context(
model_id: str,
task: TaskOption = "auto",
dtype: Optional[Union[str, torch.dtype]] = None,
mm_processor_kwargs: Optional[dict] = None,
limit_mm_per_prompt: Optional[dict] = None,
dtype: Union[str, torch.dtype] = "auto",
mm_processor_kwargs: Optional[dict[str, Any]] = None,
limit_mm_per_prompt: Optional[dict[str, int]] = None,
disable_mm_preprocessor_cache: bool = True,
):
"""Creates an InputContext for a given model.
@ -274,9 +274,6 @@ def build_model_context(
model_info.check_available_online(on_fail="skip")
model_info.check_transformers_version(on_fail="skip")
if dtype is None:
dtype = "half"
model_config = ModelConfig(
model_id,
task=task,

View File

@ -853,7 +853,7 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="half",
dtype="auto",
revision=None,
limit_mm_per_prompt=limit_mm_per_prompt,
)
@ -892,7 +892,7 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="half",
dtype="auto",
revision=None,
limit_mm_per_prompt=limit_mm_per_prompt,
)
@ -965,7 +965,7 @@ def test_hf_processor_kwargs(model_id, call_kwargs, expected_kwargs):
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="half",
dtype="auto",
revision=None,
)

View File

@ -166,7 +166,7 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
test_prompts = multilora_inference.create_test_prompts(lora_path)
# Serialize model before deserializing and binding LoRA adapters
with vllm_runner(model_ref, ) as vllm_model:
with vllm_runner(model_ref) as vllm_model:
model_path = tmp_path / (model_ref + ".tensors")
vllm_model.apply_model(
@ -208,7 +208,7 @@ def test_load_without_tensorizer_load_format(vllm_runner):
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
## Serialize model
with vllm_runner(model_ref, ) as vllm_model:
with vllm_runner(model_ref) as vllm_model:
model_path = tmp_path / (model_ref + ".tensors")
vllm_model.apply_model(

View File

@ -50,7 +50,7 @@ def _get_test_sampling_params(
"""Generate random sampling params for a batch."""
def get_mostly_n_gt1() -> int:
"""Mostly n \in [2,20], ~1/3 n=1"""
r"""Mostly n \in [2,20], ~1/3 n=1"""
x = random.randint(0, 28)
if x < 10:
return 1

View File

@ -347,7 +347,7 @@ class ModelConfig:
self.encoder_config = self._get_encoder_config()
self.hf_image_processor_config = get_hf_image_processor_config(
self.model, revision)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
self.use_async_output_proc = use_async_output_proc
self.mm_processor_kwargs = mm_processor_kwargs
self.disable_mm_preprocessor_cache = disable_mm_preprocessor_cache
@ -2526,6 +2526,14 @@ def _get_and_verify_dtype(
# NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
# because config.torch_dtype can be None.
config_dtype = getattr(config, "torch_dtype", None)
# Fallbacks for multi-modal models if the root config
# does not define torch_dtype
if config_dtype is None and hasattr(config, "text_config"):
config_dtype = getattr(config.text_config, "torch_dtype", None)
if config_dtype is None and hasattr(config, "vision_config"):
config_dtype = getattr(config.vision_config, "torch_dtype", None)
if config_dtype is None:
config_dtype = torch.float32
@ -2533,16 +2541,8 @@ def _get_and_verify_dtype(
dtype = dtype.lower()
if dtype == "auto":
if config_dtype == torch.float32:
if config.model_type in ("gemma2", "gemma3", "gemma3_text"):
logger.info(
"For Gemma 2 and 3, we downcast float32 to bfloat16 "
"instead of float16 by default. Please specify `dtype` "
"if you want to use float16.")
torch_dtype = torch.bfloat16
else:
# Following the common practice, we use float16 for float32
# models.
torch_dtype = torch.float16
# Following common practice, we use float16 for float32 models
torch_dtype = torch.float16
else:
torch_dtype = config_dtype
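
Putting the new pieces together, the resolution under dtype="auto" can be restated as a small standalone helper (a simplified sketch, not the full vLLM implementation, which also handles explicit dtype strings and platform-specific casts):

import torch


def resolve_config_dtype(config) -> torch.dtype:
    # Prefer the root config's torch_dtype; fall back to the text_config /
    # vision_config sub-configs for multi-modal models, then to float32.
    config_dtype = getattr(config, "torch_dtype", None)
    if config_dtype is None and hasattr(config, "text_config"):
        config_dtype = getattr(config.text_config, "torch_dtype", None)
    if config_dtype is None and hasattr(config, "vision_config"):
        config_dtype = getattr(config.vision_config, "torch_dtype", None)
    if config_dtype is None:
        config_dtype = torch.float32
    # Under dtype="auto", float32 checkpoints load as float16; the
    # Gemma-specific bfloat16 downcast shown above was removed.
    return torch.float16 if config_dtype == torch.float32 else config_dtype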