re-implement beam search on top of vllm core (#8726)
Co-authored-by: Brendan Wong <bjwpokemon@gmail.com>
Parent: 88577ac928
Commit: 0250dd68c5
@@ -90,6 +90,7 @@ def run_vllm(
     download_dir: Optional[str] = None,
     load_format: str = EngineArgs.load_format,
     disable_async_output_proc: bool = False,
+    use_new_beam_search_impl: bool = False,
 ) -> float:
     from vllm import LLM, SamplingParams
     llm = LLM(
@@ -132,9 +133,23 @@ def run_vllm(
                 max_tokens=output_len,
             ))
 
-    start = time.perf_counter()
-    llm.generate(prompts, sampling_params, use_tqdm=True)
-    end = time.perf_counter()
+    if not use_new_beam_search_impl:
+        start = time.perf_counter()
+        llm.generate(prompts, sampling_params, use_tqdm=True)
+        end = time.perf_counter()
+    else:
+        assert use_beam_search
+        prompts = [prompt for prompt, _, _ in requests]
+        # output_len should be the same for all requests.
+        output_len = requests[0][2]
+        for prompt, input_len, _output_len in requests:
+            assert _output_len == output_len
+        start = time.perf_counter()
+        llm.beam_search(prompts,
+                        beam_width=n,
+                        max_tokens=output_len,
+                        ignore_eos=True)
+        end = time.perf_counter()
     return end - start
 
 
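Not part of the diff: a minimal sketch of the new timing path in isolation, assuming the llm.beam_search API added later in this commit; the model name and prompts below are placeholders chosen only for illustration.

import time

from vllm import LLM

# Placeholder model and prompts (assumptions, not from the diff).
llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
prompts = ["Hello, my name is"] * 8

start = time.perf_counter()
# One call covers all prompts, mirroring the new `else` branch above.
llm.beam_search(prompts, beam_width=4, max_tokens=64, ignore_eos=True)
end = time.perf_counter()
print(f"new beam search impl: {end - start:.2f} s")
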
@@ -336,7 +351,7 @@ def main(args: argparse.Namespace):
             run_args.append(args.disable_frontend_multiprocessing)
             elapsed_time = uvloop.run(run_vllm_async(*run_args))
         else:
-            elapsed_time = run_vllm(*run_args)
+            elapsed_time = run_vllm(*run_args, args.use_new_beam_search_impl)
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
         elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -396,6 +411,7 @@ if __name__ == "__main__":
                         default=1,
                         help="Number of generated sequences per prompt.")
     parser.add_argument("--use-beam-search", action="store_true")
+    parser.add_argument("--use-new-beam-search-impl", action="store_true")
    parser.add_argument("--num-prompts",
                         type=int,
                         default=1000,
 
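Not part of the diff: a small argparse sketch showing how the new flag pairs with the existing one and reaches run_vllm; only the flags visible in this diff are included, the real parser defines many more.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument("--use-new-beam-search-impl", action="store_true")
parser.add_argument("--num-prompts", type=int, default=1000)

args = parser.parse_args(["--use-beam-search", "--use-new-beam-search-impl"])

# The new flag is forwarded positionally, mirroring main() above:
#   elapsed_time = run_vllm(*run_args, args.use_new_beam_search_impl)
# run_vllm's else-branch asserts use_beam_search, so both flags go together.
assert args.use_beam_search and args.use_new_beam_search_impl
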
@@ -798,6 +798,20 @@ class VllmRunner:
         outputs = self.generate(prompts, beam_search_params)
         return outputs
 
+    def generate_beam_search_new(
+        self,
+        prompts: Union[List[str], List[List[int]]],
+        beam_width: int,
+        max_tokens: int,
+    ) -> List[Tuple[List[List[int]], List[str]]]:
+        outputs = self.model.beam_search(prompts, beam_width, max_tokens)
+        returned_outputs = []
+        for output in outputs:
+            token_ids = [x.tokens for x in output.sequences]
+            texts = [x.text for x in output.sequences]
+            returned_outputs.append((token_ids, texts))
+        return returned_outputs
+
     def encode(self, prompts: List[str]) -> List[List[float]]:
         req_outputs = self.model.encode(prompts)
         outputs = []
 
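Not part of the diff: a hypothetical pytest-style sketch of how the helper above could be driven through the vllm_runner fixture used in the test below; the test name, dtype, and prompt are assumptions, while the model name comes from the test constants.

def test_beam_search_smoke(vllm_runner):
    # dtype and prompt are placeholders for illustration.
    with vllm_runner("TinyLlama/TinyLlama-1.1B-Chat-v1.0",
                     dtype="half") as vllm_model:
        outputs = vllm_model.generate_beam_search_new(["Hello, my name is"],
                                                      beam_width=4,
                                                      max_tokens=64)
    token_ids, texts = outputs[0]  # one (token_ids, texts) pair per prompt
    assert len(texts) <= 4         # at most beam_width sequences per prompt
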
@@ -9,7 +9,7 @@ import pytest
 # 1. Increase max_tokens to 256.
 # 2. Increase beam_width to 8.
 # 3. Use the model "huggyllama/llama-7b".
-MAX_TOKENS = [128]
+MAX_TOKENS = [64]
 BEAM_WIDTHS = [4]
 MODELS = ["TinyLlama/TinyLlama-1.1B-Chat-v1.0"]
 
@@ -33,8 +33,8 @@ def test_beam_search_single_input(
                                                    max_tokens)
 
     with vllm_runner(model, dtype=dtype) as vllm_model:
-        vllm_outputs = vllm_model.generate_beam_search(example_prompts,
-                                                       beam_width, max_tokens)
+        vllm_outputs = vllm_model.generate_beam_search_new(
+            example_prompts, beam_width, max_tokens)
 
     for i in range(len(example_prompts)):
         hf_output_ids, hf_output_texts = hf_outputs[i]
 
@@ -1,6 +1,8 @@
+import itertools
 from contextlib import contextmanager
-from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Union, cast,
-                    overload)
+from dataclasses import dataclass
+from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple,
+                    Union, cast, overload)
 
 from tqdm import tqdm
 
@@ -30,6 +32,37 @@ from vllm.utils import Counter, deprecate_kwargs, is_list_of
 logger = init_logger(__name__)
 
 
+@dataclass
+class BeamSearchSequence:
+    """A sequence for beam search.
+    It keeps track of the tokens and the log probability of the sequence.
+    The text field is optional and will only be filled when the sequence is
+    about to be returned to the user.
+    """
+    # The tokens includes the prompt.
+    tokens: List[int]
+    cum_logprob: float = 0.0
+    text: Optional[str] = None
+
+
+@dataclass
+class BeamSearchOutput:
+    """The output of beam search.
+    It contains the list of the best beam search sequences.
+    The length of the list is equal to the beam width.
+    """
+    sequences: List[BeamSearchSequence]
+
+
+class BeamSearchInstance:
+
+    def __init__(self, prompt_tokens: List[int]):
+        self.beams: List[BeamSearchSequence] = [
+            BeamSearchSequence(tokens=prompt_tokens)
+        ]
+        self.completed: List[BeamSearchSequence] = []
+
+
 class LLM:
     """An LLM for generating texts from given prompts and sampling parameters.
 
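Not part of the diff: an illustrative sketch of how the three structures above relate, assuming the classes are in scope.

# A fresh instance starts with a single beam that holds only the prompt
# tokens and has zero cumulative log probability.
inst = BeamSearchInstance(prompt_tokens=[1, 2, 3])
assert len(inst.beams) == 1
assert inst.beams[0].tokens == [1, 2, 3]
assert inst.beams[0].cum_logprob == 0.0
assert inst.completed == []

# The per-prompt result wraps the best sequences; text is only filled in
# right before the sequences are returned to the user.
out = BeamSearchOutput(sequences=inst.beams)
assert out.sequences[0].text is None
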
@@ -354,6 +387,105 @@ class LLM:
         outputs = self._run_engine(use_tqdm=use_tqdm)
         return LLMEngine.validate_outputs(outputs, RequestOutput)
 
+    def beam_search(
+        self,
+        prompts: List[Union[str, List[int]]],
+        beam_width: int,
+        max_tokens: int,
+        ignore_eos: bool = False,
+    ) -> List[BeamSearchOutput]:
+        """
+        Generate sequences using beam search.
+
+        Args:
+            prompts: A list of prompts. Each prompt can be a string or a list
+                of token IDs.
+            beam_width: The number of beams to keep at each step.
+            max_tokens: The max number of tokens to generate for each prompt.
+
+        TODO: how does beam search work together with length penalty, frequency
+        penalty, and stopping criteria, etc.?
+        """
+
+        tokenizer = self.get_tokenizer()
+        # generate 2 * beam_width candidates at each step
+        # following the huggingface transformers implementation
+        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
+        beam_search_params = SamplingParams(logprobs=2 * beam_width,
+                                            max_tokens=1,
+                                            temperature=0.0)
+        instances: List[BeamSearchInstance] = []
+
+        for prompt in prompts:
+            prompt_tokens = prompt if isinstance(
+                prompt, list) else tokenizer.encode(prompt)
+            instances.append(BeamSearchInstance(prompt_tokens))
+
+        for _ in range(max_tokens):
+            all_beams: List[BeamSearchSequence] = list(
+                sum((instance.beams for instance in instances), []))
+            pos = [0] + list(
+                itertools.accumulate(
+                    len(instance.beams) for instance in instances))
+            instance_start_and_end: List[Tuple[int, int]] = list(
+                zip(pos[:-1], pos[1:]))
+
+            if len(all_beams) == 0:
+                break
+
+            prompts_batch = [
+                TokensPrompt(prompt_token_ids=beam.tokens)
+                for beam in all_beams
+            ]
+
+            # only runs for one step
+            # we don't need to use tqdm here
+            output = self.generate(prompts_batch,
+                                   sampling_params=beam_search_params,
+                                   use_tqdm=False)
+
+            for (start, end), instance in zip(instance_start_and_end,
+                                              instances):
+                instance_new_beams = []
+                for i in range(start, end):
+                    current_beam = all_beams[i]
+                    result = output[i]
+
+                    if result.outputs[0].logprobs is not None:
+                        # if `result.outputs[0].logprobs` is None, it means
+                        # the sequence is completed because of the max-model-len
+                        # or abortion. we don't need to add it to the new beams.
+                        logprobs = result.outputs[0].logprobs[0]
+                        for token_id, logprob_obj in logprobs.items():
+                            new_beam = BeamSearchSequence(
+                                tokens=current_beam.tokens + [token_id],
+                                cum_logprob=current_beam.cum_logprob +
+                                logprob_obj.logprob)
+
+                            if token_id == tokenizer.eos_token_id and \
+                                not ignore_eos:
+                                instance.completed.append(new_beam)
+                            else:
+                                instance_new_beams.append(new_beam)
+                sorted_beams = sorted(instance_new_beams,
+                                      key=lambda x: x.cum_logprob,
+                                      reverse=True)
+                instance.beams = sorted_beams[:beam_width]
+
+        outputs = []
+        for instance in instances:
+            instance.completed.extend(instance.beams)
+            sorted_completed = sorted(instance.completed,
+                                      key=lambda x: x.cum_logprob,
+                                      reverse=True)
+            best_beams = sorted_completed[:beam_width]
+
+            for beam in best_beams:
+                beam.text = tokenizer.decode(beam.tokens)
+            outputs.append(BeamSearchOutput(sequences=best_beams))
+
+        return outputs
+
     def chat(
         self,
         messages: List[ChatCompletionMessageParam],
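Not part of the diff: a hedged end-to-end sketch of calling the new API; the model name is borrowed from the test file above and the prompt is a placeholder.

from vllm import LLM

llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
outputs = llm.beam_search(["The future of AI is"],
                          beam_width=4,
                          max_tokens=32)

for out in outputs:             # one BeamSearchOutput per prompt
    for seq in out.sequences:   # up to beam_width sequences, best first
        print(f"{seq.cum_logprob:.3f}  {seq.text}")

Each BeamSearchOutput holds at most beam_width sequences, sorted by cumulative log probability, with the decoded text filled in just before returning.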