[CI/Build] Fix VLM test failures when using transformers v4.46 (#9666)
parent d27cfbf791
commit c866e0079d
@@ -232,20 +232,22 @@ def video_assets() -> _VideoAssets:
     return VIDEO_ASSETS


-_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature)
+_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature, dict)


 class HfRunner:

-    def wrap_device(self, input: _T, device: Optional[str] = None) -> _T:
+    def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
         if device is None:
-            return self.wrap_device(
-                input, "cpu" if current_platform.is_cpu() else "cuda")
+            device = "cpu" if current_platform.is_cpu() else "cuda"

-        if hasattr(input, "device") and input.device.type == device:
-            return input
+        if isinstance(x, dict):
+            return {k: self.wrap_device(v, device) for k, v in x.items()}

-        return input.to(device)
+        if hasattr(x, "device") and x.device.type == device:
+            return x
+
+        return x.to(device)

     def __init__(
         self,
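
The reworked wrap_device resolves the target device once and then recurses into plain dicts, which is what lets the MiniCPM-V test below hand it a {"model_inputs": ...} dict instead of a BatchEncoding. A minimal standalone sketch of the same idea (outside HfRunner, with the device default hard-coded via torch.cuda.is_available() instead of vLLM's current_platform):

from typing import Optional, TypeVar

import torch

_T = TypeVar("_T")


def wrap_device(x: _T, device: Optional[str] = None) -> _T:
    # Resolve the default device once rather than recursing with it.
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # New in this commit: plain dicts are handled by moving each value.
    if isinstance(x, dict):
        return {k: wrap_device(v, device) for k, v in x.items()}

    # Anything already on the target device is returned untouched.
    if hasattr(x, "device") and x.device.type == device:
        return x

    return x.to(device)


# A dict of tensors is moved value by value.
batch = {"input_ids": torch.zeros(1, 8, dtype=torch.long),
         "pixel_values": torch.zeros(1, 3, 224, 224)}
batch = wrap_device(batch)
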
@@ -1,6 +1,7 @@
 from typing import List, Optional, Type

 import pytest
+import transformers
 from transformers import AutoModelForVision2Seq, BatchEncoding

 from vllm.multimodal.utils import rescale_image_size
@@ -93,6 +94,10 @@ def run_test(
     )


+@pytest.mark.skipif(
+    transformers.__version__.startswith("4.46.0"),
+    reason="Model broken in HF, see huggingface/transformers#34379",
+)
 @pytest.mark.parametrize("model", models)
 @pytest.mark.parametrize(
     "size_factors",
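
The skip is keyed on an exact prefix match of transformers.__version__, so only 4.46.0 is excluded. A hedged sketch of the same guard, reusable as a named marker and applied here to a stand-in test rather than the real parametrized run_test:

import pytest
import transformers

# Same condition as in the diff: skip only on transformers 4.46.0, where the
# upstream model is broken (huggingface/transformers#34379).
broken_in_4_46 = pytest.mark.skipif(
    transformers.__version__.startswith("4.46.0"),
    reason="Model broken in HF, see huggingface/transformers#34379",
)


@broken_in_4_46
def test_placeholder():
    # Stand-in body; the real tests are the parametrized ones above.
    assert True
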
@@ -32,8 +32,8 @@ HF_MULTIIMAGE_IMAGE_PROMPT = \
 models = ["openbmb/MiniCPM-Llama3-V-2_5"]


-def _wrap_inputs(hf_inputs: BatchEncoding) -> BatchEncoding:
-    return BatchEncoding({"model_inputs": hf_inputs})
+def _wrap_inputs(hf_inputs: BatchEncoding):
+    return {"model_inputs": hf_inputs}


 def trunc_hf_output(hf_output: Tuple[List[int], str,
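
Returning a plain dict from _wrap_inputs is what motivates adding dict to the _T TypeVar in conftest.py: the wrapped inputs still need to pass through HfRunner.wrap_device, which now recurses into dicts. A small illustration, with a hypothetical single-tensor batch standing in for the real MiniCPM-V processor output:

import torch
from transformers import BatchEncoding


def _wrap_inputs(hf_inputs: BatchEncoding):
    # After this commit: wrap in a plain dict rather than a nested
    # BatchEncoding.
    return {"model_inputs": hf_inputs}


# Hypothetical batch standing in for the real processor output.
hf_inputs = BatchEncoding({"input_ids": torch.zeros(1, 8, dtype=torch.long)})
wrapped = _wrap_inputs(hf_inputs)
assert isinstance(wrapped, dict) and "model_inputs" in wrapped
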
@@ -2,11 +2,12 @@ import os
 from typing import List, Optional, Tuple, Type

 import pytest
-from transformers import AutoConfig, AutoModelForVision2Seq, AutoTokenizer
+from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
+                          BatchEncoding)

 from vllm.multimodal.utils import rescale_image_size
 from vllm.sequence import SampleLogprobs
-from vllm.utils import is_hip
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, is_hip

 from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
 from ...utils import check_logprobs_close
@@ -74,6 +75,7 @@ def run_test(
     Note, the text input is also adjusted to abide by vllm contract.
     The text output is sanitized to be able to compare with hf.
     """
+    torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
     images = [asset.pil_image for asset in image_assets]

     inputs_per_image = [(
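
The new torch_dtype line turns the test's string dtype into a torch.dtype so the image tensor can be cast later (see the process hook below). A stand-in for that lookup, assuming a mapping roughly like vLLM's STR_DTYPE_TO_TORCH_DTYPE rather than quoting its actual contents:

import torch

# Stand-in mapping; the real table in vllm.utils covers more entries.
STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.half,
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "float": torch.float32,
}

torch_dtype = STR_DTYPE_TO_TORCH_DTYPE["half"]
assert torch_dtype is torch.float16
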
@@ -100,7 +102,14 @@ def run_test(
         for prompts, images in inputs_per_image
     ]

-    with hf_runner(model, dtype=dtype,
+    def process(hf_inputs: BatchEncoding):
+        hf_inputs["pixel_values"] = hf_inputs["pixel_values"] \
+            .to(torch_dtype)  # type: ignore
+        return hf_inputs
+
+    with hf_runner(model,
+                   dtype=dtype,
+                   postprocess_inputs=process,
                    auto_cls=AutoModelForVision2Seq) as hf_model:
         hf_outputs_per_image = [
             hf_model.generate_greedy_logprobs_limit(prompts,
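
The process hook is passed to hf_runner as postprocess_inputs and casts the processor's pixel_values to the requested torch dtype before the HF forward pass, so the image tensor matches the model's (e.g. half-precision) weights. A self-contained sketch of the same cast, using a made-up batch instead of the real processor output:

import torch
from transformers import BatchEncoding

torch_dtype = torch.float16  # e.g. what STR_DTYPE_TO_TORCH_DTYPE["half"] resolves to


def process(hf_inputs: BatchEncoding):
    # Cast only the image tensor; token ids must stay integral.
    hf_inputs["pixel_values"] = hf_inputs["pixel_values"].to(torch_dtype)
    return hf_inputs


# Made-up batch standing in for the real processor output.
batch = BatchEncoding({
    "input_ids": torch.zeros(1, 8, dtype=torch.long),
    "pixel_values": torch.zeros(1, 3, 336, 336),  # float32 by default
})
batch = process(batch)
assert batch["pixel_values"].dtype == torch.float16
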