re-implement beam search on top of vllm core (#8726)

Co-authored-by: Brendan Wong <bjwpokemon@gmail.com>
youkaichao 2024-09-23 22:08:12 -07:00 committed by GitHub
parent 88577ac928
commit 0250dd68c5
4 changed files with 171 additions and 9 deletions

View File

@@ -90,6 +90,7 @@ def run_vllm(
     download_dir: Optional[str] = None,
     load_format: str = EngineArgs.load_format,
     disable_async_output_proc: bool = False,
+    use_new_beam_search_impl: bool = False,
 ) -> float:
     from vllm import LLM, SamplingParams
     llm = LLM(
@@ -132,9 +133,23 @@ def run_vllm(
                 max_tokens=output_len,
             ))
-    start = time.perf_counter()
-    llm.generate(prompts, sampling_params, use_tqdm=True)
-    end = time.perf_counter()
+    if not use_new_beam_search_impl:
+        start = time.perf_counter()
+        llm.generate(prompts, sampling_params, use_tqdm=True)
+        end = time.perf_counter()
+    else:
+        assert use_beam_search
+        prompts = [prompt for prompt, _, _ in requests]
+        # output_len should be the same for all requests.
+        output_len = requests[0][2]
+        for prompt, input_len, _output_len in requests:
+            assert _output_len == output_len
+        start = time.perf_counter()
+        llm.beam_search(prompts,
+                        beam_width=n,
+                        max_tokens=output_len,
+                        ignore_eos=True)
+        end = time.perf_counter()
     return end - start
@@ -336,7 +351,7 @@ def main(args: argparse.Namespace):
             run_args.append(args.disable_frontend_multiprocessing)
             elapsed_time = uvloop.run(run_vllm_async(*run_args))
         else:
-            elapsed_time = run_vllm(*run_args)
+            elapsed_time = run_vllm(*run_args, args.use_new_beam_search_impl)
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
         elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -396,6 +411,7 @@ if __name__ == "__main__":
                         default=1,
                         help="Number of generated sequences per prompt.")
     parser.add_argument("--use-beam-search", action="store_true")
+    parser.add_argument("--use-new-beam-search-impl", action="store_true")
     parser.add_argument("--num-prompts",
                         type=int,
                         default=1000,
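
For orientation, here is a minimal sketch of what the two timing paths added above boil down to. It assumes, as the loop in the diff implies, that requests is a list of (prompt, prompt_len, output_len) tuples with a single shared output length; the helper name time_generation and its exact signature are illustrative and not part of the benchmark script.

import time

from vllm import LLM


def time_generation(llm: LLM, requests, sampling_params, n: int,
                    use_new_beam_search_impl: bool) -> float:
    # Hypothetical condensation of the branch above: the old path drives beam
    # search through SamplingParams inside generate(); the new path calls
    # llm.beam_search() directly with an explicit beam width.
    prompts = [prompt for prompt, _, _ in requests]
    start = time.perf_counter()
    if not use_new_beam_search_impl:
        llm.generate(prompts, sampling_params, use_tqdm=True)
    else:
        output_len = requests[0][2]  # all requests share one output length
        assert all(r[2] == output_len for r in requests)
        llm.beam_search(prompts, beam_width=n, max_tokens=output_len,
                        ignore_eos=True)
    return time.perf_counter() - start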

View File

@@ -798,6 +798,20 @@ class VllmRunner:
         outputs = self.generate(prompts, beam_search_params)
         return outputs
 
+    def generate_beam_search_new(
+        self,
+        prompts: Union[List[str], List[List[int]]],
+        beam_width: int,
+        max_tokens: int,
+    ) -> List[Tuple[List[List[int]], List[str]]]:
+        outputs = self.model.beam_search(prompts, beam_width, max_tokens)
+        returned_outputs = []
+        for output in outputs:
+            token_ids = [x.tokens for x in output.sequences]
+            texts = [x.text for x in output.sequences]
+            returned_outputs.append((token_ids, texts))
+        return returned_outputs
+
     def encode(self, prompts: List[str]) -> List[List[float]]:
         req_outputs = self.model.encode(prompts)
         outputs = []
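
A brief, hypothetical test-style sketch of how the helper's return shape can be consumed; vllm_runner and example_prompts are the existing test fixtures, and the model, dtype, and beam settings simply mirror the updated test below.

def test_new_beam_search_output_shape(vllm_runner, example_prompts):
    # Hypothetical sketch, not part of this commit. Each element of the
    # returned list corresponds to one prompt and is a (token_ids, texts)
    # pair with one entry per beam, ordered best (highest cumulative
    # logprob) first.
    with vllm_runner("TinyLlama/TinyLlama-1.1B-Chat-v1.0",
                     dtype="half") as vllm_model:
        outputs = vllm_model.generate_beam_search_new(example_prompts,
                                                      beam_width=4,
                                                      max_tokens=64)
    for token_ids, texts in outputs:
        assert len(token_ids) == len(texts) <= 4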

View File

@@ -9,7 +9,7 @@ import pytest
 # 1. Increase max_tokens to 256.
 # 2. Increase beam_width to 8.
 # 3. Use the model "huggyllama/llama-7b".
-MAX_TOKENS = [128]
+MAX_TOKENS = [64]
 BEAM_WIDTHS = [4]
 MODELS = ["TinyLlama/TinyLlama-1.1B-Chat-v1.0"]
@@ -33,8 +33,8 @@ def test_beam_search_single_input(
                                                 max_tokens)
     with vllm_runner(model, dtype=dtype) as vllm_model:
-        vllm_outputs = vllm_model.generate_beam_search(example_prompts,
-                                                       beam_width, max_tokens)
+        vllm_outputs = vllm_model.generate_beam_search_new(
+            example_prompts, beam_width, max_tokens)
 
     for i in range(len(example_prompts)):
         hf_output_ids, hf_output_texts = hf_outputs[i]

View File

@@ -1,6 +1,8 @@
+import itertools
 from contextlib import contextmanager
-from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Union, cast,
-                    overload)
+from dataclasses import dataclass
+from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple,
+                    Union, cast, overload)
 
 from tqdm import tqdm
@@ -30,6 +32,37 @@ from vllm.utils import Counter, deprecate_kwargs, is_list_of
 
 logger = init_logger(__name__)
 
 
+@dataclass
+class BeamSearchSequence:
+    """A sequence for beam search.
+    It keeps track of the tokens and the log probability of the sequence.
+    The text field is optional and will only be filled when the sequence is
+    about to be returned to the user.
+    """
+    # The tokens includes the prompt.
+    tokens: List[int]
+    cum_logprob: float = 0.0
+    text: Optional[str] = None
+
+
+@dataclass
+class BeamSearchOutput:
+    """The output of beam search.
+    It contains the list of the best beam search sequences.
+    The length of the list is equal to the beam width.
+    """
+    sequences: List[BeamSearchSequence]
+
+
+class BeamSearchInstance:
+
+    def __init__(self, prompt_tokens: List[int]):
+        self.beams: List[BeamSearchSequence] = [
+            BeamSearchSequence(tokens=prompt_tokens)
+        ]
+        self.completed: List[BeamSearchSequence] = []
+
+
 class LLM:
     """An LLM for generating texts from given prompts and sampling parameters.
@@ -354,6 +387,105 @@ class LLM:
         outputs = self._run_engine(use_tqdm=use_tqdm)
         return LLMEngine.validate_outputs(outputs, RequestOutput)
 
+    def beam_search(
+        self,
+        prompts: List[Union[str, List[int]]],
+        beam_width: int,
+        max_tokens: int,
+        ignore_eos: bool = False,
+    ) -> List[BeamSearchOutput]:
+        """
+        Generate sequences using beam search.
+
+        Args:
+            prompts: A list of prompts. Each prompt can be a string or a list
+                of token IDs.
+            beam_width: The number of beams to keep at each step.
+            max_tokens: The max number of tokens to generate for each prompt.
+
+        TODO: how does beam search work together with length penalty, frequency
+        penalty, and stopping criteria, etc.?
+        """
+
+        tokenizer = self.get_tokenizer()
+        # generate 2 * beam_width candidates at each step
+        # following the huggingface transformers implementation
+        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
+        beam_search_params = SamplingParams(logprobs=2 * beam_width,
+                                            max_tokens=1,
+                                            temperature=0.0)
+        instances: List[BeamSearchInstance] = []
+
+        for prompt in prompts:
+            prompt_tokens = prompt if isinstance(
+                prompt, list) else tokenizer.encode(prompt)
+            instances.append(BeamSearchInstance(prompt_tokens))
+
+        for _ in range(max_tokens):
+            all_beams: List[BeamSearchSequence] = list(
+                sum((instance.beams for instance in instances), []))
+            pos = [0] + list(
+                itertools.accumulate(
+                    len(instance.beams) for instance in instances))
+            instance_start_and_end: List[Tuple[int, int]] = list(
+                zip(pos[:-1], pos[1:]))
+
+            if len(all_beams) == 0:
+                break
+
+            prompts_batch = [
+                TokensPrompt(prompt_token_ids=beam.tokens)
+                for beam in all_beams
+            ]
+
+            # only runs for one step
+            # we don't need to use tqdm here
+            output = self.generate(prompts_batch,
+                                   sampling_params=beam_search_params,
+                                   use_tqdm=False)
+
+            for (start, end), instance in zip(instance_start_and_end,
+                                              instances):
+                instance_new_beams = []
+                for i in range(start, end):
+                    current_beam = all_beams[i]
+                    result = output[i]
+
+                    if result.outputs[0].logprobs is not None:
+                        # if `result.outputs[0].logprobs` is None, it means
+                        # the sequence is completed because of the max-model-len
+                        # or abortion. we don't need to add it to the new beams.
+                        logprobs = result.outputs[0].logprobs[0]
+                        for token_id, logprob_obj in logprobs.items():
+                            new_beam = BeamSearchSequence(
+                                tokens=current_beam.tokens + [token_id],
+                                cum_logprob=current_beam.cum_logprob +
+                                logprob_obj.logprob)
+
+                            if token_id == tokenizer.eos_token_id and \
+                                not ignore_eos:
+                                instance.completed.append(new_beam)
+                            else:
+                                instance_new_beams.append(new_beam)
+                sorted_beams = sorted(instance_new_beams,
+                                      key=lambda x: x.cum_logprob,
+                                      reverse=True)
+                instance.beams = sorted_beams[:beam_width]
+
+        outputs = []
+        for instance in instances:
+            instance.completed.extend(instance.beams)
+            sorted_completed = sorted(instance.completed,
+                                      key=lambda x: x.cum_logprob,
+                                      reverse=True)
+            best_beams = sorted_completed[:beam_width]
+
+            for beam in best_beams:
+                beam.text = tokenizer.decode(beam.tokens)
+            outputs.append(BeamSearchOutput(sequences=best_beams))
+
+        return outputs
+
     def chat(
         self,
         messages: List[ChatCompletionMessageParam],
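
For reference, a minimal usage sketch of the LLM.beam_search API added above; the model name is simply the one used in the updated test, and the prompt and generation lengths are illustrative.

from vllm import LLM

llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
outputs = llm.beam_search(["The capital of France is"],
                          beam_width=4,
                          max_tokens=32)
for output in outputs:
    # output.sequences holds the top beam_width beams, best first; `text`
    # is only filled in (prompt + completion) right before returning.
    for seq in output.sequences:
        print(f"{seq.cum_logprob:.3f} {seq.text!r}")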