[CI/Build] Update models tests & examples (#8874)
Co-authored-by: Roger Wang <ywang@roblox.com>
Parent: 19d02ff938
Commit: e1a3f5e831
@@ -9,6 +9,7 @@
# label(str): the name of the test. emoji allowed.
# fast_check(bool): whether to run this on each commit on fastcheck pipeline.
# fast_check_only(bool): run this test on fastcheck pipeline only
+# optional(bool): never run this test by default (i.e. need to unblock manually)
# command(str): the single command to run for tests. incompatible with commands.
# commands(list): the list of commands to run for test. incompatible with command.
# mirror_hardwares(list): the list of hardwares to run the test on as well. currently only supports [amd]
@@ -39,7 +40,7 @@ steps:
  # Check API reference (if it fails, you may have missing mock imports)
  - grep \"sig sig-object py\" build/html/dev/sampling_params.html

-- label: Async Engine, Inputs, Utils, Worker Test # 15min
+- label: Async Engine, Inputs, Utils, Worker Test # 24min
  fast_check: true
  source_file_dependencies:
  - vllm/

@@ -81,7 +82,7 @@ steps:
  commands:
  - pytest -v -s core

-- label: Entrypoints Test # 20min
+- label: Entrypoints Test # 40min
  working_dir: "/vllm-workspace/tests"
  fast_check: true
  mirror_hardwares: [amd]

@@ -151,7 +152,7 @@ steps:
  # OOM in the CI unless we run this separately
  - pytest -v -s tokenization

-- label: Examples Test # 12min
+- label: Examples Test # 15min
  working_dir: "/vllm-workspace/examples"
  #mirror_hardwares: [amd]
  source_file_dependencies:

@@ -169,7 +170,7 @@ steps:
  - python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
  - python3 offline_inference_encoder_decoder.py

-- label: Prefix Caching Test # 7min
+- label: Prefix Caching Test # 9min
  #mirror_hardwares: [amd]
  source_file_dependencies:
  - vllm/

@@ -177,7 +178,7 @@ steps:
  commands:
  - pytest -v -s prefix_caching

-- label: Samplers Test # 18min
+- label: Samplers Test # 36min
  source_file_dependencies:
  - vllm/model_executor/layers
  - vllm/sampling_metadata.py

@@ -193,7 +194,7 @@ steps:
  - tests/test_logits_processor
  command: pytest -v -s test_logits_processor.py

-- label: Speculative decoding tests # 22min
+- label: Speculative decoding tests # 30min
  source_file_dependencies:
  - vllm/spec_decode
  - tests/spec_decode

@@ -203,7 +204,7 @@ steps:
  - pytest -v -s spec_decode/e2e/test_multistep_correctness.py
  - pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py

-- label: LoRA Test %N # 30min each
+- label: LoRA Test %N # 15min each
  mirror_hardwares: [amd]
  source_file_dependencies:
  - vllm/lora

@@ -211,7 +212,7 @@ steps:
  command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py
  parallelism: 4

-- label: "PyTorch Fullgraph Smoke Test"
+- label: "PyTorch Fullgraph Smoke Test" # 9min
  fast_check: true
  source_file_dependencies:
  - vllm/

@@ -219,14 +220,14 @@ steps:
  commands:
  - pytest -v -s compile/test_full_graph_smoke.py

-- label: "PyTorch Fullgraph Test"
+- label: "PyTorch Fullgraph Test" # 18min
  source_file_dependencies:
  - vllm/
  - tests/compile
  commands:
  - pytest -v -s compile/test_full_graph.py

-- label: Kernels Test %N # 30min each
+- label: Kernels Test %N # 1h each
  mirror_hardwares: [amd]
  source_file_dependencies:
  - csrc/

@@ -256,7 +257,7 @@ steps:
  - pip install aiohttp
  - bash run-benchmarks.sh

-- label: Quantization Test # 15min
+- label: Quantization Test # 33min
  source_file_dependencies:
  - csrc/
  - vllm/model_executor/layers/quantization

@@ -300,7 +301,7 @@ steps:
  - pytest -v -s models/test_oot_registration.py # it needs a clean process
  - pytest -v -s models/*.py --ignore=models/test_oot_registration.py

-- label: Decoder-only Language Models Test # 1h3min
+- label: Decoder-only Language Models Test # 1h36min
  #mirror_hardwares: [amd]
  source_file_dependencies:
  - vllm/

@@ -308,7 +309,7 @@ steps:
  commands:
  - pytest -v -s models/decoder_only/language

-- label: Decoder-only Multi-Modal Models Test # 56min
+- label: Decoder-only Multi-Modal Models Test # 1h31min
  #mirror_hardwares: [amd]
  source_file_dependencies:
  - vllm/

@@ -318,15 +319,25 @@ steps:
  - pytest -v -s models/decoder_only/audio_language
  - pytest -v -s models/decoder_only/vision_language

-- label: Other Models Test # 5min
+- label: Other Models Test # 6min
  #mirror_hardwares: [amd]
  source_file_dependencies:
  - vllm/
  - tests/models/embedding/language
  - tests/models/encoder_decoder/language
+  - tests/models/encoder_decoder/vision_language
  commands:
  - pytest -v -s models/embedding/language
  - pytest -v -s models/encoder_decoder/language
+  - pytest -v -s models/encoder_decoder/vision_language
+
+- label: Custom Models Test
+  #mirror_hardwares: [amd]
+  optional: true
+  commands:
+    # PR authors can temporarily add commands below to test individual models
+    # e.g. pytest -v -s models/encoder_decoder/vision_language/test_mllama.py
+    # *To avoid merge conflicts, remember to REMOVE (not just comment out) them before merging the PR*

##### 1 GPU test #####
##### multi gpus test #####

@@ -359,7 +370,7 @@ steps:
  - # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up)
  - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep -q 'Same node test passed'

-- label: Distributed Tests (2 GPUs) # 28min
+- label: Distributed Tests (2 GPUs) # 40min
  #mirror_hardwares: [amd]
  working_dir: "/vllm-workspace/tests"
  num_gpus: 2

@@ -376,14 +387,16 @@ steps:
  - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed'
  - TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m distributed_2_gpus
  # Avoid importing model tests that cause CUDA reinitialization error
-  - pytest models/encoder_decoder/language/test_bart.py models/decoder_only/vision_language/test_broadcast.py -v -s -m distributed_2_gpus
+  - pytest models/encoder_decoder/language/test_bart.py -v -s -m distributed_2_gpus
+  - pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m distributed_2_gpus
+  - pytest models/decoder_only/vision_language/test_broadcast.py -v -s -m distributed_2_gpus
  - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
  - pip install -e ./plugins/vllm_add_dummy_model
  - pytest -v -s distributed/test_distributed_oot.py
  - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
  - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py

-- label: Multi-step Tests (4 GPUs) # 21min
+- label: Multi-step Tests (4 GPUs) # 36min
  working_dir: "/vllm-workspace/tests"
  num_gpus: 4
  source_file_dependencies:

@@ -401,7 +414,7 @@ steps:
  - pytest -v -s multi_step/test_correctness_async_llm.py
  - pytest -v -s multi_step/test_correctness_llm.py

-- label: Pipeline Parallelism Test # 23min
+- label: Pipeline Parallelism Test # 45min
  working_dir: "/vllm-workspace/tests"
  num_gpus: 4
  source_file_dependencies:

@@ -427,7 +440,7 @@ steps:
  - export VLLM_WORKER_MULTIPROC_METHOD=spawn
  - pytest -v -s -x lora/test_long_context.py

-- label: Weight Loading Multiple GPU Test
+- label: Weight Loading Multiple GPU Test # 33min
  working_dir: "/vllm-workspace/tests"
  num_gpus: 2
  source_file_dependencies:
@@ -12,6 +12,10 @@ from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.utils import FlexibleArgumentParser

+# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
+# lower-end GPUs.
+# Unless specified, these settings have been tested to work on a single L4.
+

# LLaVA-1.5
def run_llava(question, modality):

@@ -19,7 +23,7 @@ def run_llava(question, modality):
    prompt = f"USER: <image>\n{question}\nASSISTANT:"

-    llm = LLM(model="llava-hf/llava-1.5-7b-hf")
+    llm = LLM(model="llava-hf/llava-1.5-7b-hf", max_model_len=4096)
    stop_token_ids = None
    return llm, prompt, stop_token_ids
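A minimal offline-inference sketch, not part of the diff, showing how the reduced max_model_len above keeps the LLaVA example within a single L4-class GPU. It reuses the LLM, SamplingParams, and ImageAsset APIs these examples already import; the prompt and asset name are illustrative.

    from vllm import LLM, SamplingParams
    from vllm.assets.image import ImageAsset

    # Cap the context length so the KV cache fits on a lower-end GPU.
    llm = LLM(model="llava-hf/llava-1.5-7b-hf", max_model_len=4096)
    image = ImageAsset("cherry_blossom").pil_image
    outputs = llm.generate(
        {
            "prompt": "USER: <image>\nWhat is in this image?\nASSISTANT:",
            "multi_modal_data": {"image": image},
        },
        SamplingParams(temperature=0.2, max_tokens=64),
    )
    print(outputs[0].outputs[0].text)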
@@ -57,7 +61,7 @@ def run_llava_onevision(question, modality):
    <|im_start|>assistant\n"

    llm = LLM(model="llava-hf/llava-onevision-qwen2-7b-ov-hf",
-              max_model_len=32768)
+              max_model_len=16384)
    stop_token_ids = None
    return llm, prompt, stop_token_ids

@@ -67,7 +71,7 @@ def run_fuyu(question, modality):
    assert modality == "image"

    prompt = f"{question}\n"
-    llm = LLM(model="adept/fuyu-8b")
+    llm = LLM(model="adept/fuyu-8b", max_model_len=2048, max_num_seqs=2)
    stop_token_ids = None
    return llm, prompt, stop_token_ids

@@ -99,7 +103,8 @@ def run_phi3v(question, modality):
    llm = LLM(
        model="microsoft/Phi-3-vision-128k-instruct",
        trust_remote_code=True,
-        max_num_seqs=5,
+        max_model_len=4096,
+        max_num_seqs=2,
        mm_processor_kwargs={"num_crops": 16},
    )
    stop_token_ids = None

@@ -122,7 +127,7 @@ def run_chameleon(question, modality):
    assert modality == "image"

    prompt = f"{question}<image>"
-    llm = LLM(model="facebook/chameleon-7b")
+    llm = LLM(model="facebook/chameleon-7b", max_model_len=4096)
    stop_token_ids = None
    return llm, prompt, stop_token_ids

@@ -145,6 +150,8 @@ def run_minicpmv(question, modality):
                    trust_remote_code=True)
    llm = LLM(
        model=model_name,
+        max_model_len=4096,
+        max_num_seqs=2,
        trust_remote_code=True,
    )
    # NOTE The stop_token_ids are different for various versions of MiniCPM-V

@@ -177,7 +184,7 @@ def run_internvl(question, modality):
    llm = LLM(
        model=model_name,
        trust_remote_code=True,
-        max_num_seqs=5,
+        max_model_len=4096,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name,

@@ -215,7 +222,8 @@ def run_qwen_vl(question, modality):
    llm = LLM(
        model="Qwen/Qwen-VL",
        trust_remote_code=True,
-        max_num_seqs=5,
+        max_model_len=1024,
+        max_num_seqs=2,
    )

    prompt = f"{question}Picture 1: <img></img>\n"

@@ -229,8 +237,10 @@ def run_qwen2_vl(question, modality):

    model_name = "Qwen/Qwen2-VL-7B-Instruct"

+    # Tested on L40
    llm = LLM(
        model=model_name,
+        max_model_len=8192,
        max_num_seqs=5,
    )

@@ -252,10 +262,10 @@ def run_mllama(question, modality):
    # max_model_len (131072) for this model may cause OOM.
    # You may lower either to run this example on lower-end GPUs.

-    # The configuration below has been confirmed to launch on a
-    # single H100 GPU.
+    # The configuration below has been confirmed to launch on a single L40 GPU.
    llm = LLM(
        model=model_name,
        max_model_len=4096,
+        max_num_seqs=16,
        enforce_eager=True,
    )

@@ -28,12 +28,18 @@ class ModelRequestData(NamedTuple):
    chat_template: Optional[str]


+# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
+# lower-end GPUs.
+# Unless specified, these settings have been tested to work on a single L4.
+
+
def load_qwenvl_chat(question: str, image_urls: List[str]) -> ModelRequestData:
    model_name = "Qwen/Qwen-VL-Chat"
    llm = LLM(
        model=model_name,
        trust_remote_code=True,
-        max_num_seqs=5,
+        max_model_len=1024,
+        max_num_seqs=2,
        limit_mm_per_prompt={"image": len(image_urls)},
    )
    placeholders = "".join(f"Picture {i}: <img></img>\n"

@@ -83,6 +89,7 @@ def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData:
        model="microsoft/Phi-3.5-vision-instruct",
        trust_remote_code=True,
        max_model_len=4096,
+        max_num_seqs=2,
        limit_mm_per_prompt={"image": len(image_urls)},
        mm_processor_kwargs={"num_crops": 4},
    )

@@ -106,7 +113,6 @@ def load_internvl(question: str, image_urls: List[str]) -> ModelRequestData:
    llm = LLM(
        model=model_name,
        trust_remote_code=True,
-        max_num_seqs=5,
        max_model_len=4096,
        limit_mm_per_prompt={"image": len(image_urls)},
    )

@@ -148,10 +154,11 @@ def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData:

    model_name = "Qwen/Qwen2-VL-7B-Instruct"

+    # Tested on L40
    llm = LLM(
        model=model_name,
-        max_num_seqs=5,
        max_model_len=32768 if process_vision_info is None else 4096,
+        max_num_seqs=5,
        limit_mm_per_prompt={"image": len(image_urls)},
    )
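A hedged sketch, not part of the diff, of how limit_mm_per_prompt bounds the number of images accepted per prompt in these multi-image loaders; the engine arguments mirror load_phi3v above, while the URLs are placeholders.

    from vllm import LLM

    image_urls = ["https://example.com/a.jpg", "https://example.com/b.jpg"]  # placeholders
    llm = LLM(
        model="microsoft/Phi-3.5-vision-instruct",
        trust_remote_code=True,
        max_model_len=4096,
        max_num_seqs=2,
        # Reject prompts that carry more images than this example supplies.
        limit_mm_per_prompt={"image": len(image_urls)},
        mm_processor_kwargs={"num_crops": 4},
    )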
@@ -246,17 +246,14 @@ _T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature)

class HfRunner:

-    def wrap_device(self, input: _T) -> _T:
-        if not is_cpu():
-            # Check if the input is already on the GPU
-            if hasattr(input, 'device') and input.device.type == "cuda":
-                return input  # Already on GPU, no need to move
-            return input.to("cuda")
-        else:
-            # Check if the input is already on the CPU
-            if hasattr(input, 'device') and input.device.type == "cpu":
-                return input  # Already on CPU, no need to move
-            return input.to("cpu")
+    def wrap_device(self, input: _T, device: Optional[str] = None) -> _T:
+        if device is None:
+            return self.wrap_device(input, "cpu" if is_cpu() else "cuda")
+
+        if hasattr(input, "device") and input.device.type == device:
+            return input
+
+        return input.to(device)

    def __init__(
        self,
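A minimal standalone sketch, assuming plain PyTorch and no vLLM helpers, of the device handling the refactored wrap_device above implements: move a tensor-like object only when it is not already on the requested device.

    from typing import Optional

    import torch

    def wrap_device(batch, device: Optional[str] = None):
        # Default to CUDA when available, mirroring the "cpu"/"cuda" fallback above.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        if hasattr(batch, "device") and batch.device.type == device:
            return batch  # already on the target device; skip the redundant copy
        return batch.to(device)

    x = wrap_device(torch.ones(2, 2), "cpu")  # no-op: the tensor already lives on the CPU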
@@ -333,7 +330,7 @@ class HfRunner:
        inputs = self.postprocess_inputs(inputs)

        output_ids = self.model.generate(
-            **self.wrap_device(inputs),
+            **self.wrap_device(inputs, device=self.model.device.type),
            use_cache=True,
            **kwargs,
        )

@@ -406,7 +403,7 @@ class HfRunner:
        inputs = self.postprocess_inputs(inputs)

        output = self.model.generate(
-            **self.wrap_device(inputs),
+            **self.wrap_device(inputs, device=self.model.device.type),
            use_cache=True,
            do_sample=False,
            max_new_tokens=max_tokens,

@@ -414,40 +411,39 @@ class HfRunner:
                return_dict_in_generate=True,
                **kwargs,
            )
-            seq_logprobs: List[torch.Tensor] = []
-            for hidden_states in output.hidden_states:
-                last_hidden_states = hidden_states[-1][0]
-                logits = torch.matmul(
-                    last_hidden_states,
-                    self.model.get_output_embeddings().weight.t(),
-                )
-                if self.model.get_output_embeddings().bias is not None:
-                    logits += self.model.get_output_embeddings(
-                    ).bias.unsqueeze(0)
-                logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
-                seq_logprobs.append(logprobs)
+            seq_logprobs = self._hidden_states_to_seq_logprobs(
+                output.hidden_states)
            all_logprobs.append(seq_logprobs)
        return all_logprobs

-    def _hidden_states_to_logprobs(
+    def _hidden_states_to_seq_logprobs(
        self,
-        hidden_states,
-        num_logprobs,
-    ) -> Tuple[List[Dict[int, float]], int]:
+        hidden_states: Tuple[Tuple[torch.Tensor, ...], ...],
+    ) -> List[torch.Tensor]:
+        output_embeddings = self.model.get_output_embeddings()
+
        seq_logprobs: List[torch.Tensor] = []
-        output_len = len(hidden_states)
        for _, hidden_state in enumerate(hidden_states):
            last_hidden_states = hidden_state[-1][0]
            logits = torch.matmul(
-                last_hidden_states,
-                self.model.get_output_embeddings().weight.t(),
+                last_hidden_states.to(output_embeddings.weight.device),
+                output_embeddings.weight.t(),
            )
-            if getattr(self.model.get_output_embeddings(), "bias",
-                       None) is not None:
-                logits += self.model.get_output_embeddings().bias.unsqueeze(0)
+            if getattr(output_embeddings, "bias", None) is not None:
+                logits += output_embeddings.bias.unsqueeze(0)
            logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
            seq_logprobs.append(logprobs)

+        return seq_logprobs
+
+    def _hidden_states_to_logprobs(
+        self,
+        hidden_states: Tuple[Tuple[torch.Tensor, ...], ...],
+        num_logprobs: int,
+    ) -> Tuple[List[Dict[int, float]], int]:
+        seq_logprobs = self._hidden_states_to_seq_logprobs(hidden_states)
+        output_len = len(hidden_states)

        # convert to dict
        seq_logprobs_lst: List[Dict[int, float]] = []
        for tok_idx, tok_logprobs in enumerate(seq_logprobs):
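A standalone sketch of the log-prob computation that _hidden_states_to_seq_logprobs above performs, assuming hidden_states comes from a Hugging Face generate() call with output_hidden_states=True and return_dict_in_generate=True: project each step's last-layer hidden state through the output embeddings and take a log-softmax.

    from typing import List, Tuple

    import torch
    import torch.nn.functional as F

    def hidden_states_to_seq_logprobs(
        model,  # a Hugging Face language model
        hidden_states: Tuple[Tuple[torch.Tensor, ...], ...],
    ) -> List[torch.Tensor]:
        output_embeddings = model.get_output_embeddings()
        seq_logprobs: List[torch.Tensor] = []
        for per_step in hidden_states:      # one tuple of layer outputs per generated token
            last_hidden = per_step[-1][0]   # last layer, first sequence in the batch
            logits = torch.matmul(
                last_hidden.to(output_embeddings.weight.device),
                output_embeddings.weight.t(),
            )
            if getattr(output_embeddings, "bias", None) is not None:
                logits += output_embeddings.bias.unsqueeze(0)
            seq_logprobs.append(F.log_softmax(logits, dim=-1, dtype=torch.float32))
        return seq_logprobs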
@@ -500,7 +496,7 @@ class HfRunner:
        inputs = self.postprocess_inputs(inputs)

        output = self.model.generate(
-            **self.wrap_device(inputs),
+            **self.wrap_device(inputs, device=self.model.device.type),
            use_cache=True,
            do_sample=False,
            max_new_tokens=max_tokens,

@@ -543,12 +539,20 @@ class HfRunner:

        for (encoder_prompt,
             decoder_prompt) in to_enc_dec_tuple_list(encoder_decoder_prompts):

            encoder_input_ids = self.wrap_device(
-                self.tokenizer(encoder_prompt, return_tensors="pt").input_ids)
-            decoder_input_ids = (
-                None if decoder_prompt is None else self.wrap_device(
+                self.tokenizer(encoder_prompt, return_tensors="pt").input_ids,
+                device=self.model.device.type,
+            )
+
+            if decoder_prompt is None:
+                decoder_input_ids = None
+            else:
+                decoder_input_ids = self.wrap_device(
                    self.tokenizer(decoder_prompt,
-                                   return_tensors="pt").input_ids))
+                                   return_tensors="pt").input_ids,
+                    device=self.model.device.type,
+                )

            output = self.model.generate(
                encoder_input_ids,
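A hedged standalone sketch of the encoder/decoder path the hunk above rewrites, using plain Hugging Face transformers outside the runner class; facebook/bart-base is only an illustrative seq2seq checkpoint.

    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
    model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-base")

    encoder_prompt = "vLLM is a fast inference engine."
    decoder_prompt = None  # optional; BART-style models can start from the decoder start token

    encoder_input_ids = tokenizer(encoder_prompt,
                                  return_tensors="pt").input_ids.to(model.device)
    if decoder_prompt is None:
        decoder_input_ids = None
    else:
        decoder_input_ids = tokenizer(decoder_prompt,
                                      return_tensors="pt").input_ids.to(model.device)

    output = model.generate(encoder_input_ids,
                            decoder_input_ids=decoder_input_ids,
                            use_cache=True,
                            max_new_tokens=16)
    print(tokenizer.decode(output[0], skip_special_tokens=True))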
@@ -16,8 +16,7 @@ from ...utils import check_logprobs_close
# Video test
HF_VIDEO_PROMPTS = VIDEO_ASSETS.prompts({
    "sample_demo_1":
-    "<|im_start|>user <video>\nwhy is this video funny? \
-    <|im_end|><|im_start|>assistant\n"
+    "<|im_start|>user\n<video>\nwhy is this video funny?<|im_end|>\n<|im_start|>assistant\n"  # noqa: E501
})

models = ["llava-hf/llava-onevision-qwen2-7b-ov-hf"]

@@ -165,6 +164,9 @@ def run_video_test(
    )


+@pytest.mark.skip(
+    reason=
+    "Model is too big, test passed on L40 locally but will OOM on CI machine.")
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "size_factors",

@@ -208,6 +210,9 @@ def test_models(hf_runner, vllm_runner, video_assets, model, size_factors,
    )


+@pytest.mark.skip(
+    reason=
+    "Model is too big, test passed on L40 locally but will OOM on CI machine.")
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "sizes",

@@ -254,9 +259,8 @@ def run_image_test(
    # max_model_len should be greater than image_feature_size
    with vllm_runner(model,
                     dtype=dtype,
-                    max_num_seqs=1,
                     max_model_len=16384,
-                    gpu_memory_utilization=0.98,
+                    max_num_seqs=2,
                     tensor_parallel_size=tensor_parallel_size,
                     distributed_executor_backend=distributed_executor_backend,
                     enforce_eager=True,

@@ -302,8 +306,9 @@ def run_image_test(
    )


-# FIXME: Swap to a smaller model for this architecture
-@pytest.mark.skip(reason="Model OOMing on CI")
+@pytest.mark.skip(
+    reason=
+    "Model is too big, test passed on L40 locally but will OOM on CI machine.")
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])

@@ -316,14 +321,10 @@ def test_models_multiple_image_inputs(hf_runner, vllm_runner, image_assets,

    inputs = [(
        [
-            "<|im_start|>user <image><image>\nDescribe 2 images. \
-            <|im_end|><|im_start|>assistant\n",
-            "<|im_start|>user <image><image>\nDescribe 2 images. \
-            <|im_end|><|im_start|>assistant\n",
-            "<|im_start|>user <image><image><image><image>\nDescribe 4 images. \
-            <|im_end|><|im_start|>assistant\n",
-            "<|im_start|>user <image>\nWhat is the season? \
-            <|im_end|><|im_start|>assistant\n",
+            "<|im_start|>user\n<image><image>\nDescribe 2 images.<|im_end|>\n<|im_start|>assistant\n",  # noqa: E501
+            "<|im_start|>user\n<image><image>\nDescribe 2 images.<|im_end|>\n<|im_start|>assistant\n",  # noqa: E501
+            "<|im_start|>user\n<image><image><image><image>\nDescribe 4 images.<|im_end|>\n<|im_start|>assistant\n",  # noqa: E501
+            "<|im_start|>user\n<image>\nWhat is the season?<|im_end|>\n<|im_start|>assistant\n",  # noqa: E501
        ],
        [
            [stop_sign, cherry_blossom],

@@ -79,7 +79,7 @@ def run_test(
    # max_model_len should be greater than image_feature_size
    with vllm_runner(model,
                     max_model_len=4096,
-                    max_num_seqs=1,
+                    max_num_seqs=2,
                     dtype=dtype,
                     limit_mm_per_prompt={"image": mm_limit},
                     tensor_parallel_size=tensor_parallel_size,

@@ -90,7 +90,7 @@ def run_test(
    # max_model_len should be greater than image_feature_size
    with vllm_runner(model,
                     max_model_len=4096,
-                    max_num_seqs=1,
+                    max_num_seqs=2,
                     dtype=dtype,
                     limit_mm_per_prompt={"image": mm_limit},
                     tensor_parallel_size=tensor_parallel_size,

@@ -221,7 +221,7 @@ def run_test(
    # Qwen encodes each image into a fixed content size of 256
    with vllm_runner(model,
                     max_model_len=1024,
-                    max_num_seqs=1,
+                    max_num_seqs=2,
                     dtype=dtype,
                     limit_mm_per_prompt={"image": mm_limit},
                     tensor_parallel_size=tensor_parallel_size,

@@ -0,0 +1,35 @@
+import pytest
+
+from ....utils import multi_gpu_test
+
+
+@multi_gpu_test(num_gpus=2)
+@pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"])
+@pytest.mark.parametrize("model", [
+    "meta-llama/Llama-3.2-11B-Vision-Instruct",
+])
+def test_models(hf_runner, vllm_runner, image_assets,
+                distributed_executor_backend, model) -> None:
+
+    dtype = "half"
+    max_tokens = 5
+    num_logprobs = 5
+    tensor_parallel_size = 2
+
+    if model.startswith("meta-llama/Llama-3.2-11B-Vision-Instruct"):
+        from .test_mllama import models, run_test
+    else:
+        raise NotImplementedError(f"Unsupported model: {model}")
+
+    run_test(
+        hf_runner,
+        vllm_runner,
+        image_assets,
+        model=models[0],
+        size_factors=[0.25, 0.5, 1.0],
+        dtype=dtype,
+        max_tokens=max_tokens,
+        num_logprobs=num_logprobs,
+        tensor_parallel_size=tensor_parallel_size,
+        distributed_executor_backend=distributed_executor_backend,
+    )

@@ -9,7 +9,6 @@ from vllm.sequence import SampleLogprobs

from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
                          _ImageAssets)
-from ....utils import multi_gpu_test
from ...utils import check_logprobs_close

_LIMIT_IMAGE_PER_PROMPT = 1

@@ -47,14 +46,46 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
        if token_id != image_token_id or output_ids[idx - 1] != image_token_id
    ]

-    assert output_str[0] == " "
-    hf_output_str = output_str[1:]
+    hf_output_str = output_str
    if hf_output_ids[-1] == eos_token_id:
        hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)

    return hf_output_ids, hf_output_str, out_logprobs


+def _get_inputs(
+    image_assets: _ImageAssets,
+    *,
+    size_factors: Optional[List[float]] = None,
+    sizes: Optional[List[Tuple[int, int]]] = None,
+) -> List[Tuple[List[str], PromptImageInput]]:
+    images = [asset.pil_image for asset in image_assets]
+
+    if size_factors is not None:
+        inputs_per_image = [(
+            [prompt for _ in size_factors],
+            [rescale_image_size(image, factor) for factor in size_factors],
+        ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
+    elif sizes is not None:
+        inputs_per_image = [(
+            [
+                prompt if size is not None else text_only_prompts[0]
+                for size in sizes
+            ],
+            [
+                image.resize(size) if size is not None else None
+                for size in sizes
+            ],
+        ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
+        if len(sizes) == 0:
+            inputs_per_image.append(
+                (text_only_prompts, [None] * len(text_only_prompts)))
+    else:
+        raise ValueError("You must provide either `size_factors` or `sizes`")
+
+    return inputs_per_image
+
+
@overload
def run_test(
    hf_runner: Type[HfRunner],
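A hedged usage sketch of the new _get_inputs helper introduced above; the names come from the diff and the image_assets fixture is assumed to be supplied by the test suite's conftest.

    # Rescale each test image by several factors, one prompt per factor:
    inputs = _get_inputs(image_assets, size_factors=[0.25, 0.5, 1.0])
    # Or resize to explicit sizes instead:
    inputs = _get_inputs(image_assets, sizes=[(512, 512), (1024, 512)])
    # Each entry pairs a list of prompts with the matching list of images.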
@@ -103,39 +134,17 @@ def run_test(
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
-    images = [asset.pil_image for asset in image_assets]
-
-    if size_factors is not None:
-        inputs_per_image = [(
-            [prompt for _ in size_factors],
-            [rescale_image_size(image, factor) for factor in size_factors],
-        ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
-    elif sizes is not None:
-        inputs_per_image = [(
-            [
-                prompt if size is not None else text_only_prompts[0]
-                for size in sizes
-            ],
-            [
-                image.resize(size) if size is not None else None
-                for size in sizes
-            ],
-        ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
-        if len(sizes) == 0:
-            inputs_per_image.append(
-                (text_only_prompts, [None] * len(text_only_prompts)))
-    else:
-        raise ValueError("You must provide either `size_factors` or `sizes`")
-
-    _run_test(hf_runner,
-              vllm_runner,
-              inputs_per_image,
-              model,
-              dtype=dtype,
-              max_tokens=max_tokens,
-              num_logprobs=num_logprobs,
-              tensor_parallel_size=tensor_parallel_size,
-              distributed_executor_backend=distributed_executor_backend)
+    _run_test(
+        hf_runner,
+        vllm_runner,
+        _get_inputs(image_assets, size_factors=size_factors, sizes=sizes),
+        model,
+        dtype=dtype,
+        max_tokens=max_tokens,
+        num_logprobs=num_logprobs,
+        tensor_parallel_size=tensor_parallel_size,
+        distributed_executor_backend=distributed_executor_backend,
+    )


def _run_test(

@@ -167,8 +176,8 @@ def _run_test(
    # max_model_len should be greater than image_feature_size
    with vllm_runner(model,
                     dtype=dtype,
-                    max_num_seqs=16,
                     max_model_len=4096,
+                    max_num_seqs=2,
                     tensor_parallel_size=tensor_parallel_size,
                     distributed_executor_backend=distributed_executor_backend,
                     enforce_eager=True,

@@ -185,7 +194,6 @@ def _run_test(
    def process(hf_inputs: BatchEncoding):
        return hf_inputs

-    from transformers import AutoConfig
    from transformers.models.mllama import MllamaConfig as MllamaConfigHf

    # use transformer's MllamaConfig for hf_runner

@@ -193,6 +201,7 @@ def _run_test(
    AutoConfig.register("mllama", MllamaConfigHf, exist_ok=True)
    with hf_runner(model,
                   dtype=dtype,
+                  model_kwargs={"device_map": "auto"},
                   postprocess_inputs=process,
                   auto_cls=AutoModelForVision2Seq) as hf_model:
        hf_outputs_per_image = [

@@ -218,26 +227,29 @@ def _run_test(
    )


+SIZES = [
+    # Text only
+    [],
+    # Single-size
+    [(512, 512)],
+    # Single-size, batched
+    [(512, 512), (512, 512), (512, 512)],
+    # Multi-size, batched
+    [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
+     (1024, 1024), (512, 1536), (512, 2028)],
+    # Multi-size, batched, including text only
+    [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
+     (1024, 1024), (512, 1536), (512, 2028), None],
+    # mllama has 8 possible aspect ratios, carefully set the sizes
+    # to cover all of them
+]
+
+
+@pytest.mark.skip(
+    reason=
+    "Model is too big, test passed on L40 locally but will OOM on CI machine.")
@pytest.mark.parametrize("model", models)
-@pytest.mark.parametrize(
-    "sizes",
-    [
-        # Text only
-        [],
-        # Single-size
-        [(512, 512)],
-        # Single-size, batched
-        [(512, 512), (512, 512), (512, 512)],
-        # Multi-size, batched
-        [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
-         (1024, 1024), (512, 1536), (512, 2028)],
-        # Multi-size, batched, including text only
-        [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
-         (1024, 1024), (512, 1536), (512, 2028), None],
-        # mllama has 8 possible aspect ratios, carefully set the sizes
-        # to cover all of them
-    ],
-)
+@pytest.mark.parametrize("sizes", SIZES)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])

@@ -254,30 +266,3 @@ def test_models(hf_runner, vllm_runner, image_assets, model, sizes, dtype,
        num_logprobs=num_logprobs,
        tensor_parallel_size=1,
    )
-
-
-@multi_gpu_test(num_gpus=2)
-@pytest.mark.parametrize("model", models)
-@pytest.mark.parametrize(
-    "sizes",
-    [
-        [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
-         (1024, 1024), (512, 1536), (512, 2028), None],
-    ],
-)
-@pytest.mark.parametrize("dtype", ["bfloat16"])
-@pytest.mark.parametrize("max_tokens", [128])
-@pytest.mark.parametrize("num_logprobs", [5])
-def test_models_distributed(hf_runner, vllm_runner, image_assets, model, sizes,
-                            dtype, max_tokens, num_logprobs) -> None:
-    run_test(
-        hf_runner,
-        vllm_runner,
-        image_assets,
-        model,
-        sizes=sizes,
-        dtype=dtype,
-        max_tokens=max_tokens,
-        num_logprobs=num_logprobs,
-        tensor_parallel_size=2,
-    )

@@ -1,9 +1,12 @@
import warnings
from typing import Dict, List, Optional, Sequence, Tuple, Union

+import torch
+
from vllm.config import ModelConfig
from vllm.inputs import InputContext
from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs
+from vllm.utils import is_cpu

TokensText = Tuple[List[int], str]

@@ -247,6 +250,7 @@ def check_logprobs_close(
def build_model_context(model_name: str,
                        tokenizer_name: Optional[str] = None,
                        trust_remote_code: bool = False,
+                        dtype: Optional[Union[str, torch.dtype]] = None,
                        mm_processor_kwargs: Optional[Dict] = None,
                        limit_mm_per_prompt: Optional[Dict] = None):
    """Creates an InputContext for a given model.

@@ -264,12 +268,15 @@ def build_model_context(model_name: str,
    """
    if tokenizer_name is None:
        tokenizer_name = model_name
+    if dtype is None:
+        dtype = "bfloat16" if is_cpu() else "half"

    model_config = ModelConfig(
        model_name,
        tokenizer_name,
        tokenizer_mode="auto",
        trust_remote_code=trust_remote_code,
-        dtype="float32",
+        dtype=dtype,
        seed=0,
        mm_processor_kwargs=mm_processor_kwargs,
        limit_mm_per_prompt=limit_mm_per_prompt,
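A hedged usage sketch of the updated helper; build_model_context and its new dtype default come from the hunk above, while the model name and the assertion are illustrative (InputContext exposing a model_config attribute is assumed).

    import torch

    # dtype is omitted, so it now resolves to "bfloat16" on CPU and "half" on
    # GPU instead of the old hard-coded "float32".
    ctx = build_model_context("facebook/opt-125m")
    assert ctx.model_config.dtype in (torch.bfloat16, torch.float16)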
@@ -185,16 +185,8 @@ class InputRegistry:
        return wrapper

    def _get_dummy_encoder_data_factory(self, model_cls: Type[nn.Module]):
-        if model_cls in self._dummy_encoder_factories_by_model_type:
-            dummy_factory = self._dummy_encoder_factories_by_model_type[
-                model_cls]
-        else:
-            logger.warning(
-                "No dummy encoder data factory registered to %s. "
-                "Using the dummy data factory for the model instead.",
-                model_cls)
-            dummy_factory = self._get_dummy_data_factory(model_cls)
-        return dummy_factory
+        return self._dummy_encoder_factories_by_model_type \
+            .get(model_cls, self._default_dummy_data_factory)

    def dummy_data_for_profiling(
        self,

@@ -159,7 +159,8 @@ def apply_fp8_linear(

    # Making sure the dummy tensor is on the same device as the weight
    global TORCH_DEVICE_IDENTITY
-    if TORCH_DEVICE_IDENTITY.device != weight.device:
+    if (TORCH_DEVICE_IDENTITY is not None
+            and TORCH_DEVICE_IDENTITY.device != weight.device):
        TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)

    # GEMM