vllm/benchmarks/backend_request_func.py

390 lines
14 KiB
Python
Raw Normal View History

2024-02-12 22:53:00 -08:00
import json
import os
import sys
2024-02-12 22:53:00 -08:00
import time
import traceback
from dataclasses import dataclass, field
from typing import List, Optional
2024-02-12 22:53:00 -08:00
import aiohttp
from tqdm.asyncio import tqdm
# Timeout shared by every aiohttp.ClientSession created in this module.
# 6 * 60 * 60 is passed as ``total``; assuming aiohttp interprets it in
# seconds, this caps a whole request at six hours.
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
@dataclass
class RequestFuncInput:
    """Parameters describing a single benchmark request.

    The request functions in this module read these fields to build the
    backend-specific payload.
    """

    # Prompt text sent to the backend.
    prompt: str
    # Full URL of the endpoint the request is posted to.
    api_url: str
    # Prompt length; copied verbatim into RequestFuncOutput.prompt_len.
    prompt_len: int
    # Upper bound on generated tokens (sent as max_tokens/max_new_tokens).
    output_len: int
    # Model identifier, forwarded by the OpenAI-style request functions.
    model: str
    # Forwarded as "best_of" where the backend accepts it.
    best_of: int = 1
    # The request functions in this module assert that this is False.
    use_beam_search: bool = False
@dataclass
class RequestFuncOutput:
    """Result and timing measurements collected for one request."""

    # Text returned by the backend.
    generated_text: str = ""
    # True only when the request completed and the response was parsed.
    success: bool = False
    # End-to-end latency, measured with time.perf_counter().
    latency: float = 0.0
    # Time to first token.
    ttft: float = 0.0
    # List of inter-token latencies.
    itl: List[float] = field(default_factory=list)
    # Prompt length copied from the corresponding RequestFuncInput.
    prompt_len: int = 0
    # HTTP reason phrase or formatted traceback describing a failure.
    error: str = ""
2024-02-12 22:53:00 -08:00
async def async_request_tgi(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one streaming request to a TGI ``generate_stream`` endpoint.

    Args:
        request_func_input: Request description. ``api_url`` must end with
            ``"generate_stream"`` and ``use_beam_search`` must be False.
        pbar: Optional progress bar, advanced by one when the request is
            finished (whether it succeeded or not).

    Returns:
        A RequestFuncOutput holding the generated text, time to first token,
        inter-token latencies and total latency. On failure ``success`` is
        False and ``error`` holds either the HTTP reason phrase (non-200
        response) or the formatted traceback (exception).

    Raises:
        AssertionError: If the URL suffix or beam-search precondition is
            violated. All other exceptions are captured in the output.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search
        params = {
            "best_of": request_func_input.best_of,
            "max_new_tokens": request_func_input.output_len,
            "do_sample": True,
            "temperature": 0.01,  # TGI does not accept 0.0 temperature.
            "top_p": 0.99,  # TGI does not accept 1.0 top_p.
        }
        payload = {
            "inputs": request_func_input.prompt,
            "parameters": params,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"),
                                              "data:")

                        data = json.loads(chunk)
                        timestamp = time.perf_counter()
                        # First token
                        if ttft == 0.0:
                            ttft = time.perf_counter() - st
                            output.ttft = ttft

                        # Decoding phase
                        else:
                            output.itl.append(timestamp -
                                              most_recent_timestamp)

                        most_recent_timestamp = timestamp

                    output.latency = most_recent_timestamp - st
                    output.success = True
                    # The text is read from the last parsed chunk. If the
                    # stream was empty, ``data`` is unbound and the resulting
                    # NameError is reported through the except branch below.
                    output.generated_text = data["generated_text"]
                else:
                    # Record why the server rejected the request instead of
                    # failing silently with an empty error string.
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

        if pbar:
            pbar.update(1)
        return output
async def async_request_trt_llm(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one streaming request to a TensorRT-LLM ``generate_stream`` URL.

    Args:
        request_func_input: Request description. ``api_url`` must end with
            ``"generate_stream"``, ``best_of`` must be 1 and
            ``use_beam_search`` must be False.
        pbar: Optional progress bar, advanced by one when the request is
            finished (whether it succeeded or not).

    Returns:
        A RequestFuncOutput holding the generated text, time to first token,
        inter-token latencies and total latency. On failure ``success`` is
        False and ``error`` holds either the HTTP reason phrase (non-200
        response) or the formatted traceback (exception).

    Raises:
        AssertionError: If one of the preconditions above is violated. All
            other exceptions are captured in the output.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search
        assert request_func_input.best_of == 1
        payload = {
            "accumulate_tokens": True,
            "text_input": request_func_input.prompt,
            "temperature": 0.0,
            "top_p": 1.0,
            "max_tokens": request_func_input.output_len,
            "stream": True,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"),
                                              "data:")

                        data = json.loads(chunk)
                        timestamp = time.perf_counter()
                        # First token
                        if ttft == 0.0:
                            ttft = time.perf_counter() - st
                            output.ttft = ttft

                        # Decoding phase
                        else:
                            output.itl.append(timestamp -
                                              most_recent_timestamp)

                        most_recent_timestamp = timestamp

                    output.latency = most_recent_timestamp - st
                    # ``data`` is already the decoded JSON object of the last
                    # chunk, so it is indexed directly; passing it through
                    # json.loads() again would raise TypeError.
                    # NOTE(review): presumably the last chunk carries the
                    # full text because "accumulate_tokens" is set - confirm
                    # against the server documentation.
                    output.generated_text = data["text_output"]
                    output.success = True
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

        if pbar:
            pbar.update(1)
        return output
async def async_request_deepspeed_mii(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one non-streaming request to a DeepSpeed-MII endpoint.

    Args:
        request_func_input: Request description; ``best_of`` must be 1 and
            ``use_beam_search`` must be False.
        pbar: Optional progress bar, advanced by one once the request is
            finished.

    Returns:
        A RequestFuncOutput with the generated text and total latency.
        ``ttft`` is always the placeholder 0. On failure ``success`` is
        False and ``error`` holds the HTTP reason phrase or a traceback.
    """
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert request_func_input.best_of == 1
        assert not request_func_input.use_beam_search

        request_body = {
            "prompt": request_func_input.prompt,
            "max_tokens": request_func_input.output_len,
            "temperature": 0.01,  # deepspeed-mii does not accept 0.0 temp.
            "top_p": 1.0,
        }

        result = RequestFuncOutput()
        result.prompt_len = request_func_input.prompt_len

        # NOTE: DeepSpeed-MII doesn't support streaming as of Jan 28 2024,
        # will use 0 as placeholder.
        # See https://github.com/microsoft/DeepSpeed-MII/pull/311
        result.ttft = 0

        started_at = time.perf_counter()
        try:
            async with session.post(url=request_func_input.api_url,
                                    json=request_body) as resp:
                if resp.status != 200:
                    result.error = resp.reason or ""
                    result.success = False
                else:
                    body = await resp.json()
                    # Latency is taken once the full JSON body has arrived.
                    result.latency = time.perf_counter() - started_at
                    result.generated_text = body["text"][0]
                    result.success = True
        except Exception:
            result.success = False
            result.error = "".join(
                traceback.format_exception(*sys.exc_info()))

        if pbar:
            pbar.update(1)
        return result
async def async_request_openai_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one streaming request to an OpenAI-style Completions endpoint.

    The bearer token is read from the ``OPENAI_API_KEY`` environment
    variable.

    Args:
        request_func_input: Request description. ``api_url`` must end with
            ``"v1/completions"`` and ``use_beam_search`` must be False.
        pbar: Optional progress bar, advanced by one when the request is
            finished (whether it succeeded or not).

    Returns:
        A RequestFuncOutput holding the concatenated text, time to first
        token, inter-token latencies and total latency. On failure
        ``success`` is False and ``error`` holds either the HTTP reason
        phrase (non-200 response) or the formatted traceback (exception).
        Total latency is only assigned when a ``[DONE]`` chunk is seen; a
        stream that ends without one is reported as a failure through the
        exception path.

    Raises:
        AssertionError: If the URL suffix or beam-search precondition is
            violated. All other exceptions are captured in the output.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(
        "v1/completions"
    ), "OpenAI Completions API URL must end with 'v1/completions'."

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search
        payload = {
            "model": request_func_input.model,
            "prompt": request_func_input.prompt,
            "temperature": 0.0,
            "best_of": request_func_input.best_of,
            "max_tokens": request_func_input.output_len,
            "stream": True,
        }
        headers = {
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
        }

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload,
                                    headers=headers) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"),
                                              "data: ")
                        if chunk == "[DONE]":
                            latency = time.perf_counter() - st
                        else:
                            data = json.loads(chunk)

                            if data["choices"][0]["text"]:
                                timestamp = time.perf_counter()
                                # First token
                                if ttft == 0.0:
                                    ttft = time.perf_counter() - st
                                    output.ttft = ttft

                                # Decoding phase
                                # NOTE: Some completion API might have a last
                                # usage summary response without a token so we
                                # do not want to include as inter-token-latency
                                elif data.get("usage", None) is None:
                                    output.itl.append(timestamp -
                                                      most_recent_timestamp)

                                most_recent_timestamp = timestamp
                                generated_text += data["choices"][0]["text"]

                    output.generated_text = generated_text
                    output.success = True
                    output.latency = latency
                else:
                    # Record why the server rejected the request instead of
                    # failing silently with an empty error string.
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

        if pbar:
            pbar.update(1)
        return output
async def async_request_openai_chat_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one streaming request to an OpenAI-style Chat Completions URL.

    The prompt is sent as a single ``user`` message and the bearer token is
    read from the ``OPENAI_API_KEY`` environment variable.

    Args:
        request_func_input: Request description. ``api_url`` must end with
            ``"v1/chat/completions"`` and ``use_beam_search`` must be False.
        pbar: Optional progress bar, advanced by one once the request is
            finished.

    Returns:
        A RequestFuncOutput holding the concatenated delta contents, time to
        first token, inter-token latencies and total latency. On failure
        ``success`` is False and ``error`` holds the HTTP reason phrase or
        the formatted traceback.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(
        "v1/chat/completions"
    ), "OpenAI Chat Completions API URL must end with 'v1/chat/completions'."

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search

        user_message = {
            "role": "user",
            "content": request_func_input.prompt,
        }
        request_body = {
            "model": request_func_input.model,
            "messages": [user_message],
            "temperature": 0.0,
            "max_tokens": request_func_input.output_len,
            "stream": True,
        }
        request_headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        }

        result = RequestFuncOutput()
        result.prompt_len = request_func_input.prompt_len

        collected_text = ""
        first_token_latency = 0.0
        started_at = time.perf_counter()
        previous_stamp = started_at
        try:
            async with session.post(url=api_url, json=request_body,
                                    headers=request_headers) as resp:
                if resp.status != 200:
                    result.error = resp.reason or ""
                    result.success = False
                else:
                    async for raw_line in resp.content:
                        raw_line = raw_line.strip()
                        if not raw_line:
                            continue

                        event = remove_prefix(raw_line.decode("utf-8"),
                                              "data: ")
                        if event == "[DONE]":
                            # End-of-stream sentinel: total latency is taken
                            # here and nowhere else.
                            total_latency = time.perf_counter() - started_at
                            continue

                        now = time.perf_counter()
                        parsed = json.loads(event)
                        delta = parsed["choices"][0]["delta"]
                        if delta.get("content", None):
                            if first_token_latency == 0.0:
                                # First token
                                first_token_latency = (time.perf_counter() -
                                                       started_at)
                                result.ttft = first_token_latency
                            else:
                                # Decoding phase
                                result.itl.append(now - previous_stamp)
                            collected_text += delta["content"]

                        # Updated for every parsed event, including events
                        # whose delta carries no content.
                        previous_stamp = now

                    result.generated_text = collected_text
                    result.success = True
                    result.latency = total_latency
        except Exception:
            result.success = False
            result.error = "".join(
                traceback.format_exception(*sys.exc_info()))

        if pbar:
            pbar.update(1)
        return result
# Since vllm must support Python 3.8, we can't use str.removeprefix(prefix)
# introduced in Python 3.9
def remove_prefix(text: str, prefix: str) -> str:
    """Return ``text`` without a leading ``prefix``.

    If ``text`` does not start with ``prefix`` it is returned unchanged.
    Only one leading occurrence is removed.
    """
    return text[len(prefix):] if text.startswith(prefix) else text
2024-02-12 22:53:00 -08:00
# Maps a backend name to the coroutine function that issues its requests.
# "vllm", "lmdeploy" and "openai" all share the OpenAI Completions client.
ASYNC_REQUEST_FUNCS = {
    "tgi": async_request_tgi,
    "vllm": async_request_openai_completions,
    "lmdeploy": async_request_openai_completions,
    "deepspeed-mii": async_request_deepspeed_mii,
    "openai": async_request_openai_completions,
    "openai-chat": async_request_openai_chat_completions,
    "tensorrt-llm": async_request_trt_llm,
}