[CI/Build] Update models tests & examples (#8874)

Co-authored-by: Roger Wang <ywang@roblox.com>
Authored by Cyrus Leung on 2024-09-29 00:54:35 +08:00; committed by GitHub
parent 19d02ff938
commit e1a3f5e831
13 changed files with 239 additions and 184 deletions

View File

@ -9,6 +9,7 @@
# label(str): the name of the test. emoji allowed.
# fast_check(bool): whether to run this on each commit on fastcheck pipeline.
# fast_check_only(bool): run this test on fastcheck pipeline only
# optional(bool): never run this test by default (i.e. need to unblock manually)
# command(str): the single command to run for tests. incompatible with commands.
# commands(list): the list of commands to run for the test. incompatible with command.
# mirror_hardwares(list): additional hardware platforms to also run the test on. currently only supports [amd]
@ -39,7 +40,7 @@ steps:
# Check API reference (if it fails, you may have missing mock imports)
- grep \"sig sig-object py\" build/html/dev/sampling_params.html
- label: Async Engine, Inputs, Utils, Worker Test # 15min
- label: Async Engine, Inputs, Utils, Worker Test # 24min
fast_check: true
source_file_dependencies:
- vllm/
@ -81,7 +82,7 @@ steps:
commands:
- pytest -v -s core
- label: Entrypoints Test # 20min
- label: Entrypoints Test # 40min
working_dir: "/vllm-workspace/tests"
fast_check: true
mirror_hardwares: [amd]
@ -151,7 +152,7 @@ steps:
# OOM in the CI unless we run this separately
- pytest -v -s tokenization
- label: Examples Test # 12min
- label: Examples Test # 15min
working_dir: "/vllm-workspace/examples"
#mirror_hardwares: [amd]
source_file_dependencies:
@ -169,7 +170,7 @@ steps:
- python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
- python3 offline_inference_encoder_decoder.py
- label: Prefix Caching Test # 7min
- label: Prefix Caching Test # 9min
#mirror_hardwares: [amd]
source_file_dependencies:
- vllm/
@ -177,7 +178,7 @@ steps:
commands:
- pytest -v -s prefix_caching
- label: Samplers Test # 18min
- label: Samplers Test # 36min
source_file_dependencies:
- vllm/model_executor/layers
- vllm/sampling_metadata.py
@ -193,7 +194,7 @@ steps:
- tests/test_logits_processor
command: pytest -v -s test_logits_processor.py
- label: Speculative decoding tests # 22min
- label: Speculative decoding tests # 30min
source_file_dependencies:
- vllm/spec_decode
- tests/spec_decode
@ -203,7 +204,7 @@ steps:
- pytest -v -s spec_decode/e2e/test_multistep_correctness.py
- pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py
- label: LoRA Test %N # 30min each
- label: LoRA Test %N # 15min each
mirror_hardwares: [amd]
source_file_dependencies:
- vllm/lora
@ -211,7 +212,7 @@ steps:
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py
parallelism: 4
- label: "PyTorch Fullgraph Smoke Test"
- label: "PyTorch Fullgraph Smoke Test" # 9min
fast_check: true
source_file_dependencies:
- vllm/
@ -219,14 +220,14 @@ steps:
commands:
- pytest -v -s compile/test_full_graph_smoke.py
- label: "PyTorch Fullgraph Test"
- label: "PyTorch Fullgraph Test" # 18min
source_file_dependencies:
- vllm/
- tests/compile
commands:
- pytest -v -s compile/test_full_graph.py
- label: Kernels Test %N # 30min each
- label: Kernels Test %N # 1h each
mirror_hardwares: [amd]
source_file_dependencies:
- csrc/
@ -256,7 +257,7 @@ steps:
- pip install aiohttp
- bash run-benchmarks.sh
- label: Quantization Test # 15min
- label: Quantization Test # 33min
source_file_dependencies:
- csrc/
- vllm/model_executor/layers/quantization
@ -300,7 +301,7 @@ steps:
- pytest -v -s models/test_oot_registration.py # it needs a clean process
- pytest -v -s models/*.py --ignore=models/test_oot_registration.py
- label: Decoder-only Language Models Test # 1h3min
- label: Decoder-only Language Models Test # 1h36min
#mirror_hardwares: [amd]
source_file_dependencies:
- vllm/
@ -308,7 +309,7 @@ steps:
commands:
- pytest -v -s models/decoder_only/language
- label: Decoder-only Multi-Modal Models Test # 56min
- label: Decoder-only Multi-Modal Models Test # 1h31min
#mirror_hardwares: [amd]
source_file_dependencies:
- vllm/
@ -318,15 +319,25 @@ steps:
- pytest -v -s models/decoder_only/audio_language
- pytest -v -s models/decoder_only/vision_language
- label: Other Models Test # 5min
- label: Other Models Test # 6min
#mirror_hardwares: [amd]
source_file_dependencies:
- vllm/
- tests/models/embedding/language
- tests/models/encoder_decoder/language
- tests/models/encoder_decoder/vision_language
commands:
- pytest -v -s models/embedding/language
- pytest -v -s models/encoder_decoder/language
- pytest -v -s models/encoder_decoder/vision_language
- label: Custom Models Test
#mirror_hardwares: [amd]
optional: true
commands:
# PR authors can temporarily add commands below to test individual models
# e.g. pytest -v -s models/encoder_decoder/vision_language/test_mllama.py
# *To avoid merge conflicts, remember to REMOVE (not just comment out) them before merging the PR*
##### 1 GPU test #####
##### multi-GPU test #####
@ -359,7 +370,7 @@ steps:
- # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up)
- VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep -q 'Same node test passed'
- label: Distributed Tests (2 GPUs) # 28min
- label: Distributed Tests (2 GPUs) # 40min
#mirror_hardwares: [amd]
working_dir: "/vllm-workspace/tests"
num_gpus: 2
@ -376,14 +387,16 @@ steps:
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed'
- TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m distributed_2_gpus
# Avoid importing model tests that cause CUDA reinitialization error
- pytest models/encoder_decoder/language/test_bart.py models/decoder_only/vision_language/test_broadcast.py -v -s -m distributed_2_gpus
- pytest models/encoder_decoder/language/test_bart.py -v -s -m distributed_2_gpus
- pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m distributed_2_gpus
- pytest models/decoder_only/vision_language/test_broadcast.py -v -s -m distributed_2_gpus
- pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
- pip install -e ./plugins/vllm_add_dummy_model
- pytest -v -s distributed/test_distributed_oot.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py
- label: Multi-step Tests (4 GPUs) # 21min
- label: Multi-step Tests (4 GPUs) # 36min
working_dir: "/vllm-workspace/tests"
num_gpus: 4
source_file_dependencies:
@ -401,7 +414,7 @@ steps:
- pytest -v -s multi_step/test_correctness_async_llm.py
- pytest -v -s multi_step/test_correctness_llm.py
- label: Pipeline Parallelism Test # 23min
- label: Pipeline Parallelism Test # 45min
working_dir: "/vllm-workspace/tests"
num_gpus: 4
source_file_dependencies:
@ -427,7 +440,7 @@ steps:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -v -s -x lora/test_long_context.py
- label: Weight Loading Multiple GPU Test
- label: Weight Loading Multiple GPU Test # 33min
working_dir: "/vllm-workspace/tests"
num_gpus: 2
source_file_dependencies:

View File

@ -12,6 +12,10 @@ from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.utils import FlexibleArgumentParser
# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
# lower-end GPUs.
# Unless specified, these settings have been tested to work on a single L4.
# LLaVA-1.5
def run_llava(question, modality):
@ -19,7 +23,7 @@ def run_llava(question, modality):
prompt = f"USER: <image>\n{question}\nASSISTANT:"
llm = LLM(model="llava-hf/llava-1.5-7b-hf")
llm = LLM(model="llava-hf/llava-1.5-7b-hf", max_model_len=4096)
stop_token_ids = None
return llm, prompt, stop_token_ids
@ -57,7 +61,7 @@ def run_llava_onevision(question, modality):
<|im_start|>assistant\n"
llm = LLM(model="llava-hf/llava-onevision-qwen2-7b-ov-hf",
max_model_len=32768)
max_model_len=16384)
stop_token_ids = None
return llm, prompt, stop_token_ids
@ -67,7 +71,7 @@ def run_fuyu(question, modality):
assert modality == "image"
prompt = f"{question}\n"
llm = LLM(model="adept/fuyu-8b")
llm = LLM(model="adept/fuyu-8b", max_model_len=2048, max_num_seqs=2)
stop_token_ids = None
return llm, prompt, stop_token_ids
@ -99,7 +103,8 @@ def run_phi3v(question, modality):
llm = LLM(
model="microsoft/Phi-3-vision-128k-instruct",
trust_remote_code=True,
max_num_seqs=5,
max_model_len=4096,
max_num_seqs=2,
mm_processor_kwargs={"num_crops": 16},
)
stop_token_ids = None
@ -122,7 +127,7 @@ def run_chameleon(question, modality):
assert modality == "image"
prompt = f"{question}<image>"
llm = LLM(model="facebook/chameleon-7b")
llm = LLM(model="facebook/chameleon-7b", max_model_len=4096)
stop_token_ids = None
return llm, prompt, stop_token_ids
@ -145,6 +150,8 @@ def run_minicpmv(question, modality):
trust_remote_code=True)
llm = LLM(
model=model_name,
max_model_len=4096,
max_num_seqs=2,
trust_remote_code=True,
)
# NOTE The stop_token_ids are different for various versions of MiniCPM-V
@ -177,7 +184,7 @@ def run_internvl(question, modality):
llm = LLM(
model=model_name,
trust_remote_code=True,
max_num_seqs=5,
max_model_len=4096,
)
tokenizer = AutoTokenizer.from_pretrained(model_name,
@ -215,7 +222,8 @@ def run_qwen_vl(question, modality):
llm = LLM(
model="Qwen/Qwen-VL",
trust_remote_code=True,
max_num_seqs=5,
max_model_len=1024,
max_num_seqs=2,
)
prompt = f"{question}Picture 1: <img></img>\n"
@ -229,8 +237,10 @@ def run_qwen2_vl(question, modality):
model_name = "Qwen/Qwen2-VL-7B-Instruct"
# Tested on L40
llm = LLM(
model=model_name,
max_model_len=8192,
max_num_seqs=5,
)
@ -252,10 +262,10 @@ def run_mllama(question, modality):
# max_model_len (131072) for this model may cause OOM.
# You may lower either to run this example on lower-end GPUs.
# The configuration below has been confirmed to launch on a
# single H100 GPU.
# The configuration below has been confirmed to launch on a single L40 GPU.
llm = LLM(
model=model_name,
max_model_len=4096,
max_num_seqs=16,
enforce_eager=True,
)
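The pattern applied throughout this example file is to cap max_model_len (and, for some models, max_num_seqs) so each model fits on a single L4. A minimal sketch of the same construction, assuming the LLaVA-1.5 checkpoint used above, vLLM's offline LLM API, and a hypothetical local image file:

from vllm import LLM, SamplingParams
from PIL import Image

# Cap the context length as in the updated run_llava() so the KV cache
# fits on a memory-constrained GPU.
llm = LLM(model="llava-hf/llava-1.5-7b-hf", max_model_len=4096)

prompt = "USER: <image>\nWhat is shown in this image?\nASSISTANT:"
image = Image.open("example.jpg")  # hypothetical local asset

outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image}},
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)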

View File

@ -28,12 +28,18 @@ class ModelRequestData(NamedTuple):
chat_template: Optional[str]
# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
# lower-end GPUs.
# Unless specified, these settings have been tested to work on a single L4.
def load_qwenvl_chat(question: str, image_urls: List[str]) -> ModelRequestData:
model_name = "Qwen/Qwen-VL-Chat"
llm = LLM(
model=model_name,
trust_remote_code=True,
max_num_seqs=5,
max_model_len=1024,
max_num_seqs=2,
limit_mm_per_prompt={"image": len(image_urls)},
)
placeholders = "".join(f"Picture {i}: <img></img>\n"
@ -83,6 +89,7 @@ def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData:
model="microsoft/Phi-3.5-vision-instruct",
trust_remote_code=True,
max_model_len=4096,
max_num_seqs=2,
limit_mm_per_prompt={"image": len(image_urls)},
mm_processor_kwargs={"num_crops": 4},
)
@ -106,7 +113,6 @@ def load_internvl(question: str, image_urls: List[str]) -> ModelRequestData:
llm = LLM(
model=model_name,
trust_remote_code=True,
max_num_seqs=5,
max_model_len=4096,
limit_mm_per_prompt={"image": len(image_urls)},
)
@ -148,10 +154,11 @@ def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData:
model_name = "Qwen/Qwen2-VL-7B-Instruct"
# Tested on L40
llm = LLM(
model=model_name,
max_num_seqs=5,
max_model_len=32768 if process_vision_info is None else 4096,
max_num_seqs=5,
limit_mm_per_prompt={"image": len(image_urls)},
)
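As in the single-image example, the multi-image loaders now bound memory by lowering max_model_len and max_num_seqs and by sizing limit_mm_per_prompt to the number of images a prompt may carry. A sketch of the Phi-3.5-vision configuration added above, with placeholder image URLs:

from vllm import LLM

image_urls = [
    "https://example.com/cat.jpg",  # placeholder URLs for illustration
    "https://example.com/dog.jpg",
]

# Settings mirror the diff: small context and batch limits for a single L4,
# and a per-prompt image limit matching the number of images supplied.
llm = LLM(
    model="microsoft/Phi-3.5-vision-instruct",
    trust_remote_code=True,
    max_model_len=4096,
    max_num_seqs=2,
    limit_mm_per_prompt={"image": len(image_urls)},
    mm_processor_kwargs={"num_crops": 4},
)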

View File

@ -246,17 +246,14 @@ _T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature)
class HfRunner:
def wrap_device(self, input: _T) -> _T:
if not is_cpu():
# Check if the input is already on the GPU
if hasattr(input, 'device') and input.device.type == "cuda":
return input # Already on GPU, no need to move
return input.to("cuda")
else:
# Check if the input is already on the CPU
if hasattr(input, 'device') and input.device.type == "cpu":
return input # Already on CPU, no need to move
return input.to("cpu")
def wrap_device(self, input: _T, device: Optional[str] = None) -> _T:
if device is None:
return self.wrap_device(input, "cpu" if is_cpu() else "cuda")
if hasattr(input, "device") and input.device.type == device:
return input
return input.to(device)
def __init__(
self,
@ -333,7 +330,7 @@ class HfRunner:
inputs = self.postprocess_inputs(inputs)
output_ids = self.model.generate(
**self.wrap_device(inputs),
**self.wrap_device(inputs, device=self.model.device.type),
use_cache=True,
**kwargs,
)
@ -406,7 +403,7 @@ class HfRunner:
inputs = self.postprocess_inputs(inputs)
output = self.model.generate(
**self.wrap_device(inputs),
**self.wrap_device(inputs, device=self.model.device.type),
use_cache=True,
do_sample=False,
max_new_tokens=max_tokens,
@ -414,40 +411,39 @@ class HfRunner:
return_dict_in_generate=True,
**kwargs,
)
seq_logprobs: List[torch.Tensor] = []
for hidden_states in output.hidden_states:
last_hidden_states = hidden_states[-1][0]
logits = torch.matmul(
last_hidden_states,
self.model.get_output_embeddings().weight.t(),
)
if self.model.get_output_embeddings().bias is not None:
logits += self.model.get_output_embeddings(
).bias.unsqueeze(0)
logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
seq_logprobs.append(logprobs)
seq_logprobs = self._hidden_states_to_seq_logprobs(
output.hidden_states)
all_logprobs.append(seq_logprobs)
return all_logprobs
def _hidden_states_to_logprobs(
def _hidden_states_to_seq_logprobs(
self,
hidden_states,
num_logprobs,
) -> Tuple[List[Dict[int, float]], int]:
hidden_states: Tuple[Tuple[torch.Tensor, ...], ...],
) -> List[torch.Tensor]:
output_embeddings = self.model.get_output_embeddings()
seq_logprobs: List[torch.Tensor] = []
output_len = len(hidden_states)
for _, hidden_state in enumerate(hidden_states):
last_hidden_states = hidden_state[-1][0]
logits = torch.matmul(
last_hidden_states,
self.model.get_output_embeddings().weight.t(),
last_hidden_states.to(output_embeddings.weight.device),
output_embeddings.weight.t(),
)
if getattr(self.model.get_output_embeddings(), "bias",
None) is not None:
logits += self.model.get_output_embeddings().bias.unsqueeze(0)
if getattr(output_embeddings, "bias", None) is not None:
logits += output_embeddings.bias.unsqueeze(0)
logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
seq_logprobs.append(logprobs)
return seq_logprobs
def _hidden_states_to_logprobs(
self,
hidden_states: Tuple[Tuple[torch.Tensor, ...], ...],
num_logprobs: int,
) -> Tuple[List[Dict[int, float]], int]:
seq_logprobs = self._hidden_states_to_seq_logprobs(hidden_states)
output_len = len(hidden_states)
# convert to dict
seq_logprobs_lst: List[Dict[int, float]] = []
for tok_idx, tok_logprobs in enumerate(seq_logprobs):
@ -500,7 +496,7 @@ class HfRunner:
inputs = self.postprocess_inputs(inputs)
output = self.model.generate(
**self.wrap_device(inputs),
**self.wrap_device(inputs, device=self.model.device.type),
use_cache=True,
do_sample=False,
max_new_tokens=max_tokens,
@ -543,12 +539,20 @@ class HfRunner:
for (encoder_prompt,
decoder_prompt) in to_enc_dec_tuple_list(encoder_decoder_prompts):
encoder_input_ids = self.wrap_device(
self.tokenizer(encoder_prompt, return_tensors="pt").input_ids)
decoder_input_ids = (
None if decoder_prompt is None else self.wrap_device(
self.tokenizer(encoder_prompt, return_tensors="pt").input_ids,
device=self.model.device.type,
)
if decoder_prompt is None:
decoder_input_ids = None
else:
decoder_input_ids = self.wrap_device(
self.tokenizer(decoder_prompt,
return_tensors="pt").input_ids))
return_tensors="pt").input_ids,
device=self.model.device.type,
)
output = self.model.generate(
encoder_input_ids,

View File

@ -16,8 +16,7 @@ from ...utils import check_logprobs_close
# Video test
HF_VIDEO_PROMPTS = VIDEO_ASSETS.prompts({
"sample_demo_1":
"<|im_start|>user <video>\nwhy is this video funny? \
<|im_end|><|im_start|>assistant\n"
"<|im_start|>user\n<video>\nwhy is this video funny?<|im_end|>\n<|im_start|>assistant\n" # noqa: E501
})
models = ["llava-hf/llava-onevision-qwen2-7b-ov-hf"]
@ -165,6 +164,9 @@ def run_video_test(
)
@pytest.mark.skip(
reason=
"Model is too big, test passed on L40 locally but will OOM on CI machine.")
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"size_factors",
@ -208,6 +210,9 @@ def test_models(hf_runner, vllm_runner, video_assets, model, size_factors,
)
@pytest.mark.skip(
reason=
"Model is too big, test passed on L40 locally but will OOM on CI machine.")
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"sizes",
@ -254,9 +259,8 @@ def run_image_test(
# max_model_len should be greater than image_feature_size
with vllm_runner(model,
dtype=dtype,
max_num_seqs=1,
max_model_len=16384,
gpu_memory_utilization=0.98,
max_num_seqs=2,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True,
@ -302,8 +306,9 @@ def run_image_test(
)
# FIXME: Swap to a smaller model for this architecture
@pytest.mark.skip(reason="Model OOMing on CI")
@pytest.mark.skip(
reason=
"Model is too big, test passed on L40 locally but will OOM on CI machine.")
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@ -316,14 +321,10 @@ def test_models_multiple_image_inputs(hf_runner, vllm_runner, image_assets,
inputs = [(
[
"<|im_start|>user <image><image>\nDescribe 2 images. \
<|im_end|><|im_start|>assistant\n",
"<|im_start|>user <image><image>\nDescribe 2 images. \
<|im_end|><|im_start|>assistant\n",
"<|im_start|>user <image><image><image><image>\nDescribe 4 images. \
<|im_end|><|im_start|>assistant\n",
"<|im_start|>user <image>\nWhat is the season? \
<|im_end|><|im_start|>assistant\n",
"<|im_start|>user\n<image><image>\nDescribe 2 images.<|im_end|>\n<|im_start|>assistant\n", # noqa: E501
"<|im_start|>user\n<image><image>\nDescribe 2 images.<|im_end|>\n<|im_start|>assistant\n", # noqa: E501
"<|im_start|>user\n<image><image><image><image>\nDescribe 4 images.<|im_end|>\n<|im_start|>assistant\n", # noqa: E501
"<|im_start|>user\n<image>\nWhat is the season?<|im_end|>\n<|im_start|>assistant\n", # noqa: E501
],
[
[stop_sign, cherry_blossom],
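The rewritten prompts follow the model's chat template exactly, separating the role tags, the media placeholders, and the question with newlines instead of backslash-continued strings. A hypothetical helper (not part of the test) illustrating the corrected format:

def chat_prompt(question: str, media_tokens: str = "<image>") -> str:
    # Reproduces the single-line prompt format used in the updated tests,
    # with explicit newlines around the role tags.
    return (f"<|im_start|>user\n{media_tokens}\n{question}<|im_end|>\n"
            f"<|im_start|>assistant\n")

# e.g. chat_prompt("Describe 2 images.", "<image><image>")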

View File

@ -79,7 +79,7 @@ def run_test(
# max_model_len should be greater than image_feature_size
with vllm_runner(model,
max_model_len=4096,
max_num_seqs=1,
max_num_seqs=2,
dtype=dtype,
limit_mm_per_prompt={"image": mm_limit},
tensor_parallel_size=tensor_parallel_size,

View File

@ -90,7 +90,7 @@ def run_test(
# max_model_len should be greater than image_feature_size
with vllm_runner(model,
max_model_len=4096,
max_num_seqs=1,
max_num_seqs=2,
dtype=dtype,
limit_mm_per_prompt={"image": mm_limit},
tensor_parallel_size=tensor_parallel_size,

View File

@ -221,7 +221,7 @@ def run_test(
# Qwen encodes each image into a fixed content size of 256
with vllm_runner(model,
max_model_len=1024,
max_num_seqs=1,
max_num_seqs=2,
dtype=dtype,
limit_mm_per_prompt={"image": mm_limit},
tensor_parallel_size=tensor_parallel_size,

View File

@ -0,0 +1,35 @@
import pytest
from ....utils import multi_gpu_test
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"])
@pytest.mark.parametrize("model", [
"meta-llama/Llama-3.2-11B-Vision-Instruct",
])
def test_models(hf_runner, vllm_runner, image_assets,
distributed_executor_backend, model) -> None:
dtype = "half"
max_tokens = 5
num_logprobs = 5
tensor_parallel_size = 2
if model.startswith("meta-llama/Llama-3.2-11B-Vision-Instruct"):
from .test_mllama import models, run_test
else:
raise NotImplementedError(f"Unsupported model: {model}")
run_test(
hf_runner,
vllm_runner,
image_assets,
model=models[0],
size_factors=[0.25, 0.5, 1.0],
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
)

View File

@ -9,7 +9,6 @@ from vllm.sequence import SampleLogprobs
from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
_ImageAssets)
from ....utils import multi_gpu_test
from ...utils import check_logprobs_close
_LIMIT_IMAGE_PER_PROMPT = 1
@ -47,14 +46,46 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
if token_id != image_token_id or output_ids[idx - 1] != image_token_id
]
assert output_str[0] == " "
hf_output_str = output_str[1:]
hf_output_str = output_str
if hf_output_ids[-1] == eos_token_id:
hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)
return hf_output_ids, hf_output_str, out_logprobs
def _get_inputs(
image_assets: _ImageAssets,
*,
size_factors: Optional[List[float]] = None,
sizes: Optional[List[Tuple[int, int]]] = None,
) -> List[Tuple[List[str], PromptImageInput]]:
images = [asset.pil_image for asset in image_assets]
if size_factors is not None:
inputs_per_image = [(
[prompt for _ in size_factors],
[rescale_image_size(image, factor) for factor in size_factors],
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
elif sizes is not None:
inputs_per_image = [(
[
prompt if size is not None else text_only_prompts[0]
for size in sizes
],
[
image.resize(size) if size is not None else None
for size in sizes
],
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
if len(sizes) == 0:
inputs_per_image.append(
(text_only_prompts, [None] * len(text_only_prompts)))
else:
raise ValueError("You must provide either `size_factors` or `sizes`")
return inputs_per_image
@overload
def run_test(
hf_runner: Type[HfRunner],
@ -103,39 +134,17 @@ def run_test(
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
images = [asset.pil_image for asset in image_assets]
if size_factors is not None:
inputs_per_image = [(
[prompt for _ in size_factors],
[rescale_image_size(image, factor) for factor in size_factors],
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
elif sizes is not None:
inputs_per_image = [(
[
prompt if size is not None else text_only_prompts[0]
for size in sizes
],
[
image.resize(size) if size is not None else None
for size in sizes
],
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
if len(sizes) == 0:
inputs_per_image.append(
(text_only_prompts, [None] * len(text_only_prompts)))
else:
raise ValueError("You must provide either `size_factors` or `sizes`")
_run_test(hf_runner,
vllm_runner,
inputs_per_image,
model,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend)
_run_test(
hf_runner,
vllm_runner,
_get_inputs(image_assets, size_factors=size_factors, sizes=sizes),
model,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
)
def _run_test(
@ -167,8 +176,8 @@ def _run_test(
# max_model_len should be greater than image_feature_size
with vllm_runner(model,
dtype=dtype,
max_num_seqs=16,
max_model_len=4096,
max_num_seqs=2,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True,
@ -185,7 +194,6 @@ def _run_test(
def process(hf_inputs: BatchEncoding):
return hf_inputs
from transformers import AutoConfig
from transformers.models.mllama import MllamaConfig as MllamaConfigHf
# use transformer's MllamaConfig for hf_runner
@ -193,6 +201,7 @@ def _run_test(
AutoConfig.register("mllama", MllamaConfigHf, exist_ok=True)
with hf_runner(model,
dtype=dtype,
model_kwargs={"device_map": "auto"},
postprocess_inputs=process,
auto_cls=AutoModelForVision2Seq) as hf_model:
hf_outputs_per_image = [
@ -218,26 +227,29 @@ def _run_test(
)
SIZES = [
# Text only
[],
# Single-size
[(512, 512)],
# Single-size, batched
[(512, 512), (512, 512), (512, 512)],
# Multi-size, batched
[(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
(1024, 1024), (512, 1536), (512, 2028)],
# Multi-size, batched, including text only
[(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
(1024, 1024), (512, 1536), (512, 2028), None],
# mllama has 8 possible aspect ratios, carefully set the sizes
# to cover all of them
]
@pytest.mark.skip(
reason=
"Model is too big, test passed on L40 locally but will OOM on CI machine.")
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"sizes",
[
# Text only
[],
# Single-size
[(512, 512)],
# Single-size, batched
[(512, 512), (512, 512), (512, 512)],
# Multi-size, batched
[(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
(1024, 1024), (512, 1536), (512, 2028)],
# Multi-size, batched, including text only
[(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
(1024, 1024), (512, 1536), (512, 2028), None],
# mllama has 8 possible aspect ratios, carefully set the sizes
# to cover all of them
],
)
@pytest.mark.parametrize("sizes", SIZES)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@ -254,30 +266,3 @@ def test_models(hf_runner, vllm_runner, image_assets, model, sizes, dtype,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"sizes",
[
[(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
(1024, 1024), (512, 1536), (512, 2028), None],
],
)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models_distributed(hf_runner, vllm_runner, image_assets, model, sizes,
dtype, max_tokens, num_logprobs) -> None:
run_test(
hf_runner,
vllm_runner,
image_assets,
model,
sizes=sizes,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=2,
)
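For reference, size_factors (used by _get_inputs above and by the new distributed broadcast test) rescales each asset image before batching. A rough equivalent of rescale_image_size, assuming PIL images:

from PIL import Image

def rescale(image: Image.Image, factor: float) -> Image.Image:
    # Both dimensions are scaled by the same factor, so e.g.
    # size_factors=[0.25, 0.5, 1.0] pairs a prompt with the image at
    # quarter, half, and full resolution.
    w, h = image.size
    return image.resize((int(w * factor), int(h * factor)))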

View File

@ -1,9 +1,12 @@
import warnings
from typing import Dict, List, Optional, Sequence, Tuple, Union
import torch
from vllm.config import ModelConfig
from vllm.inputs import InputContext
from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs
from vllm.utils import is_cpu
TokensText = Tuple[List[int], str]
@ -247,6 +250,7 @@ def check_logprobs_close(
def build_model_context(model_name: str,
tokenizer_name: Optional[str] = None,
trust_remote_code: bool = False,
dtype: Optional[Union[str, torch.dtype]] = None,
mm_processor_kwargs: Optional[Dict] = None,
limit_mm_per_prompt: Optional[Dict] = None):
"""Creates an InputContext for a given model.
@ -264,12 +268,15 @@ def build_model_context(model_name: str,
"""
if tokenizer_name is None:
tokenizer_name = model_name
if dtype is None:
dtype = "bfloat16" if is_cpu() else "half"
model_config = ModelConfig(
model_name,
tokenizer_name,
tokenizer_mode="auto",
trust_remote_code=trust_remote_code,
dtype="float32",
dtype=dtype,
seed=0,
mm_processor_kwargs=mm_processor_kwargs,
limit_mm_per_prompt=limit_mm_per_prompt,

View File

@ -185,16 +185,8 @@ class InputRegistry:
return wrapper
def _get_dummy_encoder_data_factory(self, model_cls: Type[nn.Module]):
if model_cls in self._dummy_encoder_factories_by_model_type:
dummy_factory = self._dummy_encoder_factories_by_model_type[
model_cls]
else:
logger.warning(
"No dummy encoder data factory registered to %s. "
"Using the dummy data factory for the model instead.",
model_cls)
dummy_factory = self._get_dummy_data_factory(model_cls)
return dummy_factory
return self._dummy_encoder_factories_by_model_type \
.get(model_cls, self._default_dummy_data_factory)
def dummy_data_for_profiling(
self,

View File

@ -159,7 +159,8 @@ def apply_fp8_linear(
# Making sure the dummy tensor is on the same device as the weight
global TORCH_DEVICE_IDENTITY
if TORCH_DEVICE_IDENTITY.device != weight.device:
if (TORCH_DEVICE_IDENTITY is not None
and TORCH_DEVICE_IDENTITY.device != weight.device):
TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
# GEMM
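The fix guards the lazily-created TORCH_DEVICE_IDENTITY global so the device comparison is skipped while it is still None. A standalone sketch of the guard pattern:

import torch

TORCH_DEVICE_IDENTITY = None  # created lazily elsewhere; may still be None

def move_identity_to(weight: torch.Tensor) -> None:
    # Skip entirely while the global is uninitialized; otherwise keep the
    # dummy tensor on the same device as the weight before the GEMM.
    global TORCH_DEVICE_IDENTITY
    if (TORCH_DEVICE_IDENTITY is not None
            and TORCH_DEVICE_IDENTITY.device != weight.device):
        TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)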