diff --git a/.buildkite/run-benchmarks.sh b/.buildkite/run-benchmarks.sh index 86506862..6af6c814 100644 --- a/.buildkite/run-benchmarks.sh +++ b/.buildkite/run-benchmarks.sh @@ -23,8 +23,9 @@ wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/r # wait for server to start, timeout after 600 seconds timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1 python3 benchmarks/benchmark_serving.py \ - --backend openai \ - --dataset ./ShareGPT_V3_unfiltered_cleaned_split.json \ + --backend vllm \ + --dataset-name sharegpt \ + --dataset-path ./ShareGPT_V3_unfiltered_cleaned_split.json \ --model meta-llama/Llama-2-7b-chat-hf \ --num-prompts 20 \ --endpoint /v1/completions \ diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index 7e6f3c3e..96a372e5 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -1,8 +1,10 @@ import json import os +import sys import time -from dataclasses import dataclass -from typing import Optional +import traceback +from dataclasses import dataclass, field +from typing import List, Optional import aiohttp from tqdm.asyncio import tqdm @@ -26,8 +28,11 @@ class RequestFuncOutput: generated_text: str = "" success: bool = False latency: float = 0 - ttft: float = 0 + ttft: float = 0 # Time to first token + itl: List[float] = field( + default_factory=list) # List of inter-token latencies prompt_len: int = 0 + error: str = "" async def async_request_tgi( @@ -55,71 +60,38 @@ async def async_request_tgi( ttft = 0 st = time.perf_counter() + most_recent_timestamp = st try: async with session.post(url=api_url, json=payload) as response: if response.status == 200: - async for data in response.content.iter_any(): + async for chunk in response.content: + chunk = chunk.strip() + if not chunk: + continue + + chunk = remove_prefix(chunk.decode("utf-8"), "data:") + + data = json.loads(chunk) + timestamp = time.perf_counter() + # First token if ttft == 0: ttft = time.perf_counter() - st output.ttft = ttft - output.latency = time.perf_counter() - st - body = remove_prefix(data.decode("utf-8"), "data:") - output.generated_text = json.loads(body)["generated_text"] + # Decoding phase + else: + output.itl.append(timestamp - + most_recent_timestamp) + + most_recent_timestamp = timestamp + + output.latency = most_recent_timestamp - st output.success = True - else: - output.success = False - except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError): - output.success = False - - if pbar: - pbar.update(1) - return output - - -async def async_request_vllm( - request_func_input: RequestFuncInput, - pbar: Optional[tqdm] = None, -) -> RequestFuncOutput: - api_url = request_func_input.api_url - assert api_url.endswith("generate") - - async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: - payload = { - "prompt": request_func_input.prompt, - "n": 1, - "best_of": request_func_input.best_of, - "use_beam_search": request_func_input.use_beam_search, - "temperature": 0.0 if request_func_input.use_beam_search else 1.0, - "top_p": 1.0, - "max_tokens": request_func_input.output_len, - "ignore_eos": True, - "stream": True, - } - output = RequestFuncOutput() - output.prompt_len = request_func_input.prompt_len - - ttft = 0 - st = time.perf_counter() - try: - async with session.post(url=api_url, json=payload) as response: - if response.status == 200: - async for data in response.content.iter_any(): - if ttft == 0: - ttft = time.perf_counter() - st - output.ttft = ttft - output.latency = time.perf_counter() - st - - # When streaming, '\0' is appended to the end of response. - body = data.decode("utf-8").strip("\0") - output.generated_text = json.loads( - body)["text"][0][len(request_func_input.prompt):] - output.success = True - - else: - output.success = False - except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError): + output.generated_text = data["generated_text"] + except Exception: output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) if pbar: pbar.update(1) @@ -146,26 +118,45 @@ async def async_request_trt_llm( } output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len - ttft = 0 + ttft = 0 st = time.perf_counter() + most_recent_timestamp = st try: - async with session.post(url=api_url, json=payload) as resp: - if resp.status == 200: - async for data in resp.content.iter_any(): + async with session.post(url=api_url, json=payload) as response: + if response.status == 200: + async for chunk in response.content: + chunk = chunk.strip() + if not chunk: + continue + + chunk = remove_prefix(chunk.decode("utf-8"), "data:") + + data = json.loads(chunk) + timestamp = time.perf_counter() + # First token if ttft == 0: ttft = time.perf_counter() - st output.ttft = ttft - output.latency = time.perf_counter() - st - body = remove_prefix(data.decode("utf-8"), "data:") - output.generated_text = json.loads(body)["text_output"] + # Decoding phase + else: + output.itl.append(timestamp - + most_recent_timestamp) + + most_recent_timestamp = timestamp + + output.latency = most_recent_timestamp - st + output.generated_text = json.loads(data)["text_output"] output.success = True else: + output.error = response.reason output.success = False - except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError): + except Exception: output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) if pbar: pbar.update(1) @@ -181,35 +172,35 @@ async def async_request_deepspeed_mii( assert not request_func_input.use_beam_search payload = { - "prompts": request_func_input.prompt, - "max_new_tokens": request_func_input.output_len, - "ignore_eos": True, - "do_sample": True, - "temperature": - 0.01, # deepspeed-mii does not accept 0.0 temperature. + "prompt": request_func_input.prompt, + "max_tokens": request_func_input.output_len, + "temperature": 0.01, # deepspeed-mii does not accept 0.0 temp. "top_p": 1.0, } output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len - # DeepSpeed-MII doesn't support streaming as of Jan 28 2024, + # NOTE: DeepSpeed-MII doesn't support streaming as of Jan 28 2024, # will use 0 as placeholder. - # https://github.com/microsoft/DeepSpeed-MII/pull/311 + # See https://github.com/microsoft/DeepSpeed-MII/pull/311 output.ttft = 0 st = time.perf_counter() try: async with session.post(url=request_func_input.api_url, - json=payload) as resp: - if resp.status == 200: - parsed_resp = await resp.json() + json=payload) as response: + if response.status == 200: + parsed_resp = await response.json() output.latency = time.perf_counter() - st - output.generated_text = parsed_resp[0]["generated_text"] + output.generated_text = parsed_resp["text"][0] output.success = True else: + output.error = response.reason output.success = False - except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError): + except Exception: output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) if pbar: pbar.update(1) @@ -221,7 +212,9 @@ async def async_request_openai_completions( pbar: Optional[tqdm] = None, ) -> RequestFuncOutput: api_url = request_func_input.api_url - assert api_url.endswith("v1/completions") + assert api_url.endswith( + "v1/completions" + ), "OpenAI Completions API URL must end with 'v1/completions'." async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: assert not request_func_input.use_beam_search @@ -243,15 +236,12 @@ async def async_request_openai_completions( generated_text = "" ttft = 0 st = time.perf_counter() + most_recent_timestamp = st try: async with session.post(url=api_url, json=payload, headers=headers) as response: if response.status == 200: async for chunk in response.content: - if ttft == 0: - ttft = time.perf_counter() - st - output.ttft = ttft - chunk = chunk.strip() if not chunk: continue @@ -260,16 +250,33 @@ async def async_request_openai_completions( if chunk == "[DONE]": latency = time.perf_counter() - st else: - body = json.loads(chunk) - generated_text += body["choices"][0]["text"] + data = json.loads(chunk) + + if data["choices"][0]["text"]: + timestamp = time.perf_counter() + # First token + if ttft == 0: + ttft = time.perf_counter() - st + output.ttft = ttft + + # Decoding phase + # NOTE: Some completion API might have a last + # usage summary response without a token so we + # do not want to include as inter-token-latency + elif data.get("usage", None) is None: + output.itl.append(timestamp - + most_recent_timestamp) + + most_recent_timestamp = timestamp + generated_text += data["choices"][0]["text"] output.generated_text = generated_text output.success = True output.latency = latency - else: - output.success = False - except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError): + except Exception: output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) if pbar: pbar.update(1) @@ -283,7 +290,7 @@ async def async_request_openai_chat_completions( api_url = request_func_input.api_url assert api_url.endswith( "v1/chat/completions" - ), "OpenAI Chat API URL must end with 'v1/chat/completions'." + ), "OpenAI Chat Completions API URL must end with 'v1/chat/completions'." async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: assert not request_func_input.use_beam_search @@ -301,7 +308,7 @@ async def async_request_openai_chat_completions( } headers = { "Content-Type": "application/json", - "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", } output = RequestFuncOutput() @@ -310,15 +317,12 @@ async def async_request_openai_chat_completions( generated_text = "" ttft = 0 st = time.perf_counter() + most_recent_timestamp = st try: async with session.post(url=api_url, json=payload, headers=headers) as response: if response.status == 200: async for chunk in response.content: - if ttft == 0: - ttft = time.perf_counter() - st - output.ttft = ttft - chunk = chunk.strip() if not chunk: continue @@ -327,18 +331,35 @@ async def async_request_openai_chat_completions( if chunk == "[DONE]": latency = time.perf_counter() - st else: - body = json.loads(chunk) - if "content" in body["choices"][0]["delta"]: - generated_text += body["choices"][0]["delta"][ + timestamp = time.perf_counter() + data = json.loads(chunk) + + if "content" in data["choices"][0]["delta"]: + # First token + if ttft == 0: + ttft = time.perf_counter() - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append(timestamp - + most_recent_timestamp) + + generated_text += data["choices"][0]["delta"][ "content"] + most_recent_timestamp = timestamp + output.generated_text = generated_text output.success = True output.latency = latency else: + output.error = response.reason output.success = False - except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError): + except Exception: output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) if pbar: pbar.update(1) @@ -355,7 +376,8 @@ def remove_prefix(text: str, prefix: str) -> str: ASYNC_REQUEST_FUNCS = { "tgi": async_request_tgi, - "vllm": async_request_vllm, + "vllm": async_request_openai_completions, + "lmdeploy": async_request_openai_completions, "deepspeed-mii": async_request_deepspeed_mii, "openai": async_request_openai_completions, "openai-chat": async_request_openai_chat_completions, diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 976cd28b..bc7812ed 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -1,8 +1,8 @@ """Benchmark online serving throughput. On the server side, run one of the following commands: - (vLLM backend) - python -m vllm.entrypoints.api_server \ + vLLM OpenAI API server + python -m vllm.entrypoints.openai.api_server \ --model --swap-space 16 \ --disable-log-requests @@ -12,14 +12,19 @@ On the server side, run one of the following commands: On the client side, run: python benchmarks/benchmark_serving.py \ --backend \ - --model --dataset \ - --request-rate + --model \ + --dataset-name sharegpt \ + --dataset-path \ + --request-rate \ # By default is inf + --num-prompts # By default is 1000 """ import argparse import asyncio import json +import os import random import time +import warnings from dataclasses import dataclass from datetime import datetime from typing import AsyncGenerator, List, Tuple @@ -49,7 +54,7 @@ class BenchmarkMetrics: p99_tpot_ms: float -def sample_requests( +def sample_sharegpt_requests( dataset_path: str, num_requests: int, tokenizer: PreTrainedTokenizerBase, @@ -97,6 +102,73 @@ def sample_requests( return sampled_requests +def sample_sonnet_requests( + dataset_path: str, + num_requests: int, + input_len: int, + output_len: int, + prefix_len: int, + tokenizer: PreTrainedTokenizerBase, +) -> List[Tuple[str, str, int, int]]: + assert input_len > prefix_len, "input_len must be greater than prefix_len." + + # Load the dataset. + with open(dataset_path) as f: + poem_lines = f.readlines() + + # Tokenize the poem lines. + poem_token_ids = tokenizer(poem_lines).input_ids + average_poem_len = sum( + len(token_ids) for token_ids in poem_token_ids) / len(poem_token_ids) + + # Base prefix for all requests. + base_prompt = "Pick as many lines as you can from these poem lines:\n" + base_message = [{ + "role": "user", + "content": base_prompt, + }] + base_prompt_formatted = tokenizer.apply_chat_template( + base_message, add_generation_prompt=True, tokenize=False) + base_prompt_offset = len(tokenizer(base_prompt_formatted).input_ids) + + assert (input_len > base_prompt_offset + ), f"Please set 'args.input-len' higher than {base_prompt_offset}." + num_input_lines = round( + (input_len - base_prompt_offset) / average_poem_len) + + # First approximately `prefix_len` number of tokens in the + # prompt are fixed poem lines. + assert ( + prefix_len > base_prompt_offset + ), f"Please set 'args.prefix-len' higher than {base_prompt_offset}." + + num_prefix_lines = round( + (prefix_len - base_prompt_offset) / average_poem_len) + prefix_lines = poem_lines[:num_prefix_lines] + + # Sample the rest of lines per request. + sampled_requests: List[Tuple[str, int, int]] = [] + for _ in range(num_requests): + sampled_lines = "".join( + prefix_lines + + random.sample(poem_lines, num_input_lines - num_prefix_lines)) + + prompt = f"{base_prompt}{sampled_lines}" + message = [ + { + "role": "user", + "content": prompt, + }, + ] + prompt_formatted = tokenizer.apply_chat_template( + message, add_generation_prompt=True, tokenize=False) + prompt_len = len(tokenizer(prompt_formatted).input_ids) + sampled_requests.append( + (prompt, prompt_formatted, prompt_len, output_len)) + + return sampled_requests + + async def get_request( input_requests: List[Tuple[str, int, int]], request_rate: float, @@ -119,37 +191,42 @@ def calculate_metrics( outputs: List[RequestFuncOutput], dur_s: float, tokenizer: PreTrainedTokenizerBase, -) -> BenchmarkMetrics: - total_output = 0 +) -> Tuple[BenchmarkMetrics, List[int]]: + actual_output_lens = [] total_input = 0 completed = 0 - per_token_latencies = [] + tpots = [] ttfts = [] for i in range(len(outputs)): if outputs[i].success: - output_len = len(tokenizer.encode(outputs[i].generated_text)) - total_output += output_len + output_len = len(tokenizer(outputs[i].generated_text).input_ids) + actual_output_lens.append(output_len) total_input += input_requests[i][1] - per_token_latencies.append(outputs[i].latency / output_len) + if output_len > 1: + tpots.append( + (outputs[i].latency - outputs[i].ttft) / (output_len - 1)) ttfts.append(outputs[i].ttft) completed += 1 + else: + actual_output_lens.append(0) metrics = BenchmarkMetrics( completed=completed, total_input=total_input, - total_output=total_output, + total_output=sum(actual_output_lens), request_throughput=completed / dur_s, input_throughput=total_input / dur_s, - output_throughput=total_output / dur_s, - mean_ttft_ms=np.mean(ttfts) * 1000, - median_ttft_ms=np.median(ttfts) * 1000, - p99_ttft_ms=np.percentile(ttfts, 99) * 1000, - mean_tpot_ms=np.mean(per_token_latencies) * 1000, - median_tpot_ms=np.median(per_token_latencies) * 1000, - p99_tpot_ms=np.percentile(per_token_latencies, 99) * 1000, + output_throughput=sum(actual_output_lens) / dur_s, + mean_ttft_ms=np.mean(ttfts or 0) * + 1000, # ttfts is empty if streaming is not supported by backend + median_ttft_ms=np.median(ttfts or 0) * 1000, + p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000, + mean_tpot_ms=np.mean(tpots) * 1000, + median_tpot_ms=np.median(tpots) * 1000, + p99_tpot_ms=np.percentile(tpots, 99) * 1000, ) - return metrics + return metrics, actual_output_lens async def benchmark( @@ -189,40 +266,53 @@ async def benchmark( asyncio.create_task( request_func(request_func_input=request_func_input, pbar=pbar))) - outputs = await asyncio.gather(*tasks) + outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks) if not disable_tqdm: pbar.close() benchmark_duration = time.perf_counter() - benchmark_start_time - metrics = calculate_metrics( + metrics, actual_output_lens = calculate_metrics( input_requests=input_requests, outputs=outputs, dur_s=benchmark_duration, tokenizer=tokenizer, ) - print(f"Successful requests: {metrics.completed}") - print(f"Benchmark duration: {benchmark_duration:2f} s") - print(f"Total input tokens: {metrics.total_input}") - print(f"Total generated tokens: {metrics.total_output}") - print(f"Request throughput: {metrics.request_throughput:.2f} requests/s") - print(f"Input token throughput: {metrics.input_throughput:.2f} tokens/s") - print(f"Output token throughput: {metrics.output_throughput:.2f} tokens/s") - print(f"Mean TTFT: {metrics.mean_ttft_ms:.2f} ms") - print(f"Median TTFT: {metrics.median_ttft_ms:.2f} ms") - print(f"P99 TTFT: {metrics.p99_ttft_ms:.2f} ms") - print(f"Mean TPOT: {metrics.mean_tpot_ms:.2f} ms") - print(f"Median TPOT: {metrics.median_tpot_ms:.2f} ms") - print(f"P99 TPOT: {metrics.p99_tpot_ms:.2f} ms") + print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='=')) + print("{:<40} {:<10}".format("Successful requests:", metrics.completed)) + print("{:<40} {:<10.2f}".format("Benchmark duration (s):", + benchmark_duration)) + print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input)) + print("{:<40} {:<10}".format("Total generated tokens:", + metrics.total_output)) + print("{:<40} {:<10.2f}".format("Request throughput (req/s):", + metrics.request_throughput)) + print("{:<40} {:<10.2f}".format("Input token throughput (tok/s):", + metrics.input_throughput)) + print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", + metrics.output_throughput)) + print("{s:{c}^{n}}".format(s='Time to First Token', n=50, c='-')) + print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms)) + print("{:<40} {:<10.2f}".format("Median TTFT (ms):", + metrics.median_ttft_ms)) + print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms)) + print("{s:{c}^{n}}".format(s='Time per Output Token (excl. 1st token)', + n=50, + c='-')) + print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms)) + print("{:<40} {:<10.2f}".format("Median TPOT (ms):", + metrics.median_tpot_ms)) + print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms)) + print("=" * 50) result = { "duration": benchmark_duration, "completed": metrics.completed, "total_input_tokens": metrics.total_input, "total_output_tokens": metrics.total_output, - "request_inthroughput": metrics.request_throughput, + "request_throughput": metrics.request_throughput, "input_throughput": metrics.input_throughput, "output_throughput": metrics.output_throughput, "mean_ttft_ms": metrics.mean_ttft_ms, @@ -230,7 +320,13 @@ async def benchmark( "p99_ttft_ms": metrics.p99_ttft_ms, "mean_tpot_ms": metrics.mean_tpot_ms, "median_tpot_ms": metrics.median_tpot_ms, - "p99_tpot_ms": metrics.p99_tpot_ms + "p99_tpot_ms": metrics.p99_tpot_ms, + "input_lens": [output.prompt_len for output in outputs], + "output_lens": actual_output_lens, + "ttfts": [output.ttft for output in outputs], + "itls": [output.itl for output in outputs], + "generated_texts": [output.generated_text for output in outputs], + "errors": [output.error for output in outputs], } return result @@ -251,7 +347,58 @@ def main(args: argparse.Namespace): tokenizer = get_tokenizer(tokenizer_id, trust_remote_code=args.trust_remote_code) - input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer) + + if args.dataset is not None: + warnings.warn( + "The '--dataset' argument will be deprecated in the next " + "release. Please use '--dataset-name' and " + "'--dataset-path' in the future runs.", + stacklevel=2) + input_requests = sample_sharegpt_requests( + dataset_path=args.dataset, + num_requests=args.num_prompts, + tokenizer=tokenizer, + ) + + elif args.dataset_name == "sharegpt": + input_requests = sample_sharegpt_requests( + dataset_path=args.dataset_path, + num_requests=args.num_prompts, + tokenizer=tokenizer, + ) + + elif args.dataset_name == "sonnet": + # Do not format the prompt, pass to message directly + if args.backend == "openai-chat": + input_requests = sample_sonnet_requests( + dataset_path=args.dataset_path, + num_requests=args.num_prompts, + input_len=args.input_len, + output_len=args.output_len, + prefix_len=args.prefix_len, + tokenizer=tokenizer, + ) + input_requests = [(prompt, prompt_len, output_len) + for prompt, prompt_formatted, prompt_len, + output_len in input_requests] + else: + assert ( + tokenizer.chat_template or tokenizer.default_chat_template + ), "Tokenizer/model must have chat template for sonnet dataset." + input_requests = sample_sonnet_requests( + dataset_path=args.dataset_path, + num_requests=args.num_prompts, + input_len=args.input_len, + output_len=args.output_len, + prefix_len=args.prefix_len, + tokenizer=tokenizer, + ) + input_requests = [(prompt_formatted, prompt_len, output_len) + for prompt, prompt_formatted, prompt_len, + output_len in input_requests] + + else: + raise ValueError(f"Unknown dataset: {args.dataset_name}") benchmark_result = asyncio.run( benchmark( @@ -274,13 +421,23 @@ def main(args: argparse.Namespace): current_dt = datetime.now().strftime("%Y%m%d-%H%M%S") result_json["date"] = current_dt result_json["backend"] = backend - result_json["version"] = args.version result_json["model_id"] = model_id result_json["tokenizer_id"] = tokenizer_id result_json["best_of"] = args.best_of result_json["use_beam_search"] = args.use_beam_search result_json["num_prompts"] = args.num_prompts + # Metadata + if args.metadata: + for item in args.metadata: + if "=" in item: + kvstring = item.split("=") + result_json[kvstring[0].strip()] = kvstring[1].strip() + else: + raise ValueError( + "Invalid metadata format. Please use KEY=VALUE format." + ) + # Traffic result_json["request_rate"] = ( args.request_rate if args.request_rate < float("inf") else "inf") @@ -290,9 +447,9 @@ def main(args: argparse.Namespace): # Save to file base_model_id = model_id.split("/")[-1] - file_name = ( - f"{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json" - ) + file_name = f"{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json" #noqa + if args.result_dir: + file_name = os.path.join(args.result_dir, file_name) with open(file_name, "w") as outfile: json.dump(result_json, outfile) @@ -306,12 +463,6 @@ if __name__ == "__main__": default="vllm", choices=list(ASYNC_REQUEST_FUNCS.keys()), ) - parser.add_argument( - "--version", - type=str, - default="N/A", - help="Version of the serving backend/engine.", - ) parser.add_argument( "--base-url", type=str, @@ -323,12 +474,26 @@ if __name__ == "__main__": parser.add_argument( "--endpoint", type=str, - default="/generate", + default="/v1/completions", help="API endpoint.", ) - parser.add_argument("--dataset", + parser.add_argument( + "--dataset", + type=str, + default=None, + help="Path to the ShareGPT dataset, will be deprecated in the " + "next release.", + ) + parser.add_argument( + "--dataset-name", + type=str, + default="sharegpt", + choices=["sharegpt", "sonnet"], + help="Name of the dataset to benchmark on.", + ) + parser.add_argument("--dataset-path", type=str, - required=True, + default=None, help="Path to the dataset.") parser.add_argument( "--model", @@ -356,6 +521,27 @@ if __name__ == "__main__": default=1000, help="Number of prompts to process.", ) + parser.add_argument( + "--sonnet-input-len", + type=int, + default=550, + help= + "Number of input tokens per request, used only for sonnet dataset.", + ) + parser.add_argument( + "--sonnet-output-len", + type=int, + default=150, + help= + "Number of output tokens per request, used only for sonnet dataset.", + ) + parser.add_argument( + "--sonnet-prefix-len", + type=int, + default=200, + help= + "Number of prefix tokens per request, used only for sonnet dataset.", + ) parser.add_argument( "--request-rate", type=float, @@ -381,6 +567,21 @@ if __name__ == "__main__": action="store_true", help="Specify to save benchmark results to a json file", ) + parser.add_argument( + "--metadata", + metavar="KEY=VALUE", + nargs="*", + help="Key-value pairs (e.g, --metadata version=0.3.3 tp=1) " + "for metadata of this run to be saved in the result JSON file " + "for record keeping purposes.", + ) + parser.add_argument( + "--result-dir", + type=str, + default=None, + help="Specify directory to save benchmark json results." + "If not specified, results are saved in the current directory.", + ) args = parser.parse_args() main(args) diff --git a/benchmarks/sonnet.txt b/benchmarks/sonnet.txt new file mode 100644 index 00000000..34c444e8 --- /dev/null +++ b/benchmarks/sonnet.txt @@ -0,0 +1,518 @@ +FROM fairest creatures we desire increase, +That thereby beauty's rose might never die, +But as the riper should by time decease, +His tender heir might bear his memory: +But thou, contracted to thine own bright eyes, +Feed'st thy light'st flame with self-substantial fuel, +Making a famine where abundance lies, +Thyself thy foe, to thy sweet self too cruel. +Thou that art now the world's fresh ornament +And only herald to the gaudy spring, +Within thine own bud buriest thy content +And, tender churl, makest waste in niggarding. +Pity the world, or else this glutton be, +To eat the world's due, by the grave and thee. +When forty winters shall beseige thy brow, +And dig deep trenches in thy beauty's field, +Thy youth's proud livery, so gazed on now, +Will be a tatter'd weed, of small worth held: +Then being ask'd where all thy beauty lies, +Where all the treasure of thy lusty days, +To say, within thine own deep-sunken eyes, +Were an all-eating shame and thriftless praise. +How much more praise deserved thy beauty's use, +If thou couldst answer 'This fair child of mine +Shall sum my count and make my old excuse,' +Proving his beauty by succession thine! +This were to be new made when thou art old, +And see thy blood warm when thou feel'st it cold. +Look in thy glass, and tell the face thou viewest +Now is the time that face should form another; +Whose fresh repair if now thou not renewest, +Thou dost beguile the world, unbless some mother. +For where is she so fair whose unear'd womb +Disdains the tillage of thy husbandry? +Or who is he so fond will be the tomb +Of his self-love, to stop posterity? +Thou art thy mother's glass, and she in thee +Calls back the lovely April of her prime: +So thou through windows of thine age shall see +Despite of wrinkles this thy golden time. +But if thou live, remember'd not to be, +Die single, and thine image dies with thee. +Unthrifty loveliness, why dost thou spend +Upon thyself thy beauty's legacy? +Nature's bequest gives nothing but doth lend, +And being frank she lends to those are free. +Then, beauteous niggard, why dost thou abuse +The bounteous largess given thee to give? +Profitless usurer, why dost thou use +So great a sum of sums, yet canst not live? +For having traffic with thyself alone, +Thou of thyself thy sweet self dost deceive. +Then how, when nature calls thee to be gone, +What acceptable audit canst thou leave? +Thy unused beauty must be tomb'd with thee, +Which, used, lives th' executor to be. +Those hours, that with gentle work did frame +The lovely gaze where every eye doth dwell, +Will play the tyrants to the very same +And that unfair which fairly doth excel: +For never-resting time leads summer on +To hideous winter and confounds him there; +Sap cheque'd with frost and lusty leaves quite gone, +Beauty o'ersnow'd and bareness every where: +Then, were not summer's distillation left, +A liquid prisoner pent in walls of glass, +Beauty's effect with beauty were bereft, +Nor it nor no remembrance what it was: +But flowers distill'd though they with winter meet, +Leese but their show; their substance still lives sweet. +Then let not winter's ragged hand deface +In thee thy summer, ere thou be distill'd: +Make sweet some vial; treasure thou some place +With beauty's treasure, ere it be self-kill'd. +That use is not forbidden usury, +Which happies those that pay the willing loan; +That's for thyself to breed another thee, +Or ten times happier, be it ten for one; +Ten times thyself were happier than thou art, +If ten of thine ten times refigured thee: +Then what could death do, if thou shouldst depart, +Leaving thee living in posterity? +Be not self-will'd, for thou art much too fair +To be death's conquest and make worms thine heir. +Lo! in the orient when the gracious light +Lifts up his burning head, each under eye +Doth homage to his new-appearing sight, +Serving with looks his sacred majesty; +And having climb'd the steep-up heavenly hill, +Resembling strong youth in his middle age, +yet mortal looks adore his beauty still, +Attending on his golden pilgrimage; +But when from highmost pitch, with weary car, +Like feeble age, he reeleth from the day, +The eyes, 'fore duteous, now converted are +From his low tract and look another way: +So thou, thyself out-going in thy noon, +Unlook'd on diest, unless thou get a son. +Music to hear, why hear'st thou music sadly? +Sweets with sweets war not, joy delights in joy. +Why lovest thou that which thou receivest not gladly, +Or else receivest with pleasure thine annoy? +If the true concord of well-tuned sounds, +By unions married, do offend thine ear, +They do but sweetly chide thee, who confounds +In singleness the parts that thou shouldst bear. +Mark how one string, sweet husband to another, +Strikes each in each by mutual ordering, +Resembling sire and child and happy mother +Who all in one, one pleasing note do sing: +Whose speechless song, being many, seeming one, +Sings this to thee: 'thou single wilt prove none.' +Is it for fear to wet a widow's eye +That thou consumest thyself in single life? +Ah! if thou issueless shalt hap to die. +The world will wail thee, like a makeless wife; +The world will be thy widow and still weep +That thou no form of thee hast left behind, +When every private widow well may keep +By children's eyes her husband's shape in mind. +Look, what an unthrift in the world doth spend +Shifts but his place, for still the world enjoys it; +But beauty's waste hath in the world an end, +And kept unused, the user so destroys it. +No love toward others in that bosom sits +That on himself such murderous shame commits. +For shame! deny that thou bear'st love to any, +Who for thyself art so unprovident. +Grant, if thou wilt, thou art beloved of many, +But that thou none lovest is most evident; +For thou art so possess'd with murderous hate +That 'gainst thyself thou stick'st not to conspire. +Seeking that beauteous roof to ruinate +Which to repair should be thy chief desire. +O, change thy thought, that I may change my mind! +Shall hate be fairer lodged than gentle love? +Be, as thy presence is, gracious and kind, +Or to thyself at least kind-hearted prove: +Make thee another self, for love of me, +That beauty still may live in thine or thee. +As fast as thou shalt wane, so fast thou growest +In one of thine, from that which thou departest; +And that fresh blood which youngly thou bestowest +Thou mayst call thine when thou from youth convertest. +Herein lives wisdom, beauty and increase: +Without this, folly, age and cold decay: +If all were minded so, the times should cease +And threescore year would make the world away. +Let those whom Nature hath not made for store, +Harsh featureless and rude, barrenly perish: +Look, whom she best endow'd she gave the more; +Which bounteous gift thou shouldst in bounty cherish: +She carved thee for her seal, and meant thereby +Thou shouldst print more, not let that copy die. +When I do count the clock that tells the time, +And see the brave day sunk in hideous night; +When I behold the violet past prime, +And sable curls all silver'd o'er with white; +When lofty trees I see barren of leaves +Which erst from heat did canopy the herd, +And summer's green all girded up in sheaves +Borne on the bier with white and bristly beard, +Then of thy beauty do I question make, +That thou among the wastes of time must go, +Since sweets and beauties do themselves forsake +And die as fast as they see others grow; +And nothing 'gainst Time's scythe can make defence +Save breed, to brave him when he takes thee hence. +O, that you were yourself! but, love, you are +No longer yours than you yourself here live: +Against this coming end you should prepare, +And your sweet semblance to some other give. +So should that beauty which you hold in lease +Find no determination: then you were +Yourself again after yourself's decease, +When your sweet issue your sweet form should bear. +Who lets so fair a house fall to decay, +Which husbandry in honour might uphold +Against the stormy gusts of winter's day +And barren rage of death's eternal cold? +O, none but unthrifts! Dear my love, you know +You had a father: let your son say so. +Not from the stars do I my judgment pluck; +And yet methinks I have astronomy, +But not to tell of good or evil luck, +Of plagues, of dearths, or seasons' quality; +Nor can I fortune to brief minutes tell, +Pointing to each his thunder, rain and wind, +Or say with princes if it shall go well, +By oft predict that I in heaven find: +But from thine eyes my knowledge I derive, +And, constant stars, in them I read such art +As truth and beauty shall together thrive, +If from thyself to store thou wouldst convert; +Or else of thee this I prognosticate: +Thy end is truth's and beauty's doom and date. +When I consider every thing that grows +Holds in perfection but a little moment, +That this huge stage presenteth nought but shows +Whereon the stars in secret influence comment; +When I perceive that men as plants increase, +Cheered and cheque'd even by the self-same sky, +Vaunt in their youthful sap, at height decrease, +And wear their brave state out of memory; +Then the conceit of this inconstant stay +Sets you most rich in youth before my sight, +Where wasteful Time debateth with Decay, +To change your day of youth to sullied night; +And all in war with Time for love of you, +As he takes from you, I engraft you new. +But wherefore do not you a mightier way +Make war upon this bloody tyrant, Time? +And fortify yourself in your decay +With means more blessed than my barren rhyme? +Now stand you on the top of happy hours, +And many maiden gardens yet unset +With virtuous wish would bear your living flowers, +Much liker than your painted counterfeit: +So should the lines of life that life repair, +Which this, Time's pencil, or my pupil pen, +Neither in inward worth nor outward fair, +Can make you live yourself in eyes of men. +To give away yourself keeps yourself still, +And you must live, drawn by your own sweet skill. +Who will believe my verse in time to come, +If it were fill'd with your most high deserts? +Though yet, heaven knows, it is but as a tomb +Which hides your life and shows not half your parts. +If I could write the beauty of your eyes +And in fresh numbers number all your graces, +The age to come would say 'This poet lies: +Such heavenly touches ne'er touch'd earthly faces.' +So should my papers yellow'd with their age +Be scorn'd like old men of less truth than tongue, +And your true rights be term'd a poet's rage +And stretched metre of an antique song: +But were some child of yours alive that time, +You should live twice; in it and in my rhyme. +Shall I compare thee to a summer's day? +Thou art more lovely and more temperate: +Rough winds do shake the darling buds of May, +And summer's lease hath all too short a date: +Sometime too hot the eye of heaven shines, +And often is his gold complexion dimm'd; +And every fair from fair sometime declines, +By chance or nature's changing course untrimm'd; +But thy eternal summer shall not fade +Nor lose possession of that fair thou owest; +Nor shall Death brag thou wander'st in his shade, +When in eternal lines to time thou growest: +So long as men can breathe or eyes can see, +So long lives this and this gives life to thee. +Devouring Time, blunt thou the lion's paws, +And make the earth devour her own sweet brood; +Pluck the keen teeth from the fierce tiger's jaws, +And burn the long-lived phoenix in her blood; +Make glad and sorry seasons as thou fleets, +And do whate'er thou wilt, swift-footed Time, +To the wide world and all her fading sweets; +But I forbid thee one most heinous crime: +O, carve not with thy hours my love's fair brow, +Nor draw no lines there with thine antique pen; +Him in thy course untainted do allow +For beauty's pattern to succeeding men. +Yet, do thy worst, old Time: despite thy wrong, +My love shall in my verse ever live young. +A woman's face with Nature's own hand painted +Hast thou, the master-mistress of my passion; +A woman's gentle heart, but not acquainted +With shifting change, as is false women's fashion; +An eye more bright than theirs, less false in rolling, +Gilding the object whereupon it gazeth; +A man in hue, all 'hues' in his controlling, +Much steals men's eyes and women's souls amazeth. +And for a woman wert thou first created; +Till Nature, as she wrought thee, fell a-doting, +And by addition me of thee defeated, +By adding one thing to my purpose nothing. +But since she prick'd thee out for women's pleasure, +Mine be thy love and thy love's use their treasure. +So is it not with me as with that Muse +Stirr'd by a painted beauty to his verse, +Who heaven itself for ornament doth use +And every fair with his fair doth rehearse +Making a couplement of proud compare, +With sun and moon, with earth and sea's rich gems, +With April's first-born flowers, and all things rare +That heaven's air in this huge rondure hems. +O' let me, true in love, but truly write, +And then believe me, my love is as fair +As any mother's child, though not so bright +As those gold candles fix'd in heaven's air: +Let them say more than like of hearsay well; +I will not praise that purpose not to sell. +My glass shall not persuade me I am old, +So long as youth and thou are of one date; +But when in thee time's furrows I behold, +Then look I death my days should expiate. +For all that beauty that doth cover thee +Is but the seemly raiment of my heart, +Which in thy breast doth live, as thine in me: +How can I then be elder than thou art? +O, therefore, love, be of thyself so wary +As I, not for myself, but for thee will; +Bearing thy heart, which I will keep so chary +As tender nurse her babe from faring ill. +Presume not on thy heart when mine is slain; +Thou gavest me thine, not to give back again. +As an unperfect actor on the stage +Who with his fear is put besides his part, +Or some fierce thing replete with too much rage, +Whose strength's abundance weakens his own heart. +So I, for fear of trust, forget to say +The perfect ceremony of love's rite, +And in mine own love's strength seem to decay, +O'ercharged with burden of mine own love's might. +O, let my books be then the eloquence +And dumb presagers of my speaking breast, +Who plead for love and look for recompense +More than that tongue that more hath more express'd. +O, learn to read what silent love hath writ: +To hear with eyes belongs to love's fine wit. +Mine eye hath play'd the painter and hath stell'd +Thy beauty's form in table of my heart; +My body is the frame wherein 'tis held, +And perspective it is the painter's art. +For through the painter must you see his skill, +To find where your true image pictured lies; +Which in my bosom's shop is hanging still, +That hath his windows glazed with thine eyes. +Now see what good turns eyes for eyes have done: +Mine eyes have drawn thy shape, and thine for me +Are windows to my breast, where-through the sun +Delights to peep, to gaze therein on thee; +Yet eyes this cunning want to grace their art; +They draw but what they see, know not the heart. +Let those who are in favour with their stars +Of public honour and proud titles boast, +Whilst I, whom fortune of such triumph bars, +Unlook'd for joy in that I honour most. +Great princes' favourites their fair leaves spread +But as the marigold at the sun's eye, +And in themselves their pride lies buried, +For at a frown they in their glory die. +The painful warrior famoused for fight, +After a thousand victories once foil'd, +Is from the book of honour razed quite, +And all the rest forgot for which he toil'd: +Then happy I, that love and am beloved +Where I may not remove nor be removed. +Lord of my love, to whom in vassalage +Thy merit hath my duty strongly knit, +To thee I send this written embassage, +To witness duty, not to show my wit: +Duty so great, which wit so poor as mine +May make seem bare, in wanting words to show it, +But that I hope some good conceit of thine +In thy soul's thought, all naked, will bestow it; +Till whatsoever star that guides my moving +Points on me graciously with fair aspect +And puts apparel on my tatter'd loving, +To show me worthy of thy sweet respect: +Then may I dare to boast how I do love thee; +Till then not show my head where thou mayst prove me. +Weary with toil, I haste me to my bed, +The dear repose for limbs with travel tired; +But then begins a journey in my head, +To work my mind, when body's work's expired: +For then my thoughts, from far where I abide, +Intend a zealous pilgrimage to thee, +And keep my drooping eyelids open wide, +Looking on darkness which the blind do see +Save that my soul's imaginary sight +Presents thy shadow to my sightless view, +Which, like a jewel hung in ghastly night, +Makes black night beauteous and her old face new. +Lo! thus, by day my limbs, by night my mind, +For thee and for myself no quiet find. +How can I then return in happy plight, +That am debarr'd the benefit of rest? +When day's oppression is not eased by night, +But day by night, and night by day, oppress'd? +And each, though enemies to either's reign, +Do in consent shake hands to torture me; +The one by toil, the other to complain +How far I toil, still farther off from thee. +I tell the day, to please them thou art bright +And dost him grace when clouds do blot the heaven: +So flatter I the swart-complexion'd night, +When sparkling stars twire not thou gild'st the even. +But day doth daily draw my sorrows longer +And night doth nightly make grief's strength seem stronger. +When, in disgrace with fortune and men's eyes, +I all alone beweep my outcast state +And trouble deal heaven with my bootless cries +And look upon myself and curse my fate, +Wishing me like to one more rich in hope, +Featured like him, like him with friends possess'd, +Desiring this man's art and that man's scope, +With what I most enjoy contented least; +Yet in these thoughts myself almost despising, +Haply I think on thee, and then my state, +Like to the lark at break of day arising +From sullen earth, sings hymns at heaven's gate; +For thy sweet love remember'd such wealth brings +That then I scorn to change my state with kings. +When to the sessions of sweet silent thought +I summon up remembrance of things past, +I sigh the lack of many a thing I sought, +And with old woes new wail my dear time's waste: +Then can I drown an eye, unused to flow, +For precious friends hid in death's dateless night, +And weep afresh love's long since cancell'd woe, +And moan the expense of many a vanish'd sight: +Then can I grieve at grievances foregone, +And heavily from woe to woe tell o'er +The sad account of fore-bemoaned moan, +Which I new pay as if not paid before. +But if the while I think on thee, dear friend, +All losses are restored and sorrows end. +Thy bosom is endeared with all hearts, +Which I by lacking have supposed dead, +And there reigns love and all love's loving parts, +And all those friends which I thought buried. +How many a holy and obsequious tear +Hath dear religious love stol'n from mine eye +As interest of the dead, which now appear +But things removed that hidden in thee lie! +Thou art the grave where buried love doth live, +Hung with the trophies of my lovers gone, +Who all their parts of me to thee did give; +That due of many now is thine alone: +Their images I loved I view in thee, +And thou, all they, hast all the all of me. +If thou survive my well-contented day, +When that churl Death my bones with dust shall cover, +And shalt by fortune once more re-survey +These poor rude lines of thy deceased lover, +Compare them with the bettering of the time, +And though they be outstripp'd by every pen, +Reserve them for my love, not for their rhyme, +Exceeded by the height of happier men. +O, then vouchsafe me but this loving thought: +'Had my friend's Muse grown with this growing age, +A dearer birth than this his love had brought, +To march in ranks of better equipage: +But since he died and poets better prove, +Theirs for their style I'll read, his for his love.' +Full many a glorious morning have I seen +Flatter the mountain-tops with sovereign eye, +Kissing with golden face the meadows green, +Gilding pale streams with heavenly alchemy; +Anon permit the basest clouds to ride +With ugly rack on his celestial face, +And from the forlorn world his visage hide, +Stealing unseen to west with this disgrace: +Even so my sun one early morn did shine +With all triumphant splendor on my brow; +But out, alack! he was but one hour mine; +The region cloud hath mask'd him from me now. +Yet him for this my love no whit disdaineth; +Suns of the world may stain when heaven's sun staineth. +Why didst thou promise such a beauteous day, +And make me travel forth without my cloak, +To let base clouds o'ertake me in my way, +Hiding thy bravery in their rotten smoke? +'Tis not enough that through the cloud thou break, +To dry the rain on my storm-beaten face, +For no man well of such a salve can speak +That heals the wound and cures not the disgrace: +Nor can thy shame give physic to my grief; +Though thou repent, yet I have still the loss: +The offender's sorrow lends but weak relief +To him that bears the strong offence's cross. +Ah! but those tears are pearl which thy love sheds, +And they are rich and ransom all ill deeds. +No more be grieved at that which thou hast done: +Roses have thorns, and silver fountains mud; +Clouds and eclipses stain both moon and sun, +And loathsome canker lives in sweetest bud. +All men make faults, and even I in this, +Authorizing thy trespass with compare, +Myself corrupting, salving thy amiss, +Excusing thy sins more than thy sins are; +For to thy sensual fault I bring in sense-- +Thy adverse party is thy advocate-- +And 'gainst myself a lawful plea commence: +Such civil war is in my love and hate +That I an accessary needs must be +To that sweet thief which sourly robs from me. +Let me confess that we two must be twain, +Although our undivided loves are one: +So shall those blots that do with me remain +Without thy help by me be borne alone. +In our two loves there is but one respect, +Though in our lives a separable spite, +Which though it alter not love's sole effect, +Yet doth it steal sweet hours from love's delight. +I may not evermore acknowledge thee, +Lest my bewailed guilt should do thee shame, +Nor thou with public kindness honour me, +Unless thou take that honour from thy name: +But do not so; I love thee in such sort +As, thou being mine, mine is thy good report. +As a decrepit father takes delight +To see his active child do deeds of youth, +So I, made lame by fortune's dearest spite, +Take all my comfort of thy worth and truth. +For whether beauty, birth, or wealth, or wit, +Or any of these all, or all, or more, +Entitled in thy parts do crowned sit, +I make my love engrafted to this store: +So then I am not lame, poor, nor despised, +Whilst that this shadow doth such substance give +That I in thy abundance am sufficed +And by a part of all thy glory live. +Look, what is best, that best I wish in thee: +This wish I have; then ten times happy me! \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 4d6fb5a3..9d042601 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,7 @@ exclude = "vllm/model_executor/parallel_utils/|vllm/model_executor/models/" [tool.codespell] ignore-words-list = "dout, te, indicies" -skip = "./tests/prompts" +skip = "./tests/prompts,./benchmarks/sonnet.txt" [tool.isort] use_parentheses = true diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index eb706c0d..6494fb34 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -36,8 +36,8 @@ def test_contexted_kv_attention( torch.cuda.manual_seed(0) torch.set_default_device(device) - # Need this, otherwise when we capture the graph the process for GPU 1 would - # run on both GPU0 and GPU1 and things would hang + # Need this, otherwise when we capture the graph the process + # for GPU 1 would run on both GPU0 and GPU1 and things would hang # # see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523 torch.cuda.set_device(device)