Upgrade transformers to v4.50.3 (#13905)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor 2025-03-31 16:59:37 +01:00 committed by GitHub
parent 037bcd942c
commit e5ef4fa99a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 68 additions and 49 deletions

View File

@@ -73,7 +73,7 @@ The Transformers fallback explicitly supports the following features:
- <project:#quantization-index> (except GGUF)
- <project:#lora-adapter>
- <project:#distributed-serving> (requires `transformers>=4.49.0`)
- <project:#distributed-serving>
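
With the minimum transformers version raised to 4.50.3 project-wide, the distributed-serving entry no longer needs its own version note. A hedged sketch of serving through the Transformers fallback with tensor parallelism (assumes the `model_impl` engine argument is available to select the backend; exact argument plumbing may differ across vLLM versions):

# Sketch: force the Transformers fallback and shard the model across 2 GPUs.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.2-1B-Instruct",
    model_impl="transformers",   # use the Transformers fallback explicitly
    tensor_parallel_size=2,      # distributed serving across 2 GPUs
)
out = llm.generate(["Hello!"], SamplingParams(max_tokens=8))
print(out[0].outputs[0].text)
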
#### Remote code

View File

@@ -6,7 +6,7 @@ requests >= 2.26.0
tqdm
blake3
py-cpuinfo
transformers >= 4.48.2 # Required for Bamba model and Transformers backend.
transformers >= 4.50.3
tokenizers >= 0.19.1 # Required for Llama 3.
protobuf # Required by LlamaTokenizer.
fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint.
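
The runtime requirement moves from a Bamba-specific pin to a hard floor of 4.50.3. A minimal, diff-independent sketch for confirming that a local environment satisfies the new pin:

# Sketch: check the installed transformers version against the new minimum.
from importlib.metadata import version
from packaging.version import Version

installed = Version(version("transformers"))
assert installed >= Version("4.50.3"), f"transformers {installed} is older than 4.50.3"
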

View File

@@ -30,7 +30,7 @@ matplotlib # required for qwen-vl test
mistral_common[opencv] >= 1.5.4 # required for pixtral test
datamodel_code_generator # required for minicpm3 test
lm-eval[api]==0.4.4 # required for model evaluation test
transformers==4.48.2
transformers==4.50.3
# quantization
bitsandbytes>=0.45.3
buildkite-test-collector==0.1.9

View File

@@ -643,7 +643,7 @@ tqdm==4.66.6
# transformers
tqdm-multiprocess==0.0.11
# via lm-eval
transformers==4.48.2
transformers==4.50.3
# via
# -r requirements/test.in
# genai-perf

View File

@@ -245,7 +245,7 @@ TEST_MODELS = [
# [LANGUAGE GENERATION]
"microsoft/Phi-3.5-MoE-instruct",
"meta-llama/Llama-3.2-1B-Instruct",
# "ArthurZ/Ilama-3.2-1B", NOTE: Uncomment after #13905
"ArthurZ/Ilama-3.2-1B",
"ibm/PowerLM-3b",
# [LANGUAGE EMBEDDING]
"intfloat/e5-mistral-7b-instruct",

View File

@@ -8,9 +8,7 @@ from collections import defaultdict
from pathlib import PosixPath
import pytest
from packaging.version import Version
from transformers import AutoModelForImageTextToText, AutoModelForVision2Seq
from transformers import __version__ as TRANSFORMERS_VERSION
from vllm.platforms import current_platform
from vllm.utils import identity
@@ -126,25 +124,6 @@ VLM_TEST_SETTINGS = {
dtype="bfloat16",
marks=[pytest.mark.skip(reason="vLLM does not support PrefixLM attention mask")], # noqa: E501
),
# TODO(ywang96): Move Qwen2-VL out of core models in favor of Qwen2.5-VL
# once we upgraded to transformers>=4.49.0.
"qwen2_vl": VLMTestInfo(
models=["Qwen/Qwen2-VL-2B-Instruct"],
test_type=(
VLMTestType.IMAGE,
VLMTestType.MULTI_IMAGE,
VLMTestType.VIDEO
),
prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501
img_idx_to_prompt=lambda idx: "<|vision_start|><|image_pad|><|vision_end|>", # noqa: E501
video_idx_to_prompt=lambda idx: "<|vision_start|><|video_pad|><|vision_end|>", # noqa: E501
max_model_len=4096,
max_num_seqs=2,
auto_cls=AutoModelForVision2Seq,
vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output,
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
marks=[pytest.mark.core_model, pytest.mark.cpu_model],
),
"qwen2_5_vl": VLMTestInfo(
models=["Qwen/Qwen2.5-VL-3B-Instruct"],
test_type=(
@@ -218,12 +197,6 @@ VLM_TEST_SETTINGS = {
hf_output_post_proc=model_utils.deepseekvl2_trunc_hf_output,
stop_str=["<end▁of▁sentence>", "<begin▁of▁sentence>"], # noqa: E501
image_size_factors=[(), (1.0, ), (1.0, 1.0, 1.0), (0.1, 0.5, 1.0)],
marks=[
pytest.mark.skipif(
Version(TRANSFORMERS_VERSION) >= Version("4.48"),
reason="HF model is not compatible with transformers>=4.48",
)
],
),
"fuyu": VLMTestInfo(
models=["adept/fuyu-8b"],
@@ -336,6 +309,7 @@ VLM_TEST_SETTINGS = {
prompt_formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501
num_video_frames=16,
max_model_len=16384,
hf_model_kwargs=model_utils.llava_onevision_hf_model_kwargs("llava-hf/llava-onevision-qwen2-0.5b-ov-hf"), # noqa: E501
auto_cls=AutoModelForVision2Seq,
vllm_output_post_proc=model_utils.llava_onevision_vllm_to_hf_output,
custom_test_opts=[CustomTestOptions(
@@ -365,12 +339,6 @@ VLM_TEST_SETTINGS = {
auto_cls=AutoModelForImageTextToText,
vllm_output_post_proc=model_utils.mantis_vllm_to_hf_output,
patch_hf_runner=model_utils.mantis_patch_hf_runner,
marks=[
pytest.mark.skipif(
Version(TRANSFORMERS_VERSION) >= Version("4.48"),
reason="HF model is not compatible with transformers>=4.48",
)
],
),
"minicpmv_25": VLMTestInfo(
models=["openbmb/MiniCPM-Llama3-V-2_5"],
@@ -450,6 +418,23 @@ VLM_TEST_SETTINGS = {
vllm_output_post_proc=model_utils.qwen_vllm_to_hf_output,
prompt_path_encoder=model_utils.qwen_prompt_path_encoder,
),
"qwen2_vl": VLMTestInfo(
models=["Qwen/Qwen2-VL-2B-Instruct"],
test_type=(
VLMTestType.IMAGE,
VLMTestType.MULTI_IMAGE,
VLMTestType.VIDEO
),
prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501
img_idx_to_prompt=lambda idx: "<|vision_start|><|image_pad|><|vision_end|>", # noqa: E501
video_idx_to_prompt=lambda idx: "<|vision_start|><|video_pad|><|vision_end|>", # noqa: E501
max_model_len=4096,
max_num_seqs=2,
auto_cls=AutoModelForVision2Seq,
vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output,
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
marks=[pytest.mark.cpu_model],
),
"skywork_r1v": VLMTestInfo(
models=["Skywork/Skywork-R1V-38B"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
@@ -515,6 +500,7 @@ VLM_TEST_SETTINGS = {
max_model_len=16384,
max_num_seqs=2,
auto_cls=AutoModelForVision2Seq,
hf_model_kwargs=model_utils.llava_onevision_hf_model_kwargs("llava-hf/llava-onevision-qwen2-0.5b-ov-hf"), # noqa: E501
vllm_output_post_proc=model_utils.llava_onevision_vllm_to_hf_output,
custom_test_opts=[CustomTestOptions(
inputs=custom_inputs.multi_image_multi_aspect_ratio_inputs(

View File

@@ -104,6 +104,13 @@ def _llava_vllm_to_hf_output(vllm_output: RunnerOutput, model: str,
return hf_output_ids, hf_output_str, out_logprobs
def llava_onevision_hf_model_kwargs(model: str) -> dict:
"""Workaround to fix the sliding window issue in llava_onevision."""
config = AutoConfig.from_pretrained(model)
config.text_config.sliding_window = None
return config.to_dict()
def llava_onevision_vllm_to_hf_output(vllm_output: RunnerOutput,
model: str) -> RunnerOutput:
"""Sanitize vllm output [llava-onevision] to compare with hf output."""

View File

@@ -34,6 +34,16 @@ class _HfExamplesInfo:
The minimum version of HF Transformers that is required to run this model.
"""
max_transformers_version: Optional[str] = None
"""
The maximum version of HF Transformers that this model runs on.
"""
transformers_version_reason: Optional[str] = None
"""
The reason for the minimum/maximum version requirement.
"""
is_available_online: bool = True
"""
Set this to ``False`` if the name of this architecture no longer exists on
@@ -57,21 +67,28 @@ class _HfExamplesInfo:
If the installed transformers version does not meet the requirements,
perform the given action.
"""
if self.min_transformers_version is None:
if (self.min_transformers_version is None
and self.max_transformers_version is None):
return
current_version = TRANSFORMERS_VERSION
required_version = self.min_transformers_version
if Version(current_version) < Version(required_version):
msg = (
f"You have `transformers=={current_version}` installed, but "
f"`transformers>={required_version}` is required to run this "
"model")
min_version = self.min_transformers_version
max_version = self.max_transformers_version
msg = f"`transformers=={current_version}` installed, but `transformers"
if min_version and Version(current_version) < Version(min_version):
msg += f">={min_version}` is required to run this model."
elif max_version and Version(current_version) > Version(max_version):
msg += f"<={max_version}` is required to run this model."
else:
return
if on_fail == "error":
raise RuntimeError(msg)
else:
pytest.skip(msg)
if self.transformers_version_reason:
msg += f" Reason: {self.transformers_version_reason}"
if on_fail == "error":
raise RuntimeError(msg)
else:
pytest.skip(msg)
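
The check now enforces both a lower and an upper bound before deciding to error or skip. A standalone sketch of the comparison logic (helper name is illustrative, not part of the diff):

# Sketch of the min/max version-bound logic implemented above.
from typing import Optional
from packaging.version import Version

def version_problem(current: str, min_v: Optional[str],
                    max_v: Optional[str]) -> Optional[str]:
    if min_v and Version(current) < Version(min_v):
        return f"`transformers>={min_v}` is required to run this model."
    if max_v and Version(current) > Version(max_v):
        return f"`transformers<={max_v}` is required to run this model."
    return None

# A model capped at 4.48 is flagged on the new 4.50.3 baseline:
print(version_problem("4.50.3", None, "4.48"))
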
def check_available_online(
self,
@@ -245,6 +262,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b"), # noqa: E501
"ChameleonForConditionalGeneration": _HfExamplesInfo("facebook/chameleon-7b"), # noqa: E501
"DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny", # noqa: E501
extras={"fork": "Isotr0py/deepseek-vl2-tiny"}, # noqa: E501
max_transformers_version="4.48", # noqa: E501
transformers_version_reason="HF model is not compatible.", # noqa: E501
hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501
"FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
"Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it",
@@ -266,13 +286,19 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"LlavaNextVideoForConditionalGeneration": _HfExamplesInfo("llava-hf/LLaVA-NeXT-Video-7B-hf"), # noqa: E501
"LlavaOnevisionForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-onevision-qwen2-0.5b-ov-hf"), # noqa: E501
"MantisForConditionalGeneration": _HfExamplesInfo("TIGER-Lab/Mantis-8B-siglip-llama3", # noqa: E501
max_transformers_version="4.48", # noqa: E501
transformers_version_reason="HF model is not compatible.", # noqa: E501
hf_overrides={"architectures": ["MantisForConditionalGeneration"]}), # noqa: E501
"MiniCPMO": _HfExamplesInfo("openbmb/MiniCPM-o-2_6",
max_transformers_version="4.48",
transformers_version_reason="Use of deprecated imports which have been removed.", # noqa: E501
trust_remote_code=True),
"MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-Llama3-V-2_5",
extras={"2.6": "openbmb/MiniCPM-V-2_6"}, # noqa: E501
trust_remote_code=True),
"MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924",
max_transformers_version="4.48",
transformers_version_reason="Use of private method which no longer exists.", # noqa: E501
extras={"olmo": "allenai/Molmo-7B-O-0924"}, # noqa: E501
trust_remote_code=True),
"NVLM_D": _HfExamplesInfo("nvidia/NVLM-D-72B",