[Core][Bugfix] Fix Offline MM Beam Search (#16390)

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
commit 6b40996ae8
parent d2020acac7
Author: Alex Brooks
Committed: 2025-04-14 20:33:02 -06:00 by GitHub
4 changed files with 140 additions and 30 deletions

tests/conftest.py

@@ -29,12 +29,11 @@ from vllm.distributed import (cleanup_dist_env_and_memory,
                               init_distributed_environment,
                               initialize_model_parallel)
 from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
-                         TokensPrompt, to_enc_dec_tuple_list,
-                         zip_enc_dec_prompts)
+                         to_enc_dec_tuple_list, zip_enc_dec_prompts)
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import BeamSearchParams
-from vllm.utils import cuda_device_count_stateless, is_list_of
+from vllm.utils import cuda_device_count_stateless
 
 logger = init_logger(__name__)
 
@@ -469,12 +468,19 @@ class HfRunner:
         prompts: list[str],
         beam_width: int,
         max_tokens: int,
+        images: Optional[PromptImageInput] = None,
+        videos: Optional[PromptVideoInput] = None,
+        audios: Optional[PromptAudioInput] = None,
     ) -> list[tuple[list[list[int]], list[str]]]:
         outputs = self.generate(prompts,
                                 do_sample=False,
                                 max_new_tokens=max_tokens,
                                 num_beams=beam_width,
-                                num_return_sequences=beam_width)
+                                num_return_sequences=beam_width,
+                                images=images,
+                                videos=videos,
+                                audios=audios)
         for i in range(len(outputs)):
             output_ids, output_str = outputs[i]
             for j in range(len(output_ids)):
@@ -936,18 +942,20 @@ class VllmRunner:
 
     def generate_beam_search(
         self,
-        prompts: Union[list[str], list[list[int]]],
+        prompts: list[str],
         beam_width: int,
         max_tokens: int,
+        images: Optional[PromptImageInput] = None,
+        videos: Optional[PromptVideoInput] = None,
+        audios: Optional[PromptAudioInput] = None,
     ) -> list[tuple[list[list[int]], list[str]]]:
-        if is_list_of(prompts, str, check="all"):
-            prompts = [TextPrompt(prompt=prompt) for prompt in prompts]
-        else:
-            prompts = [
-                TokensPrompt(prompt_token_ids=tokens) for tokens in prompts
-            ]
+        inputs = self.get_inputs(prompts,
+                                 images=images,
+                                 videos=videos,
+                                 audios=audios)
+
         outputs = self.model.beam_search(
-            prompts,
+            inputs,
             BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens))
         returned_outputs = []
         for output in outputs:

tests/samplers/test_beam_search.py

@@ -5,6 +5,9 @@ Run `pytest tests/samplers/test_beam_search.py`.
 """
 import pytest
+from transformers import AutoModelForSeq2SeqLM
+
+from vllm.assets.audio import AudioAsset
 
 
 @pytest.fixture(autouse=True)
@@ -19,6 +22,7 @@ def v1(run_with_both_engines):
 # 3. Use the model "huggyllama/llama-7b".
 MAX_TOKENS = [64]
 BEAM_WIDTHS = [4]
+MM_BEAM_WIDTHS = [2]
 MODELS = ["TinyLlama/TinyLlama-1.1B-Chat-v1.0"]
@@ -48,15 +52,90 @@ def test_beam_search_single_input(
     for i in range(len(example_prompts)):
         hf_output_ids, hf_output_texts = hf_outputs[i]
         vllm_output_ids, vllm_output_texts = vllm_outputs[i]
-        for i, (hf_text,
+        for j, (hf_text,
                 vllm_text) in enumerate(zip(hf_output_texts,
                                             vllm_output_texts)):
-            print(f">>>{i}-th hf output:")
+            print(f">>>{j}-th hf output:")
             print(hf_text)
-            print(f">>>{i}-th vllm output:")
+            print(f">>>{j}-th vllm output:")
             print(vllm_text)
         assert len(hf_output_ids) == len(vllm_output_ids)
         for j in range(len(hf_output_ids)):
             assert hf_output_ids[j] == vllm_output_ids[j], (
                 f"Test{i} output{j}:\nHF: {hf_output_ids}\n"
                 f"vLLM: {vllm_output_ids}")
+
+
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", MAX_TOKENS)
+@pytest.mark.parametrize("beam_width", MM_BEAM_WIDTHS)
+def test_beam_search_passes_multimodal_data(
+    hf_runner,
+    vllm_runner,
+    dtype: str,
+    max_tokens: int,
+    beam_width: int,
+) -> None:
+    """Ensure that beam search passes multimodal data through correctly."""
+    # NOTE - this test is primarily to check that mm data is passed to beams
+    # correctly. As such, we just need to check one extra modality to make
+    # sure things pass through properly.
+    audios = [AudioAsset("mary_had_lamb").audio_and_sample_rate]
+    model = "Qwen/Qwen2-Audio-7B-Instruct"
+    audio_seq = "<|audio_bos|><|AUDIO|><|audio_eos|>"
+    prompts = [
+        f"<|im_start|>user\n{audio_seq}Can you transcribe this?<|im_end|>\n<|im_start|>assistant\n"  # noqa: E501
+    ]
+
+    with hf_runner(model, dtype=dtype,
+                   auto_cls=AutoModelForSeq2SeqLM) as hf_model:
+        audio_token_id = hf_model.config.audio_token_index
+        eos_token_id = hf_model.tokenizer.eos_token_id  # <|im_end|>
+        hf_outputs = hf_model.generate_beam_search(
+            prompts,
+            beam_width=beam_width,
+            max_tokens=max_tokens,
+            audios=audios,
+        )
+
+    with vllm_runner(model, dtype=dtype) as vllm_model:
+        vllm_outputs = vllm_model.generate_beam_search(
+            prompts,
+            beam_width=beam_width,
+            max_tokens=max_tokens,
+            audios=audios,
+        )
+
+    seq_with_no_audio_toks = lambda seq: [
+        tok for tok in seq if tok != audio_token_id
+    ]
+
+    for i in range(len(prompts)):
+        hf_output_ids, hf_output_texts = hf_outputs[i]
+        vllm_output_ids, vllm_output_texts = vllm_outputs[i]
+        for j, (hf_text,
+                vllm_text) in enumerate(zip(hf_output_texts,
+                                            vllm_output_texts)):
+            print(f">>>{j}-th hf output [NOTE: special tokens are filtered]:")
+            print(hf_text)
+            print(f">>>{j}-th vllm output:")
+            print(vllm_text)
+        assert len(hf_output_ids) == len(vllm_output_ids)
+
+        for j in range(len(hf_output_ids)):
+            # Compare everything except for the audio tokens; we do this since
+            # the IDs returned from the transformers helper expands the audio
+            # token to match features, while the vLLM helper maintains the
+            # single audio token in the input text
+            filtered_hf_output_ids = seq_with_no_audio_toks(hf_output_ids[j])
+            filtered_vllm_output_ids = seq_with_no_audio_toks(
+                vllm_output_ids[j])
+
+            # HF output IDs may contain the end of sequence
+            if len(filtered_hf_output_ids
+                   ) == len(filtered_vllm_output_ids) + 1:
+                assert filtered_hf_output_ids[-1] == eos_token_id
+                filtered_hf_output_ids = filtered_hf_output_ids[:-1]
+            assert filtered_hf_output_ids == filtered_vllm_output_ids

vllm/beam_search.py

@@ -38,9 +38,18 @@ class BeamSearchOutput:
 class BeamSearchInstance:
 
-    def __init__(self, prompt_tokens: list[int]):
+    def __init__(
+        self,
+        prompt_tokens: list[int],
+        logprobs: Optional[list[dict[int, Logprob]]] = None,
+        **kwargs,
+    ):
         self.beams: list[BeamSearchSequence] = [
-            BeamSearchSequence(tokens=prompt_tokens, logprobs=[])
+            BeamSearchSequence(
+                tokens=prompt_tokens,
+                logprobs=[] if logprobs is None else list(logprobs),
+                **kwargs,
+            )
         ]
         self.completed: list[BeamSearchSequence] = []
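
For orientation (not part of the diff): a minimal sketch of how the widened constructor is meant to be used, assuming this module is vllm.beam_search in a build that includes this change. The token IDs and audio payload below are made-up placeholders.

import numpy as np

from vllm.beam_search import BeamSearchInstance

# Hypothetical multimodal payload; in the real call path this comes from the
# user-supplied prompt dict (e.g. an (audio_array, sample_rate) tuple).
fake_audio = (np.zeros(16000, dtype=np.float32), 16000)

# Extra kwargs are forwarded verbatim to the seed BeamSearchSequence, so every
# beam derived from this prompt starts out carrying the multimodal data.
instance = BeamSearchInstance(
    prompt_tokens=[1, 2, 3],
    logprobs=None,
    multi_modal_data={"audio": fake_audio},
    mm_processor_kwargs=None,
)
print(instance.beams[0].multi_modal_data)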

vllm/entrypoints/llm.py

@@ -536,15 +536,18 @@ class LLM:
                 tokenizer.eos_token_id,
                 length_penalty)
 
-        # TODO - fix handling of multimodal data for beam search; we pass it
-        # through in the async version on the abstract EngineClient, but not
-        # here.
-        if any("multi_modal_data" in prompt
-               and prompt["multi_modal_data"] is not None
-               for prompt in prompts):
-            logger.warning(
-                "Multimodal data appears to have been provided, but is not"
-                " currently being passed through in LLM.beam_search()!")
+        def create_tokens_prompt_from_beam(
+                beam: BeamSearchSequence) -> TokensPrompt:
+            token_prompt_kwargs: TokensPrompt = {
+                "prompt_token_ids": beam.tokens
+            }
+            if beam.multi_modal_data is not None:
+                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data
+
+            if beam.mm_processor_kwargs is not None:
+                token_prompt_kwargs[
+                    "mm_processor_kwargs"] = beam.mm_processor_kwargs
+            return TokensPrompt(**token_prompt_kwargs)
 
         tokenizer = self.get_tokenizer()
         # generate 2 * beam_width candidates at each step
@@ -556,11 +559,20 @@ class LLM:
         instances: list[BeamSearchInstance] = []
 
         for prompt in prompts:
+            # Add multimodal processor kwargs & data
+            mm_kwargs = {}
+            if "multi_modal_data" in prompt:
+                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
+            if "mm_processor_kwargs" in prompt:
+                mm_kwargs["mm_processor_kwargs"] = prompt[
+                    "mm_processor_kwargs"]
+
             if is_token_prompt(prompt):
                 prompt_tokens = prompt["prompt_token_ids"]
             else:
                 prompt_tokens = tokenizer.encode(prompt["prompt"])
-            instances.append(BeamSearchInstance(prompt_tokens))
+            instances.append(
+                BeamSearchInstance(prompt_tokens, logprobs=None, **mm_kwargs))
 
         for _ in range(max_tokens):
             all_beams: list[BeamSearchSequence] = list(
@@ -575,8 +587,7 @@ class LLM:
                 break
 
             prompts_batch = [
-                TokensPrompt(prompt_token_ids=beam.tokens)
-                for beam in all_beams
+                create_tokens_prompt_from_beam(beam) for beam in all_beams
             ]
 
             # only runs for one step
@@ -602,7 +613,10 @@ class LLM:
                             tokens=current_beam.tokens + [token_id],
                             logprobs=current_beam.logprobs + [logprobs],
                             cum_logprob=current_beam.cum_logprob +
-                            logprob_obj.logprob)
+                            logprob_obj.logprob,
+                            multi_modal_data=current_beam.multi_modal_data,
+                            mm_processor_kwargs=current_beam.
+                            mm_processor_kwargs)
 
                         if token_id == tokenizer.eos_token_id and \
                             not ignore_eos:
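
Not part of the commit: a sketch of the offline flow this fix enables. The model name, prompt template, and beam settings mirror the new test above; treat the exact prompt formatting and settings as illustrative assumptions rather than a prescribed recipe.

from vllm import LLM
from vllm.assets.audio import AudioAsset
from vllm.sampling_params import BeamSearchParams

llm = LLM(model="Qwen/Qwen2-Audio-7B-Instruct", dtype="half")
audio = AudioAsset("mary_had_lamb").audio_and_sample_rate
prompt = (
    "<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>"
    "Can you transcribe this?<|im_end|>\n<|im_start|>assistant\n")

# Before this commit, multi_modal_data was dropped with a warning in the
# offline LLM.beam_search() path; now it is attached to every beam.
outputs = llm.beam_search(
    [{"prompt": prompt, "multi_modal_data": {"audio": audio}}],
    BeamSearchParams(beam_width=2, max_tokens=64),
)
for beam in outputs[0].sequences:
    print(beam.text)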