feat(benchmarks): Add Prefix Caching Benchmark to Serving Benchmark (#3277)
This commit is contained in:
parent
1956931436
commit
45b6ef6513
@ -23,8 +23,9 @@ wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/r
|
||||
# wait for server to start, timeout after 600 seconds
|
||||
timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
|
||||
python3 benchmarks/benchmark_serving.py \
|
||||
--backend openai \
|
||||
--dataset ./ShareGPT_V3_unfiltered_cleaned_split.json \
|
||||
--backend vllm \
|
||||
--dataset-name sharegpt \
|
||||
--dataset-path ./ShareGPT_V3_unfiltered_cleaned_split.json \
|
||||
--model meta-llama/Llama-2-7b-chat-hf \
|
||||
--num-prompts 20 \
|
||||
--endpoint /v1/completions \
|
||||
|
@ -1,8 +1,10 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
import traceback
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional
|
||||
|
||||
import aiohttp
|
||||
from tqdm.asyncio import tqdm
|
||||
@ -26,8 +28,11 @@ class RequestFuncOutput:
|
||||
generated_text: str = ""
|
||||
success: bool = False
|
||||
latency: float = 0
|
||||
ttft: float = 0
|
||||
ttft: float = 0 # Time to first token
|
||||
itl: List[float] = field(
|
||||
default_factory=list) # List of inter-token latencies
|
||||
prompt_len: int = 0
|
||||
error: str = ""
|
||||
|
||||
|
||||
async def async_request_tgi(
|
||||
@ -55,71 +60,38 @@ async def async_request_tgi(
|
||||
|
||||
ttft = 0
|
||||
st = time.perf_counter()
|
||||
most_recent_timestamp = st
|
||||
try:
|
||||
async with session.post(url=api_url, json=payload) as response:
|
||||
if response.status == 200:
|
||||
async for data in response.content.iter_any():
|
||||
async for chunk in response.content:
|
||||
chunk = chunk.strip()
|
||||
if not chunk:
|
||||
continue
|
||||
|
||||
chunk = remove_prefix(chunk.decode("utf-8"), "data:")
|
||||
|
||||
data = json.loads(chunk)
|
||||
timestamp = time.perf_counter()
|
||||
# First token
|
||||
if ttft == 0:
|
||||
ttft = time.perf_counter() - st
|
||||
output.ttft = ttft
|
||||
output.latency = time.perf_counter() - st
|
||||
|
||||
body = remove_prefix(data.decode("utf-8"), "data:")
|
||||
output.generated_text = json.loads(body)["generated_text"]
|
||||
# Decoding phase
|
||||
else:
|
||||
output.itl.append(timestamp -
|
||||
most_recent_timestamp)
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
|
||||
output.latency = most_recent_timestamp - st
|
||||
output.success = True
|
||||
else:
|
||||
output.success = False
|
||||
except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
|
||||
output.success = False
|
||||
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
return output
|
||||
|
||||
|
||||
async def async_request_vllm(
|
||||
request_func_input: RequestFuncInput,
|
||||
pbar: Optional[tqdm] = None,
|
||||
) -> RequestFuncOutput:
|
||||
api_url = request_func_input.api_url
|
||||
assert api_url.endswith("generate")
|
||||
|
||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||
payload = {
|
||||
"prompt": request_func_input.prompt,
|
||||
"n": 1,
|
||||
"best_of": request_func_input.best_of,
|
||||
"use_beam_search": request_func_input.use_beam_search,
|
||||
"temperature": 0.0 if request_func_input.use_beam_search else 1.0,
|
||||
"top_p": 1.0,
|
||||
"max_tokens": request_func_input.output_len,
|
||||
"ignore_eos": True,
|
||||
"stream": True,
|
||||
}
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
|
||||
ttft = 0
|
||||
st = time.perf_counter()
|
||||
try:
|
||||
async with session.post(url=api_url, json=payload) as response:
|
||||
if response.status == 200:
|
||||
async for data in response.content.iter_any():
|
||||
if ttft == 0:
|
||||
ttft = time.perf_counter() - st
|
||||
output.ttft = ttft
|
||||
output.latency = time.perf_counter() - st
|
||||
|
||||
# When streaming, '\0' is appended to the end of response.
|
||||
body = data.decode("utf-8").strip("\0")
|
||||
output.generated_text = json.loads(
|
||||
body)["text"][0][len(request_func_input.prompt):]
|
||||
output.success = True
|
||||
|
||||
else:
|
||||
output.success = False
|
||||
except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
|
||||
output.generated_text = data["generated_text"]
|
||||
except Exception:
|
||||
output.success = False
|
||||
exc_info = sys.exc_info()
|
||||
output.error = "".join(traceback.format_exception(*exc_info))
|
||||
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
@ -146,26 +118,45 @@ async def async_request_trt_llm(
|
||||
}
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
ttft = 0
|
||||
|
||||
ttft = 0
|
||||
st = time.perf_counter()
|
||||
most_recent_timestamp = st
|
||||
try:
|
||||
async with session.post(url=api_url, json=payload) as resp:
|
||||
if resp.status == 200:
|
||||
async for data in resp.content.iter_any():
|
||||
async with session.post(url=api_url, json=payload) as response:
|
||||
if response.status == 200:
|
||||
async for chunk in response.content:
|
||||
chunk = chunk.strip()
|
||||
if not chunk:
|
||||
continue
|
||||
|
||||
chunk = remove_prefix(chunk.decode("utf-8"), "data:")
|
||||
|
||||
data = json.loads(chunk)
|
||||
timestamp = time.perf_counter()
|
||||
# First token
|
||||
if ttft == 0:
|
||||
ttft = time.perf_counter() - st
|
||||
output.ttft = ttft
|
||||
output.latency = time.perf_counter() - st
|
||||
|
||||
body = remove_prefix(data.decode("utf-8"), "data:")
|
||||
output.generated_text = json.loads(body)["text_output"]
|
||||
# Decoding phase
|
||||
else:
|
||||
output.itl.append(timestamp -
|
||||
most_recent_timestamp)
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
|
||||
output.latency = most_recent_timestamp - st
|
||||
output.generated_text = json.loads(data)["text_output"]
|
||||
output.success = True
|
||||
|
||||
else:
|
||||
output.error = response.reason
|
||||
output.success = False
|
||||
except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
|
||||
except Exception:
|
||||
output.success = False
|
||||
exc_info = sys.exc_info()
|
||||
output.error = "".join(traceback.format_exception(*exc_info))
|
||||
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
@ -181,35 +172,35 @@ async def async_request_deepspeed_mii(
|
||||
assert not request_func_input.use_beam_search
|
||||
|
||||
payload = {
|
||||
"prompts": request_func_input.prompt,
|
||||
"max_new_tokens": request_func_input.output_len,
|
||||
"ignore_eos": True,
|
||||
"do_sample": True,
|
||||
"temperature":
|
||||
0.01, # deepspeed-mii does not accept 0.0 temperature.
|
||||
"prompt": request_func_input.prompt,
|
||||
"max_tokens": request_func_input.output_len,
|
||||
"temperature": 0.01, # deepspeed-mii does not accept 0.0 temp.
|
||||
"top_p": 1.0,
|
||||
}
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
|
||||
# DeepSpeed-MII doesn't support streaming as of Jan 28 2024,
|
||||
# NOTE: DeepSpeed-MII doesn't support streaming as of Jan 28 2024,
|
||||
# will use 0 as placeholder.
|
||||
# https://github.com/microsoft/DeepSpeed-MII/pull/311
|
||||
# See https://github.com/microsoft/DeepSpeed-MII/pull/311
|
||||
output.ttft = 0
|
||||
|
||||
st = time.perf_counter()
|
||||
try:
|
||||
async with session.post(url=request_func_input.api_url,
|
||||
json=payload) as resp:
|
||||
if resp.status == 200:
|
||||
parsed_resp = await resp.json()
|
||||
json=payload) as response:
|
||||
if response.status == 200:
|
||||
parsed_resp = await response.json()
|
||||
output.latency = time.perf_counter() - st
|
||||
output.generated_text = parsed_resp[0]["generated_text"]
|
||||
output.generated_text = parsed_resp["text"][0]
|
||||
output.success = True
|
||||
else:
|
||||
output.error = response.reason
|
||||
output.success = False
|
||||
except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
|
||||
except Exception:
|
||||
output.success = False
|
||||
exc_info = sys.exc_info()
|
||||
output.error = "".join(traceback.format_exception(*exc_info))
|
||||
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
@ -221,7 +212,9 @@ async def async_request_openai_completions(
|
||||
pbar: Optional[tqdm] = None,
|
||||
) -> RequestFuncOutput:
|
||||
api_url = request_func_input.api_url
|
||||
assert api_url.endswith("v1/completions")
|
||||
assert api_url.endswith(
|
||||
"v1/completions"
|
||||
), "OpenAI Completions API URL must end with 'v1/completions'."
|
||||
|
||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||
assert not request_func_input.use_beam_search
|
||||
@ -243,15 +236,12 @@ async def async_request_openai_completions(
|
||||
generated_text = ""
|
||||
ttft = 0
|
||||
st = time.perf_counter()
|
||||
most_recent_timestamp = st
|
||||
try:
|
||||
async with session.post(url=api_url, json=payload,
|
||||
headers=headers) as response:
|
||||
if response.status == 200:
|
||||
async for chunk in response.content:
|
||||
if ttft == 0:
|
||||
ttft = time.perf_counter() - st
|
||||
output.ttft = ttft
|
||||
|
||||
chunk = chunk.strip()
|
||||
if not chunk:
|
||||
continue
|
||||
@ -260,16 +250,33 @@ async def async_request_openai_completions(
|
||||
if chunk == "[DONE]":
|
||||
latency = time.perf_counter() - st
|
||||
else:
|
||||
body = json.loads(chunk)
|
||||
generated_text += body["choices"][0]["text"]
|
||||
data = json.loads(chunk)
|
||||
|
||||
if data["choices"][0]["text"]:
|
||||
timestamp = time.perf_counter()
|
||||
# First token
|
||||
if ttft == 0:
|
||||
ttft = time.perf_counter() - st
|
||||
output.ttft = ttft
|
||||
|
||||
# Decoding phase
|
||||
# NOTE: Some completion API might have a last
|
||||
# usage summary response without a token so we
|
||||
# do not want to include as inter-token-latency
|
||||
elif data.get("usage", None) is None:
|
||||
output.itl.append(timestamp -
|
||||
most_recent_timestamp)
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
generated_text += data["choices"][0]["text"]
|
||||
|
||||
output.generated_text = generated_text
|
||||
output.success = True
|
||||
output.latency = latency
|
||||
else:
|
||||
output.success = False
|
||||
except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
|
||||
except Exception:
|
||||
output.success = False
|
||||
exc_info = sys.exc_info()
|
||||
output.error = "".join(traceback.format_exception(*exc_info))
|
||||
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
@ -283,7 +290,7 @@ async def async_request_openai_chat_completions(
|
||||
api_url = request_func_input.api_url
|
||||
assert api_url.endswith(
|
||||
"v1/chat/completions"
|
||||
), "OpenAI Chat API URL must end with 'v1/chat/completions'."
|
||||
), "OpenAI Chat Completions API URL must end with 'v1/chat/completions'."
|
||||
|
||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||
assert not request_func_input.use_beam_search
|
||||
@ -301,7 +308,7 @@ async def async_request_openai_chat_completions(
|
||||
}
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
|
||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
|
||||
}
|
||||
|
||||
output = RequestFuncOutput()
|
||||
@ -310,15 +317,12 @@ async def async_request_openai_chat_completions(
|
||||
generated_text = ""
|
||||
ttft = 0
|
||||
st = time.perf_counter()
|
||||
most_recent_timestamp = st
|
||||
try:
|
||||
async with session.post(url=api_url, json=payload,
|
||||
headers=headers) as response:
|
||||
if response.status == 200:
|
||||
async for chunk in response.content:
|
||||
if ttft == 0:
|
||||
ttft = time.perf_counter() - st
|
||||
output.ttft = ttft
|
||||
|
||||
chunk = chunk.strip()
|
||||
if not chunk:
|
||||
continue
|
||||
@ -327,18 +331,35 @@ async def async_request_openai_chat_completions(
|
||||
if chunk == "[DONE]":
|
||||
latency = time.perf_counter() - st
|
||||
else:
|
||||
body = json.loads(chunk)
|
||||
if "content" in body["choices"][0]["delta"]:
|
||||
generated_text += body["choices"][0]["delta"][
|
||||
timestamp = time.perf_counter()
|
||||
data = json.loads(chunk)
|
||||
|
||||
if "content" in data["choices"][0]["delta"]:
|
||||
# First token
|
||||
if ttft == 0:
|
||||
ttft = time.perf_counter() - st
|
||||
output.ttft = ttft
|
||||
|
||||
# Decoding phase
|
||||
else:
|
||||
output.itl.append(timestamp -
|
||||
most_recent_timestamp)
|
||||
|
||||
generated_text += data["choices"][0]["delta"][
|
||||
"content"]
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
|
||||
output.generated_text = generated_text
|
||||
output.success = True
|
||||
output.latency = latency
|
||||
else:
|
||||
output.error = response.reason
|
||||
output.success = False
|
||||
except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
|
||||
except Exception:
|
||||
output.success = False
|
||||
exc_info = sys.exc_info()
|
||||
output.error = "".join(traceback.format_exception(*exc_info))
|
||||
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
@ -355,7 +376,8 @@ def remove_prefix(text: str, prefix: str) -> str:
|
||||
|
||||
ASYNC_REQUEST_FUNCS = {
|
||||
"tgi": async_request_tgi,
|
||||
"vllm": async_request_vllm,
|
||||
"vllm": async_request_openai_completions,
|
||||
"lmdeploy": async_request_openai_completions,
|
||||
"deepspeed-mii": async_request_deepspeed_mii,
|
||||
"openai": async_request_openai_completions,
|
||||
"openai-chat": async_request_openai_chat_completions,
|
||||
|
@ -1,8 +1,8 @@
|
||||
"""Benchmark online serving throughput.
|
||||
|
||||
On the server side, run one of the following commands:
|
||||
(vLLM backend)
|
||||
python -m vllm.entrypoints.api_server \
|
||||
vLLM OpenAI API server
|
||||
python -m vllm.entrypoints.openai.api_server \
|
||||
--model <your_model> --swap-space 16 \
|
||||
--disable-log-requests
|
||||
|
||||
@ -12,14 +12,19 @@ On the server side, run one of the following commands:
|
||||
On the client side, run:
|
||||
python benchmarks/benchmark_serving.py \
|
||||
--backend <backend> \
|
||||
--model <your_model> --dataset <target_dataset> \
|
||||
--request-rate <request_rate>
|
||||
--model <your_model> \
|
||||
--dataset-name sharegpt \
|
||||
--dataset-path <path to dataset> \
|
||||
--request-rate <request_rate> \ # By default <request_rate> is inf
|
||||
--num-prompts <num_prompts> # By default <num_prompts> is 1000
|
||||
"""
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import AsyncGenerator, List, Tuple
|
||||
@ -49,7 +54,7 @@ class BenchmarkMetrics:
|
||||
p99_tpot_ms: float
|
||||
|
||||
|
||||
def sample_requests(
|
||||
def sample_sharegpt_requests(
|
||||
dataset_path: str,
|
||||
num_requests: int,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
@ -97,6 +102,73 @@ def sample_requests(
|
||||
return sampled_requests
|
||||
|
||||
|
||||
def sample_sonnet_requests(
    dataset_path: str,
    num_requests: int,
    input_len: int,
    output_len: int,
    prefix_len: int,
    tokenizer: "PreTrainedTokenizerBase",
) -> List[Tuple[str, str, int, int]]:
    """Build benchmark prompts from a text file of poem lines.

    Every prompt starts with the same instruction sentence followed by the
    same leading lines of the file (the shared prefix); the remaining lines
    are drawn at random from the whole file for each request.

    Args:
        dataset_path: Path to a text file with one poem line per line.
        num_requests: Number of prompts to generate.
        input_len: Approximate target prompt length in tokens. Line counts
            are derived from the average tokenized line length, so the
            actual length can differ.
        output_len: Requested output length; copied verbatim into each
            returned tuple.
        prefix_len: Approximate number of leading prompt tokens that are
            identical across all requests.
        tokenizer: Callable returning an object with ``input_ids`` and
            providing ``apply_chat_template``.

    Returns:
        One ``(prompt, prompt_formatted, prompt_len, output_len)`` tuple per
        request, where ``prompt`` is the raw text, ``prompt_formatted`` is
        the chat-template rendering of it, and ``prompt_len`` is the token
        count of ``prompt_formatted``.

    Raises:
        AssertionError: If ``input_len`` is not greater than ``prefix_len``,
            or either is not greater than the token length of the formatted
            base prompt.
        ValueError: If the number of lines to sample per request is negative
            or exceeds the number of lines in the dataset.
    """
    assert input_len > prefix_len, "input_len must be greater than prefix_len."

    # Load the dataset.
    with open(dataset_path, encoding="utf-8") as f:
        poem_lines = f.readlines()

    # Tokenize the poem lines.
    poem_token_ids = tokenizer(poem_lines).input_ids
    average_poem_len = sum(
        len(token_ids) for token_ids in poem_token_ids) / len(poem_token_ids)

    # Base prefix for all requests.
    base_prompt = "Pick as many lines as you can from these poem lines:\n"
    base_message = [{
        "role": "user",
        "content": base_prompt,
    }]
    base_prompt_formatted = tokenizer.apply_chat_template(
        base_message, add_generation_prompt=True, tokenize=False)
    base_prompt_offset = len(tokenizer(base_prompt_formatted).input_ids)

    assert (input_len > base_prompt_offset
            ), f"Please set 'args.input-len' higher than {base_prompt_offset}."
    num_input_lines = round(
        (input_len - base_prompt_offset) / average_poem_len)

    # First approximately `prefix_len` number of tokens in the
    # prompt are fixed poem lines.
    assert (
        prefix_len > base_prompt_offset
    ), f"Please set 'args.prefix-len' higher than {base_prompt_offset}."

    num_prefix_lines = round(
        (prefix_len - base_prompt_offset) / average_poem_len)
    prefix_lines = poem_lines[:num_prefix_lines]

    # The per-request sample size is loop-invariant. random.sample() draws
    # without replacement, so validate the size once and fail with a message
    # that names the actual cause instead of its generic ValueError.
    num_sampled_lines = num_input_lines - num_prefix_lines
    if not 0 <= num_sampled_lines <= len(poem_lines):
        raise ValueError(
            f"Cannot sample {num_sampled_lines} lines per request from a "
            f"dataset of {len(poem_lines)} lines; adjust input_len/prefix_len "
            "or use a larger dataset.")

    # Sample the rest of lines per request.
    sampled_requests: List[Tuple[str, str, int, int]] = []
    for _ in range(num_requests):
        sampled_lines = "".join(
            prefix_lines + random.sample(poem_lines, num_sampled_lines))

        prompt = f"{base_prompt}{sampled_lines}"
        message = [
            {
                "role": "user",
                "content": prompt,
            },
        ]
        prompt_formatted = tokenizer.apply_chat_template(
            message, add_generation_prompt=True, tokenize=False)
        prompt_len = len(tokenizer(prompt_formatted).input_ids)
        sampled_requests.append(
            (prompt, prompt_formatted, prompt_len, output_len))

    return sampled_requests
|
||||
|
||||
|
||||
async def get_request(
|
||||
input_requests: List[Tuple[str, int, int]],
|
||||
request_rate: float,
|
||||
@ -119,37 +191,42 @@ def calculate_metrics(
|
||||
outputs: List[RequestFuncOutput],
|
||||
dur_s: float,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
) -> BenchmarkMetrics:
|
||||
total_output = 0
|
||||
) -> Tuple[BenchmarkMetrics, List[int]]:
|
||||
actual_output_lens = []
|
||||
total_input = 0
|
||||
completed = 0
|
||||
per_token_latencies = []
|
||||
tpots = []
|
||||
ttfts = []
|
||||
for i in range(len(outputs)):
|
||||
if outputs[i].success:
|
||||
output_len = len(tokenizer.encode(outputs[i].generated_text))
|
||||
total_output += output_len
|
||||
output_len = len(tokenizer(outputs[i].generated_text).input_ids)
|
||||
actual_output_lens.append(output_len)
|
||||
total_input += input_requests[i][1]
|
||||
per_token_latencies.append(outputs[i].latency / output_len)
|
||||
if output_len > 1:
|
||||
tpots.append(
|
||||
(outputs[i].latency - outputs[i].ttft) / (output_len - 1))
|
||||
ttfts.append(outputs[i].ttft)
|
||||
completed += 1
|
||||
else:
|
||||
actual_output_lens.append(0)
|
||||
|
||||
metrics = BenchmarkMetrics(
|
||||
completed=completed,
|
||||
total_input=total_input,
|
||||
total_output=total_output,
|
||||
total_output=sum(actual_output_lens),
|
||||
request_throughput=completed / dur_s,
|
||||
input_throughput=total_input / dur_s,
|
||||
output_throughput=total_output / dur_s,
|
||||
mean_ttft_ms=np.mean(ttfts) * 1000,
|
||||
median_ttft_ms=np.median(ttfts) * 1000,
|
||||
p99_ttft_ms=np.percentile(ttfts, 99) * 1000,
|
||||
mean_tpot_ms=np.mean(per_token_latencies) * 1000,
|
||||
median_tpot_ms=np.median(per_token_latencies) * 1000,
|
||||
p99_tpot_ms=np.percentile(per_token_latencies, 99) * 1000,
|
||||
output_throughput=sum(actual_output_lens) / dur_s,
|
||||
mean_ttft_ms=np.mean(ttfts or 0) *
|
||||
1000, # ttfts is empty if streaming is not supported by backend
|
||||
median_ttft_ms=np.median(ttfts or 0) * 1000,
|
||||
p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,
|
||||
mean_tpot_ms=np.mean(tpots) * 1000,
|
||||
median_tpot_ms=np.median(tpots) * 1000,
|
||||
p99_tpot_ms=np.percentile(tpots, 99) * 1000,
|
||||
)
|
||||
|
||||
return metrics
|
||||
return metrics, actual_output_lens
|
||||
|
||||
|
||||
async def benchmark(
|
||||
@ -189,40 +266,53 @@ async def benchmark(
|
||||
asyncio.create_task(
|
||||
request_func(request_func_input=request_func_input,
|
||||
pbar=pbar)))
|
||||
outputs = await asyncio.gather(*tasks)
|
||||
outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
|
||||
|
||||
if not disable_tqdm:
|
||||
pbar.close()
|
||||
|
||||
benchmark_duration = time.perf_counter() - benchmark_start_time
|
||||
|
||||
metrics = calculate_metrics(
|
||||
metrics, actual_output_lens = calculate_metrics(
|
||||
input_requests=input_requests,
|
||||
outputs=outputs,
|
||||
dur_s=benchmark_duration,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
print(f"Successful requests: {metrics.completed}")
|
||||
print(f"Benchmark duration: {benchmark_duration:2f} s")
|
||||
print(f"Total input tokens: {metrics.total_input}")
|
||||
print(f"Total generated tokens: {metrics.total_output}")
|
||||
print(f"Request throughput: {metrics.request_throughput:.2f} requests/s")
|
||||
print(f"Input token throughput: {metrics.input_throughput:.2f} tokens/s")
|
||||
print(f"Output token throughput: {metrics.output_throughput:.2f} tokens/s")
|
||||
print(f"Mean TTFT: {metrics.mean_ttft_ms:.2f} ms")
|
||||
print(f"Median TTFT: {metrics.median_ttft_ms:.2f} ms")
|
||||
print(f"P99 TTFT: {metrics.p99_ttft_ms:.2f} ms")
|
||||
print(f"Mean TPOT: {metrics.mean_tpot_ms:.2f} ms")
|
||||
print(f"Median TPOT: {metrics.median_tpot_ms:.2f} ms")
|
||||
print(f"P99 TPOT: {metrics.p99_tpot_ms:.2f} ms")
|
||||
print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
|
||||
print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
|
||||
print("{:<40} {:<10.2f}".format("Benchmark duration (s):",
|
||||
benchmark_duration))
|
||||
print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
|
||||
print("{:<40} {:<10}".format("Total generated tokens:",
|
||||
metrics.total_output))
|
||||
print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
|
||||
metrics.request_throughput))
|
||||
print("{:<40} {:<10.2f}".format("Input token throughput (tok/s):",
|
||||
metrics.input_throughput))
|
||||
print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
|
||||
metrics.output_throughput))
|
||||
print("{s:{c}^{n}}".format(s='Time to First Token', n=50, c='-'))
|
||||
print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
|
||||
print("{:<40} {:<10.2f}".format("Median TTFT (ms):",
|
||||
metrics.median_ttft_ms))
|
||||
print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
|
||||
print("{s:{c}^{n}}".format(s='Time per Output Token (excl. 1st token)',
|
||||
n=50,
|
||||
c='-'))
|
||||
print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms))
|
||||
print("{:<40} {:<10.2f}".format("Median TPOT (ms):",
|
||||
metrics.median_tpot_ms))
|
||||
print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
|
||||
print("=" * 50)
|
||||
|
||||
result = {
|
||||
"duration": benchmark_duration,
|
||||
"completed": metrics.completed,
|
||||
"total_input_tokens": metrics.total_input,
|
||||
"total_output_tokens": metrics.total_output,
|
||||
"request_inthroughput": metrics.request_throughput,
|
||||
"request_throughput": metrics.request_throughput,
|
||||
"input_throughput": metrics.input_throughput,
|
||||
"output_throughput": metrics.output_throughput,
|
||||
"mean_ttft_ms": metrics.mean_ttft_ms,
|
||||
@ -230,7 +320,13 @@ async def benchmark(
|
||||
"p99_ttft_ms": metrics.p99_ttft_ms,
|
||||
"mean_tpot_ms": metrics.mean_tpot_ms,
|
||||
"median_tpot_ms": metrics.median_tpot_ms,
|
||||
"p99_tpot_ms": metrics.p99_tpot_ms
|
||||
"p99_tpot_ms": metrics.p99_tpot_ms,
|
||||
"input_lens": [output.prompt_len for output in outputs],
|
||||
"output_lens": actual_output_lens,
|
||||
"ttfts": [output.ttft for output in outputs],
|
||||
"itls": [output.itl for output in outputs],
|
||||
"generated_texts": [output.generated_text for output in outputs],
|
||||
"errors": [output.error for output in outputs],
|
||||
}
|
||||
return result
|
||||
|
||||
@ -251,7 +347,58 @@ def main(args: argparse.Namespace):
|
||||
|
||||
tokenizer = get_tokenizer(tokenizer_id,
|
||||
trust_remote_code=args.trust_remote_code)
|
||||
input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
|
||||
|
||||
if args.dataset is not None:
|
||||
warnings.warn(
|
||||
"The '--dataset' argument will be deprecated in the next "
|
||||
"release. Please use '--dataset-name' and "
|
||||
"'--dataset-path' in the future runs.",
|
||||
stacklevel=2)
|
||||
input_requests = sample_sharegpt_requests(
|
||||
dataset_path=args.dataset,
|
||||
num_requests=args.num_prompts,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
elif args.dataset_name == "sharegpt":
|
||||
input_requests = sample_sharegpt_requests(
|
||||
dataset_path=args.dataset_path,
|
||||
num_requests=args.num_prompts,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
elif args.dataset_name == "sonnet":
|
||||
# Do not format the prompt, pass to message directly
|
||||
if args.backend == "openai-chat":
|
||||
input_requests = sample_sonnet_requests(
|
||||
dataset_path=args.dataset_path,
|
||||
num_requests=args.num_prompts,
|
||||
input_len=args.input_len,
|
||||
output_len=args.output_len,
|
||||
prefix_len=args.prefix_len,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
input_requests = [(prompt, prompt_len, output_len)
|
||||
for prompt, prompt_formatted, prompt_len,
|
||||
output_len in input_requests]
|
||||
else:
|
||||
assert (
|
||||
tokenizer.chat_template or tokenizer.default_chat_template
|
||||
), "Tokenizer/model must have chat template for sonnet dataset."
|
||||
input_requests = sample_sonnet_requests(
|
||||
dataset_path=args.dataset_path,
|
||||
num_requests=args.num_prompts,
|
||||
input_len=args.input_len,
|
||||
output_len=args.output_len,
|
||||
prefix_len=args.prefix_len,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
input_requests = [(prompt_formatted, prompt_len, output_len)
|
||||
for prompt, prompt_formatted, prompt_len,
|
||||
output_len in input_requests]
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown dataset: {args.dataset_name}")
|
||||
|
||||
benchmark_result = asyncio.run(
|
||||
benchmark(
|
||||
@ -274,13 +421,23 @@ def main(args: argparse.Namespace):
|
||||
current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||
result_json["date"] = current_dt
|
||||
result_json["backend"] = backend
|
||||
result_json["version"] = args.version
|
||||
result_json["model_id"] = model_id
|
||||
result_json["tokenizer_id"] = tokenizer_id
|
||||
result_json["best_of"] = args.best_of
|
||||
result_json["use_beam_search"] = args.use_beam_search
|
||||
result_json["num_prompts"] = args.num_prompts
|
||||
|
||||
# Metadata
|
||||
if args.metadata:
|
||||
for item in args.metadata:
|
||||
if "=" in item:
|
||||
kvstring = item.split("=")
|
||||
result_json[kvstring[0].strip()] = kvstring[1].strip()
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid metadata format. Please use KEY=VALUE format."
|
||||
)
|
||||
|
||||
# Traffic
|
||||
result_json["request_rate"] = (
|
||||
args.request_rate if args.request_rate < float("inf") else "inf")
|
||||
@ -290,9 +447,9 @@ def main(args: argparse.Namespace):
|
||||
|
||||
# Save to file
|
||||
base_model_id = model_id.split("/")[-1]
|
||||
file_name = (
|
||||
f"{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json"
|
||||
)
|
||||
file_name = f"{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json" #noqa
|
||||
if args.result_dir:
|
||||
file_name = os.path.join(args.result_dir, file_name)
|
||||
with open(file_name, "w") as outfile:
|
||||
json.dump(result_json, outfile)
|
||||
|
||||
@ -306,12 +463,6 @@ if __name__ == "__main__":
|
||||
default="vllm",
|
||||
choices=list(ASYNC_REQUEST_FUNCS.keys()),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--version",
|
||||
type=str,
|
||||
default="N/A",
|
||||
help="Version of the serving backend/engine.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--base-url",
|
||||
type=str,
|
||||
@ -323,12 +474,26 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--endpoint",
|
||||
type=str,
|
||||
default="/generate",
|
||||
default="/v1/completions",
|
||||
help="API endpoint.",
|
||||
)
|
||||
parser.add_argument("--dataset",
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the ShareGPT dataset, will be deprecated in the "
|
||||
"next release.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-name",
|
||||
type=str,
|
||||
default="sharegpt",
|
||||
choices=["sharegpt", "sonnet"],
|
||||
help="Name of the dataset to benchmark on.",
|
||||
)
|
||||
parser.add_argument("--dataset-path",
|
||||
type=str,
|
||||
required=True,
|
||||
default=None,
|
||||
help="Path to the dataset.")
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
@ -356,6 +521,27 @@ if __name__ == "__main__":
|
||||
default=1000,
|
||||
help="Number of prompts to process.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sonnet-input-len",
|
||||
type=int,
|
||||
default=550,
|
||||
help=
|
||||
"Number of input tokens per request, used only for sonnet dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sonnet-output-len",
|
||||
type=int,
|
||||
default=150,
|
||||
help=
|
||||
"Number of output tokens per request, used only for sonnet dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sonnet-prefix-len",
|
||||
type=int,
|
||||
default=200,
|
||||
help=
|
||||
"Number of prefix tokens per request, used only for sonnet dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--request-rate",
|
||||
type=float,
|
||||
@ -381,6 +567,21 @@ if __name__ == "__main__":
|
||||
action="store_true",
|
||||
help="Specify to save benchmark results to a json file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--metadata",
|
||||
metavar="KEY=VALUE",
|
||||
nargs="*",
|
||||
help="Key-value pairs (e.g, --metadata version=0.3.3 tp=1) "
|
||||
"for metadata of this run to be saved in the result JSON file "
|
||||
"for record keeping purposes.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--result-dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Specify directory to save benchmark json results."
|
||||
"If not specified, results are saved in the current directory.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
518
benchmarks/sonnet.txt
Normal file
518
benchmarks/sonnet.txt
Normal file
@ -0,0 +1,518 @@
|
||||
FROM fairest creatures we desire increase,
|
||||
That thereby beauty's rose might never die,
|
||||
But as the riper should by time decease,
|
||||
His tender heir might bear his memory:
|
||||
But thou, contracted to thine own bright eyes,
|
||||
Feed'st thy light'st flame with self-substantial fuel,
|
||||
Making a famine where abundance lies,
|
||||
Thyself thy foe, to thy sweet self too cruel.
|
||||
Thou that art now the world's fresh ornament
|
||||
And only herald to the gaudy spring,
|
||||
Within thine own bud buriest thy content
|
||||
And, tender churl, makest waste in niggarding.
|
||||
Pity the world, or else this glutton be,
|
||||
To eat the world's due, by the grave and thee.
|
||||
When forty winters shall beseige thy brow,
|
||||
And dig deep trenches in thy beauty's field,
|
||||
Thy youth's proud livery, so gazed on now,
|
||||
Will be a tatter'd weed, of small worth held:
|
||||
Then being ask'd where all thy beauty lies,
|
||||
Where all the treasure of thy lusty days,
|
||||
To say, within thine own deep-sunken eyes,
|
||||
Were an all-eating shame and thriftless praise.
|
||||
How much more praise deserved thy beauty's use,
|
||||
If thou couldst answer 'This fair child of mine
|
||||
Shall sum my count and make my old excuse,'
|
||||
Proving his beauty by succession thine!
|
||||
This were to be new made when thou art old,
|
||||
And see thy blood warm when thou feel'st it cold.
|
||||
Look in thy glass, and tell the face thou viewest
|
||||
Now is the time that face should form another;
|
||||
Whose fresh repair if now thou not renewest,
|
||||
Thou dost beguile the world, unbless some mother.
|
||||
For where is she so fair whose unear'd womb
|
||||
Disdains the tillage of thy husbandry?
|
||||
Or who is he so fond will be the tomb
|
||||
Of his self-love, to stop posterity?
|
||||
Thou art thy mother's glass, and she in thee
|
||||
Calls back the lovely April of her prime:
|
||||
So thou through windows of thine age shall see
|
||||
Despite of wrinkles this thy golden time.
|
||||
But if thou live, remember'd not to be,
|
||||
Die single, and thine image dies with thee.
|
||||
Unthrifty loveliness, why dost thou spend
|
||||
Upon thyself thy beauty's legacy?
|
||||
Nature's bequest gives nothing but doth lend,
|
||||
And being frank she lends to those are free.
|
||||
Then, beauteous niggard, why dost thou abuse
|
||||
The bounteous largess given thee to give?
|
||||
Profitless usurer, why dost thou use
|
||||
So great a sum of sums, yet canst not live?
|
||||
For having traffic with thyself alone,
|
||||
Thou of thyself thy sweet self dost deceive.
|
||||
Then how, when nature calls thee to be gone,
|
||||
What acceptable audit canst thou leave?
|
||||
Thy unused beauty must be tomb'd with thee,
|
||||
Which, used, lives th' executor to be.
|
||||
Those hours, that with gentle work did frame
|
||||
The lovely gaze where every eye doth dwell,
|
||||
Will play the tyrants to the very same
|
||||
And that unfair which fairly doth excel:
|
||||
For never-resting time leads summer on
|
||||
To hideous winter and confounds him there;
|
||||
Sap cheque'd with frost and lusty leaves quite gone,
|
||||
Beauty o'ersnow'd and bareness every where:
|
||||
Then, were not summer's distillation left,
|
||||
A liquid prisoner pent in walls of glass,
|
||||
Beauty's effect with beauty were bereft,
|
||||
Nor it nor no remembrance what it was:
|
||||
But flowers distill'd though they with winter meet,
|
||||
Leese but their show; their substance still lives sweet.
|
||||
Then let not winter's ragged hand deface
|
||||
In thee thy summer, ere thou be distill'd:
|
||||
Make sweet some vial; treasure thou some place
|
||||
With beauty's treasure, ere it be self-kill'd.
|
||||
That use is not forbidden usury,
|
||||
Which happies those that pay the willing loan;
|
||||
That's for thyself to breed another thee,
|
||||
Or ten times happier, be it ten for one;
|
||||
Ten times thyself were happier than thou art,
|
||||
If ten of thine ten times refigured thee:
|
||||
Then what could death do, if thou shouldst depart,
|
||||
Leaving thee living in posterity?
|
||||
Be not self-will'd, for thou art much too fair
|
||||
To be death's conquest and make worms thine heir.
|
||||
Lo! in the orient when the gracious light
|
||||
Lifts up his burning head, each under eye
|
||||
Doth homage to his new-appearing sight,
|
||||
Serving with looks his sacred majesty;
|
||||
And having climb'd the steep-up heavenly hill,
|
||||
Resembling strong youth in his middle age,
|
||||
yet mortal looks adore his beauty still,
|
||||
Attending on his golden pilgrimage;
|
||||
But when from highmost pitch, with weary car,
|
||||
Like feeble age, he reeleth from the day,
|
||||
The eyes, 'fore duteous, now converted are
|
||||
From his low tract and look another way:
|
||||
So thou, thyself out-going in thy noon,
|
||||
Unlook'd on diest, unless thou get a son.
|
||||
Music to hear, why hear'st thou music sadly?
|
||||
Sweets with sweets war not, joy delights in joy.
|
||||
Why lovest thou that which thou receivest not gladly,
|
||||
Or else receivest with pleasure thine annoy?
|
||||
If the true concord of well-tuned sounds,
|
||||
By unions married, do offend thine ear,
|
||||
They do but sweetly chide thee, who confounds
|
||||
In singleness the parts that thou shouldst bear.
|
||||
Mark how one string, sweet husband to another,
|
||||
Strikes each in each by mutual ordering,
|
||||
Resembling sire and child and happy mother
|
||||
Who all in one, one pleasing note do sing:
|
||||
Whose speechless song, being many, seeming one,
|
||||
Sings this to thee: 'thou single wilt prove none.'
|
||||
Is it for fear to wet a widow's eye
|
||||
That thou consumest thyself in single life?
|
||||
Ah! if thou issueless shalt hap to die.
|
||||
The world will wail thee, like a makeless wife;
|
||||
The world will be thy widow and still weep
|
||||
That thou no form of thee hast left behind,
|
||||
When every private widow well may keep
|
||||
By children's eyes her husband's shape in mind.
|
||||
Look, what an unthrift in the world doth spend
|
||||
Shifts but his place, for still the world enjoys it;
|
||||
But beauty's waste hath in the world an end,
|
||||
And kept unused, the user so destroys it.
|
||||
No love toward others in that bosom sits
|
||||
That on himself such murderous shame commits.
|
||||
For shame! deny that thou bear'st love to any,
|
||||
Who for thyself art so unprovident.
|
||||
Grant, if thou wilt, thou art beloved of many,
|
||||
But that thou none lovest is most evident;
|
||||
For thou art so possess'd with murderous hate
|
||||
That 'gainst thyself thou stick'st not to conspire.
|
||||
Seeking that beauteous roof to ruinate
|
||||
Which to repair should be thy chief desire.
|
||||
O, change thy thought, that I may change my mind!
|
||||
Shall hate be fairer lodged than gentle love?
|
||||
Be, as thy presence is, gracious and kind,
|
||||
Or to thyself at least kind-hearted prove:
|
||||
Make thee another self, for love of me,
|
||||
That beauty still may live in thine or thee.
|
||||
As fast as thou shalt wane, so fast thou growest
|
||||
In one of thine, from that which thou departest;
|
||||
And that fresh blood which youngly thou bestowest
|
||||
Thou mayst call thine when thou from youth convertest.
|
||||
Herein lives wisdom, beauty and increase:
|
||||
Without this, folly, age and cold decay:
|
||||
If all were minded so, the times should cease
|
||||
And threescore year would make the world away.
|
||||
Let those whom Nature hath not made for store,
|
||||
Harsh featureless and rude, barrenly perish:
|
||||
Look, whom she best endow'd she gave the more;
|
||||
Which bounteous gift thou shouldst in bounty cherish:
|
||||
She carved thee for her seal, and meant thereby
|
||||
Thou shouldst print more, not let that copy die.
|
||||
When I do count the clock that tells the time,
|
||||
And see the brave day sunk in hideous night;
|
||||
When I behold the violet past prime,
|
||||
And sable curls all silver'd o'er with white;
|
||||
When lofty trees I see barren of leaves
|
||||
Which erst from heat did canopy the herd,
|
||||
And summer's green all girded up in sheaves
|
||||
Borne on the bier with white and bristly beard,
|
||||
Then of thy beauty do I question make,
|
||||
That thou among the wastes of time must go,
|
||||
Since sweets and beauties do themselves forsake
|
||||
And die as fast as they see others grow;
|
||||
And nothing 'gainst Time's scythe can make defence
|
||||
Save breed, to brave him when he takes thee hence.
|
||||
O, that you were yourself! but, love, you are
|
||||
No longer yours than you yourself here live:
|
||||
Against this coming end you should prepare,
|
||||
And your sweet semblance to some other give.
|
||||
So should that beauty which you hold in lease
|
||||
Find no determination: then you were
|
||||
Yourself again after yourself's decease,
|
||||
When your sweet issue your sweet form should bear.
|
||||
Who lets so fair a house fall to decay,
|
||||
Which husbandry in honour might uphold
|
||||
Against the stormy gusts of winter's day
|
||||
And barren rage of death's eternal cold?
|
||||
O, none but unthrifts! Dear my love, you know
|
||||
You had a father: let your son say so.
|
||||
Not from the stars do I my judgment pluck;
|
||||
And yet methinks I have astronomy,
|
||||
But not to tell of good or evil luck,
|
||||
Of plagues, of dearths, or seasons' quality;
|
||||
Nor can I fortune to brief minutes tell,
|
||||
Pointing to each his thunder, rain and wind,
|
||||
Or say with princes if it shall go well,
|
||||
By oft predict that I in heaven find:
|
||||
But from thine eyes my knowledge I derive,
|
||||
And, constant stars, in them I read such art
|
||||
As truth and beauty shall together thrive,
|
||||
If from thyself to store thou wouldst convert;
|
||||
Or else of thee this I prognosticate:
|
||||
Thy end is truth's and beauty's doom and date.
|
||||
When I consider every thing that grows
|
||||
Holds in perfection but a little moment,
|
||||
That this huge stage presenteth nought but shows
|
||||
Whereon the stars in secret influence comment;
|
||||
When I perceive that men as plants increase,
|
||||
Cheered and cheque'd even by the self-same sky,
|
||||
Vaunt in their youthful sap, at height decrease,
|
||||
And wear their brave state out of memory;
|
||||
Then the conceit of this inconstant stay
|
||||
Sets you most rich in youth before my sight,
|
||||
Where wasteful Time debateth with Decay,
|
||||
To change your day of youth to sullied night;
|
||||
And all in war with Time for love of you,
|
||||
As he takes from you, I engraft you new.
|
||||
But wherefore do not you a mightier way
|
||||
Make war upon this bloody tyrant, Time?
|
||||
And fortify yourself in your decay
|
||||
With means more blessed than my barren rhyme?
|
||||
Now stand you on the top of happy hours,
|
||||
And many maiden gardens yet unset
|
||||
With virtuous wish would bear your living flowers,
|
||||
Much liker than your painted counterfeit:
|
||||
So should the lines of life that life repair,
|
||||
Which this, Time's pencil, or my pupil pen,
|
||||
Neither in inward worth nor outward fair,
|
||||
Can make you live yourself in eyes of men.
|
||||
To give away yourself keeps yourself still,
|
||||
And you must live, drawn by your own sweet skill.
|
||||
Who will believe my verse in time to come,
|
||||
If it were fill'd with your most high deserts?
|
||||
Though yet, heaven knows, it is but as a tomb
|
||||
Which hides your life and shows not half your parts.
|
||||
If I could write the beauty of your eyes
|
||||
And in fresh numbers number all your graces,
|
||||
The age to come would say 'This poet lies:
|
||||
Such heavenly touches ne'er touch'd earthly faces.'
|
||||
So should my papers yellow'd with their age
|
||||
Be scorn'd like old men of less truth than tongue,
|
||||
And your true rights be term'd a poet's rage
|
||||
And stretched metre of an antique song:
|
||||
But were some child of yours alive that time,
|
||||
You should live twice; in it and in my rhyme.
|
||||
Shall I compare thee to a summer's day?
|
||||
Thou art more lovely and more temperate:
|
||||
Rough winds do shake the darling buds of May,
|
||||
And summer's lease hath all too short a date:
|
||||
Sometime too hot the eye of heaven shines,
|
||||
And often is his gold complexion dimm'd;
|
||||
And every fair from fair sometime declines,
|
||||
By chance or nature's changing course untrimm'd;
|
||||
But thy eternal summer shall not fade
|
||||
Nor lose possession of that fair thou owest;
|
||||
Nor shall Death brag thou wander'st in his shade,
|
||||
When in eternal lines to time thou growest:
|
||||
So long as men can breathe or eyes can see,
|
||||
So long lives this and this gives life to thee.
|
||||
Devouring Time, blunt thou the lion's paws,
|
||||
And make the earth devour her own sweet brood;
|
||||
Pluck the keen teeth from the fierce tiger's jaws,
|
||||
And burn the long-lived phoenix in her blood;
|
||||
Make glad and sorry seasons as thou fleets,
|
||||
And do whate'er thou wilt, swift-footed Time,
|
||||
To the wide world and all her fading sweets;
|
||||
But I forbid thee one most heinous crime:
|
||||
O, carve not with thy hours my love's fair brow,
|
||||
Nor draw no lines there with thine antique pen;
|
||||
Him in thy course untainted do allow
|
||||
For beauty's pattern to succeeding men.
|
||||
Yet, do thy worst, old Time: despite thy wrong,
|
||||
My love shall in my verse ever live young.
|
||||
A woman's face with Nature's own hand painted
|
||||
Hast thou, the master-mistress of my passion;
|
||||
A woman's gentle heart, but not acquainted
|
||||
With shifting change, as is false women's fashion;
|
||||
An eye more bright than theirs, less false in rolling,
|
||||
Gilding the object whereupon it gazeth;
|
||||
A man in hue, all 'hues' in his controlling,
|
||||
Much steals men's eyes and women's souls amazeth.
|
||||
And for a woman wert thou first created;
|
||||
Till Nature, as she wrought thee, fell a-doting,
|
||||
And by addition me of thee defeated,
|
||||
By adding one thing to my purpose nothing.
|
||||
But since she prick'd thee out for women's pleasure,
|
||||
Mine be thy love and thy love's use their treasure.
|
||||
So is it not with me as with that Muse
|
||||
Stirr'd by a painted beauty to his verse,
|
||||
Who heaven itself for ornament doth use
|
||||
And every fair with his fair doth rehearse
|
||||
Making a couplement of proud compare,
|
||||
With sun and moon, with earth and sea's rich gems,
|
||||
With April's first-born flowers, and all things rare
|
||||
That heaven's air in this huge rondure hems.
|
||||
O' let me, true in love, but truly write,
|
||||
And then believe me, my love is as fair
|
||||
As any mother's child, though not so bright
|
||||
As those gold candles fix'd in heaven's air:
|
||||
Let them say more than like of hearsay well;
|
||||
I will not praise that purpose not to sell.
|
||||
My glass shall not persuade me I am old,
|
||||
So long as youth and thou are of one date;
|
||||
But when in thee time's furrows I behold,
|
||||
Then look I death my days should expiate.
|
||||
For all that beauty that doth cover thee
|
||||
Is but the seemly raiment of my heart,
|
||||
Which in thy breast doth live, as thine in me:
|
||||
How can I then be elder than thou art?
|
||||
O, therefore, love, be of thyself so wary
|
||||
As I, not for myself, but for thee will;
|
||||
Bearing thy heart, which I will keep so chary
|
||||
As tender nurse her babe from faring ill.
|
||||
Presume not on thy heart when mine is slain;
|
||||
Thou gavest me thine, not to give back again.
|
||||
As an unperfect actor on the stage
|
||||
Who with his fear is put besides his part,
|
||||
Or some fierce thing replete with too much rage,
|
||||
Whose strength's abundance weakens his own heart.
|
||||
So I, for fear of trust, forget to say
|
||||
The perfect ceremony of love's rite,
|
||||
And in mine own love's strength seem to decay,
|
||||
O'ercharged with burden of mine own love's might.
|
||||
O, let my books be then the eloquence
|
||||
And dumb presagers of my speaking breast,
|
||||
Who plead for love and look for recompense
|
||||
More than that tongue that more hath more express'd.
|
||||
O, learn to read what silent love hath writ:
|
||||
To hear with eyes belongs to love's fine wit.
|
||||
Mine eye hath play'd the painter and hath stell'd
|
||||
Thy beauty's form in table of my heart;
|
||||
My body is the frame wherein 'tis held,
|
||||
And perspective it is the painter's art.
|
||||
For through the painter must you see his skill,
|
||||
To find where your true image pictured lies;
|
||||
Which in my bosom's shop is hanging still,
|
||||
That hath his windows glazed with thine eyes.
|
||||
Now see what good turns eyes for eyes have done:
|
||||
Mine eyes have drawn thy shape, and thine for me
|
||||
Are windows to my breast, where-through the sun
|
||||
Delights to peep, to gaze therein on thee;
|
||||
Yet eyes this cunning want to grace their art;
|
||||
They draw but what they see, know not the heart.
|
||||
Let those who are in favour with their stars
|
||||
Of public honour and proud titles boast,
|
||||
Whilst I, whom fortune of such triumph bars,
|
||||
Unlook'd for joy in that I honour most.
|
||||
Great princes' favourites their fair leaves spread
|
||||
But as the marigold at the sun's eye,
|
||||
And in themselves their pride lies buried,
|
||||
For at a frown they in their glory die.
|
||||
The painful warrior famoused for fight,
|
||||
After a thousand victories once foil'd,
|
||||
Is from the book of honour razed quite,
|
||||
And all the rest forgot for which he toil'd:
|
||||
Then happy I, that love and am beloved
|
||||
Where I may not remove nor be removed.
|
||||
Lord of my love, to whom in vassalage
|
||||
Thy merit hath my duty strongly knit,
|
||||
To thee I send this written embassage,
|
||||
To witness duty, not to show my wit:
|
||||
Duty so great, which wit so poor as mine
|
||||
May make seem bare, in wanting words to show it,
|
||||
But that I hope some good conceit of thine
|
||||
In thy soul's thought, all naked, will bestow it;
|
||||
Till whatsoever star that guides my moving
|
||||
Points on me graciously with fair aspect
|
||||
And puts apparel on my tatter'd loving,
|
||||
To show me worthy of thy sweet respect:
|
||||
Then may I dare to boast how I do love thee;
|
||||
Till then not show my head where thou mayst prove me.
|
||||
Weary with toil, I haste me to my bed,
|
||||
The dear repose for limbs with travel tired;
|
||||
But then begins a journey in my head,
|
||||
To work my mind, when body's work's expired:
|
||||
For then my thoughts, from far where I abide,
|
||||
Intend a zealous pilgrimage to thee,
|
||||
And keep my drooping eyelids open wide,
|
||||
Looking on darkness which the blind do see
|
||||
Save that my soul's imaginary sight
|
||||
Presents thy shadow to my sightless view,
|
||||
Which, like a jewel hung in ghastly night,
|
||||
Makes black night beauteous and her old face new.
|
||||
Lo! thus, by day my limbs, by night my mind,
|
||||
For thee and for myself no quiet find.
|
||||
How can I then return in happy plight,
|
||||
That am debarr'd the benefit of rest?
|
||||
When day's oppression is not eased by night,
|
||||
But day by night, and night by day, oppress'd?
|
||||
And each, though enemies to either's reign,
|
||||
Do in consent shake hands to torture me;
|
||||
The one by toil, the other to complain
|
||||
How far I toil, still farther off from thee.
|
||||
I tell the day, to please them thou art bright
|
||||
And dost him grace when clouds do blot the heaven:
|
||||
So flatter I the swart-complexion'd night,
|
||||
When sparkling stars twire not thou gild'st the even.
|
||||
But day doth daily draw my sorrows longer
|
||||
And night doth nightly make grief's strength seem stronger.
|
||||
When, in disgrace with fortune and men's eyes,
|
||||
I all alone beweep my outcast state
|
||||
And trouble deal heaven with my bootless cries
|
||||
And look upon myself and curse my fate,
|
||||
Wishing me like to one more rich in hope,
|
||||
Featured like him, like him with friends possess'd,
|
||||
Desiring this man's art and that man's scope,
|
||||
With what I most enjoy contented least;
|
||||
Yet in these thoughts myself almost despising,
|
||||
Haply I think on thee, and then my state,
|
||||
Like to the lark at break of day arising
|
||||
From sullen earth, sings hymns at heaven's gate;
|
||||
For thy sweet love remember'd such wealth brings
|
||||
That then I scorn to change my state with kings.
|
||||
When to the sessions of sweet silent thought
|
||||
I summon up remembrance of things past,
|
||||
I sigh the lack of many a thing I sought,
|
||||
And with old woes new wail my dear time's waste:
|
||||
Then can I drown an eye, unused to flow,
|
||||
For precious friends hid in death's dateless night,
|
||||
And weep afresh love's long since cancell'd woe,
|
||||
And moan the expense of many a vanish'd sight:
|
||||
Then can I grieve at grievances foregone,
|
||||
And heavily from woe to woe tell o'er
|
||||
The sad account of fore-bemoaned moan,
|
||||
Which I new pay as if not paid before.
|
||||
But if the while I think on thee, dear friend,
|
||||
All losses are restored and sorrows end.
|
||||
Thy bosom is endeared with all hearts,
|
||||
Which I by lacking have supposed dead,
|
||||
And there reigns love and all love's loving parts,
|
||||
And all those friends which I thought buried.
|
||||
How many a holy and obsequious tear
|
||||
Hath dear religious love stol'n from mine eye
|
||||
As interest of the dead, which now appear
|
||||
But things removed that hidden in thee lie!
|
||||
Thou art the grave where buried love doth live,
|
||||
Hung with the trophies of my lovers gone,
|
||||
Who all their parts of me to thee did give;
|
||||
That due of many now is thine alone:
|
||||
Their images I loved I view in thee,
|
||||
And thou, all they, hast all the all of me.
|
||||
If thou survive my well-contented day,
|
||||
When that churl Death my bones with dust shall cover,
|
||||
And shalt by fortune once more re-survey
|
||||
These poor rude lines of thy deceased lover,
|
||||
Compare them with the bettering of the time,
|
||||
And though they be outstripp'd by every pen,
|
||||
Reserve them for my love, not for their rhyme,
|
||||
Exceeded by the height of happier men.
|
||||
O, then vouchsafe me but this loving thought:
|
||||
'Had my friend's Muse grown with this growing age,
|
||||
A dearer birth than this his love had brought,
|
||||
To march in ranks of better equipage:
|
||||
But since he died and poets better prove,
|
||||
Theirs for their style I'll read, his for his love.'
|
||||
Full many a glorious morning have I seen
|
||||
Flatter the mountain-tops with sovereign eye,
|
||||
Kissing with golden face the meadows green,
|
||||
Gilding pale streams with heavenly alchemy;
|
||||
Anon permit the basest clouds to ride
|
||||
With ugly rack on his celestial face,
|
||||
And from the forlorn world his visage hide,
|
||||
Stealing unseen to west with this disgrace:
|
||||
Even so my sun one early morn did shine
|
||||
With all triumphant splendor on my brow;
|
||||
But out, alack! he was but one hour mine;
|
||||
The region cloud hath mask'd him from me now.
|
||||
Yet him for this my love no whit disdaineth;
|
||||
Suns of the world may stain when heaven's sun staineth.
|
||||
Why didst thou promise such a beauteous day,
|
||||
And make me travel forth without my cloak,
|
||||
To let base clouds o'ertake me in my way,
|
||||
Hiding thy bravery in their rotten smoke?
|
||||
'Tis not enough that through the cloud thou break,
|
||||
To dry the rain on my storm-beaten face,
|
||||
For no man well of such a salve can speak
|
||||
That heals the wound and cures not the disgrace:
|
||||
Nor can thy shame give physic to my grief;
|
||||
Though thou repent, yet I have still the loss:
|
||||
The offender's sorrow lends but weak relief
|
||||
To him that bears the strong offence's cross.
|
||||
Ah! but those tears are pearl which thy love sheds,
|
||||
And they are rich and ransom all ill deeds.
|
||||
No more be grieved at that which thou hast done:
|
||||
Roses have thorns, and silver fountains mud;
|
||||
Clouds and eclipses stain both moon and sun,
|
||||
And loathsome canker lives in sweetest bud.
|
||||
All men make faults, and even I in this,
|
||||
Authorizing thy trespass with compare,
|
||||
Myself corrupting, salving thy amiss,
|
||||
Excusing thy sins more than thy sins are;
|
||||
For to thy sensual fault I bring in sense--
|
||||
Thy adverse party is thy advocate--
|
||||
And 'gainst myself a lawful plea commence:
|
||||
Such civil war is in my love and hate
|
||||
That I an accessary needs must be
|
||||
To that sweet thief which sourly robs from me.
|
||||
Let me confess that we two must be twain,
|
||||
Although our undivided loves are one:
|
||||
So shall those blots that do with me remain
|
||||
Without thy help by me be borne alone.
|
||||
In our two loves there is but one respect,
|
||||
Though in our lives a separable spite,
|
||||
Which though it alter not love's sole effect,
|
||||
Yet doth it steal sweet hours from love's delight.
|
||||
I may not evermore acknowledge thee,
|
||||
Lest my bewailed guilt should do thee shame,
|
||||
Nor thou with public kindness honour me,
|
||||
Unless thou take that honour from thy name:
|
||||
But do not so; I love thee in such sort
|
||||
As, thou being mine, mine is thy good report.
|
||||
As a decrepit father takes delight
|
||||
To see his active child do deeds of youth,
|
||||
So I, made lame by fortune's dearest spite,
|
||||
Take all my comfort of thy worth and truth.
|
||||
For whether beauty, birth, or wealth, or wit,
|
||||
Or any of these all, or all, or more,
|
||||
Entitled in thy parts do crowned sit,
|
||||
I make my love engrafted to this store:
|
||||
So then I am not lame, poor, nor despised,
|
||||
Whilst that this shadow doth such substance give
|
||||
That I in thy abundance am sufficed
|
||||
And by a part of all thy glory live.
|
||||
Look, what is best, that best I wish in thee:
|
||||
This wish I have; then ten times happy me!
|
@ -50,7 +50,7 @@ exclude = "vllm/model_executor/parallel_utils/|vllm/model_executor/models/"
|
||||
|
||||
[tool.codespell]
|
||||
ignore-words-list = "dout, te, indicies"
|
||||
skip = "./tests/prompts"
|
||||
skip = "./tests/prompts,./benchmarks/sonnet.txt"
|
||||
|
||||
[tool.isort]
|
||||
use_parentheses = true
|
||||
|
@ -36,8 +36,8 @@ def test_contexted_kv_attention(
|
||||
torch.cuda.manual_seed(0)
|
||||
torch.set_default_device(device)
|
||||
|
||||
# Need this, otherwise when we capture the graph the process for GPU 1 would
|
||||
# run on both GPU0 and GPU1 and things would hang
|
||||
# Need this, otherwise when we capture the graph the process
|
||||
# for GPU 1 would run on both GPU0 and GPU1 and things would hang
|
||||
#
|
||||
# see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523
|
||||
torch.cuda.set_device(device)
|
||||
|
Loading…
x
Reference in New Issue
Block a user