[Core][Bugfix] Fix Offline MM Beam Search (#16390)

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>

commit 6b40996ae8
parent d2020acac7
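
What this fixes, as a hedged sketch: offline beam search via LLM.beam_search() previously only warned that multi_modal_data attached to a prompt was being dropped; after this change the payload rides along with every beam. The snippet below mirrors the new test (model name, audio asset, and chat markers come from it), but the exact prompt wording, the "audio" payload layout, and the output traversal via BeamSearchOutput.sequences / .text are illustrative assumptions, not an authoritative API reference.

# Hedged usage sketch: offline multimodal beam search after this fix.
from vllm import LLM
from vllm.assets.audio import AudioAsset
from vllm.sampling_params import BeamSearchParams

llm = LLM(model="Qwen/Qwen2-Audio-7B-Instruct", dtype="half")

audio = AudioAsset("mary_had_lamb").audio_and_sample_rate
prompt = {
    # Prompt text is illustrative; it follows the format used in the new test.
    "prompt": ("<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>"
               "Can you transcribe this?<|im_end|>\n<|im_start|>assistant\n"),
    # Before this commit, beam_search() ignored this payload with a warning.
    "multi_modal_data": {"audio": audio},
}

outputs = llm.beam_search([prompt],
                          BeamSearchParams(beam_width=2, max_tokens=64))
for beam in outputs[0].sequences:
    print(beam.text)
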
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -29,12 +29,11 @@ from vllm.distributed import (cleanup_dist_env_and_memory,
                               init_distributed_environment,
                               initialize_model_parallel)
 from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
-                         TokensPrompt, to_enc_dec_tuple_list,
-                         zip_enc_dec_prompts)
+                         to_enc_dec_tuple_list, zip_enc_dec_prompts)
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import BeamSearchParams
-from vllm.utils import cuda_device_count_stateless, is_list_of
+from vllm.utils import cuda_device_count_stateless
 
 logger = init_logger(__name__)
 
@@ -469,12 +468,19 @@ class HfRunner:
         prompts: list[str],
         beam_width: int,
         max_tokens: int,
+        images: Optional[PromptImageInput] = None,
+        videos: Optional[PromptVideoInput] = None,
+        audios: Optional[PromptAudioInput] = None,
     ) -> list[tuple[list[list[int]], list[str]]]:
         outputs = self.generate(prompts,
                                 do_sample=False,
                                 max_new_tokens=max_tokens,
                                 num_beams=beam_width,
-                                num_return_sequences=beam_width)
+                                num_return_sequences=beam_width,
+                                images=images,
+                                videos=videos,
+                                audios=audios)
 
         for i in range(len(outputs)):
             output_ids, output_str = outputs[i]
             for j in range(len(output_ids)):
@@ -936,18 +942,20 @@ class VllmRunner:
 
     def generate_beam_search(
         self,
-        prompts: Union[list[str], list[list[int]]],
+        prompts: list[str],
         beam_width: int,
         max_tokens: int,
+        images: Optional[PromptImageInput] = None,
+        videos: Optional[PromptVideoInput] = None,
+        audios: Optional[PromptAudioInput] = None,
     ) -> list[tuple[list[list[int]], list[str]]]:
-        if is_list_of(prompts, str, check="all"):
-            prompts = [TextPrompt(prompt=prompt) for prompt in prompts]
-        else:
-            prompts = [
-                TokensPrompt(prompt_token_ids=tokens) for tokens in prompts
-            ]
+        inputs = self.get_inputs(prompts,
+                                 images=images,
+                                 videos=videos,
+                                 audios=audios)
+
         outputs = self.model.beam_search(
-            prompts,
+            inputs,
             BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens))
         returned_outputs = []
         for output in outputs:
--- a/tests/samplers/test_beam_search.py
+++ b/tests/samplers/test_beam_search.py
@@ -5,6 +5,9 @@ Run `pytest tests/samplers/test_beam_search.py`.
 """
 
 import pytest
+from transformers import AutoModelForSeq2SeqLM
+
+from vllm.assets.audio import AudioAsset
 
 
 @pytest.fixture(autouse=True)
@@ -19,6 +22,7 @@ def v1(run_with_both_engines):
 # 3. Use the model "huggyllama/llama-7b".
 MAX_TOKENS = [64]
 BEAM_WIDTHS = [4]
+MM_BEAM_WIDTHS = [2]
 MODELS = ["TinyLlama/TinyLlama-1.1B-Chat-v1.0"]
 
 
@@ -48,15 +52,90 @@ def test_beam_search_single_input(
     for i in range(len(example_prompts)):
         hf_output_ids, hf_output_texts = hf_outputs[i]
         vllm_output_ids, vllm_output_texts = vllm_outputs[i]
-        for i, (hf_text,
+        for j, (hf_text,
                 vllm_text) in enumerate(zip(hf_output_texts,
                                             vllm_output_texts)):
-            print(f">>>{i}-th hf output:")
+            print(f">>>{j}-th hf output:")
             print(hf_text)
-            print(f">>>{i}-th vllm output:")
+            print(f">>>{j}-th vllm output:")
             print(vllm_text)
         assert len(hf_output_ids) == len(vllm_output_ids)
         for j in range(len(hf_output_ids)):
             assert hf_output_ids[j] == vllm_output_ids[j], (
                 f"Test{i} output{j}:\nHF: {hf_output_ids}\n"
                 f"vLLM: {vllm_output_ids}")
+
+
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", MAX_TOKENS)
+@pytest.mark.parametrize("beam_width", MM_BEAM_WIDTHS)
+def test_beam_search_passes_multimodal_data(
+    hf_runner,
+    vllm_runner,
+    dtype: str,
+    max_tokens: int,
+    beam_width: int,
+) -> None:
+    """Ensure that beam search passes multimodal data through correctly."""
+    # NOTE - this test is primarily to check that mm data is passed to beams
+    # correctly. As such, we just need to check one extra modality to make
+    # sure things pass through properly.
+    audios = [AudioAsset("mary_had_lamb").audio_and_sample_rate]
+    model = "Qwen/Qwen2-Audio-7B-Instruct"
+    audio_seq = "<|audio_bos|><|AUDIO|><|audio_eos|>"
+    prompts = [
+        f"<|im_start|>user\n{audio_seq}Can you transcribe this?<|im_end|>\n<|im_start|>assistant\n" #noqa: E501
+    ]
+
+    with hf_runner(model, dtype=dtype,
+                   auto_cls=AutoModelForSeq2SeqLM) as hf_model:
+        audio_token_id = hf_model.config.audio_token_index
+        eos_token_id = hf_model.tokenizer.eos_token_id  # <|im_end|>
+        hf_outputs = hf_model.generate_beam_search(
+            prompts,
+            beam_width=beam_width,
+            max_tokens=max_tokens,
+            audios=audios,
+        )
+
+    with vllm_runner(model, dtype=dtype) as vllm_model:
+        vllm_outputs = vllm_model.generate_beam_search(
+            prompts,
+            beam_width=beam_width,
+            max_tokens=max_tokens,
+            audios=audios,
+        )
+
+    seq_with_no_audio_toks = lambda seq: [
+        tok for tok in seq if tok != audio_token_id
+    ]
+
+    for i in range(len(prompts)):
+        hf_output_ids, hf_output_texts = hf_outputs[i]
+        vllm_output_ids, vllm_output_texts = vllm_outputs[i]
+
+        for j, (hf_text,
+                vllm_text) in enumerate(zip(hf_output_texts,
                                            vllm_output_texts)):
+            print(f">>>{j}-th hf output [NOTE: special tokens are filtered]:")
+            print(hf_text)
+            print(f">>>{j}-th vllm output:")
+            print(vllm_text)
+        assert len(hf_output_ids) == len(vllm_output_ids)
+
+        for j in range(len(hf_output_ids)):
+            # Compare everything except for the audio tokens; we do this since
+            # the IDs returned from the transformers helper expands the audio
+            # token to match features, while the vLLM helper maintains the
+            # single audio token in the input text
+            filtered_hf_output_ids = seq_with_no_audio_toks(hf_output_ids[j])
+            filtered_vllm_output_ids = seq_with_no_audio_toks(
+                vllm_output_ids[j])
+
+            # HF output IDs may contain the end of sequence
+            if len(filtered_hf_output_ids
+                   ) == len(filtered_vllm_output_ids) + 1:
+                assert filtered_hf_output_ids[-1] == eos_token_id
+                filtered_hf_output_ids = filtered_hf_output_ids[:-1]
+
+            assert filtered_hf_output_ids == filtered_vllm_output_ids
--- a/vllm/beam_search.py
+++ b/vllm/beam_search.py
@@ -38,9 +38,18 @@ class BeamSearchOutput:
 
 class BeamSearchInstance:
 
-    def __init__(self, prompt_tokens: list[int]):
+    def __init__(
+        self,
+        prompt_tokens: list[int],
+        logprobs: Optional[list[dict[int, Logprob]]] = None,
+        **kwargs,
+    ):
         self.beams: list[BeamSearchSequence] = [
-            BeamSearchSequence(tokens=prompt_tokens, logprobs=[])
+            BeamSearchSequence(
+                tokens=prompt_tokens,
+                logprobs=[] if logprobs is None else list(logprobs),
+                **kwargs,
+            )
         ]
         self.completed: list[BeamSearchSequence] = []
 
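
The widened constructor above exists so the offline entrypoint can seed the first beam with per-request multimodal state. A minimal sketch of the intended call, with placeholder values (the waveform and token IDs below are illustrative, not from the commit):

import numpy as np
from vllm.beam_search import BeamSearchInstance

# Placeholder per-request state; the real caller pulls these from the prompt dict.
prompt_tokens = [1, 2, 3, 4]
mm_kwargs = {"multi_modal_data": {"audio": (np.zeros(16000), 16000)}}

# Extra kwargs are forwarded verbatim into the seed BeamSearchSequence, so the
# multimodal payload is attached to the beam from step 0 onward.
instance = BeamSearchInstance(prompt_tokens, logprobs=None, **mm_kwargs)
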
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -536,15 +536,18 @@ class LLM:
             tokenizer.eos_token_id,
             length_penalty)
 
-        # TODO - fix handling of multimodal data for beam search; we pass it
-        # through in the async version on the abstract EngineClient, but not
-        # here.
-        if any("multi_modal_data" in prompt
-               and prompt["multi_modal_data"] is not None
-               for prompt in prompts):
-            logger.warning(
-                "Multimodal data appears to have been provided, but is not"
-                " currently being passed through in LLM.beam_search()!")
+        def create_tokens_prompt_from_beam(
+                beam: BeamSearchSequence) -> TokensPrompt:
+            token_prompt_kwargs: TokensPrompt = {
+                "prompt_token_ids": beam.tokens
+            }
+            if beam.multi_modal_data is not None:
+                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data
+
+            if beam.mm_processor_kwargs is not None:
+                token_prompt_kwargs[
+                    "mm_processor_kwargs"] = beam.mm_processor_kwargs
+            return TokensPrompt(**token_prompt_kwargs)
 
         tokenizer = self.get_tokenizer()
         # generate 2 * beam_width candidates at each step
@@ -556,11 +559,20 @@ class LLM:
         instances: list[BeamSearchInstance] = []
 
         for prompt in prompts:
+            # Add multimodal processor kwargs & data
+            mm_kwargs = {}
+            if "multi_modal_data" in prompt:
+                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
+            if "mm_processor_kwargs" in prompt:
+                mm_kwargs["mm_processor_kwargs"] = prompt[
+                    "mm_processor_kwargs"]
+
             if is_token_prompt(prompt):
                 prompt_tokens = prompt["prompt_token_ids"]
             else:
                 prompt_tokens = tokenizer.encode(prompt["prompt"])
-            instances.append(BeamSearchInstance(prompt_tokens))
+            instances.append(
+                BeamSearchInstance(prompt_tokens, logprobs=None, **mm_kwargs))
 
         for _ in range(max_tokens):
             all_beams: list[BeamSearchSequence] = list(
@@ -575,8 +587,7 @@ class LLM:
                 break
 
             prompts_batch = [
-                TokensPrompt(prompt_token_ids=beam.tokens)
-                for beam in all_beams
+                create_tokens_prompt_from_beam(beam) for beam in all_beams
             ]
 
             # only runs for one step
@@ -602,7 +613,10 @@ class LLM:
                             tokens=current_beam.tokens + [token_id],
                             logprobs=current_beam.logprobs + [logprobs],
                             cum_logprob=current_beam.cum_logprob +
-                            logprob_obj.logprob)
+                            logprob_obj.logprob,
+                            multi_modal_data=current_beam.multi_modal_data,
+                            mm_processor_kwargs=current_beam.
+                            mm_processor_kwargs)
 
                         if token_id == tokenizer.eos_token_id and \
                             not ignore_eos:
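
Net effect of the entrypoint changes, sketched with dummy values: at every decoding step each surviving beam is resubmitted as a TokensPrompt that still carries its multimodal payload rather than a bare token list. The beam below is hypothetical; only the field names come from the diff.

import numpy as np
from vllm.beam_search import BeamSearchSequence
from vllm.inputs import TokensPrompt

# Hypothetical beam state partway through decoding.
beam = BeamSearchSequence(
    tokens=[101, 42, 7],
    logprobs=[],
    multi_modal_data={"audio": (np.zeros(16000), 16000)},
)

# Equivalent of create_tokens_prompt_from_beam(beam): the token prefix and the
# original multimodal data (plus mm_processor_kwargs, when present) go back to
# the engine together, which is what the removed warning path failed to do.
step_prompt = TokensPrompt(
    prompt_token_ids=beam.tokens,
    multi_modal_data=beam.multi_modal_data,
)
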