[BUGFIX] Skip tokenization support for throughput benchmark (#12712)

Signed-off-by: root <root@banff-cyxtera-s73-5.ctr.dcgpu>
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Co-authored-by: root <root@banff-cyxtera-s73-5.ctr.dcgpu>
Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>
Author: Aleksandr Malyshev (committed by GitHub)
Date:   2025-03-07 02:51:47 -08:00
Commit: 0ca3b8e01c (parent: cc10281498)
2 changed files with 34 additions and 21 deletions

benchmarks/benchmark_throughput.py

@@ -7,7 +7,7 @@ import os
 import random
 import time
 from functools import cache
-from typing import Any, Optional
+from typing import Any, Optional, Union

 import torch
 import uvloop
@@ -20,7 +20,7 @@ from transformers import (AutoModelForCausalLM, AutoTokenizer,
 from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
 from vllm.entrypoints.openai.api_server import (
     build_async_engine_client_from_engine_args)
-from vllm.inputs import TextPrompt
+from vllm.inputs import TextPrompt, TokensPrompt
 from vllm.lora.request import LoRARequest
 from vllm.lora.utils import get_adapter_absolute_path
 from vllm.multimodal import MultiModalDataDict
@@ -178,10 +178,13 @@ def run_vllm(
             "Please ensure that max_model_len is greater than the sum of"
             " prompt_len and expected_output_len for all requests.")
     # Add the requests to the engine.
-    prompts: list[TextPrompt] = []
+    prompts: list[Union[TextPrompt, TokensPrompt]] = []
     sampling_params: list[SamplingParams] = []
     for request in requests:
         prompts.append(
+            TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"],
+                         multi_modal_data=request.multi_modal_data)
+            if "prompt_token_ids" in request.prompt else \
             TextPrompt(prompt=request.prompt,
                        multi_modal_data=request.multi_modal_data))
         sampling_params.append(
@@ -242,11 +245,14 @@ async def run_vllm_async(
             " prompt_len and expected_output_len for all requests.")

     # Add the requests to the engine.
-    prompts: list[TextPrompt] = []
+    prompts: list[Union[TextPrompt, TokensPrompt]] = []
     sampling_params: list[SamplingParams] = []
     lora_requests: list[Optional[LoRARequest]] = []
     for request in requests:
         prompts.append(
+            TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"],
+                         multi_modal_data=request.multi_modal_data)
+            if "prompt_token_ids" in request.prompt else \
             TextPrompt(prompt=request.prompt,
                        multi_modal_data=request.multi_modal_data))
         sampling_params.append(
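
For reference, a minimal standalone sketch of the prompt-selection logic that run_vllm and run_vllm_async now share. The helper name to_engine_prompt is illustrative only and does not exist in the patch; it simply shows how a dict carrying pre-computed token ids is routed to TokensPrompt while plain-text prompts keep using TextPrompt:

from typing import Any, Optional, Union

from vllm.inputs import TextPrompt, TokensPrompt


def to_engine_prompt(
    prompt: Union[str, dict[str, Any]],
    multi_modal_data: Optional[dict] = None,
) -> Union[TextPrompt, TokensPrompt]:
    # With --skip-tokenizer-init the benchmark builds prompts as dicts that
    # already carry token ids, so they are handed to the engine untouched.
    if isinstance(prompt, dict) and "prompt_token_ids" in prompt:
        return TokensPrompt(prompt_token_ids=prompt["prompt_token_ids"],
                            multi_modal_data=multi_modal_data)
    # Otherwise the prompt is plain text and the engine tokenizes it itself.
    return TextPrompt(prompt=prompt, multi_modal_data=multi_modal_data)

The same conditional expression appears verbatim in both hunks above; only run_vllm_async additionally tracks per-request LoRA adapters.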
@@ -393,11 +399,16 @@ def main(args: argparse.Namespace):
                 random.randint(0, vocab_size - 1)
                 for _ in range(args.input_len)
             ]
-            # As tokenizer may add additional tokens like BOS, we need to try
-            # different lengths to get the desired input length.
-            for _ in range(5):  # Max attempts to correct
-                candidate_prompt = request_tokenizer.decode(candidate_ids)
-                tokenized_len = len(request_tokenizer.encode(candidate_prompt))
+
+            candidate_prompt = {"prompt_token_ids": candidate_ids}
+
+            if not args.skip_tokenizer_init:
+                # As tokenizer may add additional tokens like BOS, we need
+                # to try different lengths to get the desired input length.
+                for _ in range(5):  # Max attempts to correct
+                    candidate_prompt = request_tokenizer.decode(candidate_ids)
+                    tokenized_len = len(
+                        request_tokenizer.encode(candidate_prompt))

-                if tokenized_len == args.input_len:
-                    break
+                    if tokenized_len == args.input_len:
+                        break
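
Condensed as a standalone function, the synthetic-prompt construction above looks roughly like the following. build_candidate_prompt is a hypothetical name, and the length correction the real benchmark applies between retries sits outside the hunk shown, so it is omitted here:

import random
from typing import Any, Optional, Union


def build_candidate_prompt(input_len: int,
                           vocab_size: int,
                           tokenizer: Optional[Any] = None,
                           max_attempts: int = 5) -> Union[str, dict]:
    candidate_ids = [
        random.randint(0, vocab_size - 1) for _ in range(input_len)
    ]
    # Tokenizer skipped: the raw token ids themselves become the prompt.
    if tokenizer is None:
        return {"prompt_token_ids": candidate_ids}
    # Tokenizer available: decode and re-encode, since special tokens such as
    # BOS can change the length; retry a bounded number of times.
    candidate_prompt = tokenizer.decode(candidate_ids)
    for _ in range(max_attempts):
        candidate_prompt = tokenizer.decode(candidate_ids)
        tokenized_len = len(tokenizer.encode(candidate_prompt))
        if tokenized_len == input_len:
            break
    return candidate_prompt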

vllm/engine/arg_utils.py

@@ -276,7 +276,9 @@ class EngineArgs:
         parser.add_argument(
             '--skip-tokenizer-init',
             action='store_true',
-            help='Skip initialization of tokenizer and detokenizer.')
+            help='Skip initialization of tokenizer and detokenizer. '
+            'Expects valid prompt_token_ids and None for prompt from '
+            'the input. The generated output will contain token ids.')
         parser.add_argument(
             '--revision',
             type=nullable_str,
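
To illustrate the expanded help text, here is a hedged end-to-end sketch of what skipping tokenizer initialization implies when driving the engine directly. The model name and token ids are placeholders, and passing detokenize=False is an assumption made to match the token-id-only output the help text describes:

from vllm import LLM, SamplingParams
from vllm.inputs import TokensPrompt

# No tokenizer or detokenizer is loaded, so inputs must already be token ids.
llm = LLM(model="facebook/opt-125m", skip_tokenizer_init=True)

sampling_params = SamplingParams(max_tokens=16, detokenize=False)
outputs = llm.generate(
    TokensPrompt(prompt_token_ids=[1, 2, 3, 4]),
    sampling_params,
)

# The generated output carries token ids rather than decoded text.
print(outputs[0].outputs[0].token_ids)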