[BUGFIX] Skip tokenization support for throughput benchmark (#12712)

Signed-off-by: root <root@banff-cyxtera-s73-5.ctr.dcgpu>
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Co-authored-by: root <root@banff-cyxtera-s73-5.ctr.dcgpu>
Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>
Aleksandr Malyshev 2025-03-07 02:51:47 -08:00 committed by GitHub
parent cc10281498
commit 0ca3b8e01c
2 changed files with 34 additions and 21 deletions

benchmarks/benchmark_throughput.py

@@ -7,7 +7,7 @@ import os
 import random
 import time
 from functools import cache
-from typing import Any, Optional
+from typing import Any, Optional, Union

 import torch
 import uvloop
@@ -20,7 +20,7 @@ from transformers import (AutoModelForCausalLM, AutoTokenizer,
 from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
 from vllm.entrypoints.openai.api_server import (
     build_async_engine_client_from_engine_args)
-from vllm.inputs import TextPrompt
+from vllm.inputs import TextPrompt, TokensPrompt
 from vllm.lora.request import LoRARequest
 from vllm.lora.utils import get_adapter_absolute_path
 from vllm.multimodal import MultiModalDataDict
@@ -178,10 +178,13 @@ def run_vllm(
             "Please ensure that max_model_len is greater than the sum of"
             " prompt_len and expected_output_len for all requests.")
     # Add the requests to the engine.
-    prompts: list[TextPrompt] = []
+    prompts: list[Union[TextPrompt, TokensPrompt]] = []
     sampling_params: list[SamplingParams] = []
     for request in requests:
         prompts.append(
+            TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"],
+                         multi_modal_data=request.multi_modal_data)
+            if "prompt_token_ids" in request.prompt else \
             TextPrompt(prompt=request.prompt,
                        multi_modal_data=request.multi_modal_data))
         sampling_params.append(
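
The conditional expression added here (and mirrored in run_vllm_async below) is easier to read as a standalone helper. The sketch below is illustrative only: the helper name is invented, request stands for the benchmark's SampleRequest, and the membership test works because request.prompt is either a {"prompt_token_ids": [...]} dict (when tokenization is skipped) or a plain string.

from typing import Union

from vllm.inputs import TextPrompt, TokensPrompt


def to_engine_prompt(request) -> Union[TextPrompt, TokensPrompt]:
    # Pre-tokenized requests carry their ids in a dict and are forwarded
    # as a TokensPrompt; everything else is wrapped in a TextPrompt.
    if "prompt_token_ids" in request.prompt:
        return TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"],
                            multi_modal_data=request.multi_modal_data)
    return TextPrompt(prompt=request.prompt,
                      multi_modal_data=request.multi_modal_data)
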
@@ -242,11 +245,14 @@ async def run_vllm_async(
                 " prompt_len and expected_output_len for all requests.")

         # Add the requests to the engine.
-        prompts: list[TextPrompt] = []
+        prompts: list[Union[TextPrompt, TokensPrompt]] = []
         sampling_params: list[SamplingParams] = []
         lora_requests: list[Optional[LoRARequest]] = []
         for request in requests:
             prompts.append(
+                TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"],
+                             multi_modal_data=request.multi_modal_data)
+                if "prompt_token_ids" in request.prompt else \
                 TextPrompt(prompt=request.prompt,
                            multi_modal_data=request.multi_modal_data))
             sampling_params.append(
@@ -393,24 +399,29 @@ def main(args: argparse.Namespace):
                 random.randint(0, vocab_size - 1)
                 for _ in range(args.input_len)
             ]
-            # As tokenizer may add additional tokens like BOS, we need to try
-            # different lengths to get the desired input length.
-            for _ in range(5):  # Max attempts to correct
-                candidate_prompt = request_tokenizer.decode(candidate_ids)
-                tokenized_len = len(request_tokenizer.encode(candidate_prompt))
-
-                if tokenized_len == args.input_len:
-                    break
+
+            candidate_prompt = {"prompt_token_ids": candidate_ids}

-                # Adjust length based on difference
-                diff = args.input_len - tokenized_len
-                if diff > 0:
-                    candidate_ids.extend([
-                        random.randint(100, vocab_size - 100)
-                        for _ in range(diff)
-                    ])
-                else:
-                    candidate_ids = candidate_ids[:diff]
+            if not args.skip_tokenizer_init:
+                # As tokenizer may add additional tokens like BOS, we need
+                # to try different lengths to get the desired input length.
+                for _ in range(5):  # Max attempts to correct
+                    candidate_prompt = request_tokenizer.decode(candidate_ids)
+                    tokenized_len = len(
+                        request_tokenizer.encode(candidate_prompt))
+
+                    if tokenized_len == args.input_len:
+                        break
+
+                    # Adjust length based on difference
+                    diff = args.input_len - tokenized_len
+                    if diff > 0:
+                        candidate_ids.extend([
+                            random.randint(100, vocab_size - 100)
+                            for _ in range(diff)
+                        ])
+                    else:
+                        candidate_ids = candidate_ids[:diff]
             requests.append(
                 SampleRequest(prompt=candidate_prompt,
                               prompt_len=args.input_len,
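
Read as plain code, the new request-generation logic in main() boils down to the following sketch; the function name and signature are invented for illustration, while the behaviour follows the hunk above.

import random


def make_random_prompt(tokenizer, input_len: int, vocab_size: int,
                       skip_tokenizer_init: bool, max_attempts: int = 5):
    candidate_ids = [
        random.randint(0, vocab_size - 1) for _ in range(input_len)
    ]
    # When tokenization is skipped, the raw ids are used directly.
    candidate_prompt = {"prompt_token_ids": candidate_ids}
    if skip_tokenizer_init:
        return candidate_prompt
    # Otherwise decode/re-encode until the re-tokenized length matches
    # input_len, since tokenizers may add special tokens such as BOS.
    for _ in range(max_attempts):
        candidate_prompt = tokenizer.decode(candidate_ids)
        tokenized_len = len(tokenizer.encode(candidate_prompt))
        if tokenized_len == input_len:
            break
        diff = input_len - tokenized_len
        if diff > 0:
            candidate_ids.extend(
                random.randint(100, vocab_size - 100) for _ in range(diff))
        else:
            candidate_ids = candidate_ids[:diff]
    return candidate_prompt
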

vllm/engine/arg_utils.py

@@ -276,7 +276,9 @@ class EngineArgs:
         parser.add_argument(
             '--skip-tokenizer-init',
             action='store_true',
-            help='Skip initialization of tokenizer and detokenizer.')
+            help='Skip initialization of tokenizer and detokenizer. '
+            'Expects valid prompt_token_ids and None for prompt from '
+            'the input. The generated output will contain token ids.')
         parser.add_argument(
             '--revision',
             type=nullable_str,
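
For context, the expanded help string corresponds to offline usage roughly like the sketch below; the model name and token ids are placeholders, and exact API details may vary between vLLM versions.

from vllm import LLM, SamplingParams
from vllm.inputs import TokensPrompt

# With skip_tokenizer_init=True no tokenizer is loaded: prompts must already
# be token ids, and generation results are returned as token ids only.
llm = LLM(model="facebook/opt-125m", skip_tokenizer_init=True)
outputs = llm.generate(
    TokensPrompt(prompt_token_ids=[1, 3293, 338, 263, 1243]),
    SamplingParams(max_tokens=16, detokenize=False),
)
for output in outputs:
    print(output.outputs[0].token_ids)  # no decoded text is available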