From f690372b6803c43870aefa631ab207ea3d163b1d Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 19 Mar 2025 13:49:33 +0800 Subject: [PATCH] [Core] Update dtype detection and defaults (#14858) Signed-off-by: DarkLight1337 --- tests/compile/test_basic_correctness.py | 2 +- tests/conftest.py | 116 ++++++++++-------- tests/entrypoints/llm/test_chat.py | 1 - tests/entrypoints/openai/test_audio.py | 2 - tests/entrypoints/openai/test_video.py | 2 - tests/entrypoints/openai/test_vision.py | 2 - .../openai/test_vision_embedding.py | 2 - tests/entrypoints/test_chat_utils.py | 6 +- .../audio_language/test_ultravox.py | 15 +-- .../vision_language/test_models.py | 39 +----- .../vision_language/vlm_utils/core.py | 3 - .../vision_language/vlm_utils/model_utils.py | 91 +++++++------- .../vision_language/vlm_utils/types.py | 11 +- .../vision_language/test_dse_qwen2_vl.py | 52 ++++---- .../vision_language/test_llava_next.py | 3 +- .../embedding/vision_language/test_phi3v.py | 3 +- .../vision_language/test_mllama.py | 7 +- tests/models/utils.py | 11 +- tests/multimodal/test_processing.py | 6 +- tests/tensorizer_loader/test_tensorizer.py | 4 +- tests/v1/engine/test_llm_engine.py | 2 +- vllm/config.py | 22 ++-- 22 files changed, 175 insertions(+), 227 deletions(-) diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py index b639fd71..0b76779b 100644 --- a/tests/compile/test_basic_correctness.py +++ b/tests/compile/test_basic_correctness.py @@ -60,7 +60,7 @@ class TestSetting: # embedding model TestSetting( model="BAAI/bge-multilingual-gemma2", - model_args=["--task", "embed"], + model_args=["--task", "embed", "--dtype", "bfloat16"], pp_size=1, tp_size=1, attn_backend="FLASH_ATTN", diff --git a/tests/conftest.py b/tests/conftest.py index 30e5ca2e..0c71d981 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,8 +14,8 @@ import torch.nn as nn import torch.nn.functional as F from huggingface_hub import snapshot_download from PIL import Image -from transformers import (AutoModelForCausalLM, AutoTokenizer, BatchEncoding, - BatchFeature) +from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, + BatchEncoding, BatchFeature) from transformers.models.auto.auto_factory import _BaseAutoModelClass from tests.models.utils import (TokensTextLogprobs, @@ -23,7 +23,7 @@ from tests.models.utils import (TokensTextLogprobs, from vllm import LLM, SamplingParams from vllm.assets.image import ImageAsset from vllm.assets.video import VideoAsset -from vllm.config import TaskOption, TokenizerPoolConfig +from vllm.config import TaskOption, TokenizerPoolConfig, _get_and_verify_dtype from vllm.connections import global_http_connection from vllm.distributed import (cleanup_dist_env_and_memory, init_distributed_environment, @@ -34,8 +34,7 @@ from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt, from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.sampling_params import BeamSearchParams -from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless, - identity, is_list_of) +from vllm.utils import cuda_device_count_stateless, is_list_of logger = init_logger(__name__) @@ -271,14 +270,18 @@ _R = TypeVar("_R") class HfRunner: - def wrap_device(self, x: _T, device: Optional[str] = None) -> _T: + def get_default_device(self): from vllm.platforms import current_platform + + return ("cpu" if current_platform.is_cpu() + or current_platform.is_openvino() else "cuda") + + def wrap_device(self, x: _T, device: 
Optional[str] = None) -> _T: if x is None or isinstance(x, (bool, )): return x if device is None: - device = "cpu" if current_platform.is_cpu( - ) or current_platform.is_openvino() else "cuda" + device = self.device if isinstance(x, dict): return {k: self.wrap_device(v, device) for k, v in x.items()} @@ -291,45 +294,59 @@ class HfRunner: def __init__( self, model_name: str, - dtype: str = "half", + dtype: str = "auto", *, model_kwargs: Optional[dict[str, Any]] = None, is_sentence_transformer: bool = False, is_cross_encoder: bool = False, skip_tokenizer_init: bool = False, auto_cls: type[_BaseAutoModelClass] = AutoModelForCausalLM, - postprocess_inputs: Callable[..., BatchEncoding] = identity, ) -> None: - torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype] - self.model_name = model_name + self.config = AutoConfig.from_pretrained( + model_name, + trust_remote_code=True, + ) + self.device = self.get_default_device() + self.dtype = torch_dtype = _get_and_verify_dtype(self.config, dtype) + + model_kwargs = model_kwargs if model_kwargs is not None else {} + model_kwargs.setdefault("torch_dtype", torch_dtype) + if is_sentence_transformer: # Lazy init required for AMD CI from sentence_transformers import SentenceTransformer - self.model = self.wrap_device( - SentenceTransformer( - model_name, - device="cpu", - trust_remote_code=True, - ).to(dtype=torch_dtype)) + + self.model = SentenceTransformer( + model_name, + device=self.device, + model_kwargs=model_kwargs, + trust_remote_code=True, + ) elif is_cross_encoder: # Lazy init required for AMD CI from sentence_transformers import CrossEncoder - self.model = CrossEncoder(model_name, - device="cpu", - trust_remote_code=True) - self.model.model = self.wrap_device(self.model.model)\ - .to(dtype=torch_dtype) + + self.model = CrossEncoder( + model_name, + device=self.device, + automodel_args=model_kwargs, + trust_remote_code=True, + ) else: - model_kwargs = model_kwargs if model_kwargs is not None else {} - self.model = self.wrap_device( - auto_cls.from_pretrained( - model_name, - torch_dtype=torch_dtype, - trust_remote_code=True, - **model_kwargs, - )) + model = auto_cls.from_pretrained( + model_name, + trust_remote_code=True, + **model_kwargs, + ) + + if (getattr(model, "quantization_method", None) != "bitsandbytes" + and len({p.device + for p in model.parameters()}) < 2): + model = model.to(self.device) + + self.model = model if not skip_tokenizer_init: self.tokenizer = AutoTokenizer.from_pretrained( @@ -349,16 +366,13 @@ class HfRunner: if skip_tokenizer_init: self.tokenizer = self.processor.tokenizer - self.dtype = dtype - self.postprocess_inputs = postprocess_inputs - def get_inputs( self, prompts: list[str], images: Optional[PromptImageInput] = None, videos: Optional[PromptVideoInput] = None, audios: Optional[PromptAudioInput] = None, - ) -> list[BatchEncoding]: + ) -> list[Union[BatchFeature, BatchEncoding]]: if images is not None: assert len(prompts) == len(images) @@ -368,7 +382,7 @@ class HfRunner: if audios is not None: assert len(prompts) == len(audios) - all_inputs: list[BatchEncoding] = [] + all_inputs: list[Union[BatchFeature, BatchEncoding]] = [] for i, prompt in enumerate(prompts): processor_kwargs: dict[str, Any] = { "text": prompt, @@ -384,7 +398,8 @@ class HfRunner: processor_kwargs["sampling_rate"] = sr inputs = self.processor(**processor_kwargs) - inputs = self.postprocess_inputs(inputs, dtype=self.dtype) + if isinstance(inputs, BatchFeature): + inputs = inputs.to(dtype=self.dtype) all_inputs.append(inputs) @@ -417,7 +432,7 @@ class 
HfRunner: outputs: list[tuple[list[list[int]], list[str]]] = [] for inputs in all_inputs: output_ids = self.model.generate( - **self.wrap_device(inputs, device=self.model.device.type), + **self.wrap_device(inputs), use_cache=True, **kwargs, ) @@ -488,7 +503,7 @@ class HfRunner: all_logprobs: list[list[torch.Tensor]] = [] for inputs in all_inputs: output = self.model.generate( - **self.wrap_device(inputs, device=self.model.device.type), + **self.wrap_device(inputs), use_cache=True, do_sample=False, max_new_tokens=max_tokens, @@ -569,7 +584,7 @@ class HfRunner: for inputs in all_inputs: output = self.model.generate( - **self.wrap_device(inputs, device=self.model.device.type), + **self.wrap_device(inputs), use_cache=True, do_sample=False, max_new_tokens=max_tokens, @@ -620,19 +635,15 @@ class HfRunner: if images is not None and images[i] is not None: processor_kwargs["images"] = images[i] - encoder_inputs = self.wrap_device( - self.processor(**processor_kwargs), - device=self.model.device.type, - ) + encoder_inputs = self.processor(**processor_kwargs) + encoder_inputs = self.wrap_device(encoder_inputs) if decoder_prompt is None: decoder_input_ids = None else: - decoder_input_ids = self.wrap_device( - self.tokenizer(decoder_prompt, - return_tensors="pt").input_ids, - device=self.model.device.type, - ) + decoder_inputs = self.tokenizer(decoder_prompt, + return_tensors="pt") + decoder_input_ids = self.wrap_device(decoder_inputs.input_ids) output = self.model.generate( decoder_input_ids=decoder_input_ids, @@ -684,6 +695,7 @@ class VllmRunner: """ The default value of some arguments have been modified from :class:`~vllm.LLM` as follows: + - `trust_remote_code`: Set to `True` instead of `False` for convenience. - `seed`: Set to `0` instead of `None` for test reproducibility. - `max_model_len`: Set to `1024` instead of `None` to reduce memory usage. @@ -701,10 +713,8 @@ class VllmRunner: tokenizer_mode: str = "auto", trust_remote_code: bool = True, seed: Optional[int] = 0, - # Use smaller max model length, otherwise bigger model cannot run due - # to kv cache size limit. 
max_model_len: int = 1024, - dtype: str = "half", + dtype: str = "auto", disable_log_stats: bool = True, tensor_parallel_size: int = 1, block_size: int = 16, @@ -1110,4 +1120,4 @@ def pytest_collection_modifyitems(config, items): skip_optional = pytest.mark.skip(reason="need --optional option to run") for item in items: if "optional" in item.keywords: - item.add_marker(skip_optional) \ No newline at end of file + item.add_marker(skip_optional) diff --git a/tests/entrypoints/llm/test_chat.py b/tests/entrypoints/llm/test_chat.py index 710bad4e..e96081c1 100644 --- a/tests/entrypoints/llm/test_chat.py +++ b/tests/entrypoints/llm/test_chat.py @@ -64,7 +64,6 @@ def test_multi_chat(): def test_chat_multi_image(image_urls: list[str]): llm = LLM( model="microsoft/Phi-3.5-vision-instruct", - dtype="bfloat16", max_model_len=4096, max_num_seqs=5, enforce_eager=True, diff --git a/tests/entrypoints/openai/test_audio.py b/tests/entrypoints/openai/test_audio.py index 56fb2932..3267dcc1 100644 --- a/tests/entrypoints/openai/test_audio.py +++ b/tests/entrypoints/openai/test_audio.py @@ -18,8 +18,6 @@ TEST_AUDIO_URLS = [ @pytest.fixture(scope="module") def server(): args = [ - "--dtype", - "bfloat16", "--max-model-len", "2048", "--max-num-seqs", diff --git a/tests/entrypoints/openai/test_video.py b/tests/entrypoints/openai/test_video.py index 36d62224..8c7564ba 100644 --- a/tests/entrypoints/openai/test_video.py +++ b/tests/entrypoints/openai/test_video.py @@ -24,8 +24,6 @@ def server(): args = [ "--task", "generate", - "--dtype", - "bfloat16", "--max-model-len", "32768", "--max-num-seqs", diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/test_vision.py index d605394f..bb100e57 100644 --- a/tests/entrypoints/openai/test_vision.py +++ b/tests/entrypoints/openai/test_vision.py @@ -25,8 +25,6 @@ def server(): args = [ "--task", "generate", - "--dtype", - "bfloat16", "--max-model-len", "2048", "--max-num-seqs", diff --git a/tests/entrypoints/openai/test_vision_embedding.py b/tests/entrypoints/openai/test_vision_embedding.py index 100aca6f..74e5c4cc 100644 --- a/tests/entrypoints/openai/test_vision_embedding.py +++ b/tests/entrypoints/openai/test_vision_embedding.py @@ -28,8 +28,6 @@ def server(): args = [ "--task", "embed", - "--dtype", - "bfloat16", "--max-model-len", "2048", "--max-num-seqs", diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py index c52fa905..e3b7b660 100644 --- a/tests/entrypoints/test_chat_utils.py +++ b/tests/entrypoints/test_chat_utils.py @@ -34,7 +34,7 @@ def phi3v_model_config(): tokenizer=PHI3V_MODEL_ID, tokenizer_mode="auto", trust_remote_code=True, - dtype="bfloat16", + dtype="auto", seed=0, limit_mm_per_prompt={ "image": 2, @@ -58,7 +58,7 @@ def mllama_model_config(): tokenizer=MLLAMA_MODEL_ID, tokenizer_mode="auto", trust_remote_code=True, - dtype="bfloat16", + dtype="auto", seed=0, limit_mm_per_prompt={ "image": 2, @@ -669,7 +669,7 @@ def test_multimodal_image_parsing_matches_hf(model, image_url): tokenizer=MLLAMA_MODEL_ID, tokenizer_mode="auto", trust_remote_code=True, - dtype="bfloat16", + dtype="auto", seed=0, limit_mm_per_prompt={ "image": 2, diff --git a/tests/models/decoder_only/audio_language/test_ultravox.py b/tests/models/decoder_only/audio_language/test_ultravox.py index f8770bca..83ece5d2 100644 --- a/tests/models/decoder_only/audio_language/test_ultravox.py +++ b/tests/models/decoder_only/audio_language/test_ultravox.py @@ -5,11 +5,10 @@ from typing import Optional import numpy as np import pytest 
import pytest_asyncio -from transformers import AutoModel, AutoTokenizer, BatchEncoding +from transformers import AutoModel, AutoTokenizer from vllm.multimodal.audio import resample_audio from vllm.sequence import SampleLogprobs -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE from ....conftest import HfRunner, VllmRunner from ....utils import RemoteOpenAIServer @@ -107,8 +106,6 @@ def run_test( **kwargs, ): """Inference result should be the same between hf and vllm.""" - torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype] - # NOTE: take care of the order. run vLLM first, and then run HF. # vLLM needs a fresh new process without cuda initialization. # if we run HF first, the cuda initialization will be done and it @@ -124,15 +121,7 @@ def run_test( for vllm_prompt, _, audio in prompts_and_audios ] - def process(hf_inputs: BatchEncoding, **kwargs): - hf_inputs["audio_values"] = hf_inputs["audio_values"] \ - .to(torch_dtype) # type: ignore - return hf_inputs - - with hf_runner(model, - dtype=dtype, - postprocess_inputs=process, - auto_cls=AutoModel) as hf_model: + with hf_runner(model, dtype=dtype, auto_cls=AutoModel) as hf_model: hf_outputs_per_audio = [ hf_model.generate_greedy_logprobs_limit( [hf_prompt], diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py index 2f903a33..5690249e 100644 --- a/tests/models/decoder_only/vision_language/test_models.py +++ b/tests/models/decoder_only/vision_language/test_models.py @@ -122,9 +122,6 @@ VLM_TEST_SETTINGS = { "cherry_blossom": "What is in the picture?", }), auto_cls=AutoModelForImageTextToText, - postprocess_inputs=model_utils.cast_dtype_post_processor( - "pixel_values" - ), vllm_output_post_proc=model_utils.paligemma_vllm_to_hf_output, dtype="bfloat16", marks=[pytest.mark.skip(reason="vLLM does not support PrefixLM attention mask")], # noqa: E501 @@ -179,7 +176,6 @@ VLM_TEST_SETTINGS = { # "cherry_blossom": "Please infer the season with reason.", # noqa: E501 # }), # multi_image_prompt="Describe the two images shortly.", # noqa: E501 - # postprocess_inputs=model_utils.cast_dtype_post_processor("pixel_values"), # noqa: E501 # stop_str=["<|im_end|>"], # image_size_factors=[(0.10, 0.15)], # max_tokens=64, @@ -200,9 +196,6 @@ VLM_TEST_SETTINGS = { max_model_len=4096, max_num_seqs=2, auto_cls=AutoModelForImageTextToText, - postprocess_inputs=model_utils.cast_dtype_post_processor( - "pixel_values" - ), # For chameleon, we only compare the sequences vllm_output_post_proc = lambda vllm_output, model: vllm_output[:2], hf_output_post_proc = lambda hf_output, model: hf_output[:2], @@ -222,7 +215,6 @@ VLM_TEST_SETTINGS = { }), multi_image_prompt="image_1:\nimage_2:\nWhich image can we see the car and the tower?", # noqa: E501 patch_hf_runner=model_utils.deepseekvl2_patch_hf_runner, - postprocess_inputs=model_utils.cast_dtype_post_processor("images"), hf_output_post_proc=model_utils.deepseekvl2_trunc_hf_output, stop_str=["<|end▁of▁sentence|>", "<|begin▁of▁sentence|>"], # noqa: E501 image_size_factors=[(), (1.0, ), (1.0, 1.0, 1.0), (0.1, 0.5, 1.0)], @@ -258,7 +250,6 @@ VLM_TEST_SETTINGS = { max_model_len=4096, max_num_seqs=2, auto_cls=AutoModelForImageTextToText, - dtype="bfloat16", vllm_runner_kwargs={"mm_processor_kwargs": {"do_pan_and_scan": True}}, patch_hf_runner=model_utils.gemma3_patch_hf_runner, ), @@ -272,7 +263,6 @@ VLM_TEST_SETTINGS = { }), max_model_len=2048, max_num_seqs=2, - dtype="bfloat16", get_stop_token_ids=lambda tok: [151329, 151336, 151338], 
patch_hf_runner=model_utils.glm4v_patch_hf_runner, # The image embeddings match with HF but the outputs of the language @@ -295,7 +285,6 @@ VLM_TEST_SETTINGS = { }), multi_image_prompt="Image-1: \nImage-2: \nDescribe the two images in short.", # noqa: E501 max_model_len=8192, - dtype="bfloat16", use_tokenizer_eos=True, num_logprobs=10, patch_hf_runner=model_utils.h2ovl_patch_hf_runner, @@ -324,10 +313,6 @@ VLM_TEST_SETTINGS = { }), multi_image_prompt="Image-1: \nImage-2: \nDescribe the two images in short.", # noqa: E501 max_model_len=4096, - # NOTE: Mono-InternVL-2B doesn't work with fp16, - # it will result NaN during inference. - # See: https://huggingface.co/OpenGVLab/Mono-InternVL-2B/discussions/9 - dtype="bfloat16", use_tokenizer_eos=True, patch_hf_runner=model_utils.internvl_patch_hf_runner, ), @@ -351,9 +336,6 @@ VLM_TEST_SETTINGS = { prompt_formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 num_video_frames=16, max_model_len=16384, - postprocess_inputs=model_utils.cast_dtype_post_processor( - "pixel_values_videos" - ), auto_cls=AutoModelForVision2Seq, vllm_output_post_proc=model_utils.llava_onevision_vllm_to_hf_output, custom_test_opts=[CustomTestOptions( @@ -378,9 +360,6 @@ VLM_TEST_SETTINGS = { test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), prompt_formatter=lambda img_prompt: f"<|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501 max_model_len=4096, - postprocess_inputs=model_utils.cast_dtype_post_processor( - "pixel_values" - ), get_stop_token_ids=lambda tok: [128009], auto_cls=AutoModelForImageTextToText, vllm_output_post_proc=model_utils.mantis_vllm_to_hf_output, @@ -400,8 +379,8 @@ VLM_TEST_SETTINGS = { max_model_len=4096, max_num_seqs=2, get_stop_token_ids=lambda tok: [tok.eos_id, tok.eot_id], - postprocess_inputs=model_utils.wrap_inputs_post_processor, hf_output_post_proc=model_utils.minicpmv_trunc_hf_output, + patch_hf_runner=model_utils.minicpmv_25_patch_hf_runner, ), "minicpmo_26": VLMTestInfo( models=["openbmb/MiniCPM-o-2_6"], @@ -411,11 +390,8 @@ VLM_TEST_SETTINGS = { max_model_len=4096, max_num_seqs=2, get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']), # noqa: E501 - postprocess_inputs=model_utils.ignore_inputs_post_processor( - "image_sizes" - ), hf_output_post_proc=model_utils.minicpmv_trunc_hf_output, - patch_hf_runner=model_utils.minicpmo_patch_hf_runner + patch_hf_runner=model_utils.minicpmo_26_patch_hf_runner, ), "minicpmv_26": VLMTestInfo( models=["openbmb/MiniCPM-V-2_6"], @@ -425,10 +401,8 @@ VLM_TEST_SETTINGS = { max_model_len=4096, max_num_seqs=2, get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']), # noqa: E501 - postprocess_inputs=model_utils.ignore_inputs_post_processor( - "image_sizes" - ), hf_output_post_proc=model_utils.minicpmv_trunc_hf_output, + patch_hf_runner=model_utils.minicpmv_26_patch_hf_runner, ), "molmo": VLMTestInfo( models=["allenai/Molmo-7B-D-0924"], @@ -437,7 +411,6 @@ VLM_TEST_SETTINGS = { max_model_len=4096, max_num_seqs=2, patch_hf_runner=model_utils.molmo_patch_hf_runner, - postprocess_inputs=model_utils.molmo_post_processor, ), # Tests for phi3v currently live in another file because of a bug in # transformers. Once this issue is fixed, we can enable them here instead. 
@@ -482,9 +455,6 @@ VLM_TEST_SETTINGS = { prompt_formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:", max_model_len=4096, auto_cls=AutoModelForImageTextToText, - postprocess_inputs=model_utils.cast_dtype_post_processor( - "pixel_values" - ), vllm_output_post_proc = lambda vllm_output, model: vllm_output[:2], hf_output_post_proc = lambda hf_output, model: hf_output[:2], comparator=check_outputs_equal, @@ -529,9 +499,6 @@ VLM_TEST_SETTINGS = { test_type=VLMTestType.CUSTOM_INPUTS, max_model_len=16384, max_num_seqs=2, - postprocess_inputs=model_utils.cast_dtype_post_processor( - "pixel_values" - ), auto_cls=AutoModelForVision2Seq, vllm_output_post_proc=model_utils.llava_onevision_vllm_to_hf_output, custom_test_opts=[CustomTestOptions( diff --git a/tests/models/decoder_only/vision_language/vlm_utils/core.py b/tests/models/decoder_only/vision_language/vlm_utils/core.py index 31f0209b..2eae643f 100644 --- a/tests/models/decoder_only/vision_language/vlm_utils/core.py +++ b/tests/models/decoder_only/vision_language/vlm_utils/core.py @@ -4,7 +4,6 @@ from typing import Any, Callable, Optional, Union import torch from PIL.Image import Image -from transformers import BatchEncoding from transformers.models.auto.auto_factory import _BaseAutoModelClass from vllm.config import TaskOption @@ -31,7 +30,6 @@ def run_test( vllm_output_post_proc: Optional[Callable[[RunnerOutput, str], Any]], auto_cls: type[_BaseAutoModelClass], use_tokenizer_eos: bool, - postprocess_inputs: Callable[[BatchEncoding], BatchEncoding], comparator: Callable[..., None], get_stop_token_ids: Optional[Callable[[AnyTokenizer], list[int]]], stop_str: Optional[list[str]], @@ -101,7 +99,6 @@ def run_test( hf_model = hf_runner(model, dtype=dtype, auto_cls=auto_cls, - postprocess_inputs=postprocess_inputs, model_kwargs=hf_model_kwargs) # Some models need to patch things like the model processor, e.g., internvl diff --git a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py index 3b4d1237..c84bf6dc 100644 --- a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py +++ b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py @@ -6,16 +6,15 @@ typically specific to a small subset of models. 
import re import types from pathlib import PosixPath -from typing import Callable, Optional, Union +from typing import Optional, Union import torch from PIL.Image import Image -from transformers import (AutoConfig, AutoTokenizer, BatchEncoding, +from transformers import (AutoConfig, AutoTokenizer, BatchFeature, GenerationConfig) from vllm.sequence import SampleLogprobs from vllm.transformers_utils.tokenizer import patch_padding_side -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE from .....conftest import HfRunner, ImageAsset, _ImageAssets from .types import RunnerOutput @@ -211,40 +210,6 @@ def get_llava_embeddings(image_assets: _ImageAssets): return [asset.image_embeds for asset in image_assets] -####### postprocessors to run on HF BatchEncoding -def cast_dtype_post_processor( - hf_inp_key: str) -> Callable[[BatchEncoding, str], BatchEncoding]: - """Gets a handle to a post processor which converts a given key into a - target data type.""" - - def process(hf_inputs: BatchEncoding, dtype: str): - torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype] - hf_inputs[hf_inp_key] = hf_inputs[hf_inp_key].to(torch_dtype) - return hf_inputs - - return process - - -def ignore_inputs_post_processor( - hf_inp_key: str) -> Callable[[BatchEncoding, str], BatchEncoding]: - """Gets a handle to a post processor which ignores a given key.""" - - def process(hf_inputs: BatchEncoding, dtype: str): - del hf_inputs[hf_inp_key] - return hf_inputs - - return process - - -def wrap_inputs_post_processor(hf_inputs: BatchEncoding, dtype: str): - return {"model_inputs": hf_inputs} - - -def molmo_post_processor(hf_inputs: BatchEncoding, dtype: str): - hf_inputs = cast_dtype_post_processor("images")(hf_inputs, dtype) - return {k: v.unsqueeze(0) for k, v in hf_inputs.items()} - - ####### Prompt path encoders for models that need models on disk def qwen_prompt_path_encoder( tmp_path: PosixPath, prompt: str, assets: Union[list[ImageAsset], @@ -295,8 +260,7 @@ def deepseekvl2_patch_hf_runner(hf_model: HfRunner) -> HfRunner: for k in inputs.keys() # noqa if k not in ("seq_lens", "sft_format") } - inputs = BatchEncoding(data=inputs, tensor_type="pt") - return inputs + return BatchFeature(data=inputs, tensor_type="pt") hf_model.processor = processor hf_model.model.get_output_embeddings = lambda: \ @@ -529,10 +493,52 @@ def mantis_patch_hf_runner(hf_model: HfRunner) -> HfRunner: return hf_model -def minicpmo_patch_hf_runner(hf_model: HfRunner) -> HfRunner: +def minicpmv_25_patch_hf_runner(hf_model: HfRunner) -> HfRunner: orig_generate = hf_model.model.generate - def _generate(self, *args, **kwargs): + def _generate( + self, + *args, + input_ids=None, + pixel_values=None, + image_sizes=None, + image_bound=None, + tgt_sizes=None, + **kwargs, + ): + model_inputs = { + "input_ids": input_ids, + "pixel_values": pixel_values, + "image_sizes": image_sizes, + "image_bound": image_bound, + "tgt_sizes": tgt_sizes, + } + for k in list(model_inputs.keys()): + if model_inputs[k] is None: + model_inputs.pop(k) + + return orig_generate(model_inputs, *args, decode_text=False, **kwargs) + + hf_model.model.generate = types.MethodType(_generate, hf_model.model) + + return hf_model + + +def minicpmo_26_patch_hf_runner(hf_model: HfRunner) -> HfRunner: + orig_generate = hf_model.model.generate + + def _generate(self, *args, image_sizes=None, **kwargs): + return orig_generate(*args, decode_text=False, **kwargs) + + hf_model.model.generate = types.MethodType(_generate, hf_model.model) + + return hf_model + + +def minicpmv_26_patch_hf_runner(hf_model: 
HfRunner) -> HfRunner: + orig_generate = hf_model.model.generate + + def _generate(self, *args, image_sizes=None, **kwargs): return orig_generate(*args, decode_text=False, **kwargs) hf_model.model.generate = types.MethodType(_generate, hf_model.model) @@ -551,10 +557,11 @@ def molmo_patch_hf_runner(hf_model: HfRunner) -> HfRunner: def _generate(self, max_new_tokens=None, do_sample=None, **kwargs): batch = { - k: kwargs.pop(k) + k: kwargs.pop(k).unsqueeze(0) for k in ("input_ids", "images", "image_input_idx", "image_masks") if k in kwargs } + batch = BatchFeature(batch).to(dtype=self.dtype) return self.generate_from_batch( batch, diff --git a/tests/models/decoder_only/vision_language/vlm_utils/types.py b/tests/models/decoder_only/vision_language/vlm_utils/types.py index bdbdbc7e..1ae61ea4 100644 --- a/tests/models/decoder_only/vision_language/vlm_utils/types.py +++ b/tests/models/decoder_only/vision_language/vlm_utils/types.py @@ -8,13 +8,12 @@ from typing import Any, Callable, NamedTuple, Optional, Union import torch from PIL.Image import Image from pytest import MarkDecorator -from transformers import AutoModelForCausalLM, BatchEncoding +from transformers import AutoModelForCausalLM from transformers.models.auto.auto_factory import _BaseAutoModelClass from vllm.config import TaskOption from vllm.sequence import SampleLogprobs from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.utils import identity from .....conftest import IMAGE_ASSETS, HfRunner, ImageAsset, _ImageAssets from ....utils import check_logprobs_close @@ -110,11 +109,6 @@ class VLMTestInfo(NamedTuple): # Indicates we should explicitly pass the EOS from the tokenizer use_tokenizer_eos: bool = False auto_cls: type[_BaseAutoModelClass] = AutoModelForCausalLM - # Callable to pass to the HF runner to run on inputs; for now, we also pass - # the data type to input post processing, because almost all of the uses of - # postprocess_inputs are to fix the data types of BatchEncoding values. 
- postprocess_inputs: Callable[[BatchEncoding, str], - BatchEncoding] = identity patch_hf_runner: Optional[Callable[[HfRunner], HfRunner]] = None # Post processors that if defined, will run oun the outputs of the @@ -130,7 +124,7 @@ class VLMTestInfo(NamedTuple): # is all combinations of .models + all fields below max_tokens: Union[int, tuple[int]] = 128 num_logprobs: Union[int, tuple[int]] = 5 - dtype: Union[str, Iterable[str]] = "half" + dtype: Union[str, Union[list[str], tuple[str, ...]]] = "auto" distributed_executor_backend: Optional[Union[str, Iterable[str]]] = None # Only expanded in video tests num_video_frames: Union[int, tuple[int]] = 16 @@ -171,7 +165,6 @@ class VLMTestInfo(NamedTuple): "vllm_output_post_proc": self.vllm_output_post_proc, "auto_cls": self.auto_cls, "use_tokenizer_eos": self.use_tokenizer_eos, - "postprocess_inputs": self.postprocess_inputs, "comparator": self.comparator, "get_stop_token_ids": self.get_stop_token_ids, "hf_model_kwargs": self.hf_model_kwargs, diff --git a/tests/models/embedding/vision_language/test_dse_qwen2_vl.py b/tests/models/embedding/vision_language/test_dse_qwen2_vl.py index 7391df6e..3c15b0b5 100644 --- a/tests/models/embedding/vision_language/test_dse_qwen2_vl.py +++ b/tests/models/embedding/vision_language/test_dse_qwen2_vl.py @@ -1,12 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 -from functools import partial from typing import Callable import pytest import torch +import torch.nn.functional as F from PIL import Image -from transformers import BatchEncoding, Qwen2VLForConditionalGeneration +from transformers import Qwen2VLForConditionalGeneration from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner from ....utils import large_gpu_test @@ -75,10 +75,6 @@ def apply_chat_template_and_add_eos( return prompt -def postprocess_inputs(hf_model: HfRunner, inputs: BatchEncoding, **kwargs): - return hf_model.model.prepare_inputs_for_generation(**inputs, **kwargs) - - def _run_test( hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], @@ -118,14 +114,8 @@ def _run_test( with hf_runner(model, dtype=dtype, auto_cls=Qwen2VLForConditionalGeneration) as hf_model: - hf_model.postprocess_inputs = partial( - postprocess_inputs, - hf_model, - cache_position=torch.arange( - 0, - 1, # 1 for batch size - requires_grad=False), - use_cache=False) + + prompts = [] for text, image, embed_text in zip(input_texts, input_images, embed_texts): # dse requires non-standard input processing @@ -133,20 +123,34 @@ def _run_test( messages = get_messages(image, text, embed_text) prompt = apply_chat_template_and_add_eos( messages, hf_model.processor.apply_chat_template) - inputs = hf_model.get_inputs( - prompts=[[prompt]], - images=[[image]], - ) - with torch.no_grad(): + + prompts.append(prompt) + + all_inputs = hf_model.get_inputs( + prompts=prompts, + images=input_images, + ) + + with torch.no_grad(): + all_outputs = [] + for inputs in all_inputs: + inputs = hf_model.model.prepare_inputs_for_generation( + **inputs, + cache_position=torch.arange(1), # 1 for batch size + use_cache=False, + ) outputs = hf_model.model( - **hf_model.wrap_device(inputs[0], - device=hf_model.model.device.type), + **hf_model.wrap_device(inputs), return_dict=True, output_hidden_states=True, ) - pooled_output = torch.nn.functional.normalize( - outputs.hidden_states[-1][0, -1], p=2, dim=-1) - hf_outputs.append(pooled_output.tolist()) + pooled_output = F.normalize(outputs.hidden_states[-1][0, -1], + p=2, + dim=-1) + + all_outputs.append(pooled_output.tolist()) + + 
hf_outputs = all_outputs check_embeddings_close( embeddings_0_lst=hf_outputs, diff --git a/tests/models/embedding/vision_language/test_llava_next.py b/tests/models/embedding/vision_language/test_llava_next.py index d5d410f1..4da59ff5 100644 --- a/tests/models/embedding/vision_language/test_llava_next.py +++ b/tests/models/embedding/vision_language/test_llava_next.py @@ -86,8 +86,7 @@ def _run_test( for inputs in all_inputs: # Based on: https://huggingface.co/royokong/e5-v outputs = hf_model.model( - **hf_model.wrap_device(inputs, - device=hf_model.model.device.type), + **hf_model.wrap_device(inputs), return_dict=True, output_hidden_states=True, ) diff --git a/tests/models/embedding/vision_language/test_phi3v.py b/tests/models/embedding/vision_language/test_phi3v.py index 3226138a..9cc767c2 100644 --- a/tests/models/embedding/vision_language/test_phi3v.py +++ b/tests/models/embedding/vision_language/test_phi3v.py @@ -53,8 +53,7 @@ def _run_test( for inputs in all_inputs: # Based on: https://github.com/TIGER-AI-Lab/VLM2Vec/blob/db3b951bccabba220c1f53ab46a734e50dd2fc08/src/model.py outputs = hf_model.model( - **hf_model.wrap_device(inputs, - device=hf_model.model.device.type), + **hf_model.wrap_device(inputs), return_dict=True, output_hidden_states=True, ) diff --git a/tests/models/encoder_decoder/vision_language/test_mllama.py b/tests/models/encoder_decoder/vision_language/test_mllama.py index d2cdcfe4..b6ea31cc 100644 --- a/tests/models/encoder_decoder/vision_language/test_mllama.py +++ b/tests/models/encoder_decoder/vision_language/test_mllama.py @@ -4,8 +4,7 @@ from typing import Optional, overload import pytest import torch -from transformers import (AutoConfig, AutoModelForImageTextToText, - AutoTokenizer, BatchEncoding) +from transformers import AutoConfig, AutoModelForImageTextToText, AutoTokenizer from vllm import LLM, SamplingParams from vllm.attention.backends.flash_attn import FlashAttentionMetadata @@ -227,13 +226,9 @@ def _run_test( for prompts, images in inputs ] - def process(hf_inputs: BatchEncoding, **kwargs): - return hf_inputs - with hf_runner(model, dtype=dtype, model_kwargs={"device_map": "auto"}, - postprocess_inputs=process, auto_cls=AutoModelForImageTextToText) as hf_model: hf_outputs_per_image = [ hf_model.generate_greedy_logprobs_limit(prompts, diff --git a/tests/models/utils.py b/tests/models/utils.py index 2280a6c9..7109169e 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -2,7 +2,7 @@ import warnings from collections.abc import Sequence -from typing import Optional, Union +from typing import Any, Optional, Union import torch @@ -254,9 +254,9 @@ def check_logprobs_close( def build_model_context( model_id: str, task: TaskOption = "auto", - dtype: Optional[Union[str, torch.dtype]] = None, - mm_processor_kwargs: Optional[dict] = None, - limit_mm_per_prompt: Optional[dict] = None, + dtype: Union[str, torch.dtype] = "auto", + mm_processor_kwargs: Optional[dict[str, Any]] = None, + limit_mm_per_prompt: Optional[dict[str, int]] = None, disable_mm_preprocessor_cache: bool = True, ): """Creates an InputContext for a given model. 
@@ -274,9 +274,6 @@ def build_model_context( model_info.check_available_online(on_fail="skip") model_info.check_transformers_version(on_fail="skip") - if dtype is None: - dtype = "half" - model_config = ModelConfig( model_id, task=task, diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py index a358eee5..fbb7e507 100644 --- a/tests/multimodal/test_processing.py +++ b/tests/multimodal/test_processing.py @@ -853,7 +853,7 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid): tokenizer_mode="auto", trust_remote_code=False, seed=0, - dtype="half", + dtype="auto", revision=None, limit_mm_per_prompt=limit_mm_per_prompt, ) @@ -892,7 +892,7 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid): tokenizer_mode="auto", trust_remote_code=False, seed=0, - dtype="half", + dtype="auto", revision=None, limit_mm_per_prompt=limit_mm_per_prompt, ) @@ -965,7 +965,7 @@ def test_hf_processor_kwargs(model_id, call_kwargs, expected_kwargs): tokenizer_mode="auto", trust_remote_code=False, seed=0, - dtype="half", + dtype="auto", revision=None, ) diff --git a/tests/tensorizer_loader/test_tensorizer.py b/tests/tensorizer_loader/test_tensorizer.py index b268d4bf..5b9661bf 100644 --- a/tests/tensorizer_loader/test_tensorizer.py +++ b/tests/tensorizer_loader/test_tensorizer.py @@ -166,7 +166,7 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path): test_prompts = multilora_inference.create_test_prompts(lora_path) # Serialize model before deserializing and binding LoRA adapters - with vllm_runner(model_ref, ) as vllm_model: + with vllm_runner(model_ref) as vllm_model: model_path = tmp_path / (model_ref + ".tensors") vllm_model.apply_model( @@ -208,7 +208,7 @@ def test_load_without_tensorizer_load_format(vllm_runner): @pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path): ## Serialize model - with vllm_runner(model_ref, ) as vllm_model: + with vllm_runner(model_ref) as vllm_model: model_path = tmp_path / (model_ref + ".tensors") vllm_model.apply_model( diff --git a/tests/v1/engine/test_llm_engine.py b/tests/v1/engine/test_llm_engine.py index 5446653c..cefb89eb 100644 --- a/tests/v1/engine/test_llm_engine.py +++ b/tests/v1/engine/test_llm_engine.py @@ -50,7 +50,7 @@ def _get_test_sampling_params( """Generate random sampling params for a batch.""" def get_mostly_n_gt1() -> int: - """Mostly n \in [2,20], ~1/3 n=1""" + r"""Mostly n \in [2,20], ~1/3 n=1""" x = random.randint(0, 28) if x < 10: return 1 diff --git a/vllm/config.py b/vllm/config.py index c510677d..c248122d 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -347,7 +347,7 @@ class ModelConfig: self.encoder_config = self._get_encoder_config() self.hf_image_processor_config = get_hf_image_processor_config( self.model, revision) - self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) + self.dtype = _get_and_verify_dtype(self.hf_config, dtype) self.use_async_output_proc = use_async_output_proc self.mm_processor_kwargs = mm_processor_kwargs self.disable_mm_preprocessor_cache = disable_mm_preprocessor_cache @@ -2526,6 +2526,14 @@ def _get_and_verify_dtype( # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct # because config.torch_dtype can be None. 
config_dtype = getattr(config, "torch_dtype", None) + + # Fallbacks for multi-modal models if the root config + # does not define torch_dtype + if config_dtype is None and hasattr(config, "text_config"): + config_dtype = getattr(config.text_config, "torch_dtype", None) + if config_dtype is None and hasattr(config, "vision_config"): + config_dtype = getattr(config.vision_config, "torch_dtype", None) + if config_dtype is None: config_dtype = torch.float32 @@ -2533,16 +2541,8 @@ def _get_and_verify_dtype( dtype = dtype.lower() if dtype == "auto": if config_dtype == torch.float32: - if config.model_type in ("gemma2", "gemma3", "gemma3_text"): - logger.info( - "For Gemma 2 and 3, we downcast float32 to bfloat16 " - "instead of float16 by default. Please specify `dtype` " - "if you want to use float16.") - torch_dtype = torch.bfloat16 - else: - # Following the common practice, we use float16 for float32 - # models. - torch_dtype = torch.float16 + # Following common practice, we use float16 for float32 models + torch_dtype = torch.float16 else: torch_dtype = config_dtype
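
Note on the core change: the final hunk extends dtype detection so that multi-modal models whose root config carries no torch_dtype fall back to text_config and then vision_config before defaulting to float32, and the Gemma-specific bfloat16 downcast is removed, so dtype="auto" now maps float32 checkpoints to float16 uniformly. The snippet below is a minimal standalone sketch of that fallback order, not the merged vLLM code; resolve_config_dtype is a hypothetical helper name used here only for illustration.

    # Minimal sketch (illustration only) of the fallback order added to
    # _get_and_verify_dtype in this patch; resolve_config_dtype is a
    # hypothetical helper, not a vLLM API.
    import torch
    from types import SimpleNamespace


    def resolve_config_dtype(config) -> torch.dtype:
        # Root config first; config.torch_dtype may legitimately be None.
        config_dtype = getattr(config, "torch_dtype", None)

        # Multi-modal fallbacks when the root config does not define
        # torch_dtype: try the text sub-config, then the vision sub-config.
        if config_dtype is None and hasattr(config, "text_config"):
            config_dtype = getattr(config.text_config, "torch_dtype", None)
        if config_dtype is None and hasattr(config, "vision_config"):
            config_dtype = getattr(config.vision_config, "torch_dtype", None)

        # Final fallback, matching the pre-existing behaviour.
        if config_dtype is None:
            config_dtype = torch.float32
        return config_dtype


    # Example: a multi-modal config whose dtype lives on text_config only.
    mm_config = SimpleNamespace(
        text_config=SimpleNamespace(torch_dtype=torch.bfloat16))
    assert resolve_config_dtype(mm_config) is torch.bfloat16

This is also why the test suite can drop most explicit "--dtype bfloat16" overrides and the postprocess_inputs casting helpers: with HfRunner defaulting to dtype="auto" and resolving through _get_and_verify_dtype, the HF and vLLM sides now derive the same dtype from the checkpoint config.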