[Core][Bugfix] Fix Offline MM Beam Search (#16390)
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
parent d2020acac7
commit 6b40996ae8
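For context, a minimal usage sketch of what this fix enables: multimodal data attached to a prompt is now carried through the offline LLM.beam_search() path instead of being dropped with a warning. This sketch is not part of the diff; the model name, prompt template, and audio asset are illustrative assumptions borrowed from the test added in this commit, and the output field names follow BeamSearchOutput/BeamSearchSequence as referenced below.

# Hypothetical offline example (illustrative only, not part of this diff).
from vllm import LLM
from vllm.assets.audio import AudioAsset
from vllm.sampling_params import BeamSearchParams

llm = LLM(model="Qwen/Qwen2-Audio-7B-Instruct", dtype="half")  # assumed example model
audio = AudioAsset("mary_had_lamb").audio_and_sample_rate

outputs = llm.beam_search(
    [{
        "prompt": ("<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>"
                   "Can you transcribe this?<|im_end|>\n<|im_start|>assistant\n"),
        # Before this fix, this field was ignored (with a warning) by the
        # offline LLM.beam_search(); now it is attached to every beam.
        "multi_modal_data": {"audio": audio},
    }],
    BeamSearchParams(beam_width=2, max_tokens=64),
)
for beam in outputs[0].sequences:  # BeamSearchOutput.sequences, per vllm.beam_search
    print(beam.text)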
@@ -29,12 +29,11 @@ from vllm.distributed import (cleanup_dist_env_and_memory,
                               init_distributed_environment,
                               initialize_model_parallel)
 from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
-                         TokensPrompt, to_enc_dec_tuple_list,
-                         zip_enc_dec_prompts)
+                         to_enc_dec_tuple_list, zip_enc_dec_prompts)
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import BeamSearchParams
-from vllm.utils import cuda_device_count_stateless, is_list_of
+from vllm.utils import cuda_device_count_stateless
 
 logger = init_logger(__name__)
 
@@ -469,12 +468,19 @@ class HfRunner:
         prompts: list[str],
         beam_width: int,
         max_tokens: int,
+        images: Optional[PromptImageInput] = None,
+        videos: Optional[PromptVideoInput] = None,
+        audios: Optional[PromptAudioInput] = None,
     ) -> list[tuple[list[list[int]], list[str]]]:
         outputs = self.generate(prompts,
                                 do_sample=False,
                                 max_new_tokens=max_tokens,
                                 num_beams=beam_width,
-                                num_return_sequences=beam_width)
+                                num_return_sequences=beam_width,
+                                images=images,
+                                videos=videos,
+                                audios=audios)
 
         for i in range(len(outputs)):
             output_ids, output_str = outputs[i]
             for j in range(len(output_ids)):
@@ -936,18 +942,20 @@ class VllmRunner:
 
     def generate_beam_search(
         self,
-        prompts: Union[list[str], list[list[int]]],
+        prompts: list[str],
         beam_width: int,
         max_tokens: int,
+        images: Optional[PromptImageInput] = None,
+        videos: Optional[PromptVideoInput] = None,
+        audios: Optional[PromptAudioInput] = None,
     ) -> list[tuple[list[list[int]], list[str]]]:
-        if is_list_of(prompts, str, check="all"):
-            prompts = [TextPrompt(prompt=prompt) for prompt in prompts]
-        else:
-            prompts = [
-                TokensPrompt(prompt_token_ids=tokens) for tokens in prompts
-            ]
+        inputs = self.get_inputs(prompts,
+                                 images=images,
+                                 videos=videos,
+                                 audios=audios)
+
         outputs = self.model.beam_search(
-            prompts,
+            inputs,
             BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens))
         returned_outputs = []
         for output in outputs:
@@ -5,6 +5,9 @@ Run `pytest tests/samplers/test_beam_search.py`.
 """
 
 import pytest
+from transformers import AutoModelForSeq2SeqLM
+
+from vllm.assets.audio import AudioAsset
 
 
 @pytest.fixture(autouse=True)
@@ -19,6 +22,7 @@ def v1(run_with_both_engines):
 # 3. Use the model "huggyllama/llama-7b".
 MAX_TOKENS = [64]
 BEAM_WIDTHS = [4]
+MM_BEAM_WIDTHS = [2]
 MODELS = ["TinyLlama/TinyLlama-1.1B-Chat-v1.0"]
 
 
@@ -48,15 +52,90 @@ def test_beam_search_single_input(
     for i in range(len(example_prompts)):
         hf_output_ids, hf_output_texts = hf_outputs[i]
         vllm_output_ids, vllm_output_texts = vllm_outputs[i]
-        for i, (hf_text,
+        for j, (hf_text,
                 vllm_text) in enumerate(zip(hf_output_texts,
                                             vllm_output_texts)):
-            print(f">>>{i}-th hf output:")
+            print(f">>>{j}-th hf output:")
             print(hf_text)
-            print(f">>>{i}-th vllm output:")
+            print(f">>>{j}-th vllm output:")
             print(vllm_text)
         assert len(hf_output_ids) == len(vllm_output_ids)
         for j in range(len(hf_output_ids)):
             assert hf_output_ids[j] == vllm_output_ids[j], (
                 f"Test{i} output{j}:\nHF: {hf_output_ids}\n"
                 f"vLLM: {vllm_output_ids}")
+
+
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", MAX_TOKENS)
+@pytest.mark.parametrize("beam_width", MM_BEAM_WIDTHS)
+def test_beam_search_passes_multimodal_data(
+    hf_runner,
+    vllm_runner,
+    dtype: str,
+    max_tokens: int,
+    beam_width: int,
+) -> None:
+    """Ensure that beam search passes multimodal data through correctly."""
+    # NOTE - this test is primarily to check that mm data is passed to beams
+    # correctly. As such, we just need to check one extra modality to make
+    # sure things pass through properly.
+    audios = [AudioAsset("mary_had_lamb").audio_and_sample_rate]
+    model = "Qwen/Qwen2-Audio-7B-Instruct"
+    audio_seq = "<|audio_bos|><|AUDIO|><|audio_eos|>"
+    prompts = [
+        f"<|im_start|>user\n{audio_seq}Can you transcribe this?<|im_end|>\n<|im_start|>assistant\n" #noqa: E501
+    ]
+
+    with hf_runner(model, dtype=dtype,
+                   auto_cls=AutoModelForSeq2SeqLM) as hf_model:
+        audio_token_id = hf_model.config.audio_token_index
+        eos_token_id = hf_model.tokenizer.eos_token_id  # <|im_end|>
+        hf_outputs = hf_model.generate_beam_search(
+            prompts,
+            beam_width=beam_width,
+            max_tokens=max_tokens,
+            audios=audios,
+        )
+
+    with vllm_runner(model, dtype=dtype) as vllm_model:
+        vllm_outputs = vllm_model.generate_beam_search(
+            prompts,
+            beam_width=beam_width,
+            max_tokens=max_tokens,
+            audios=audios,
+        )
+
+    seq_with_no_audio_toks = lambda seq: [
+        tok for tok in seq if tok != audio_token_id
+    ]
+
+    for i in range(len(prompts)):
+        hf_output_ids, hf_output_texts = hf_outputs[i]
+        vllm_output_ids, vllm_output_texts = vllm_outputs[i]
+
+        for j, (hf_text,
+                vllm_text) in enumerate(zip(hf_output_texts,
+                                            vllm_output_texts)):
+            print(f">>>{j}-th hf output [NOTE: special tokens are filtered]:")
+            print(hf_text)
+            print(f">>>{j}-th vllm output:")
+            print(vllm_text)
+        assert len(hf_output_ids) == len(vllm_output_ids)
+
+        for j in range(len(hf_output_ids)):
+            # Compare everything except for the audio tokens; we do this since
+            # the IDs returned from the transformers helper expands the audio
+            # token to match features, while the vLLM helper maintains the
+            # single audio token in the input text
+            filtered_hf_output_ids = seq_with_no_audio_toks(hf_output_ids[j])
+            filtered_vllm_output_ids = seq_with_no_audio_toks(
+                vllm_output_ids[j])
+
+            # HF output IDs may contain the end of sequence
+            if len(filtered_hf_output_ids
+                   ) == len(filtered_vllm_output_ids) + 1:
+                assert filtered_hf_output_ids[-1] == eos_token_id
+                filtered_hf_output_ids = filtered_hf_output_ids[:-1]
+
+            assert filtered_hf_output_ids == filtered_vllm_output_ids
@@ -38,9 +38,18 @@ class BeamSearchOutput:
 
 class BeamSearchInstance:
 
-    def __init__(self, prompt_tokens: list[int]):
+    def __init__(
+        self,
+        prompt_tokens: list[int],
+        logprobs: Optional[list[dict[int, Logprob]]] = None,
+        **kwargs,
+    ):
         self.beams: list[BeamSearchSequence] = [
-            BeamSearchSequence(tokens=prompt_tokens, logprobs=[])
+            BeamSearchSequence(
+                tokens=prompt_tokens,
+                logprobs=[] if logprobs is None else list(logprobs),
+                **kwargs,
+            )
         ]
         self.completed: list[BeamSearchSequence] = []
 
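The widened constructor above forwards any extra keyword arguments into the seed BeamSearchSequence. A minimal sketch of how a caller is expected to seed an instance with multimodal kwargs; the import path, token IDs, and audio payload below are assumptions for illustration only.

import numpy as np

from vllm.beam_search import BeamSearchInstance  # assumed module path

# Hypothetical audio payload; in the real call path this comes straight from
# the user's prompt's "multi_modal_data" entry.
waveform = np.zeros(16000, dtype=np.float32)

instance = BeamSearchInstance(
    [151644, 872, 198],                             # arbitrary prompt token IDs
    logprobs=None,                                  # seed beam starts with no logprobs
    multi_modal_data={"audio": (waveform, 16000)},  # forwarded to the seed beam
)
assert instance.beams[0].multi_modal_data is not None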
@@ -536,15 +536,18 @@ class LLM:
                                          tokenizer.eos_token_id,
                                          length_penalty)
 
-        # TODO - fix handling of multimodal data for beam search; we pass it
-        # through in the async version on the abstract EngineClient, but not
-        # here.
-        if any("multi_modal_data" in prompt
-               and prompt["multi_modal_data"] is not None
-               for prompt in prompts):
-            logger.warning(
-                "Multimodal data appears to have been provided, but is not"
-                " currently being passed through in LLM.beam_search()!")
+        def create_tokens_prompt_from_beam(
+                beam: BeamSearchSequence) -> TokensPrompt:
+            token_prompt_kwargs: TokensPrompt = {
+                "prompt_token_ids": beam.tokens
+            }
+            if beam.multi_modal_data is not None:
+                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data
+
+            if beam.mm_processor_kwargs is not None:
+                token_prompt_kwargs[
+                    "mm_processor_kwargs"] = beam.mm_processor_kwargs
+            return TokensPrompt(**token_prompt_kwargs)
 
         tokenizer = self.get_tokenizer()
         # generate 2 * beam_width candidates at each step
@@ -556,11 +559,20 @@
         instances: list[BeamSearchInstance] = []
 
         for prompt in prompts:
+            # Add multimodal processor kwargs & data
+            mm_kwargs = {}
+            if "multi_modal_data" in prompt:
+                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
+            if "mm_processor_kwargs" in prompt:
+                mm_kwargs["mm_processor_kwargs"] = prompt[
+                    "mm_processor_kwargs"]
+
             if is_token_prompt(prompt):
                 prompt_tokens = prompt["prompt_token_ids"]
             else:
                 prompt_tokens = tokenizer.encode(prompt["prompt"])
-            instances.append(BeamSearchInstance(prompt_tokens))
+            instances.append(
+                BeamSearchInstance(prompt_tokens, logprobs=None, **mm_kwargs))
 
         for _ in range(max_tokens):
             all_beams: list[BeamSearchSequence] = list(
@@ -575,8 +587,7 @@
                 break
 
             prompts_batch = [
-                TokensPrompt(prompt_token_ids=beam.tokens)
-                for beam in all_beams
+                create_tokens_prompt_from_beam(beam) for beam in all_beams
             ]
 
             # only runs for one step
@@ -602,7 +613,10 @@
                         tokens=current_beam.tokens + [token_id],
                         logprobs=current_beam.logprobs + [logprobs],
                         cum_logprob=current_beam.cum_logprob +
-                        logprob_obj.logprob)
+                        logprob_obj.logprob,
+                        multi_modal_data=current_beam.multi_modal_data,
+                        mm_processor_kwargs=current_beam.
+                        mm_processor_kwargs)
 
                     if token_id == tokenizer.eos_token_id and \
                             not ignore_eos:
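Taken together, the last three hunks mean every candidate beam carries its multimodal payload forward, so the TokensPrompt rebuilt by create_tokens_prompt_from_beam() re-attaches it on each decoding step. A small sketch of that invariant, using the field names from this diff; the import path and the placeholder payload are assumptions for illustration only.

from vllm.beam_search import BeamSearchSequence  # assumed module path

parent = BeamSearchSequence(
    tokens=[1, 2, 3],
    logprobs=[],
    multi_modal_data={"audio": ("<waveform placeholder>", 16000)},
)

# Mirrors how the step loop extends a beam: the child sequence keeps the
# parent's multi_modal_data / mm_processor_kwargs instead of dropping them.
child = BeamSearchSequence(
    tokens=parent.tokens + [42],
    logprobs=parent.logprobs + [{}],
    cum_logprob=parent.cum_logprob + (-0.25),
    multi_modal_data=parent.multi_modal_data,
    mm_processor_kwargs=parent.mm_processor_kwargs,
)
assert child.multi_modal_data is parent.multi_modal_data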