re-implement beam search on top of vllm core (#8726)
Co-authored-by: Brendan Wong <bjwpokemon@gmail.com>
This commit is contained in:
parent 88577ac928
commit 0250dd68c5
@@ -90,6 +90,7 @@ def run_vllm(
     download_dir: Optional[str] = None,
     load_format: str = EngineArgs.load_format,
     disable_async_output_proc: bool = False,
+    use_new_beam_search_impl: bool = False,
 ) -> float:
     from vllm import LLM, SamplingParams
     llm = LLM(
@@ -132,9 +133,23 @@ def run_vllm(
                 max_tokens=output_len,
             ))
 
-    start = time.perf_counter()
-    llm.generate(prompts, sampling_params, use_tqdm=True)
-    end = time.perf_counter()
+    if not use_new_beam_search_impl:
+        start = time.perf_counter()
+        llm.generate(prompts, sampling_params, use_tqdm=True)
+        end = time.perf_counter()
+    else:
+        assert use_beam_search
+        prompts = [prompt for prompt, _, _ in requests]
+        # output_len should be the same for all requests.
+        output_len = requests[0][2]
+        for prompt, input_len, _output_len in requests:
+            assert _output_len == output_len
+        start = time.perf_counter()
+        llm.beam_search(prompts,
+                        beam_width=n,
+                        max_tokens=output_len,
+                        ignore_eos=True)
+        end = time.perf_counter()
     return end - start
 
 
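For reference, a minimal sketch of the new timed path in isolation. It assumes `requests` holds `(prompt, input_len, output_len)` tuples and `n` is the benchmark's per-prompt sequence count, as in the script above; the model name is only an example. The shared `output_len` assert exists because `beam_search` takes a single `max_tokens` for the whole batch.

```python
import time

from vllm import LLM

llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")  # example model, not prescribed by this diff
requests = [("The capital of France is", 6, 32)]  # hypothetical (prompt, input_len, output_len) tuples
n = 4  # number of sequences per prompt, used as the beam width in this path

prompts = [prompt for prompt, _, _ in requests]
output_len = requests[0][2]
assert all(_output_len == output_len for _, _, _output_len in requests)

start = time.perf_counter()
llm.beam_search(prompts, beam_width=n, max_tokens=output_len, ignore_eos=True)
print(f"beam search took {time.perf_counter() - start:.2f}s")
```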
@@ -336,7 +351,7 @@ def main(args: argparse.Namespace):
            run_args.append(args.disable_frontend_multiprocessing)
            elapsed_time = uvloop.run(run_vllm_async(*run_args))
        else:
-            elapsed_time = run_vllm(*run_args)
+            elapsed_time = run_vllm(*run_args, args.use_new_beam_search_impl)
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
         elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -396,6 +411,7 @@ if __name__ == "__main__":
                         default=1,
                         help="Number of generated sequences per prompt.")
     parser.add_argument("--use-beam-search", action="store_true")
+    parser.add_argument("--use-new-beam-search-impl", action="store_true")
     parser.add_argument("--num-prompts",
                         type=int,
                         default=1000,
@@ -798,6 +798,20 @@ class VllmRunner:
         outputs = self.generate(prompts, beam_search_params)
         return outputs
 
+    def generate_beam_search_new(
+        self,
+        prompts: Union[List[str], List[List[int]]],
+        beam_width: int,
+        max_tokens: int,
+    ) -> List[Tuple[List[List[int]], List[str]]]:
+        outputs = self.model.beam_search(prompts, beam_width, max_tokens)
+        returned_outputs = []
+        for output in outputs:
+            token_ids = [x.tokens for x in output.sequences]
+            texts = [x.text for x in output.sequences]
+            returned_outputs.append((token_ids, texts))
+        return returned_outputs
+
     def encode(self, prompts: List[str]) -> List[List[float]]:
         req_outputs = self.model.encode(prompts)
         outputs = []
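The helper above flattens each beam-search result into one `(token_ids, texts)` pair per prompt so the existing HF-vs-vLLM comparison logic keeps working. A rough pytest-style sketch of consuming it, assuming the usual `vllm_runner` fixture and an example model:

```python
# Hypothetical shape check for the new helper; `vllm_runner` is assumed to be
# the pytest fixture already provided by conftest.py.
def test_beam_search_output_shape(vllm_runner):
    beam_width, max_tokens = 4, 64
    with vllm_runner("TinyLlama/TinyLlama-1.1B-Chat-v1.0", dtype="half") as vllm_model:
        outputs = vllm_model.generate_beam_search_new(
            ["The future of AI is"], beam_width, max_tokens)

    for token_ids, texts in outputs:  # one (token_ids, texts) pair per prompt
        # one token-id list and one decoded text per returned beam
        assert len(token_ids) == len(texts) == beam_width
```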
@@ -9,7 +9,7 @@ import pytest
 # 1. Increase max_tokens to 256.
 # 2. Increase beam_width to 8.
 # 3. Use the model "huggyllama/llama-7b".
-MAX_TOKENS = [128]
+MAX_TOKENS = [64]
 BEAM_WIDTHS = [4]
 MODELS = ["TinyLlama/TinyLlama-1.1B-Chat-v1.0"]
 
@@ -33,8 +33,8 @@ def test_beam_search_single_input(
                                                    max_tokens)
 
     with vllm_runner(model, dtype=dtype) as vllm_model:
-        vllm_outputs = vllm_model.generate_beam_search(example_prompts,
-                                                       beam_width, max_tokens)
+        vllm_outputs = vllm_model.generate_beam_search_new(
+            example_prompts, beam_width, max_tokens)
 
     for i in range(len(example_prompts)):
         hf_output_ids, hf_output_texts = hf_outputs[i]
@@ -1,6 +1,8 @@
+import itertools
 from contextlib import contextmanager
-from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Union, cast,
-                    overload)
+from dataclasses import dataclass
+from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple,
+                    Union, cast, overload)
 
 from tqdm import tqdm
 
@@ -30,6 +32,37 @@ from vllm.utils import Counter, deprecate_kwargs, is_list_of
 logger = init_logger(__name__)
 
 
+@dataclass
+class BeamSearchSequence:
+    """A sequence for beam search.
+    It keeps track of the tokens and the log probability of the sequence.
+    The text field is optional and will only be filled when the sequence is
+    about to be returned to the user.
+    """
+    # The tokens includes the prompt.
+    tokens: List[int]
+    cum_logprob: float = 0.0
+    text: Optional[str] = None
+
+
+@dataclass
+class BeamSearchOutput:
+    """The output of beam search.
+    It contains the list of the best beam search sequences.
+    The length of the list is equal to the beam width.
+    """
+    sequences: List[BeamSearchSequence]
+
+
+class BeamSearchInstance:
+
+    def __init__(self, prompt_tokens: List[int]):
+        self.beams: List[BeamSearchSequence] = [
+            BeamSearchSequence(tokens=prompt_tokens)
+        ]
+        self.completed: List[BeamSearchSequence] = []
+
+
 class LLM:
     """An LLM for generating texts from given prompts and sampling parameters.
 
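A small sketch of how these bookkeeping classes are meant to be used, under the assumption that they are importable from `vllm.entrypoints.llm` at this commit: each prompt gets a `BeamSearchInstance`, candidate continuations are wrapped in `BeamSearchSequence` objects, and only the top `beam_width` by cumulative log probability survive each step.

```python
# Per-prompt bookkeeping sketch; the import path is an assumption based on
# where the classes are defined in this diff.
from vllm.entrypoints.llm import BeamSearchInstance, BeamSearchSequence

beam_width = 1
instance = BeamSearchInstance(prompt_tokens=[1, 2, 3])

# Pretend the engine proposed two continuations for the single live beam.
candidates = [
    BeamSearchSequence(tokens=instance.beams[0].tokens + [7], cum_logprob=-0.1),
    BeamSearchSequence(tokens=instance.beams[0].tokens + [9], cum_logprob=-2.3),
]

# Keep only the top `beam_width` candidates, as LLM.beam_search does after
# every decoding step.
instance.beams = sorted(candidates, key=lambda x: x.cum_logprob,
                        reverse=True)[:beam_width]
assert instance.beams[0].tokens == [1, 2, 3, 7]
```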
@@ -354,6 +387,105 @@ class LLM:
         outputs = self._run_engine(use_tqdm=use_tqdm)
         return LLMEngine.validate_outputs(outputs, RequestOutput)
 
+    def beam_search(
+        self,
+        prompts: List[Union[str, List[int]]],
+        beam_width: int,
+        max_tokens: int,
+        ignore_eos: bool = False,
+    ) -> List[BeamSearchOutput]:
+        """
+        Generate sequences using beam search.
+
+        Args:
+            prompts: A list of prompts. Each prompt can be a string or a list
+                of token IDs.
+            beam_width: The number of beams to keep at each step.
+            max_tokens: The max number of tokens to generate for each prompt.
+
+        TODO: how does beam search work together with length penalty, frequency
+        penalty, and stopping criteria, etc.?
+        """
+
+        tokenizer = self.get_tokenizer()
+        # generate 2 * beam_width candidates at each step
+        # following the huggingface transformers implementation
+        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
+        beam_search_params = SamplingParams(logprobs=2 * beam_width,
+                                            max_tokens=1,
+                                            temperature=0.0)
+        instances: List[BeamSearchInstance] = []
+
+        for prompt in prompts:
+            prompt_tokens = prompt if isinstance(
+                prompt, list) else tokenizer.encode(prompt)
+            instances.append(BeamSearchInstance(prompt_tokens))
+
+        for _ in range(max_tokens):
+            all_beams: List[BeamSearchSequence] = list(
+                sum((instance.beams for instance in instances), []))
+            pos = [0] + list(
+                itertools.accumulate(
+                    len(instance.beams) for instance in instances))
+            instance_start_and_end: List[Tuple[int, int]] = list(
+                zip(pos[:-1], pos[1:]))
+
+            if len(all_beams) == 0:
+                break
+
+            prompts_batch = [
+                TokensPrompt(prompt_token_ids=beam.tokens)
+                for beam in all_beams
+            ]
+
+            # only runs for one step
+            # we don't need to use tqdm here
+            output = self.generate(prompts_batch,
+                                   sampling_params=beam_search_params,
+                                   use_tqdm=False)
+
+            for (start, end), instance in zip(instance_start_and_end,
+                                              instances):
+                instance_new_beams = []
+                for i in range(start, end):
+                    current_beam = all_beams[i]
+                    result = output[i]
+
+                    if result.outputs[0].logprobs is not None:
+                        # if `result.outputs[0].logprobs` is None, it means
+                        # the sequence is completed because of the max-model-len
+                        # or abortion. we don't need to add it to the new beams.
+                        logprobs = result.outputs[0].logprobs[0]
+                        for token_id, logprob_obj in logprobs.items():
+                            new_beam = BeamSearchSequence(
+                                tokens=current_beam.tokens + [token_id],
+                                cum_logprob=current_beam.cum_logprob +
+                                logprob_obj.logprob)
+
+                            if token_id == tokenizer.eos_token_id and \
+                                not ignore_eos:
+                                instance.completed.append(new_beam)
+                            else:
+                                instance_new_beams.append(new_beam)
+                sorted_beams = sorted(instance_new_beams,
+                                      key=lambda x: x.cum_logprob,
+                                      reverse=True)
+                instance.beams = sorted_beams[:beam_width]
+
+        outputs = []
+        for instance in instances:
+            instance.completed.extend(instance.beams)
+            sorted_completed = sorted(instance.completed,
+                                      key=lambda x: x.cum_logprob,
+                                      reverse=True)
+            best_beams = sorted_completed[:beam_width]
+
+            for beam in best_beams:
+                beam.text = tokenizer.decode(beam.tokens)
+            outputs.append(BeamSearchOutput(sequences=best_beams))
+
+        return outputs
+
     def chat(
         self,
         messages: List[ChatCompletionMessageParam],
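With this method, beam search runs as a Python-level loop over single-token `generate()` calls (requesting `2 * beam_width` logprobs per step) instead of living inside the engine core. An end-to-end usage sketch; the model name is only an example and the generated beams will vary:

```python
from vllm import LLM

llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")  # example model
outputs = llm.beam_search(
    ["The capital of France is"],
    beam_width=4,
    max_tokens=32,
)

for output in outputs:  # one BeamSearchOutput per prompt
    # `sequences` holds the best `beam_width` beams, ranked by cumulative
    # log probability; each `text` is decoded from the full token list,
    # which includes the prompt tokens.
    for seq in output.sequences:
        print(f"{seq.cum_logprob:.3f}  {seq.text}")
```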