[Bugfix] fix beam search input errors and latency benchmark script (#11875)
Signed-off-by: Ye Qi <yeq@meta.com>
Co-authored-by: yeq <yeq@devgpu004.lla3.facebook.com>
parent 0bd1ff4346
commit 1d967acb45
@@ -13,6 +13,7 @@ from tqdm import tqdm
 from vllm import LLM, SamplingParams
 from vllm.engine.arg_utils import EngineArgs
 from vllm.inputs import PromptType
+from vllm.sampling_params import BeamSearchParams
 from vllm.utils import FlexibleArgumentParser
 
 
@@ -40,6 +41,20 @@ def main(args: argparse.Namespace):
         "prompt_token_ids": batch
     } for batch in dummy_prompt_token_ids.tolist()]
 
+    def llm_generate():
+        if not args.use_beam_search:
+            llm.generate(dummy_prompts,
+                         sampling_params=sampling_params,
+                         use_tqdm=False)
+        else:
+            llm.beam_search(
+                dummy_prompts,
+                BeamSearchParams(
+                    beam_width=args.n,
+                    max_tokens=args.output_len,
+                    ignore_eos=True,
+                ))
+
     def run_to_completion(profile_dir: Optional[str] = None):
         if profile_dir:
             with torch.profiler.profile(
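Note: the new llm_generate helper routes beam-search runs through llm.beam_search instead of llm.generate. Below is a minimal, self-contained sketch of the same dispatch; the model name, prompt contents, and argument values are chosen purely for illustration and are not part of this diff.

from vllm import LLM, SamplingParams
from vllm.inputs import TokensPrompt
from vllm.sampling_params import BeamSearchParams

llm = LLM(model="facebook/opt-125m")  # hypothetical small model for the sketch
# Dummy token-id prompts, mirroring the benchmark's "prompt_token_ids" dicts.
dummy_prompts = [TokensPrompt(prompt_token_ids=[0] * 32) for _ in range(4)]

use_beam_search = True  # stand-in for args.use_beam_search
if not use_beam_search:
    llm.generate(dummy_prompts,
                 sampling_params=SamplingParams(max_tokens=16),
                 use_tqdm=False)
else:
    llm.beam_search(
        dummy_prompts,
        BeamSearchParams(beam_width=4, max_tokens=16, ignore_eos=True))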
@@ -49,15 +64,11 @@ def main(args: argparse.Namespace):
                     ],
                     on_trace_ready=torch.profiler.tensorboard_trace_handler(
                         str(profile_dir))) as p:
-                llm.generate(dummy_prompts,
-                             sampling_params=sampling_params,
-                             use_tqdm=False)
+                llm_generate()
             print(p.key_averages().table(sort_by="self_cuda_time_total"))
         else:
             start_time = time.perf_counter()
-            llm.generate(dummy_prompts,
-                         sampling_params=sampling_params,
-                         use_tqdm=False)
+            llm_generate()
             end_time = time.perf_counter()
             latency = end_time - start_time
             return latency
@@ -21,7 +21,7 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
                                          parse_chat_messages,
                                          resolve_chat_template_content_format)
 from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
-from vllm.inputs.parse import parse_and_batch_prompt
+from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.guided_decoding.guided_fields import (
@@ -457,7 +457,7 @@ class LLM:
 
     def beam_search(
         self,
-        prompts: List[Union[str, List[int]]],
+        prompts: List[Union[TokensPrompt, TextPrompt]],
         params: BeamSearchParams,
     ) -> List[BeamSearchOutput]:
         """
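With the updated signature, beam_search accepts the same prompt dictionaries as generate rather than raw strings or token lists. A sketch of a call under the new signature; the model name and token ids are illustrative assumptions, not taken from this change.

from vllm import LLM
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.sampling_params import BeamSearchParams

llm = LLM(model="facebook/opt-125m")  # hypothetical model for the sketch
# Mixed text and token prompts are both accepted now.
outputs = llm.beam_search(
    [
        TextPrompt(prompt="The capital of France is"),
        TokensPrompt(prompt_token_ids=[1, 2, 3, 4]),  # made-up token ids
    ],
    BeamSearchParams(beam_width=4, max_tokens=32),
)
# Per the signature above, outputs is a List[BeamSearchOutput], one per prompt.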
@@ -493,8 +493,10 @@ class LLM:
         instances: List[BeamSearchInstance] = []
 
         for prompt in prompts:
-            prompt_tokens = prompt if isinstance(
-                prompt, list) else tokenizer.encode(prompt)
+            if is_token_prompt(prompt):
+                prompt_tokens = prompt["prompt_token_ids"]
+            else:
+                prompt_tokens = tokenizer.encode(prompt["prompt"])
             instances.append(BeamSearchInstance(prompt_tokens))
 
         for _ in range(max_tokens):
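For reference, the new branch relies on is_token_prompt to tell the two prompt shapes apart. A rough sketch of equivalent logic, assuming the helper simply checks for the prompt_token_ids key; the extract_prompt_tokens wrapper below is a hypothetical name used only for illustration.

from typing import List, Union
from vllm.inputs import TextPrompt, TokensPrompt

def extract_prompt_tokens(prompt: Union[TextPrompt, TokensPrompt],
                          tokenizer) -> List[int]:
    # Token prompts already carry ids; text prompts are encoded first.
    if "prompt_token_ids" in prompt:  # what is_token_prompt is assumed to check
        return prompt["prompt_token_ids"]
    return tokenizer.encode(prompt["prompt"])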
|
Loading…
x
Reference in New Issue
Block a user