vllm/benchmarks/backend_request_func.py

476 lines
18 KiB
Python
Raw Normal View History

2024-02-12 22:53:00 -08:00
import json
import os
import sys
2024-02-12 22:53:00 -08:00
import time
import traceback
from dataclasses import dataclass, field
from typing import List, Optional, Union
2024-02-12 22:53:00 -08:00
import aiohttp
import huggingface_hub.constants
2024-02-12 22:53:00 -08:00
from tqdm.asyncio import tqdm
from transformers import (AutoTokenizer, PreTrainedTokenizer,
PreTrainedTokenizerFast)
2024-02-12 22:53:00 -08:00
# Client timeout shared by the aiohttp sessions opened by the request
# functions in this file: a total budget of 6 hours (6 * 60 * 60 seconds).
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
@dataclass
class RequestFuncInput:
    """Parameters describing a single benchmark request.

    The first five fields are required; the rest are optional knobs that
    individual backend request functions may read.
    """

    # Required fields.
    prompt: str  # Prompt text placed in the request payload.
    api_url: str  # Endpoint the request is POSTed to.
    prompt_len: int  # Copied onto the matching RequestFuncOutput.
    output_len: int  # Sent as the backend's generation-length limit.
    model: str  # Used as the payload "model" unless model_name is set.

    # Optional fields.
    model_name: Optional[str] = None  # Takes precedence over `model`.
    best_of: int = 1
    logprobs: Optional[int] = None
    extra_body: Optional[dict] = None  # Merged into the payload if set.
    multi_modal_content: Optional[dict] = None  # Extra chat content part.
    ignore_eos: bool = False
2024-02-12 22:53:00 -08:00
@dataclass
class RequestFuncOutput:
    """Outcome and timing measurements recorded for one request.

    Every field has a default, so a fresh instance represents a request that
    has not succeeded yet; the request functions fill it in as they go.
    """

    generated_text: str = ""
    success: bool = False
    latency: float = 0.0
    output_tokens: int = 0
    # Time to first token.
    ttft: float = 0.0
    # List of inter-token latencies (a fresh list per instance).
    itl: List[float] = field(default_factory=list)
    # Average next-token latency.
    tpot: float = 0.0
    prompt_len: int = 0
    error: str = ""
2024-02-12 22:53:00 -08:00
async def async_request_tgi(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one streaming request to a TGI ``generate_stream`` endpoint.

    Each non-empty, non-comment line of the response is parsed as a JSON
    event. The arrival time of the first event gives TTFT; the gaps between
    consecutive events are appended to ``itl``.

    Args:
        request_func_input: Request description; ``api_url`` must end with
            ``generate_stream``.
        pbar: Optional progress bar, advanced by one when the request ends
            (successfully or not).

    Returns:
        A ``RequestFuncOutput``. On any failure ``success`` is ``False`` and
        ``error`` carries the HTTP reason or the formatted traceback.

    Raises:
        AssertionError: If ``api_url`` does not end with ``generate_stream``.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        params = {
            "best_of": request_func_input.best_of,
            "max_new_tokens": request_func_input.output_len,
            "do_sample": True,
            "temperature": 0.01,  # TGI does not accept 0.0 temperature.
            "top_p": 0.99,  # TGI does not accept 1.0 top_p.
            "truncate": request_func_input.prompt_len,
            # TGI does not accept ignore_eos flag.
        }
        payload = {
            "inputs": request_func_input.prompt,
            "parameters": params,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        first_chunk_received = False
        data = None  # Last parsed event; stays None if nothing arrives.
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue
                        chunk_bytes = chunk_bytes.decode("utf-8")

                        # NOTE: Sometimes TGI returns a ping response without
                        # any data, we should skip it.
                        if chunk_bytes.startswith(":"):
                            continue
                        chunk = chunk_bytes.removeprefix("data:")

                        data = json.loads(chunk)
                        timestamp = time.perf_counter()
                        if not first_chunk_received:
                            # First token: reuse `timestamp` so that
                            # ttft + sum(itl) equals the reported latency.
                            first_chunk_received = True
                            output.ttft = timestamp - st
                        else:
                            # Decoding phase
                            output.itl.append(timestamp -
                                              most_recent_timestamp)

                        most_recent_timestamp = timestamp

                    if data is None:
                        # A 200 response that carried no events is a
                        # failure; report it explicitly instead of tripping
                        # over an unbound variable below.
                        output.success = False
                        output.error = (
                            "Stream ended without any data event; "
                            "no generated text was received.")
                    else:
                        output.latency = most_recent_timestamp - st
                        output.success = True
                        output.generated_text = data["generated_text"]
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            # Best-effort benchmark client: record the failure, never raise.
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

        if pbar:
            pbar.update(1)
        return output
async def async_request_trt_llm(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one streaming request to a TensorRT-LLM ``generate_stream`` URL.

    Every non-empty response line is parsed as JSON and its ``text_output``
    is appended to the generated text. The first event sets TTFT; later
    events append their gap from the previous event to ``itl``.

    Args:
        request_func_input: Request description; ``api_url`` must end with
            ``generate_stream`` and ``best_of`` must be 1.
        pbar: Optional progress bar, advanced by one when the request ends.

    Returns:
        A ``RequestFuncOutput``; failures are reported through ``success``
        and ``error`` rather than raised.
    """
    endpoint = request_func_input.api_url
    assert endpoint.endswith("generate_stream")

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert request_func_input.best_of == 1
        payload = {
            "accumulate_tokens": True,
            "text_input": request_func_input.prompt,
            "temperature": 0.0,
            "top_p": 1.0,
            "max_tokens": request_func_input.output_len,
            "stream": True,
        }
        if request_func_input.ignore_eos:
            payload["min_length"] = request_func_input.output_len

        result = RequestFuncOutput()
        result.prompt_len = request_func_input.prompt_len

        first_token_latency = 0.0
        started_at = time.perf_counter()
        previous_at = started_at
        try:
            async with session.post(url=endpoint, json=payload) as response:
                if response.status != 200:
                    result.error = response.reason or ""
                    result.success = False
                else:
                    async for raw_line in response.content:
                        line = raw_line.strip()
                        if not line:
                            continue

                        decoded = line.decode("utf-8")
                        event = json.loads(decoded.removeprefix("data:"))
                        result.generated_text += event["text_output"]

                        now = time.perf_counter()
                        if first_token_latency == 0.0:
                            # First token
                            first_token_latency = now - started_at
                            result.ttft = first_token_latency
                        else:
                            # Decoding phase
                            result.itl.append(now - previous_at)
                        previous_at = now

                    result.latency = previous_at - started_at
                    result.success = True
        except Exception:
            result.success = False
            result.error = "".join(
                traceback.format_exception(*sys.exc_info()))

        if pbar:
            pbar.update(1)
        return result
async def async_request_deepspeed_mii(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one non-streaming request to a DeepSpeed-MII endpoint.

    The whole JSON body is awaited at once, so only end-to-end latency is
    measured; TTFT is left at a placeholder of 0.

    Args:
        request_func_input: Request description; ``best_of`` must be 1.
        pbar: Optional progress bar, advanced by one when the request ends.

    Returns:
        A ``RequestFuncOutput``; failures are reported through ``success``
        and ``error`` rather than raised.
    """
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert request_func_input.best_of == 1

        payload = {
            "prompt": request_func_input.prompt,
            "max_tokens": request_func_input.output_len,
            "temperature": 0.01,  # deepspeed-mii does not accept 0.0 temp.
            "top_p": 1.0,
        }
        result = RequestFuncOutput()
        result.prompt_len = request_func_input.prompt_len

        # NOTE: DeepSpeed-MII doesn't support streaming as of Jan 28 2024,
        # will use 0 as placeholder.
        # See https://github.com/microsoft/DeepSpeed-MII/pull/311
        result.ttft = 0

        started_at = time.perf_counter()
        try:
            async with session.post(url=request_func_input.api_url,
                                    json=payload) as response:
                if response.status != 200:
                    result.error = response.reason or ""
                    result.success = False
                else:
                    body = await response.json()
                    result.latency = time.perf_counter() - started_at
                    result.generated_text = body["text"][0]
                    result.success = True
        except Exception:
            result.success = False
            result.error = "".join(
                traceback.format_exception(*sys.exc_info()))

        if pbar:
            pbar.update(1)
        return result
async def async_request_openai_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one streaming request to an OpenAI-style Completions endpoint.

    The response is read line by line as server-sent events. Events with a
    non-empty ``choices`` list are timed as tokens (TTFT for the first, ITL
    for the rest); an event without choices but with ``usage`` supplies the
    completion token count.

    Args:
        request_func_input: Request description; ``api_url`` must end with
            ``completions`` or ``profile``. ``extra_body``, if set, is merged
            into the payload and may override its keys.
        pbar: Optional progress bar, advanced by one when the request ends.

    Returns:
        A ``RequestFuncOutput``. The request counts as successful only if at
        least one event with choices was received; other failures are
        reported through ``success`` and ``error`` rather than raised.

    Raises:
        AssertionError: If ``api_url`` has an unexpected suffix.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(
        ("completions", "profile")
    ), "OpenAI Completions API URL must end with 'completions' or 'profile'."

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        payload = {
            "model": (request_func_input.model_name
                      or request_func_input.model),
            "prompt": request_func_input.prompt,
            "temperature": 0.0,
            "best_of": request_func_input.best_of,
            "max_tokens": request_func_input.output_len,
            "logprobs": request_func_input.logprobs,
            "stream": True,
            "ignore_eos": request_func_input.ignore_eos,
            "stream_options": {
                "include_usage": True,
            },
        }
        if request_func_input.extra_body:
            payload.update(request_func_input.extra_body)
        headers = {
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
        }

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        # Collect pieces and join once instead of repeated string +=.
        text_parts = []
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload,
                                    headers=headers) as response:
                if response.status == 200:
                    first_chunk_received = False
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = chunk_bytes.decode("utf-8").removeprefix(
                            "data: ")
                        if chunk == "[DONE]":
                            continue

                        data = json.loads(chunk)

                        # NOTE: Some completion API might have a last
                        # usage summary response without a token so we
                        # want to check a token was generated
                        if choices := data.get("choices"):
                            # Note that text could be empty (or missing)
                            # here, e.g. for special tokens.
                            text = choices[0].get("text")
                            timestamp = time.perf_counter()
                            if not first_chunk_received:
                                # First token: reuse `timestamp` so that
                                # ttft + sum(itl) equals the latency.
                                first_chunk_received = True
                                output.ttft = timestamp - st
                            else:
                                # Decoding phase
                                output.itl.append(timestamp -
                                                  most_recent_timestamp)

                            most_recent_timestamp = timestamp
                            # A null/missing text must not abort the whole
                            # request with a TypeError.
                            text_parts.append(text or "")
                        elif usage := data.get("usage"):
                            output.output_tokens = usage.get(
                                "completion_tokens")

                    if first_chunk_received:
                        output.success = True
                    else:
                        output.success = False
                        output.error = (
                            "Never received a valid chunk to calculate "
                            "TTFT. This response will be marked as failed!")
                    output.generated_text = "".join(text_parts)
                    output.latency = most_recent_timestamp - st
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            # Best-effort benchmark client: record the failure, never raise.
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

        if pbar:
            pbar.update(1)
        return output
async def async_request_openai_chat_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one streaming request to an OpenAI-style Chat Completions URL.

    The prompt is sent as a single user message whose content is a text part
    plus, when provided, ``multi_modal_content``. The response is read line
    by line as server-sent events: events with a non-empty ``choices`` list
    are timed (TTFT for the first, ITL for the rest) and an event without
    choices but with ``usage`` supplies the completion token count.

    Args:
        request_func_input: Request description; ``api_url`` must end with
            ``chat/completions``. ``extra_body``, if set, is merged into the
            payload and may override its keys.
        pbar: Optional progress bar, advanced by one when the request ends.

    Returns:
        A ``RequestFuncOutput``; failures are reported through ``success``
        and ``error`` rather than raised.

    Raises:
        AssertionError: If ``api_url`` has an unexpected suffix.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(
        "chat/completions"
    ), "OpenAI Chat Completions API URL must end with 'chat/completions'."

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        message_content = [{
            "type": "text",
            "text": request_func_input.prompt
        }]
        if request_func_input.multi_modal_content:
            message_content.append(request_func_input.multi_modal_content)
        payload = {
            "model": (request_func_input.model_name
                      or request_func_input.model),
            "messages": [
                {
                    "role": "user",
                    "content": message_content
                },
            ],
            "temperature": 0.0,
            "max_completion_tokens": request_func_input.output_len,
            "stream": True,
            "ignore_eos": request_func_input.ignore_eos,
            "stream_options": {
                "include_usage": True,
            },
        }
        if request_func_input.extra_body:
            payload.update(request_func_input.extra_body)
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        }

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        # Collect pieces and join once instead of repeated string +=.
        text_parts = []
        first_chunk_received = False
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload,
                                    headers=headers) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = chunk_bytes.decode("utf-8").removeprefix(
                            "data: ")
                        if chunk == "[DONE]":
                            continue

                        timestamp = time.perf_counter()
                        data = json.loads(chunk)

                        if choices := data.get("choices"):
                            # Named apart from the request's message
                            # content list to avoid shadowing it.
                            delta_content = choices[0]["delta"].get(
                                "content")
                            if not first_chunk_received:
                                # First token
                                first_chunk_received = True
                                output.ttft = timestamp - st
                            else:
                                # Decoding phase
                                output.itl.append(timestamp -
                                                  most_recent_timestamp)

                            # The delta may carry no content (null or
                            # absent); that must not abort the request
                            # with a TypeError.
                            text_parts.append(delta_content or "")
                        elif usage := data.get("usage"):
                            output.output_tokens = usage.get(
                                "completion_tokens")

                        most_recent_timestamp = timestamp

                    output.generated_text = "".join(text_parts)
                    output.success = True
                    output.latency = most_recent_timestamp - st
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            # Best-effort benchmark client: record the failure, never raise.
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

        if pbar:
            pbar.update(1)
        return output
def get_model(pretrained_model_name_or_path: str) -> str:
    """Return the path or identifier to load the model files from.

    Unless the ``VLLM_USE_MODELSCOPE`` environment variable equals ``true``
    (case-insensitive), the argument is returned unchanged. Otherwise the
    result of ModelScope's ``snapshot_download`` for that id is returned;
    the call is given ignore patterns for ``.pt``, ``.safetensors`` and
    ``.bin`` files.
    """
    modelscope_enabled = (os.getenv('VLLM_USE_MODELSCOPE',
                                    'False').lower() == 'true')
    if not modelscope_enabled:
        return pretrained_model_name_or_path

    # Imported lazily so modelscope is only required when opted in.
    from modelscope import snapshot_download

    return snapshot_download(
        model_id=pretrained_model_name_or_path,
        local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
        ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"])
def get_tokenizer(
    pretrained_model_name_or_path: str,
    tokenizer_mode: str = "auto",
    trust_remote_code: bool = False,
    **kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    """Load the tokenizer for a model name or local path.

    Args:
        pretrained_model_name_or_path: Model identifier or path. If it is
            not an existing local path it is first passed through
            ``get_model``.
        tokenizer_mode: ``"slow"`` forces ``use_fast=False``; ``"mistral"``
            loads vLLM's ``MistralTokenizer``; any other value uses
            ``AutoTokenizer``.
        trust_remote_code: Forwarded to ``AutoTokenizer.from_pretrained``.
        **kwargs: Forwarded to ``AutoTokenizer.from_pretrained``.

    Returns:
        The loaded tokenizer.

    Raises:
        ValueError: If ``use_fast`` is truthy while ``tokenizer_mode`` is
            ``"slow"``.
        ImportError: If mistral mode is requested but vllm cannot be
            imported.
    """
    name_or_path = pretrained_model_name_or_path
    if name_or_path is not None and not os.path.exists(name_or_path):
        name_or_path = get_model(name_or_path)

    if tokenizer_mode == "slow":
        if kwargs.get("use_fast", False):
            raise ValueError(
                "Cannot use the fast tokenizer in slow tokenizer mode.")
        kwargs["use_fast"] = False

    if tokenizer_mode != "mistral":
        return AutoTokenizer.from_pretrained(
            name_or_path,
            trust_remote_code=trust_remote_code,
            **kwargs,
        )

    # Mistral mode: vllm is imported lazily so it is only needed here.
    try:
        from vllm.transformers_utils.tokenizer import MistralTokenizer
    except ImportError as e:
        raise ImportError("MistralTokenizer requires vllm package.\n"
                          "Please install it with `pip install vllm` "
                          "to use mistral tokenizer mode.") from e
    return MistralTokenizer.from_pretrained(str(name_or_path))
2024-02-12 22:53:00 -08:00
# Maps a backend name to the coroutine function that issues one request to
# it. Several backends share the OpenAI Completions implementation.
ASYNC_REQUEST_FUNCS = {
    "tgi": async_request_tgi,
    "vllm": async_request_openai_completions,
    "lmdeploy": async_request_openai_completions,
    "deepspeed-mii": async_request_deepspeed_mii,
    "openai": async_request_openai_completions,
    "openai-chat": async_request_openai_chat_completions,
    "tensorrt-llm": async_request_trt_llm,
    "scalellm": async_request_openai_completions,
    "sglang": async_request_openai_completions,
}