import json
import os
import time
from dataclasses import dataclass
from typing import Optional

import aiohttp
from tqdm.asyncio import tqdm

AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)


@dataclass
class RequestFuncInput:
    prompt: str
    api_url: str
    prompt_len: int
    output_len: int
    model: str
    best_of: int = 1
    use_beam_search: bool = False


@dataclass
class RequestFuncOutput:
    generated_text: str = ""
    success: bool = False
    latency: float = 0.0
    ttft: float = 0.0
    prompt_len: int = 0
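
# Illustrative only: how a caller might fill in a RequestFuncInput. The URL
# and model name are placeholder assumptions, not values defined in this
# module; the benchmark driver normally supplies them from CLI arguments.
#
#     example = RequestFuncInput(
#         prompt="The capital of France is",
#         api_url="http://localhost:8000/generate",
#         prompt_len=6,
#         output_len=16,
#         model="placeholder-model",
#     )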


async def async_request_tgi(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search
        params = {
            "best_of": request_func_input.best_of,
            "max_new_tokens": request_func_input.output_len,
            "do_sample": True,
            "temperature": 0.01,  # TGI does not accept 0.0 temperature.
            "top_p": 0.99,  # TGI does not accept 1.0 top_p.
        }
        payload = {
            "inputs": request_func_input.prompt,
            "parameters": params,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        ttft = 0
        st = time.perf_counter()
        try:
            async with session.post(url=api_url, json=payload) as response:
                if response.status == 200:
                    async for data in response.content.iter_any():
                        if ttft == 0:
                            ttft = time.perf_counter() - st
                            output.ttft = ttft
                    output.latency = time.perf_counter() - st

                    body = remove_prefix(data.decode("utf-8"), "data:")
                    output.generated_text = json.loads(body)["generated_text"]
                    output.success = True
                else:
                    output.success = False
        except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
            output.success = False

        if pbar:
            pbar.update(1)
        return output
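
# For reference (an assumption about TGI's wire format, not guaranteed by
# this module): /generate_stream emits server-sent-event frames such as
#     data:{"token": {...}, "generated_text": null, ...}
# with the final frame carrying the complete output in "generated_text",
# which is why only the last chunk received above is parsed.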


async def async_request_vllm(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    api_url = request_func_input.api_url
    assert api_url.endswith("generate")

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        payload = {
            "prompt": request_func_input.prompt,
            "n": 1,
            "best_of": request_func_input.best_of,
            "use_beam_search": request_func_input.use_beam_search,
            "temperature": 0.0 if request_func_input.use_beam_search else 1.0,
            "top_p": 1.0,
            "max_tokens": request_func_input.output_len,
            "ignore_eos": True,
            "stream": True,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        ttft = 0
        st = time.perf_counter()
        try:
            async with session.post(url=api_url, json=payload) as response:
                if response.status == 200:
                    async for data in response.content.iter_any():
                        if ttft == 0:
                            ttft = time.perf_counter() - st
                            output.ttft = ttft
                    output.latency = time.perf_counter() - st

                    # When streaming, a '\0' is appended to the end of the
                    # response.
                    body = data.decode("utf-8").strip("\0")
                    output.generated_text = json.loads(
                        body)["text"][0][len(request_func_input.prompt):]
                    output.success = True
                else:
                    output.success = False
        except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
            output.success = False

        if pbar:
            pbar.update(1)
        return output


async def async_request_trt_llm(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search
        assert request_func_input.best_of == 1
        payload = {
            "accumulate_tokens": True,
            "text_input": request_func_input.prompt,
            "temperature": 0.0,
            "top_p": 1.0,
            "max_tokens": request_func_input.output_len,
            "stream": True,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len
        ttft = 0

        st = time.perf_counter()
        try:
            async with session.post(url=api_url, json=payload) as resp:
                if resp.status == 200:
                    async for data in resp.content.iter_any():
                        if ttft == 0:
                            ttft = time.perf_counter() - st
                            output.ttft = ttft
                    output.latency = time.perf_counter() - st

                    body = remove_prefix(data.decode("utf-8"), "data:")
                    output.generated_text = json.loads(body)["text_output"]
                    output.success = True
                else:
                    output.success = False
        except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
            output.success = False

        if pbar:
            pbar.update(1)
        return output


async def async_request_deepspeed_mii(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert request_func_input.best_of == 1
        assert not request_func_input.use_beam_search

        payload = {
            "prompts": request_func_input.prompt,
            "max_new_tokens": request_func_input.output_len,
            "ignore_eos": True,
            "do_sample": True,
            "temperature": 0.01,  # deepspeed-mii does not accept 0.0 temperature.
            "top_p": 1.0,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        # DeepSpeed-MII doesn't support streaming as of Jan 28 2024,
        # so use 0 as a placeholder for TTFT.
        # https://github.com/microsoft/DeepSpeed-MII/pull/311
        output.ttft = 0

        st = time.perf_counter()
        try:
            async with session.post(url=request_func_input.api_url,
                                    json=payload) as resp:
                if resp.status == 200:
                    parsed_resp = await resp.json()
                    output.latency = time.perf_counter() - st
                    output.generated_text = parsed_resp[0]["generated_text"]
                    output.success = True
                else:
                    output.success = False
        except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
            output.success = False

        if pbar:
            pbar.update(1)
        return output


async def async_request_openai_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    api_url = request_func_input.api_url
    assert api_url.endswith("v1/completions")

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search
        payload = {
            "model": request_func_input.model,
            "prompt": request_func_input.prompt,
            "temperature": 0.0,
            "best_of": request_func_input.best_of,
            "max_tokens": request_func_input.output_len,
            "stream": True,
        }
        headers = {
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
        }

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        ttft = 0
        st = time.perf_counter()
        try:
            async with session.post(url=api_url, json=payload,
                                    headers=headers) as response:
                if response.status == 200:
                    async for chunk in response.content:
                        if ttft == 0:
                            ttft = time.perf_counter() - st
                            output.ttft = ttft

                        chunk = chunk.strip()
                        if not chunk:
                            continue

                        chunk = remove_prefix(chunk.decode("utf-8"), "data: ")
                        if chunk == "[DONE]":
                            latency = time.perf_counter() - st
                        else:
                            body = json.loads(chunk)
                            generated_text += body["choices"][0]["text"]

                    output.generated_text = generated_text
                    output.success = True
                    output.latency = latency
                else:
                    output.success = False
        except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
            output.success = False

        if pbar:
            pbar.update(1)
        return output
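
# For reference: each streamed completions chunk is a server-sent-event line
# of the form
#     data: {"choices": [{"text": " Paris", "index": 0, ...}], ...}
# and the stream terminates with a literal "data: [DONE]" sentinel, which is
# why the loop above checks for "[DONE]" before calling json.loads().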


async def async_request_openai_chat_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    api_url = request_func_input.api_url
    assert api_url.endswith(
        "v1/chat/completions"
    ), "OpenAI Chat API URL must end with 'v1/chat/completions'."

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search
        payload = {
            "model": request_func_input.model,
            "messages": [
                {
                    "role": "user",
                    "content": request_func_input.prompt,
                },
            ],
            "temperature": 0.0,
            "max_tokens": request_func_input.output_len,
            "stream": True,
        }
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
        }

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        ttft = 0
        st = time.perf_counter()
        try:
            async with session.post(url=api_url, json=payload,
                                    headers=headers) as response:
                if response.status == 200:
                    async for chunk in response.content:
                        if ttft == 0:
                            ttft = time.perf_counter() - st
                            output.ttft = ttft

                        chunk = chunk.strip()
                        if not chunk:
                            continue

                        chunk = remove_prefix(chunk.decode("utf-8"), "data: ")
                        if chunk == "[DONE]":
                            latency = time.perf_counter() - st
                        else:
                            body = json.loads(chunk)
                            if "content" in body["choices"][0]["delta"]:
                                generated_text += body["choices"][0]["delta"][
                                    "content"]

                    output.generated_text = generated_text
                    output.success = True
                    output.latency = latency
                else:
                    output.success = False
        except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
            output.success = False

        if pbar:
            pbar.update(1)
        return output


# Since vLLM must support Python 3.8, we can't use str.removeprefix(prefix),
# which was introduced in Python 3.9.
def remove_prefix(text: str, prefix: str) -> str:
    if text.startswith(prefix):
        return text[len(prefix):]
    return text


ASYNC_REQUEST_FUNCS = {
    "tgi": async_request_tgi,
    "vllm": async_request_vllm,
    "deepspeed-mii": async_request_deepspeed_mii,
    "openai": async_request_openai_completions,
    "openai-chat": async_request_openai_chat_completions,
    "tensorrt-llm": async_request_trt_llm,
}
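

# Minimal smoke test (illustrative only, not part of the benchmark harness):
# fire a single request at a locally running vLLM server through the
# registry. The URL, prompt_len, and model name are assumptions for
# demonstration; the benchmark driver normally supplies them from CLI args.
if __name__ == "__main__":
    import asyncio

    example_input = RequestFuncInput(
        prompt="The capital of France is",
        api_url="http://localhost:8000/generate",
        prompt_len=6,
        output_len=16,
        model="placeholder-model",  # unused by the vLLM backend's payload
    )
    result = asyncio.run(ASYNC_REQUEST_FUNCS["vllm"](example_input))
    print(f"success={result.success}, ttft={result.ttft:.4f}s, "
          f"latency={result.latency:.4f}s")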