[BUGFIX] Skip tokenization support for throughput benchmark (#12712)
Signed-off-by: root <root@banff-cyxtera-s73-5.ctr.dcgpu>
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Co-authored-by: root <root@banff-cyxtera-s73-5.ctr.dcgpu>
Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>
parent cc10281498
commit 0ca3b8e01c
diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py
@@ -7,7 +7,7 @@ import os
 import random
 import time
 from functools import cache
-from typing import Any, Optional
+from typing import Any, Optional, Union
 
 import torch
 import uvloop
@@ -20,7 +20,7 @@ from transformers import (AutoModelForCausalLM, AutoTokenizer,
 from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
 from vllm.entrypoints.openai.api_server import (
     build_async_engine_client_from_engine_args)
-from vllm.inputs import TextPrompt
+from vllm.inputs import TextPrompt, TokensPrompt
 from vllm.lora.request import LoRARequest
 from vllm.lora.utils import get_adapter_absolute_path
 from vllm.multimodal import MultiModalDataDict
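
For context: TextPrompt and TokensPrompt are the two TypedDict prompt shapes accepted by vLLM's generate API, and the distinction is the crux of this fix. A minimal sketch of the two, using only the fields touched by this diff (the token ids below are arbitrary example values):

    from vllm.inputs import TextPrompt, TokensPrompt

    # A text prompt carries a raw string; the engine must tokenize it.
    text = TextPrompt(prompt="The capital of France is")

    # A tokens prompt carries pre-computed token ids, so the engine can
    # consume it without any tokenizer -- the case --skip-tokenizer-init
    # is designed for.
    tokens = TokensPrompt(prompt_token_ids=[101, 2034, 2003, 102])
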
@@ -178,10 +178,13 @@ def run_vllm(
                 "Please ensure that max_model_len is greater than the sum of"
                 " prompt_len and expected_output_len for all requests.")
     # Add the requests to the engine.
-    prompts: list[TextPrompt] = []
+    prompts: list[Union[TextPrompt, TokensPrompt]] = []
     sampling_params: list[SamplingParams] = []
     for request in requests:
         prompts.append(
+            TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"],
+                         multi_modal_data=request.multi_modal_data)
+            if "prompt_token_ids" in request.prompt else \
             TextPrompt(prompt=request.prompt,
                        multi_modal_data=request.multi_modal_data))
         sampling_params.append(
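
The conditional expression added above dispatches per request: a pre-tokenized request stores a dict containing "prompt_token_ids" in request.prompt, while a plain-text request stores a string. A standalone sketch of the same dispatch, where to_engine_prompt is a hypothetical helper name (the benchmark inlines this logic rather than factoring it out):

    from typing import Any, Optional, Union

    from vllm.inputs import TextPrompt, TokensPrompt


    def to_engine_prompt(
            prompt: Union[str, dict[str, Any]],
            multi_modal_data: Optional[dict] = None,
    ) -> Union[TextPrompt, TokensPrompt]:
        # Pre-tokenized request: pass the ids straight through so the
        # engine never needs a tokenizer.
        if isinstance(prompt, dict) and "prompt_token_ids" in prompt:
            return TokensPrompt(prompt_token_ids=prompt["prompt_token_ids"],
                                multi_modal_data=multi_modal_data)
        # Plain-text request: let the engine tokenize the string.
        return TextPrompt(prompt=prompt, multi_modal_data=multi_modal_data)
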
@@ -242,11 +245,14 @@ async def run_vllm_async(
                 " prompt_len and expected_output_len for all requests.")
 
         # Add the requests to the engine.
-        prompts: list[TextPrompt] = []
+        prompts: list[Union[TextPrompt, TokensPrompt]] = []
         sampling_params: list[SamplingParams] = []
         lora_requests: list[Optional[LoRARequest]] = []
         for request in requests:
             prompts.append(
+                TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"],
+                             multi_modal_data=request.multi_modal_data)
+                if "prompt_token_ids" in request.prompt else \
                 TextPrompt(prompt=request.prompt,
                            multi_modal_data=request.multi_modal_data))
             sampling_params.append(
@@ -393,24 +399,29 @@ def main(args: argparse.Namespace):
                 random.randint(0, vocab_size - 1)
                 for _ in range(args.input_len)
             ]
-            # As tokenizer may add additional tokens like BOS, we need to try
-            # different lengths to get the desired input length.
-            for _ in range(5):  # Max attempts to correct
-                candidate_prompt = request_tokenizer.decode(candidate_ids)
-                tokenized_len = len(request_tokenizer.encode(candidate_prompt))
-
-            if tokenized_len == args.input_len:
-                    break
-
-                # Adjust length based on difference
-                diff = args.input_len - tokenized_len
-                if diff > 0:
-                    candidate_ids.extend([
-                        random.randint(100, vocab_size - 100)
-                        for _ in range(diff)
-                    ])
-                else:
-                    candidate_ids = candidate_ids[:diff]
+
+            candidate_prompt = {"prompt_token_ids": candidate_ids}
+
+            if not args.skip_tokenizer_init:
+                # As tokenizer may add additional tokens like BOS, we need
+                # to try different lengths to get the desired input length.
+                for _ in range(5):  # Max attempts to correct
+                    candidate_prompt = request_tokenizer.decode(candidate_ids)
+                    tokenized_len = len(
+                        request_tokenizer.encode(candidate_prompt))
+
+                    if tokenized_len == args.input_len:
+                        break
+
+                    # Adjust length based on difference
+                    diff = args.input_len - tokenized_len
+                    if diff > 0:
+                        candidate_ids.extend([
+                            random.randint(100, vocab_size - 100)
+                            for _ in range(diff)
+                        ])
+                    else:
+                        candidate_ids = candidate_ids[:diff]
             requests.append(
                 SampleRequest(prompt=candidate_prompt,
                               prompt_len=args.input_len,
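
The key behavioral change: when --skip-tokenizer-init is set, the synthetic prompt stays as raw token ids and the decode/re-encode length-correction loop is bypassed entirely, so the requested input length holds exactly by construction. A toy, self-contained illustration of the tokenizer-free path (vocab_size and input_len are arbitrary example values):

    import random

    vocab_size = 32000
    input_len = 8

    candidate_ids = [
        random.randint(0, vocab_size - 1) for _ in range(input_len)
    ]

    # Tokenizer-free path: the ids themselves are the prompt, so no
    # decode/encode round trip (and no BOS drift) can change its length.
    candidate_prompt = {"prompt_token_ids": candidate_ids}
    assert len(candidate_prompt["prompt_token_ids"]) == input_len
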
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
@@ -276,7 +276,9 @@ class EngineArgs:
         parser.add_argument(
             '--skip-tokenizer-init',
             action='store_true',
-            help='Skip initialization of tokenizer and detokenizer.')
+            help='Skip initialization of tokenizer and detokenizer. '
+            'Expects valid prompt_token_ids and None for prompt from '
+            'the input. The generated output will contain token ids.')
         parser.add_argument(
             '--revision',
             type=nullable_str,
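
With the flag wired through, the throughput benchmark can run end to end without ever loading a tokenizer. A plausible invocation (the model name and lengths are placeholders; the remaining flags follow the benchmark's existing CLI):

    python benchmarks/benchmark_throughput.py \
        --model meta-llama/Llama-2-7b-hf \
        --input-len 256 \
        --output-len 128 \
        --num-prompts 100 \
        --skip-tokenizer-init

Since no detokenizer exists in this mode, the generated output is reported as token ids, as the updated help text notes.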