[Bugfix] fix beam search input errors and latency benchmark script (#11875)

Signed-off-by: Ye Qi <yeq@meta.com>
Co-authored-by: yeq <yeq@devgpu004.lla3.facebook.com>
Authored by Ye (Charlotte) Qi on 2025-01-09 01:36:39 -08:00 · committed by GitHub
parent 0bd1ff4346
commit 1d967acb45
2 changed files with 23 additions and 10 deletions


@@ -13,6 +13,7 @@ from tqdm import tqdm
 from vllm import LLM, SamplingParams
 from vllm.engine.arg_utils import EngineArgs
 from vllm.inputs import PromptType
+from vllm.sampling_params import BeamSearchParams
 from vllm.utils import FlexibleArgumentParser
@@ -40,6 +41,20 @@ def main(args: argparse.Namespace):
         "prompt_token_ids": batch
     } for batch in dummy_prompt_token_ids.tolist()]
 
+    def llm_generate():
+        if not args.use_beam_search:
+            llm.generate(dummy_prompts,
+                         sampling_params=sampling_params,
+                         use_tqdm=False)
+        else:
+            llm.beam_search(
+                dummy_prompts,
+                BeamSearchParams(
+                    beam_width=args.n,
+                    max_tokens=args.output_len,
+                    ignore_eos=True,
+                ))
+
     def run_to_completion(profile_dir: Optional[str] = None):
         if profile_dir:
             with torch.profiler.profile(
@@ -49,15 +64,11 @@ def main(args: argparse.Namespace):
                     ],
                     on_trace_ready=torch.profiler.tensorboard_trace_handler(
                         str(profile_dir))) as p:
-                llm.generate(dummy_prompts,
-                             sampling_params=sampling_params,
-                             use_tqdm=False)
+                llm_generate()
             print(p.key_averages().table(sort_by="self_cuda_time_total"))
         else:
             start_time = time.perf_counter()
-            llm.generate(dummy_prompts,
-                         sampling_params=sampling_params,
-                         use_tqdm=False)
+            llm_generate()
             end_time = time.perf_counter()
             latency = end_time - start_time
             return latency
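
For reference, the beam-search branch of the new llm_generate() helper amounts to the following standalone call. This is a minimal sketch, not part of the commit: the model name is illustrative, and the prompt is a TokensPrompt-style dict of prompt_token_ids like the benchmark's dummy_prompts.

from vllm import LLM
from vllm.sampling_params import BeamSearchParams

llm = LLM(model="facebook/opt-125m")  # illustrative model choice

# Token-id prompts mirror the benchmark's dummy_prompts; the ids here are arbitrary.
dummy_prompts = [{"prompt_token_ids": [1, 2, 3, 4]}]

outputs = llm.beam_search(
    dummy_prompts,
    BeamSearchParams(beam_width=4, max_tokens=128, ignore_eos=True),
)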


@@ -21,7 +21,7 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
                                          parse_chat_messages,
                                          resolve_chat_template_content_format)
 from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
-from vllm.inputs.parse import parse_and_batch_prompt
+from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.guided_decoding.guided_fields import (
@@ -457,7 +457,7 @@ class LLM:
     def beam_search(
         self,
-        prompts: List[Union[str, List[int]]],
+        prompts: List[Union[TokensPrompt, TextPrompt]],
         params: BeamSearchParams,
     ) -> List[BeamSearchOutput]:
         """
@@ -493,8 +493,10 @@ class LLM:
         instances: List[BeamSearchInstance] = []
 
         for prompt in prompts:
-            prompt_tokens = prompt if isinstance(
-                prompt, list) else tokenizer.encode(prompt)
+            if is_token_prompt(prompt):
+                prompt_tokens = prompt["prompt_token_ids"]
+            else:
+                prompt_tokens = tokenizer.encode(prompt["prompt"])
             instances.append(BeamSearchInstance(prompt_tokens))
 
         for _ in range(max_tokens):
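
With the signature change above, beam_search no longer accepts bare strings or token-id lists; each prompt must be a TextPrompt or TokensPrompt dict, where text prompts are tokenized with tokenizer.encode and token prompts are used as-is. A minimal before/after sketch (model name illustrative, token ids arbitrary):

from vllm import LLM
from vllm.sampling_params import BeamSearchParams

llm = LLM(model="facebook/opt-125m")  # illustrative model choice
params = BeamSearchParams(beam_width=4, max_tokens=32)

# Before this change: llm.beam_search(["Hello, my name is"], params)
# After: wrap each prompt as a TextPrompt or TokensPrompt dict.
outputs = llm.beam_search(
    [
        {"prompt": "Hello, my name is"},        # TextPrompt: encoded via tokenizer.encode
        {"prompt_token_ids": [9906, 11, 856]},  # TokensPrompt: token ids used directly
    ],
    params,
)
# outputs is a List[BeamSearchOutput], one entry per input prompt.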