[Model]: add some tests for aria model (#10770)

Signed-off-by: xffxff <1247714429@qq.com>
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
Authored by zhou fan on 2024-12-02 13:36:36 +08:00, committed by GitHub
parent 995a148575
commit ef31eabc68
4 changed files with 51 additions and 3 deletions


@@ -656,6 +656,7 @@ class VllmRunner:
         model_name: str,
         task: TaskOption = "auto",
         tokenizer_name: Optional[str] = None,
+        tokenizer_mode: str = "auto",
         # Use smaller max model length, otherwise bigger model cannot run due
         # to kv cache size limit.
         max_model_len: int = 1024,
@@ -672,6 +673,7 @@ class VllmRunner:
             model=model_name,
             task=task,
             tokenizer=tokenizer_name,
+            tokenizer_mode=tokenizer_mode,
             trust_remote_code=True,
             dtype=dtype,
             swap_space=swap_space,
@@ -842,6 +844,7 @@ class VllmRunner:
         audios: Optional[PromptAudioInput] = None,
         videos: Optional[PromptVideoInput] = None,
         stop_token_ids: Optional[List[int]] = None,
+        stop: Optional[List[str]] = None,
     ) -> Union[List[TokensTextLogprobs],
                List[TokensTextLogprobsPromptLogprobs]]:
         greedy_logprobs_params = SamplingParams(
@@ -849,7 +852,8 @@ class VllmRunner:
             max_tokens=max_tokens,
             logprobs=num_logprobs,
             prompt_logprobs=num_prompt_logprobs,
-            stop_token_ids=stop_token_ids)
+            stop_token_ids=stop_token_ids,
+            stop=stop)
         return self.generate_w_logprobs(prompts,
                                         greedy_logprobs_params,
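
Note: the two new arguments are simply threaded through, tokenizer_mode to the LLM(...) constructor and stop to SamplingParams(...). Below is a minimal sketch of how a test could exercise them; the wrapper method name and logprob count are assumptions (they are not shown in this hunk), while the model name, tokenizer mode, stop string, and generation limits come from the aria settings added further down.

# Hedged sketch, not part of the diff. Image inputs are omitted for brevity.
with vllm_runner("rhymes-ai/Aria",
                 tokenizer_mode="slow",   # forwarded to LLM(tokenizer_mode=...)
                 max_model_len=4096,
                 max_num_seqs=2,
                 dtype="bfloat16") as vllm_model:
    outputs = vllm_model.generate_greedy_logprobs(  # method name assumed
        example_prompts,                 # any list of prompt strings
        max_tokens=64,
        num_logprobs=5,                  # assumed value
        stop=["<|im_end|>"],             # forwarded to SamplingParams(stop=...)
    )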


@@ -8,6 +8,7 @@ from typing import Type
 import pytest
 import transformers
 from transformers import AutoModelForVision2Seq
+from transformers.utils import is_flash_attn_2_available
 
 from vllm.platforms import current_platform
 from vllm.utils import cuda_device_count_stateless, identity
@@ -134,6 +135,35 @@ VLM_TEST_SETTINGS = {
         marks=[pytest.mark.core_model, pytest.mark.cpu_model],
     ),
     #### Extended model tests
+    "aria": VLMTestInfo(
+        models=["rhymes-ai/Aria"],
+        tokenizer_mode="slow",
+        test_type=(
+            VLMTestType.IMAGE,
+            VLMTestType.MULTI_IMAGE,
+        ),
+        dtype="bfloat16",
+        prompt_formatter=lambda img_prompt: f"<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n ",  # noqa: E501
+        img_idx_to_prompt=lambda idx: "<fim_prefix><|img|><fim_suffix>\n",
+        max_model_len=4096,
+        max_num_seqs=2,
+        single_image_prompts=IMAGE_ASSETS.prompts({
+            "stop_sign": "<vlm_image>Please describe the image shortly.",
+            "cherry_blossom": "<vlm_image>Please infer the season with reason.",
+        }),
+        multi_image_prompt="<vlm_image><vlm_image>Describe the two images shortly.",  # noqa: E501
+        postprocess_inputs=model_utils.get_key_type_post_processor("pixel_values"),
+        stop_str=["<|im_end|>"],
+        image_size_factors=[(0.10, 0.15)],
+        max_tokens=64,
+        marks=[
+            pytest.mark.skipif(
+                not is_flash_attn_2_available(),
+                reason="Model needs flash-attn for numeric convergence.",
+            ),
+            large_gpu_mark(min_gb=64),
+        ],
+    ),
     "blip2": VLMTestInfo(
         models=["Salesforce/blip2-opt-2.7b"],
         test_type=VLMTestType.IMAGE,
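
Note: the prompt_formatter and img_idx_to_prompt lambdas above compose the final prompt string. A quick sketch of that composition, under the assumption (suggested by the field names) that the framework substitutes each "<vlm_image>" placeholder via img_idx_to_prompt before applying prompt_formatter. The stop_str entry exists because "<|im_end|>" is matched as a plain string rather than a special stop token (see the comment added in the VLMTestInfo definition further down).

# Illustrative only: composing the aria single-image prompt by hand.
img_idx_to_prompt = lambda idx: "<fim_prefix><|img|><fim_suffix>\n"
prompt_formatter = lambda img_prompt: (
    f"<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n ")

base = "<vlm_image>Please describe the image shortly."
img_prompt = base.replace("<vlm_image>", img_idx_to_prompt(0))
print(prompt_formatter(img_prompt))
# <|im_start|>user
# <fim_prefix><|img|><fim_suffix>
# Please describe the image shortly.<|im_end|>
# <|im_start|>assistant
# (followed by a single trailing space)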


@@ -29,6 +29,8 @@ def run_test(
     postprocess_inputs: Callable[[BatchEncoding], BatchEncoding],
     comparator: Callable[..., None],
     get_stop_token_ids: Optional[Callable[[AutoTokenizer], List[int]]],
+    stop_str: Optional[List[str]],
+    tokenizer_mode: str,
     limit_mm_per_prompt: Dict[str, int],
     model_kwargs: Optional[Dict[str, Any]],
     patch_hf_runner: Optional[Callable[[HfRunner], HfRunner]],
@@ -50,11 +52,14 @@ def run_test(
     # vLLM needs a fresh new process without cuda initialization.
     # if we run HF first, the cuda initialization will be done and it
     # will hurt multiprocessing backend with fork method (the default method).
-    vllm_kwargs = {}
+    vllm_kwargs: Dict[str, Any] = {}
     if get_stop_token_ids is not None:
         vllm_kwargs["stop_token_ids"] = get_stop_token_ids(tokenizer)
+    if stop_str:
+        vllm_kwargs["stop"] = stop_str
 
     with vllm_runner(model,
+                     tokenizer_mode=tokenizer_mode,
                      max_model_len=max_model_len,
                      max_num_seqs=max_num_seqs,
                      dtype=dtype,
@@ -85,6 +90,8 @@ def run_test(
     hf_kwargs = {}
     if use_tokenizer_eos:
         hf_kwargs["eos_token_id"] = tokenizer.eos_token_id
+    if stop_str:
+        hf_kwargs["stop_strings"] = stop_str
 
     with hf_model, torch.no_grad():
         for prompts, media in inputs:
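
Note: plumbing the same stop list into both runners keeps the vLLM and HF outputs comparable, since generation is truncated at the same string on both sides. On the vLLM side the list becomes SamplingParams(stop=...); on the HF side it is forwarded to generate() as stop_strings, which recent transformers releases support together with a tokenizer argument. A minimal sketch under those assumptions:

# Hedged sketch of the stop handling on each side (simplified, not verbatim).
from vllm import SamplingParams

stop_str = ["<|im_end|>"]

# vLLM side: VllmRunner builds its sampling params from the forwarded kwarg.
sampling_params = SamplingParams(max_tokens=64, stop=stop_str)

# HF side (illustrative; whether the runner also passes the tokenizer is an
# assumption here):
# output_ids = hf_model.model.generate(**inputs, max_new_tokens=64,
#                                      stop_strings=stop_str,
#                                      tokenizer=tokenizer)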


@@ -97,6 +97,9 @@ class VLMTestInfo(NamedTuple):
     # Optional callable which gets a list of token IDs from the model tokenizer
     get_stop_token_ids: Optional[Callable[[AutoTokenizer], List[int]]] = None
+    # Optional list of strings to stop generation, useful when stop tokens are
+    # not special tokens in the tokenizer
+    stop_str: Optional[List[str]] = None
 
     # Exposed options for HF runner
     model_kwargs: Optional[Dict[str, Any]] = None
@@ -148,6 +151,8 @@ class VLMTestInfo(NamedTuple):
     marks: Optional[List[MarkDecorator]] = None
 
+    tokenizer_mode: str = "auto"
+
     def get_non_parametrized_runner_kwargs(self):
         """Returns a dictionary of expandable kwargs for items that are used
         in all test types, which are NOT used when creating the parametrized
@@ -166,8 +171,10 @@ class VLMTestInfo(NamedTuple):
             "postprocess_inputs": self.postprocess_inputs,
             "comparator": self.comparator,
             "get_stop_token_ids": self.get_stop_token_ids,
+            "stop_str": self.stop_str,
             "model_kwargs": self.model_kwargs,
             "patch_hf_runner": self.patch_hf_runner,
+            "tokenizer_mode": self.tokenizer_mode
         }
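
Note: both new VLMTestInfo fields default to no-ops (stop_str=None, tokenizer_mode="auto"), so existing entries behave exactly as before; only entries that set them, such as "aria" above, opt in. For illustration, the new keys as they would come out of get_non_parametrized_runner_kwargs(), with values copied from the settings in this diff:

# Illustrative only.
aria_kwargs = VLM_TEST_SETTINGS["aria"].get_non_parametrized_runner_kwargs()
assert aria_kwargs["stop_str"] == ["<|im_end|>"]
assert aria_kwargs["tokenizer_mode"] == "slow"

# Entries that do not set the new fields keep the old defaults:
blip2_kwargs = VLM_TEST_SETTINGS["blip2"].get_non_parametrized_runner_kwargs()
assert blip2_kwargs["stop_str"] is None
assert blip2_kwargs["tokenizer_mode"] == "auto"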