vllm/tests/samplers/test_beam_search.py
Alex Brooks 6b40996ae8
[Core][Bugfix] Fix Offline MM Beam Search (#16390)
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
2025-04-15 10:33:02 +08:00

142 lines
5.1 KiB
Python

# SPDX-License-Identifier: Apache-2.0
"""Compare the outputs of HF and vLLM when using beam search.
Run `pytest tests/samplers/test_beam_search.py`.
"""
import pytest
from transformers import AutoModelForSeq2SeqLM
from vllm.assets.audio import AudioAsset
@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
    """Opt every test in this module into running on both engines.

    The fixture is autouse, so merely depending on ``run_with_both_engines``
    applies that fixture to each test here; no extra setup is needed.
    """
# FIXME(zhuohan): The test can not pass if we:
#   1. Increase max_tokens to 256.
#   2. Increase beam_width to 8.
#   3. Use the model "huggyllama/llama-7b".

# Generation budgets (new tokens) shared by both beam-search tests.
MAX_TOKENS: list[int] = [64]
# Beam widths for test_beam_search_single_input.
BEAM_WIDTHS: list[int] = [4]
# Beam widths for test_beam_search_passes_multimodal_data.
MM_BEAM_WIDTHS: list[int] = [2]
# Models compared in test_beam_search_single_input.
MODELS: list[str] = ["TinyLlama/TinyLlama-1.1B-Chat-v1.0"]
@pytest.mark.skip_v1  # FIXME: This fails on V1 right now.
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", MAX_TOKENS)
@pytest.mark.parametrize("beam_width", BEAM_WIDTHS)
def test_beam_search_single_input(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    beam_width: int,
) -> None:
    """Check that HF and vLLM beam search return identical token IDs.

    Both runners run beam search on the first example prompt with the same
    beam width and token budget; every returned sequence of token IDs must
    match exactly, position by position.
    """
    # Restrict the comparison to the first example prompt.
    prompts = example_prompts[:1]

    with hf_runner(model, dtype=dtype) as hf_model:
        hf_results = hf_model.generate_beam_search(prompts, beam_width,
                                                   max_tokens)
    with vllm_runner(model, dtype=dtype) as vllm_model:
        vllm_results = vllm_model.generate_beam_search(prompts, beam_width,
                                                       max_tokens)

    for idx, _prompt in enumerate(prompts):
        hf_ids, hf_texts = hf_results[idx]
        vllm_ids, vllm_texts = vllm_results[idx]

        # Dump the decoded beams side by side to make failures easier to read.
        for beam, (hf_text, vllm_text) in enumerate(zip(hf_texts, vllm_texts)):
            print(f">>>{beam}-th hf output:", hf_text, sep="\n")
            print(f">>>{beam}-th vllm output:", vllm_text, sep="\n")

        assert len(hf_ids) == len(vllm_ids)
        # Lengths are equal (checked above), so zip covers every beam.
        for beam, (hf_seq, vllm_seq) in enumerate(zip(hf_ids, vllm_ids)):
            assert hf_seq == vllm_seq, (
                f"Test{idx} output{beam}:\nHF: {hf_ids}\n"
                f"vLLM: {vllm_ids}")
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", MAX_TOKENS)
@pytest.mark.parametrize("beam_width", MM_BEAM_WIDTHS)
def test_beam_search_passes_multimodal_data(
    hf_runner,
    vllm_runner,
    dtype: str,
    max_tokens: int,
    beam_width: int,
) -> None:
    """Ensure that beam search passes multimodal data through correctly.

    Runs beam search on a single audio prompt with both the HF and vLLM
    runners and checks that each returned beam has the same token IDs once
    audio placeholder tokens are removed (and a single trailing EOS that
    only HF emits is tolerated).
    """
    # NOTE - this test is primarily to check that mm data is passed to beams
    # correctly. As such, we just need to check one extra modality to make
    # sure things pass through properly.
    audios = [AudioAsset("mary_had_lamb").audio_and_sample_rate]
    model = "Qwen/Qwen2-Audio-7B-Instruct"
    audio_seq = "<|audio_bos|><|AUDIO|><|audio_eos|>"
    prompts = [
        f"<|im_start|>user\n{audio_seq}Can you transcribe this?<|im_end|>\n<|im_start|>assistant\n"  # noqa: E501
    ]

    with hf_runner(model, dtype=dtype,
                   auto_cls=AutoModelForSeq2SeqLM) as hf_model:
        audio_token_id = hf_model.config.audio_token_index
        eos_token_id = hf_model.tokenizer.eos_token_id  # <|im_end|>
        hf_outputs = hf_model.generate_beam_search(
            prompts,
            beam_width=beam_width,
            max_tokens=max_tokens,
            audios=audios,
        )

    with vllm_runner(model, dtype=dtype) as vllm_model:
        vllm_outputs = vllm_model.generate_beam_search(
            prompts,
            beam_width=beam_width,
            max_tokens=max_tokens,
            audios=audios,
        )

    def without_audio_tokens(seq):
        """Return a copy of ``seq`` with every audio token ID removed."""
        return [tok for tok in seq if tok != audio_token_id]

    # zip() below would silently drop unmatched results, so first make sure
    # both runners produced exactly one result per prompt.
    assert len(hf_outputs) == len(vllm_outputs) == len(prompts)

    for i, (hf_output,
            vllm_output) in enumerate(zip(hf_outputs, vllm_outputs)):
        hf_output_ids, hf_output_texts = hf_output
        vllm_output_ids, vllm_output_texts = vllm_output

        for j, (hf_text,
                vllm_text) in enumerate(zip(hf_output_texts,
                                            vllm_output_texts)):
            print(f">>>{j}-th hf output [NOTE: special tokens are filtered]:")
            print(hf_text)
            print(f">>>{j}-th vllm output:")
            print(vllm_text)

        assert len(hf_output_ids) == len(vllm_output_ids)
        # Lengths are equal (checked above), so zip covers every beam.
        for j, (hf_beam_ids, vllm_beam_ids) in enumerate(
                zip(hf_output_ids, vllm_output_ids)):
            # Compare everything except for the audio tokens; we do this
            # since the IDs returned from the transformers helper expand the
            # audio token to match features, while the vLLM helper maintains
            # the single audio token in the input text.
            hf_ids = without_audio_tokens(hf_beam_ids)
            vllm_ids = without_audio_tokens(vllm_beam_ids)

            # HF output IDs may additionally end with the EOS token.
            if len(hf_ids) == len(vllm_ids) + 1:
                assert hf_ids[-1] == eos_token_id, (
                    f"Test{i} output{j}: extra trailing HF token "
                    f"{hf_ids[-1]} is not EOS ({eos_token_id})")
                hf_ids = hf_ids[:-1]

            assert hf_ids == vllm_ids, (f"Test{i} output{j}:\nHF: {hf_ids}\n"
                                        f"vLLM: {vllm_ids}")