From 6b40996ae8a5b065de5b3b650b5b1324f67f6334 Mon Sep 17 00:00:00 2001
From: Alex Brooks
Date: Mon, 14 Apr 2025 20:33:02 -0600
Subject: [PATCH] [Core][Bugfix] Fix Offline MM Beam Search (#16390)

Signed-off-by: Alex-Brooks
Co-authored-by: Cyrus Leung
---
 tests/conftest.py                  | 32 ++++++-----
 tests/samplers/test_beam_search.py | 85 ++++++++++++++++++++++++++++--
 vllm/beam_search.py                | 13 ++++-
 vllm/entrypoints/llm.py            | 40 +++++++++-----
 4 files changed, 140 insertions(+), 30 deletions(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index 69447d3c..d272f448 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -29,12 +29,11 @@ from vllm.distributed import (cleanup_dist_env_and_memory,
                               init_distributed_environment,
                               initialize_model_parallel)
 from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
-                         TokensPrompt, to_enc_dec_tuple_list,
-                         zip_enc_dec_prompts)
+                         to_enc_dec_tuple_list, zip_enc_dec_prompts)
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import BeamSearchParams
-from vllm.utils import cuda_device_count_stateless, is_list_of
+from vllm.utils import cuda_device_count_stateless
 
 logger = init_logger(__name__)
 
@@ -469,12 +468,19 @@ class HfRunner:
         prompts: list[str],
         beam_width: int,
         max_tokens: int,
+        images: Optional[PromptImageInput] = None,
+        videos: Optional[PromptVideoInput] = None,
+        audios: Optional[PromptAudioInput] = None,
     ) -> list[tuple[list[list[int]], list[str]]]:
         outputs = self.generate(prompts,
                                 do_sample=False,
                                 max_new_tokens=max_tokens,
                                 num_beams=beam_width,
-                                num_return_sequences=beam_width)
+                                num_return_sequences=beam_width,
+                                images=images,
+                                videos=videos,
+                                audios=audios)
+
         for i in range(len(outputs)):
             output_ids, output_str = outputs[i]
             for j in range(len(output_ids)):
@@ -936,18 +942,20 @@ class VllmRunner:
 
     def generate_beam_search(
         self,
-        prompts: Union[list[str], list[list[int]]],
+        prompts: list[str],
         beam_width: int,
         max_tokens: int,
+        images: Optional[PromptImageInput] = None,
+        videos: Optional[PromptVideoInput] = None,
+        audios: Optional[PromptAudioInput] = None,
     ) -> list[tuple[list[list[int]], list[str]]]:
-        if is_list_of(prompts, str, check="all"):
-            prompts = [TextPrompt(prompt=prompt) for prompt in prompts]
-        else:
-            prompts = [
-                TokensPrompt(prompt_token_ids=tokens) for tokens in prompts
-            ]
+        inputs = self.get_inputs(prompts,
+                                 images=images,
+                                 videos=videos,
+                                 audios=audios)
+
         outputs = self.model.beam_search(
-            prompts,
+            inputs,
             BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens))
         returned_outputs = []
         for output in outputs:
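[Note: the snippet below is not part of the patch. It is a minimal sketch of how the updated `generate_beam_search` helpers above are meant to be driven now that they accept image/video/audio inputs; it assumes the `vllm_runner` fixture from tests/conftest.py and reuses the model and audio asset from the new test that follows.]

```python
# Sketch only (not in this patch): exercising VllmRunner.generate_beam_search
# with an audio input, mirroring the new multimodal beam search test.
from vllm.assets.audio import AudioAsset


def run_audio_beam_search(vllm_runner):
    # (waveform, sample_rate) tuple for the bundled "mary_had_lamb" asset
    audios = [AudioAsset("mary_had_lamb").audio_and_sample_rate]
    prompts = [
        "<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>"
        "Can you transcribe this?<|im_end|>\n<|im_start|>assistant\n"
    ]
    with vllm_runner("Qwen/Qwen2-Audio-7B-Instruct",
                     dtype="half") as vllm_model:
        # The audio inputs are now forwarded into LLM.beam_search
        # instead of being dropped with a warning.
        return vllm_model.generate_beam_search(prompts,
                                               beam_width=2,
                                               max_tokens=64,
                                               audios=audios)
```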
diff --git a/tests/samplers/test_beam_search.py b/tests/samplers/test_beam_search.py
index a1a81b38..5de1137e 100644
--- a/tests/samplers/test_beam_search.py
+++ b/tests/samplers/test_beam_search.py
@@ -5,6 +5,9 @@ Run `pytest tests/samplers/test_beam_search.py`.
 """
 
 import pytest
+from transformers import AutoModelForSeq2SeqLM
+
+from vllm.assets.audio import AudioAsset
 
 
 @pytest.fixture(autouse=True)
@@ -19,6 +22,7 @@ def v1(run_with_both_engines):
 # 3. Use the model "huggyllama/llama-7b".
 MAX_TOKENS = [64]
 BEAM_WIDTHS = [4]
+MM_BEAM_WIDTHS = [2]
 MODELS = ["TinyLlama/TinyLlama-1.1B-Chat-v1.0"]
@@ -48,15 +52,90 @@ def test_beam_search_single_input(
     for i in range(len(example_prompts)):
         hf_output_ids, hf_output_texts = hf_outputs[i]
         vllm_output_ids, vllm_output_texts = vllm_outputs[i]
-        for i, (hf_text,
+        for j, (hf_text,
                 vllm_text) in enumerate(zip(hf_output_texts,
                                             vllm_output_texts)):
-            print(f">>>{i}-th hf output:")
+            print(f">>>{j}-th hf output:")
             print(hf_text)
-            print(f">>>{i}-th vllm output:")
+            print(f">>>{j}-th vllm output:")
             print(vllm_text)
         assert len(hf_output_ids) == len(vllm_output_ids)
         for j in range(len(hf_output_ids)):
             assert hf_output_ids[j] == vllm_output_ids[j], (
                 f"Test{i} output{j}:\nHF: {hf_output_ids}\n"
                 f"vLLM: {vllm_output_ids}")
+
+
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", MAX_TOKENS)
+@pytest.mark.parametrize("beam_width", MM_BEAM_WIDTHS)
+def test_beam_search_passes_multimodal_data(
+    hf_runner,
+    vllm_runner,
+    dtype: str,
+    max_tokens: int,
+    beam_width: int,
+) -> None:
+    """Ensure that beam search passes multimodal data through correctly."""
+    # NOTE - this test is primarily to check that mm data is passed to beams
+    # correctly. As such, we just need to check one extra modality to make
+    # sure things pass through properly.
+    audios = [AudioAsset("mary_had_lamb").audio_and_sample_rate]
+    model = "Qwen/Qwen2-Audio-7B-Instruct"
+    audio_seq = "<|audio_bos|><|AUDIO|><|audio_eos|>"
+    prompts = [
+        f"<|im_start|>user\n{audio_seq}Can you transcribe this?<|im_end|>\n<|im_start|>assistant\n"  #noqa: E501
+    ]
+
+    with hf_runner(model, dtype=dtype,
+                   auto_cls=AutoModelForSeq2SeqLM) as hf_model:
+        audio_token_id = hf_model.config.audio_token_index
+        eos_token_id = hf_model.tokenizer.eos_token_id  # <|im_end|>
+        hf_outputs = hf_model.generate_beam_search(
+            prompts,
+            beam_width=beam_width,
+            max_tokens=max_tokens,
+            audios=audios,
+        )
+
+    with vllm_runner(model, dtype=dtype) as vllm_model:
+        vllm_outputs = vllm_model.generate_beam_search(
+            prompts,
+            beam_width=beam_width,
+            max_tokens=max_tokens,
+            audios=audios,
+        )
+
+    seq_with_no_audio_toks = lambda seq: [
+        tok for tok in seq if tok != audio_token_id
+    ]
+
+    for i in range(len(prompts)):
+        hf_output_ids, hf_output_texts = hf_outputs[i]
+        vllm_output_ids, vllm_output_texts = vllm_outputs[i]
+
+        for j, (hf_text,
+                vllm_text) in enumerate(zip(hf_output_texts,
+                                            vllm_output_texts)):
+            print(f">>>{j}-th hf output [NOTE: special tokens are filtered]:")
+            print(hf_text)
+            print(f">>>{j}-th vllm output:")
+            print(vllm_text)
+        assert len(hf_output_ids) == len(vllm_output_ids)
+
+        for j in range(len(hf_output_ids)):
+            # Compare everything except for the audio tokens; we do this since
+            # the IDs returned from the transformers helper expands the audio
+            # token to match features, while the vLLM helper maintains the
+            # single audio token in the input text
+            filtered_hf_output_ids = seq_with_no_audio_toks(hf_output_ids[j])
+            filtered_vllm_output_ids = seq_with_no_audio_toks(
+                vllm_output_ids[j])
+
+            # HF output IDs may contain the end of sequence
+            if len(filtered_hf_output_ids
+                   ) == len(filtered_vllm_output_ids) + 1:
+                assert filtered_hf_output_ids[-1] == eos_token_id
+                filtered_hf_output_ids = filtered_hf_output_ids[:-1]
+
+            assert filtered_hf_output_ids == filtered_vllm_output_ids
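[Note: the snippet below is also not part of the patch. It spells out the comparison rule the new test relies on: HF expands the single <|AUDIO|> placeholder into one token per audio feature, while vLLM keeps the single placeholder token, so audio tokens are stripped (and a trailing EOS on the HF side tolerated) before the beams are compared. Function name is illustrative.]

```python
# Sketch only: the beam comparison applied by
# test_beam_search_passes_multimodal_data above.
def beams_match(hf_ids: list[int], vllm_ids: list[int], audio_token_id: int,
                eos_token_id: int) -> bool:
    # Drop audio placeholder tokens from both sides before comparing.
    hf_filtered = [tok for tok in hf_ids if tok != audio_token_id]
    vllm_filtered = [tok for tok in vllm_ids if tok != audio_token_id]
    # HF output IDs may additionally contain the end-of-sequence token.
    if (len(hf_filtered) == len(vllm_filtered) + 1
            and hf_filtered[-1] == eos_token_id):
        hf_filtered = hf_filtered[:-1]
    return hf_filtered == vllm_filtered
```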
diff --git a/vllm/beam_search.py b/vllm/beam_search.py
index 5d4ebdb7..967510ab 100644
--- a/vllm/beam_search.py
+++ b/vllm/beam_search.py
@@ -38,9 +38,18 @@ class BeamSearchOutput:
 class BeamSearchInstance:
 
-    def __init__(self, prompt_tokens: list[int]):
+    def __init__(
+        self,
+        prompt_tokens: list[int],
+        logprobs: Optional[list[dict[int, Logprob]]] = None,
+        **kwargs,
+    ):
         self.beams: list[BeamSearchSequence] = [
-            BeamSearchSequence(tokens=prompt_tokens, logprobs=[])
+            BeamSearchSequence(
+                tokens=prompt_tokens,
+                logprobs=[] if logprobs is None else list(logprobs),
+                **kwargs,
+            )
         ]
         self.completed: list[BeamSearchSequence] = []
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index a707087a..57c7ab73 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -536,15 +536,18 @@ class LLM:
                                          tokenizer.eos_token_id,
                                          length_penalty)
 
-        # TODO - fix handling of multimodal data for beam search; we pass it
-        # through in the async version on the abstract EngineClient, but not
-        # here.
-        if any("multi_modal_data" in prompt
-               and prompt["multi_modal_data"] is not None
-               for prompt in prompts):
-            logger.warning(
-                "Multimodal data appears to have been provided, but is not"
-                " currently being passed through in LLM.beam_search()!")
+        def create_tokens_prompt_from_beam(
+                beam: BeamSearchSequence) -> TokensPrompt:
+            token_prompt_kwargs: TokensPrompt = {
+                "prompt_token_ids": beam.tokens
+            }
+            if beam.multi_modal_data is not None:
+                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data
+
+            if beam.mm_processor_kwargs is not None:
+                token_prompt_kwargs[
+                    "mm_processor_kwargs"] = beam.mm_processor_kwargs
+            return TokensPrompt(**token_prompt_kwargs)
 
         tokenizer = self.get_tokenizer()
         # generate 2 * beam_width candidates at each step
@@ -556,11 +559,20 @@
         instances: list[BeamSearchInstance] = []
 
         for prompt in prompts:
+            # Add multimodal processor kwargs & data
+            mm_kwargs = {}
+            if "multi_modal_data" in prompt:
+                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
+            if "mm_processor_kwargs" in prompt:
+                mm_kwargs["mm_processor_kwargs"] = prompt[
+                    "mm_processor_kwargs"]
+
             if is_token_prompt(prompt):
                 prompt_tokens = prompt["prompt_token_ids"]
             else:
                 prompt_tokens = tokenizer.encode(prompt["prompt"])
-            instances.append(BeamSearchInstance(prompt_tokens))
+            instances.append(
+                BeamSearchInstance(prompt_tokens, logprobs=None, **mm_kwargs))
 
         for _ in range(max_tokens):
             all_beams: list[BeamSearchSequence] = list(
@@ -575,8 +587,7 @@
                 break
 
             prompts_batch = [
-                TokensPrompt(prompt_token_ids=beam.tokens)
-                for beam in all_beams
+                create_tokens_prompt_from_beam(beam) for beam in all_beams
             ]
 
             # only runs for one step
@@ -602,7 +613,10 @@
                             tokens=current_beam.tokens + [token_id],
                             logprobs=current_beam.logprobs + [logprobs],
                             cum_logprob=current_beam.cum_logprob +
-                            logprob_obj.logprob)
+                            logprob_obj.logprob,
+                            multi_modal_data=current_beam.multi_modal_data,
+                            mm_processor_kwargs=current_beam.
+                            mm_processor_kwargs)
 
                         if token_id == tokenizer.eos_token_id and \
                             not ignore_eos:
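[Note: not part of the patch. A minimal end-to-end sketch of the behavior this change fixes: multimodal data attached to an offline prompt is now carried on every beam by `LLM.beam_search` rather than being dropped with a warning. Model, prompt, and audio asset are taken from the new test; the output attribute names are assumed to follow `BeamSearchOutput`/`BeamSearchSequence` in vllm/beam_search.py.]

```python
# Sketch only: offline multimodal beam search after this fix.
from vllm import LLM
from vllm.assets.audio import AudioAsset
from vllm.sampling_params import BeamSearchParams

llm = LLM(model="Qwen/Qwen2-Audio-7B-Instruct", dtype="half")

# (waveform, sample_rate) tuple attached as multimodal data on the prompt.
audio = AudioAsset("mary_had_lamb").audio_and_sample_rate
prompt = {
    "prompt": "<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>"
              "Can you transcribe this?<|im_end|>\n<|im_start|>assistant\n",
    "multi_modal_data": {"audio": audio},
}

outputs = llm.beam_search([prompt],
                          BeamSearchParams(beam_width=2, max_tokens=64))
for beam in outputs[0].sequences:
    print(beam.text)
```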