[Model]: add some tests for aria model (#10770)
Signed-off-by: xffxff <1247714429@qq.com>
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
commit ef31eabc68
parent 995a148575
@@ -656,6 +656,7 @@ class VllmRunner:
         model_name: str,
         task: TaskOption = "auto",
         tokenizer_name: Optional[str] = None,
+        tokenizer_mode: str = "auto",
         # Use smaller max model length, otherwise bigger model cannot run due
         # to kv cache size limit.
         max_model_len: int = 1024,
@@ -672,6 +673,7 @@ class VllmRunner:
             model=model_name,
             task=task,
             tokenizer=tokenizer_name,
+            tokenizer_mode=tokenizer_mode,
             trust_remote_code=True,
             dtype=dtype,
             swap_space=swap_space,
@@ -842,6 +844,7 @@ class VllmRunner:
         audios: Optional[PromptAudioInput] = None,
         videos: Optional[PromptVideoInput] = None,
         stop_token_ids: Optional[List[int]] = None,
+        stop: Optional[List[str]] = None,
     ) -> Union[List[TokensTextLogprobs],
                List[TokensTextLogprobsPromptLogprobs]]:
         greedy_logprobs_params = SamplingParams(
@@ -849,7 +852,8 @@ class VllmRunner:
             max_tokens=max_tokens,
             logprobs=num_logprobs,
             prompt_logprobs=num_prompt_logprobs,
-            stop_token_ids=stop_token_ids)
+            stop_token_ids=stop_token_ids,
+            stop=stop)

         return self.generate_w_logprobs(prompts,
                                         greedy_logprobs_params,
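Taken together, the two conftest changes let a test pick a slow tokenizer and stop generation on plain strings. Below is a minimal sketch of how a test could exercise them, assuming the method shown in the hunk above is VllmRunner.generate_greedy_logprobs and using the standard vllm_runner/example_prompts fixtures; the model name and values mirror the aria settings added further down, and the test itself is illustrative, not part of this commit.

# Illustrative sketch only: drive the new tokenizer_mode and stop arguments
# through the vllm_runner fixture.
def test_stop_string_sketch(vllm_runner, example_prompts):
    with vllm_runner("rhymes-ai/Aria",
                     tokenizer_mode="slow",      # new VllmRunner argument
                     max_model_len=4096,
                     max_num_seqs=2,
                     dtype="bfloat16") as vllm_model:
        outputs = vllm_model.generate_greedy_logprobs(
            example_prompts,
            max_tokens=64,
            num_logprobs=5,
            stop=["<|im_end|>"],                 # new stop-string argument
        )
    assert len(outputs) == len(example_prompts)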
@@ -8,6 +8,7 @@ from typing import Type
 import pytest
 import transformers
 from transformers import AutoModelForVision2Seq
+from transformers.utils import is_flash_attn_2_available

 from vllm.platforms import current_platform
 from vllm.utils import cuda_device_count_stateless, identity
@@ -134,6 +135,35 @@ VLM_TEST_SETTINGS = {
         marks=[pytest.mark.core_model, pytest.mark.cpu_model],
     ),
     #### Extended model tests
+    "aria": VLMTestInfo(
+        models=["rhymes-ai/Aria"],
+        tokenizer_mode="slow",
+        test_type=(
+            VLMTestType.IMAGE,
+            VLMTestType.MULTI_IMAGE,
+        ),
+        dtype="bfloat16",
+        prompt_formatter=lambda img_prompt: f"<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n ",  # noqa: E501
+        img_idx_to_prompt=lambda idx: "<fim_prefix><|img|><fim_suffix>\n",
+        max_model_len=4096,
+        max_num_seqs=2,
+        single_image_prompts=IMAGE_ASSETS.prompts({
+            "stop_sign": "<vlm_image>Please describe the image shortly.",
+            "cherry_blossom": "<vlm_image>Please infer the season with reason.",
+        }),
+        multi_image_prompt="<vlm_image><vlm_image>Describe the two images shortly.",  # noqa: E501
+        postprocess_inputs=model_utils.get_key_type_post_processor("pixel_values"),
+        stop_str=["<|im_end|>"],
+        image_size_factors=[(0.10, 0.15)],
+        max_tokens=64,
+        marks=[
+            pytest.mark.skipif(
+                not is_flash_attn_2_available(),
+                reason="Model needs flash-attn for numeric convergence.",
+            ),
+            large_gpu_mark(min_gb=64),
+        ],
+    ),
     "blip2": VLMTestInfo(
         models=["Salesforce/blip2-opt-2.7b"],
         test_type=VLMTestType.IMAGE,
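To make the prompt plumbing above concrete, here is a small sketch of what the single-image stop-sign prompt looks like once img_idx_to_prompt and prompt_formatter are applied. The placeholder substitution is normally done by the shared test builders; the snippet below only mimics it for illustration.

# Illustration only: reproduce the aria prompt formatting from the entry above.
def img_idx_to_prompt(idx: int) -> str:
    return "<fim_prefix><|img|><fim_suffix>\n"

def prompt_formatter(img_prompt: str) -> str:
    return f"<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n "

raw = "<vlm_image>Please describe the image shortly."
# The test framework swaps each <vlm_image> placeholder for the image prompt.
filled = raw.replace("<vlm_image>", img_idx_to_prompt(0))
print(prompt_formatter(filled))
# <|im_start|>user
# <fim_prefix><|img|><fim_suffix>
# Please describe the image shortly.<|im_end|>
# <|im_start|>assistant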
@@ -29,6 +29,8 @@ def run_test(
     postprocess_inputs: Callable[[BatchEncoding], BatchEncoding],
     comparator: Callable[..., None],
     get_stop_token_ids: Optional[Callable[[AutoTokenizer], List[int]]],
+    stop_str: Optional[List[str]],
+    tokenizer_mode: str,
     limit_mm_per_prompt: Dict[str, int],
     model_kwargs: Optional[Dict[str, Any]],
     patch_hf_runner: Optional[Callable[[HfRunner], HfRunner]],
@@ -50,11 +52,14 @@ def run_test(
     # vLLM needs a fresh new process without cuda initialization.
     # if we run HF first, the cuda initialization will be done and it
     # will hurt multiprocessing backend with fork method (the default method).
-    vllm_kwargs = {}
+    vllm_kwargs: Dict[str, Any] = {}
     if get_stop_token_ids is not None:
         vllm_kwargs["stop_token_ids"] = get_stop_token_ids(tokenizer)
+    if stop_str:
+        vllm_kwargs["stop"] = stop_str

     with vllm_runner(model,
+                     tokenizer_mode=tokenizer_mode,
                      max_model_len=max_model_len,
                      max_num_seqs=max_num_seqs,
                      dtype=dtype,
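With the aria entry, get_stop_token_ids stays at its None default and stop_str is ["<|im_end|>"], so only the string criterion ends up in vllm_kwargs. A tiny sketch of the assembled dict, mirroring the hunk above (illustrative, not from this diff):

from typing import Any, Dict

# Illustrative: the kwargs run_test would assemble for the aria settings, where
# get_stop_token_ids is left at None and stop_str is ["<|im_end|>"].
vllm_kwargs: Dict[str, Any] = {}
stop_str = ["<|im_end|>"]

if stop_str:
    vllm_kwargs["stop"] = stop_str

assert vllm_kwargs == {"stop": ["<|im_end|>"]}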
@@ -85,6 +90,8 @@ def run_test(
     hf_kwargs = {}
     if use_tokenizer_eos:
         hf_kwargs["eos_token_id"] = tokenizer.eos_token_id
+    if stop_str:
+        hf_kwargs["stop_strings"] = stop_str

     with hf_model, torch.no_grad():
         for prompts, media in inputs:
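On the HF side, stop_strings is a generate() option in recent transformers releases and needs access to a tokenizer to match the strings. A standalone sketch with a placeholder model and prompt (not part of this PR, just to show the behavior being mirrored):

# Standalone sketch of how stop_strings behaves in Hugging Face generate();
# the model, tokenizer, and prompt here are placeholders.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Describe the image:", return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_new_tokens=16,
    stop_strings=["<|im_end|>"],  # generation halts when this string appears
    tokenizer=tokenizer,          # required so generate() can match stop strings
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))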
@@ -138,4 +145,4 @@ def process_runner_outputs(
 def process_outputs(output_processor, model, outputs_per_image):
     """Applies a model specific post-processor function to a runner's output"""
     return [[output_processor(res, model) for res in outputs]
-            for outputs in outputs_per_image]
+            for outputs in outputs_per_image]
@@ -97,6 +97,9 @@ class VLMTestInfo(NamedTuple):

     # Optional callable which gets a list of token IDs from the model tokenizer
     get_stop_token_ids: Optional[Callable[[AutoTokenizer], List[int]]] = None
+    # Optional list of strings to stop generation, useful when stop tokens are
+    # not special tokens in the tokenizer
+    stop_str: Optional[List[str]] = None

     # Exposed options for HF runner
     model_kwargs: Optional[Dict[str, Any]] = None
@@ -148,6 +151,8 @@ class VLMTestInfo(NamedTuple):

     marks: Optional[List[MarkDecorator]] = None

+    tokenizer_mode: str = "auto"
+
     def get_non_parametrized_runner_kwargs(self):
         """Returns a dictionary of expandable kwargs for items that are used
         in all test types, which are NOT used when creating the parametrized
@@ -166,8 +171,10 @@ class VLMTestInfo(NamedTuple):
             "postprocess_inputs": self.postprocess_inputs,
             "comparator": self.comparator,
             "get_stop_token_ids": self.get_stop_token_ids,
+            "stop_str": self.stop_str,
             "model_kwargs": self.model_kwargs,
             "patch_hf_runner": self.patch_hf_runner,
+            "tokenizer_mode": self.tokenizer_mode
         }

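Finally, an illustrative check (meant to be read inside the test settings module, not as standalone code) of how the two new fields surface through the shared kwargs dict that the parametrized tests expand into run_test:

# Illustration only: the new fields travel through the shared kwargs dict, so
# run_test picks them up without extra per-test wiring.
aria_info = VLM_TEST_SETTINGS["aria"]
shared_kwargs = aria_info.get_non_parametrized_runner_kwargs()

assert shared_kwargs["stop_str"] == ["<|im_end|>"]
assert shared_kwargs["tokenizer_mode"] == "slow"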