[Misc] Make benchmarks use EngineArgs (#9529)
This commit is contained in:
parent
23b899a8e6
commit
cb6fdaa0a0
@ -1,5 +1,6 @@
|
|||||||
"""Benchmark the latency of processing a single batch of requests."""
|
"""Benchmark the latency of processing a single batch of requests."""
|
||||||
import argparse
|
import argparse
|
||||||
|
import dataclasses
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -10,43 +11,19 @@ import torch
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
from vllm.engine.arg_utils import DEVICE_OPTIONS, EngineArgs
|
from vllm.engine.arg_utils import EngineArgs
|
||||||
from vllm.inputs import PromptType
|
from vllm.inputs import PromptType
|
||||||
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
|
||||||
from vllm.utils import FlexibleArgumentParser
|
from vllm.utils import FlexibleArgumentParser
|
||||||
|
|
||||||
|
|
||||||
def main(args: argparse.Namespace):
|
def main(args: argparse.Namespace):
|
||||||
print(args)
|
print(args)
|
||||||
|
|
||||||
|
engine_args = EngineArgs.from_cli_args(args)
|
||||||
|
|
||||||
# NOTE(woosuk): If the request cannot be processed in a single batch,
|
# NOTE(woosuk): If the request cannot be processed in a single batch,
|
||||||
# the engine will automatically process the request in multiple batches.
|
# the engine will automatically process the request in multiple batches.
|
||||||
llm = LLM(
|
llm = LLM(**dataclasses.asdict(engine_args))
|
||||||
model=args.model,
|
|
||||||
speculative_model=args.speculative_model,
|
|
||||||
num_speculative_tokens=args.num_speculative_tokens,
|
|
||||||
speculative_draft_tensor_parallel_size=\
|
|
||||||
args.speculative_draft_tensor_parallel_size,
|
|
||||||
tokenizer=args.tokenizer,
|
|
||||||
quantization=args.quantization,
|
|
||||||
tensor_parallel_size=args.tensor_parallel_size,
|
|
||||||
trust_remote_code=args.trust_remote_code,
|
|
||||||
dtype=args.dtype,
|
|
||||||
max_model_len=args.max_model_len,
|
|
||||||
enforce_eager=args.enforce_eager,
|
|
||||||
kv_cache_dtype=args.kv_cache_dtype,
|
|
||||||
quantization_param_path=args.quantization_param_path,
|
|
||||||
device=args.device,
|
|
||||||
ray_workers_use_nsight=args.ray_workers_use_nsight,
|
|
||||||
enable_chunked_prefill=args.enable_chunked_prefill,
|
|
||||||
download_dir=args.download_dir,
|
|
||||||
block_size=args.block_size,
|
|
||||||
gpu_memory_utilization=args.gpu_memory_utilization,
|
|
||||||
load_format=args.load_format,
|
|
||||||
distributed_executor_backend=args.distributed_executor_backend,
|
|
||||||
otlp_traces_endpoint=args.otlp_traces_endpoint,
|
|
||||||
enable_prefix_caching=args.enable_prefix_caching,
|
|
||||||
)
|
|
||||||
|
|
||||||
sampling_params = SamplingParams(
|
sampling_params = SamplingParams(
|
||||||
n=args.n,
|
n=args.n,
|
||||||
@ -125,19 +102,6 @@ if __name__ == '__main__':
|
|||||||
parser = FlexibleArgumentParser(
|
parser = FlexibleArgumentParser(
|
||||||
description='Benchmark the latency of processing a single batch of '
|
description='Benchmark the latency of processing a single batch of '
|
||||||
'requests till completion.')
|
'requests till completion.')
|
||||||
parser.add_argument('--model', type=str, default='facebook/opt-125m')
|
|
||||||
parser.add_argument('--speculative-model', type=str, default=None)
|
|
||||||
parser.add_argument('--num-speculative-tokens', type=int, default=None)
|
|
||||||
parser.add_argument('--speculative-draft-tensor-parallel-size',
|
|
||||||
'-spec-draft-tp',
|
|
||||||
type=int,
|
|
||||||
default=None)
|
|
||||||
parser.add_argument('--tokenizer', type=str, default=None)
|
|
||||||
parser.add_argument('--quantization',
|
|
||||||
'-q',
|
|
||||||
choices=[*QUANTIZATION_METHODS, None],
|
|
||||||
default=None)
|
|
||||||
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
|
|
||||||
parser.add_argument('--input-len', type=int, default=32)
|
parser.add_argument('--input-len', type=int, default=32)
|
||||||
parser.add_argument('--output-len', type=int, default=128)
|
parser.add_argument('--output-len', type=int, default=128)
|
||||||
parser.add_argument('--batch-size', type=int, default=8)
|
parser.add_argument('--batch-size', type=int, default=8)
|
||||||
@ -154,45 +118,6 @@ if __name__ == '__main__':
|
|||||||
type=int,
|
type=int,
|
||||||
default=30,
|
default=30,
|
||||||
help='Number of iterations to run.')
|
help='Number of iterations to run.')
|
||||||
parser.add_argument('--trust-remote-code',
|
|
||||||
action='store_true',
|
|
||||||
help='trust remote code from huggingface')
|
|
||||||
parser.add_argument(
|
|
||||||
'--max-model-len',
|
|
||||||
type=int,
|
|
||||||
default=None,
|
|
||||||
help='Maximum length of a sequence (including prompt and output). '
|
|
||||||
'If None, will be derived from the model.')
|
|
||||||
parser.add_argument(
|
|
||||||
'--dtype',
|
|
||||||
type=str,
|
|
||||||
default='auto',
|
|
||||||
choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
|
|
||||||
help='data type for model weights and activations. '
|
|
||||||
'The "auto" option will use FP16 precision '
|
|
||||||
'for FP32 and FP16 models, and BF16 precision '
|
|
||||||
'for BF16 models.')
|
|
||||||
parser.add_argument('--enforce-eager',
|
|
||||||
action='store_true',
|
|
||||||
help='enforce eager mode and disable CUDA graph')
|
|
||||||
parser.add_argument(
|
|
||||||
'--kv-cache-dtype',
|
|
||||||
type=str,
|
|
||||||
choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
|
|
||||||
default="auto",
|
|
||||||
help='Data type for kv cache storage. If "auto", will use model '
|
|
||||||
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
|
|
||||||
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
|
|
||||||
parser.add_argument(
|
|
||||||
'--quantization-param-path',
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help='Path to the JSON file containing the KV cache scaling factors. '
|
|
||||||
'This should generally be supplied, when KV cache dtype is FP8. '
|
|
||||||
'Otherwise, KV cache scaling factors default to 1.0, which may cause '
|
|
||||||
'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
|
|
||||||
'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
|
|
||||||
'instead supported for common inference criteria.')
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--profile',
|
'--profile',
|
||||||
action='store_true',
|
action='store_true',
|
||||||
@ -203,78 +128,12 @@ if __name__ == '__main__':
|
|||||||
default=None,
|
default=None,
|
||||||
help=('path to save the pytorch profiler output. Can be visualized '
|
help=('path to save the pytorch profiler output. Can be visualized '
|
||||||
'with ui.perfetto.dev or Tensorboard.'))
|
'with ui.perfetto.dev or Tensorboard.'))
|
||||||
parser.add_argument("--device",
|
|
||||||
type=str,
|
|
||||||
default="auto",
|
|
||||||
choices=DEVICE_OPTIONS,
|
|
||||||
help='device type for vLLM execution')
|
|
||||||
parser.add_argument('--block-size',
|
|
||||||
type=int,
|
|
||||||
default=16,
|
|
||||||
help='block size of key/value cache')
|
|
||||||
parser.add_argument(
|
|
||||||
'--enable-chunked-prefill',
|
|
||||||
action='store_true',
|
|
||||||
help='If True, the prefill requests can be chunked based on the '
|
|
||||||
'max_num_batched_tokens')
|
|
||||||
parser.add_argument("--enable-prefix-caching",
|
|
||||||
action='store_true',
|
|
||||||
help="Enable automatic prefix caching")
|
|
||||||
parser.add_argument(
|
|
||||||
"--ray-workers-use-nsight",
|
|
||||||
action='store_true',
|
|
||||||
help="If specified, use nsight to profile ray workers",
|
|
||||||
)
|
|
||||||
parser.add_argument('--download-dir',
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help='directory to download and load the weights, '
|
|
||||||
'default to the default cache dir of huggingface')
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--output-json',
|
'--output-json',
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
help='Path to save the latency results in JSON format.')
|
help='Path to save the latency results in JSON format.')
|
||||||
parser.add_argument('--gpu-memory-utilization',
|
|
||||||
type=float,
|
parser = EngineArgs.add_cli_args(parser)
|
||||||
default=0.9,
|
|
||||||
help='the fraction of GPU memory to be used for '
|
|
||||||
'the model executor, which can range from 0 to 1.'
|
|
||||||
'If unspecified, will use the default value of 0.9.')
|
|
||||||
parser.add_argument(
|
|
||||||
'--load-format',
|
|
||||||
type=str,
|
|
||||||
default=EngineArgs.load_format,
|
|
||||||
choices=[
|
|
||||||
'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer',
|
|
||||||
'bitsandbytes'
|
|
||||||
],
|
|
||||||
help='The format of the model weights to load.\n\n'
|
|
||||||
'* "auto" will try to load the weights in the safetensors format '
|
|
||||||
'and fall back to the pytorch bin format if safetensors format '
|
|
||||||
'is not available.\n'
|
|
||||||
'* "pt" will load the weights in the pytorch bin format.\n'
|
|
||||||
'* "safetensors" will load the weights in the safetensors format.\n'
|
|
||||||
'* "npcache" will load the weights in pytorch format and store '
|
|
||||||
'a numpy cache to speed up the loading.\n'
|
|
||||||
'* "dummy" will initialize the weights with random values, '
|
|
||||||
'which is mainly for profiling.\n'
|
|
||||||
'* "tensorizer" will load the weights using tensorizer from '
|
|
||||||
'CoreWeave. See the Tensorize vLLM Model script in the Examples'
|
|
||||||
'section for more information.\n'
|
|
||||||
'* "bitsandbytes" will load the weights using bitsandbytes '
|
|
||||||
'quantization.\n')
|
|
||||||
parser.add_argument(
|
|
||||||
'--distributed-executor-backend',
|
|
||||||
choices=['ray', 'mp'],
|
|
||||||
default=None,
|
|
||||||
help='Backend to use for distributed serving. When more than 1 GPU '
|
|
||||||
'is used, will be automatically set to "ray" if installed '
|
|
||||||
'or "mp" (multiprocessing) otherwise.')
|
|
||||||
parser.add_argument(
|
|
||||||
'--otlp-traces-endpoint',
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help='Target URL to which OpenTelemetry traces will be sent.')
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
main(args)
|
main(args)
|
||||||
|
@ -25,6 +25,7 @@ ShareGPT example usage:
|
|||||||
--input-length-range 128:256
|
--input-length-range 128:256
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
@ -33,6 +34,7 @@ from typing import List, Optional, Tuple
|
|||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
|
from vllm.engine.arg_utils import EngineArgs
|
||||||
from vllm.utils import FlexibleArgumentParser
|
from vllm.utils import FlexibleArgumentParser
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -129,12 +131,9 @@ def main(args):
|
|||||||
filtered_datasets = [(PROMPT, prompt_len, args.output_len)
|
filtered_datasets = [(PROMPT, prompt_len, args.output_len)
|
||||||
] * args.num_prompts
|
] * args.num_prompts
|
||||||
|
|
||||||
llm = LLM(model=args.model,
|
engine_args = EngineArgs.from_cli_args(args)
|
||||||
tokenizer_mode='auto',
|
|
||||||
trust_remote_code=True,
|
llm = LLM(**dataclasses.asdict(engine_args))
|
||||||
enforce_eager=True,
|
|
||||||
tensor_parallel_size=args.tensor_parallel_size,
|
|
||||||
enable_prefix_caching=args.enable_prefix_caching)
|
|
||||||
|
|
||||||
sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
|
sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
|
||||||
|
|
||||||
@ -162,18 +161,11 @@ if __name__ == "__main__":
|
|||||||
parser = FlexibleArgumentParser(
|
parser = FlexibleArgumentParser(
|
||||||
description=
|
description=
|
||||||
'Benchmark the performance with or without automatic prefix caching.')
|
'Benchmark the performance with or without automatic prefix caching.')
|
||||||
parser.add_argument('--model',
|
|
||||||
type=str,
|
|
||||||
default='baichuan-inc/Baichuan2-13B-Chat')
|
|
||||||
parser.add_argument("--dataset-path",
|
parser.add_argument("--dataset-path",
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
help="Path to the dataset.")
|
help="Path to the dataset.")
|
||||||
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
|
|
||||||
parser.add_argument('--output-len', type=int, default=10)
|
parser.add_argument('--output-len', type=int, default=10)
|
||||||
parser.add_argument('--enable-prefix-caching',
|
|
||||||
action='store_true',
|
|
||||||
help='enable prefix caching')
|
|
||||||
parser.add_argument('--num-prompts',
|
parser.add_argument('--num-prompts',
|
||||||
type=int,
|
type=int,
|
||||||
default=1,
|
default=1,
|
||||||
@ -190,9 +182,7 @@ if __name__ == "__main__":
|
|||||||
default='128:256',
|
default='128:256',
|
||||||
help='Range of input lengths for sampling prompts,'
|
help='Range of input lengths for sampling prompts,'
|
||||||
'specified as "min:max" (e.g., "128:256").')
|
'specified as "min:max" (e.g., "128:256").')
|
||||||
parser.add_argument("--seed",
|
|
||||||
type=int,
|
parser = EngineArgs.add_cli_args(parser)
|
||||||
default=0,
|
|
||||||
help='Random seed for reproducibility')
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
main(args)
|
main(args)
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
"""Benchmark offline prioritization."""
|
"""Benchmark offline prioritization."""
|
||||||
import argparse
|
import argparse
|
||||||
|
import dataclasses
|
||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
@ -7,7 +8,8 @@ from typing import List, Optional, Tuple
|
|||||||
|
|
||||||
from transformers import AutoTokenizer, PreTrainedTokenizerBase
|
from transformers import AutoTokenizer, PreTrainedTokenizerBase
|
||||||
|
|
||||||
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
from vllm.engine.arg_utils import EngineArgs
|
||||||
|
from vllm.utils import FlexibleArgumentParser
|
||||||
|
|
||||||
|
|
||||||
def sample_requests(
|
def sample_requests(
|
||||||
@ -62,46 +64,11 @@ def sample_requests(
|
|||||||
|
|
||||||
def run_vllm(
|
def run_vllm(
|
||||||
requests: List[Tuple[str, int, int]],
|
requests: List[Tuple[str, int, int]],
|
||||||
model: str,
|
|
||||||
tokenizer: str,
|
|
||||||
quantization: Optional[str],
|
|
||||||
tensor_parallel_size: int,
|
|
||||||
seed: int,
|
|
||||||
n: int,
|
n: int,
|
||||||
trust_remote_code: bool,
|
engine_args: EngineArgs,
|
||||||
dtype: str,
|
|
||||||
max_model_len: Optional[int],
|
|
||||||
enforce_eager: bool,
|
|
||||||
kv_cache_dtype: str,
|
|
||||||
quantization_param_path: Optional[str],
|
|
||||||
device: str,
|
|
||||||
enable_prefix_caching: bool,
|
|
||||||
enable_chunked_prefill: bool,
|
|
||||||
max_num_batched_tokens: int,
|
|
||||||
gpu_memory_utilization: float = 0.9,
|
|
||||||
download_dir: Optional[str] = None,
|
|
||||||
) -> float:
|
) -> float:
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
llm = LLM(
|
llm = LLM(**dataclasses.asdict(engine_args))
|
||||||
model=model,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
quantization=quantization,
|
|
||||||
tensor_parallel_size=tensor_parallel_size,
|
|
||||||
seed=seed,
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
dtype=dtype,
|
|
||||||
max_model_len=max_model_len,
|
|
||||||
gpu_memory_utilization=gpu_memory_utilization,
|
|
||||||
enforce_eager=enforce_eager,
|
|
||||||
kv_cache_dtype=kv_cache_dtype,
|
|
||||||
quantization_param_path=quantization_param_path,
|
|
||||||
device=device,
|
|
||||||
enable_prefix_caching=enable_prefix_caching,
|
|
||||||
download_dir=download_dir,
|
|
||||||
enable_chunked_prefill=enable_chunked_prefill,
|
|
||||||
max_num_batched_tokens=max_num_batched_tokens,
|
|
||||||
disable_log_stats=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add the requests to the engine.
|
# Add the requests to the engine.
|
||||||
prompts = []
|
prompts = []
|
||||||
@ -142,16 +109,8 @@ def main(args: argparse.Namespace):
|
|||||||
args.output_len)
|
args.output_len)
|
||||||
|
|
||||||
if args.backend == "vllm":
|
if args.backend == "vllm":
|
||||||
elapsed_time = run_vllm(requests, args.model, args.tokenizer,
|
elapsed_time = run_vllm(requests, args.n,
|
||||||
args.quantization, args.tensor_parallel_size,
|
EngineArgs.from_cli_args(args))
|
||||||
args.seed, args.n, args.trust_remote_code,
|
|
||||||
args.dtype, args.max_model_len,
|
|
||||||
args.enforce_eager, args.kv_cache_dtype,
|
|
||||||
args.quantization_param_path, args.device,
|
|
||||||
args.enable_prefix_caching,
|
|
||||||
args.enable_chunked_prefill,
|
|
||||||
args.max_num_batched_tokens,
|
|
||||||
args.gpu_memory_utilization, args.download_dir)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown backend: {args.backend}")
|
raise ValueError(f"Unknown backend: {args.backend}")
|
||||||
total_num_tokens = sum(prompt_len + output_len
|
total_num_tokens = sum(prompt_len + output_len
|
||||||
@ -173,7 +132,7 @@ def main(args: argparse.Namespace):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Benchmark the throughput.")
|
parser = FlexibleArgumentParser(description="Benchmark the throughput.")
|
||||||
parser.add_argument("--backend",
|
parser.add_argument("--backend",
|
||||||
type=str,
|
type=str,
|
||||||
choices=["vllm", "hf", "mii"],
|
choices=["vllm", "hf", "mii"],
|
||||||
@ -191,13 +150,6 @@ if __name__ == "__main__":
|
|||||||
default=None,
|
default=None,
|
||||||
help="Output length for each request. Overrides the "
|
help="Output length for each request. Overrides the "
|
||||||
"output length from the dataset.")
|
"output length from the dataset.")
|
||||||
parser.add_argument("--model", type=str, default="facebook/opt-125m")
|
|
||||||
parser.add_argument("--tokenizer", type=str, default=None)
|
|
||||||
parser.add_argument('--quantization',
|
|
||||||
'-q',
|
|
||||||
choices=[*QUANTIZATION_METHODS, None],
|
|
||||||
default=None)
|
|
||||||
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
|
|
||||||
parser.add_argument("--n",
|
parser.add_argument("--n",
|
||||||
type=int,
|
type=int,
|
||||||
default=1,
|
default=1,
|
||||||
@ -206,81 +158,13 @@ if __name__ == "__main__":
|
|||||||
type=int,
|
type=int,
|
||||||
default=200,
|
default=200,
|
||||||
help="Number of prompts to process.")
|
help="Number of prompts to process.")
|
||||||
parser.add_argument("--seed", type=int, default=0)
|
|
||||||
parser.add_argument('--trust-remote-code',
|
|
||||||
action='store_true',
|
|
||||||
help='trust remote code from huggingface')
|
|
||||||
parser.add_argument(
|
|
||||||
'--max-model-len',
|
|
||||||
type=int,
|
|
||||||
default=None,
|
|
||||||
help='Maximum length of a sequence (including prompt and output). '
|
|
||||||
'If None, will be derived from the model.')
|
|
||||||
parser.add_argument(
|
|
||||||
'--dtype',
|
|
||||||
type=str,
|
|
||||||
default='auto',
|
|
||||||
choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
|
|
||||||
help='data type for model weights and activations. '
|
|
||||||
'The "auto" option will use FP16 precision '
|
|
||||||
'for FP32 and FP16 models, and BF16 precision '
|
|
||||||
'for BF16 models.')
|
|
||||||
parser.add_argument('--gpu-memory-utilization',
|
|
||||||
type=float,
|
|
||||||
default=0.9,
|
|
||||||
help='the fraction of GPU memory to be used for '
|
|
||||||
'the model executor, which can range from 0 to 1.'
|
|
||||||
'If unspecified, will use the default value of 0.9.')
|
|
||||||
parser.add_argument("--enforce-eager",
|
|
||||||
action="store_true",
|
|
||||||
help="enforce eager execution")
|
|
||||||
parser.add_argument(
|
|
||||||
'--kv-cache-dtype',
|
|
||||||
type=str,
|
|
||||||
choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
|
|
||||||
default="auto",
|
|
||||||
help='Data type for kv cache storage. If "auto", will use model '
|
|
||||||
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
|
|
||||||
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
|
|
||||||
parser.add_argument(
|
|
||||||
'--quantization-param-path',
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help='Path to the JSON file containing the KV cache scaling factors. '
|
|
||||||
'This should generally be supplied, when KV cache dtype is FP8. '
|
|
||||||
'Otherwise, KV cache scaling factors default to 1.0, which may cause '
|
|
||||||
'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
|
|
||||||
'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
|
|
||||||
'instead supported for common inference criteria.')
|
|
||||||
parser.add_argument(
|
|
||||||
"--device",
|
|
||||||
type=str,
|
|
||||||
default="cuda",
|
|
||||||
choices=["cuda", "cpu"],
|
|
||||||
help='device type for vLLM execution, supporting CUDA and CPU.')
|
|
||||||
parser.add_argument(
|
|
||||||
"--enable-prefix-caching",
|
|
||||||
action='store_true',
|
|
||||||
help="enable automatic prefix caching for vLLM backend.")
|
|
||||||
parser.add_argument("--enable-chunked-prefill",
|
|
||||||
action='store_true',
|
|
||||||
help="enable chunked prefill for vLLM backend.")
|
|
||||||
parser.add_argument('--max-num-batched-tokens',
|
|
||||||
type=int,
|
|
||||||
default=None,
|
|
||||||
help='maximum number of batched tokens per '
|
|
||||||
'iteration')
|
|
||||||
parser.add_argument('--download-dir',
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help='directory to download and load the weights, '
|
|
||||||
'default to the default cache dir of huggingface')
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--output-json',
|
'--output-json',
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
help='Path to save the throughput results in JSON format.')
|
help='Path to save the throughput results in JSON format.')
|
||||||
|
|
||||||
|
parser = EngineArgs.add_cli_args(parser)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
if args.tokenizer is None:
|
if args.tokenizer is None:
|
||||||
args.tokenizer = args.model
|
args.tokenizer = args.model
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
"""Benchmark offline inference throughput."""
|
"""Benchmark offline inference throughput."""
|
||||||
import argparse
|
import argparse
|
||||||
|
import dataclasses
|
||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
@ -11,10 +12,9 @@ from tqdm import tqdm
|
|||||||
from transformers import (AutoModelForCausalLM, AutoTokenizer,
|
from transformers import (AutoModelForCausalLM, AutoTokenizer,
|
||||||
PreTrainedTokenizerBase)
|
PreTrainedTokenizerBase)
|
||||||
|
|
||||||
from vllm.engine.arg_utils import DEVICE_OPTIONS, AsyncEngineArgs, EngineArgs
|
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
|
||||||
from vllm.entrypoints.openai.api_server import (
|
from vllm.entrypoints.openai.api_server import (
|
||||||
build_async_engine_client_from_engine_args)
|
build_async_engine_client_from_engine_args)
|
||||||
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
|
||||||
from vllm.sampling_params import BeamSearchParams
|
from vllm.sampling_params import BeamSearchParams
|
||||||
from vllm.utils import FlexibleArgumentParser, merge_async_iterators
|
from vllm.utils import FlexibleArgumentParser, merge_async_iterators
|
||||||
|
|
||||||
@ -67,53 +67,11 @@ def sample_requests(
|
|||||||
|
|
||||||
def run_vllm(
|
def run_vllm(
|
||||||
requests: List[Tuple[str, int, int]],
|
requests: List[Tuple[str, int, int]],
|
||||||
model: str,
|
|
||||||
tokenizer: str,
|
|
||||||
quantization: Optional[str],
|
|
||||||
tensor_parallel_size: int,
|
|
||||||
seed: int,
|
|
||||||
n: int,
|
n: int,
|
||||||
trust_remote_code: bool,
|
engine_args: EngineArgs,
|
||||||
dtype: str,
|
|
||||||
max_model_len: Optional[int],
|
|
||||||
enforce_eager: bool,
|
|
||||||
kv_cache_dtype: str,
|
|
||||||
quantization_param_path: Optional[str],
|
|
||||||
device: str,
|
|
||||||
enable_prefix_caching: bool,
|
|
||||||
enable_chunked_prefill: bool,
|
|
||||||
max_num_batched_tokens: int,
|
|
||||||
distributed_executor_backend: Optional[str],
|
|
||||||
gpu_memory_utilization: float = 0.9,
|
|
||||||
num_scheduler_steps: int = 1,
|
|
||||||
download_dir: Optional[str] = None,
|
|
||||||
load_format: str = EngineArgs.load_format,
|
|
||||||
disable_async_output_proc: bool = False,
|
|
||||||
) -> float:
|
) -> float:
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
llm = LLM(
|
llm = LLM(**dataclasses.asdict(engine_args))
|
||||||
model=model,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
quantization=quantization,
|
|
||||||
tensor_parallel_size=tensor_parallel_size,
|
|
||||||
seed=seed,
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
dtype=dtype,
|
|
||||||
max_model_len=max_model_len,
|
|
||||||
gpu_memory_utilization=gpu_memory_utilization,
|
|
||||||
enforce_eager=enforce_eager,
|
|
||||||
kv_cache_dtype=kv_cache_dtype,
|
|
||||||
quantization_param_path=quantization_param_path,
|
|
||||||
device=device,
|
|
||||||
enable_prefix_caching=enable_prefix_caching,
|
|
||||||
download_dir=download_dir,
|
|
||||||
enable_chunked_prefill=enable_chunked_prefill,
|
|
||||||
max_num_batched_tokens=max_num_batched_tokens,
|
|
||||||
distributed_executor_backend=distributed_executor_backend,
|
|
||||||
load_format=load_format,
|
|
||||||
num_scheduler_steps=num_scheduler_steps,
|
|
||||||
disable_async_output_proc=disable_async_output_proc,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add the requests to the engine.
|
# Add the requests to the engine.
|
||||||
prompts: List[str] = []
|
prompts: List[str] = []
|
||||||
@ -155,56 +113,11 @@ def run_vllm(
|
|||||||
|
|
||||||
async def run_vllm_async(
|
async def run_vllm_async(
|
||||||
requests: List[Tuple[str, int, int]],
|
requests: List[Tuple[str, int, int]],
|
||||||
model: str,
|
|
||||||
tokenizer: str,
|
|
||||||
quantization: Optional[str],
|
|
||||||
tensor_parallel_size: int,
|
|
||||||
seed: int,
|
|
||||||
n: int,
|
n: int,
|
||||||
trust_remote_code: bool,
|
engine_args: AsyncEngineArgs,
|
||||||
dtype: str,
|
|
||||||
max_model_len: Optional[int],
|
|
||||||
enforce_eager: bool,
|
|
||||||
kv_cache_dtype: str,
|
|
||||||
quantization_param_path: Optional[str],
|
|
||||||
device: str,
|
|
||||||
enable_prefix_caching: bool,
|
|
||||||
enable_chunked_prefill: bool,
|
|
||||||
max_num_batched_tokens: int,
|
|
||||||
distributed_executor_backend: Optional[str],
|
|
||||||
gpu_memory_utilization: float = 0.9,
|
|
||||||
num_scheduler_steps: int = 1,
|
|
||||||
download_dir: Optional[str] = None,
|
|
||||||
load_format: str = EngineArgs.load_format,
|
|
||||||
disable_async_output_proc: bool = False,
|
|
||||||
disable_frontend_multiprocessing: bool = False,
|
disable_frontend_multiprocessing: bool = False,
|
||||||
) -> float:
|
) -> float:
|
||||||
from vllm import SamplingParams
|
from vllm import SamplingParams
|
||||||
engine_args = AsyncEngineArgs(
|
|
||||||
model=model,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
quantization=quantization,
|
|
||||||
tensor_parallel_size=tensor_parallel_size,
|
|
||||||
seed=seed,
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
dtype=dtype,
|
|
||||||
max_model_len=max_model_len,
|
|
||||||
gpu_memory_utilization=gpu_memory_utilization,
|
|
||||||
enforce_eager=enforce_eager,
|
|
||||||
kv_cache_dtype=kv_cache_dtype,
|
|
||||||
quantization_param_path=quantization_param_path,
|
|
||||||
device=device,
|
|
||||||
enable_prefix_caching=enable_prefix_caching,
|
|
||||||
download_dir=download_dir,
|
|
||||||
enable_chunked_prefill=enable_chunked_prefill,
|
|
||||||
max_num_batched_tokens=max_num_batched_tokens,
|
|
||||||
distributed_executor_backend=distributed_executor_backend,
|
|
||||||
load_format=load_format,
|
|
||||||
num_scheduler_steps=num_scheduler_steps,
|
|
||||||
disable_async_output_proc=disable_async_output_proc,
|
|
||||||
worker_use_ray=False,
|
|
||||||
disable_log_requests=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
async with build_async_engine_client_from_engine_args(
|
async with build_async_engine_client_from_engine_args(
|
||||||
engine_args, disable_frontend_multiprocessing) as llm:
|
engine_args, disable_frontend_multiprocessing) as llm:
|
||||||
@ -328,23 +241,17 @@ def main(args: argparse.Namespace):
|
|||||||
args.output_len)
|
args.output_len)
|
||||||
|
|
||||||
if args.backend == "vllm":
|
if args.backend == "vllm":
|
||||||
run_args = [
|
|
||||||
requests, args.model, args.tokenizer, args.quantization,
|
|
||||||
args.tensor_parallel_size, args.seed, args.n,
|
|
||||||
args.trust_remote_code, args.dtype, args.max_model_len,
|
|
||||||
args.enforce_eager, args.kv_cache_dtype,
|
|
||||||
args.quantization_param_path, args.device,
|
|
||||||
args.enable_prefix_caching, args.enable_chunked_prefill,
|
|
||||||
args.max_num_batched_tokens, args.distributed_executor_backend,
|
|
||||||
args.gpu_memory_utilization, args.num_scheduler_steps,
|
|
||||||
args.download_dir, args.load_format, args.disable_async_output_proc
|
|
||||||
]
|
|
||||||
|
|
||||||
if args.async_engine:
|
if args.async_engine:
|
||||||
run_args.append(args.disable_frontend_multiprocessing)
|
elapsed_time = uvloop.run(
|
||||||
elapsed_time = uvloop.run(run_vllm_async(*run_args))
|
run_vllm_async(
|
||||||
|
requests,
|
||||||
|
args.n,
|
||||||
|
AsyncEngineArgs.from_cli_args(args),
|
||||||
|
args.disable_frontend_multiprocessing,
|
||||||
|
))
|
||||||
else:
|
else:
|
||||||
elapsed_time = run_vllm(*run_args)
|
elapsed_time = run_vllm(requests, args.n,
|
||||||
|
EngineArgs.from_cli_args(args))
|
||||||
elif args.backend == "hf":
|
elif args.backend == "hf":
|
||||||
assert args.tensor_parallel_size == 1
|
assert args.tensor_parallel_size == 1
|
||||||
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
|
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
|
||||||
@ -391,13 +298,6 @@ if __name__ == "__main__":
|
|||||||
default=None,
|
default=None,
|
||||||
help="Output length for each request. Overrides the "
|
help="Output length for each request. Overrides the "
|
||||||
"output length from the dataset.")
|
"output length from the dataset.")
|
||||||
parser.add_argument("--model", type=str, default="facebook/opt-125m")
|
|
||||||
parser.add_argument("--tokenizer", type=str, default=None)
|
|
||||||
parser.add_argument('--quantization',
|
|
||||||
'-q',
|
|
||||||
choices=[*QUANTIZATION_METHODS, None],
|
|
||||||
default=None)
|
|
||||||
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
|
|
||||||
parser.add_argument("--n",
|
parser.add_argument("--n",
|
||||||
type=int,
|
type=int,
|
||||||
default=1,
|
default=1,
|
||||||
@ -406,123 +306,15 @@ if __name__ == "__main__":
|
|||||||
type=int,
|
type=int,
|
||||||
default=1000,
|
default=1000,
|
||||||
help="Number of prompts to process.")
|
help="Number of prompts to process.")
|
||||||
parser.add_argument("--seed", type=int, default=0)
|
|
||||||
parser.add_argument("--hf-max-batch-size",
|
parser.add_argument("--hf-max-batch-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=None,
|
default=None,
|
||||||
help="Maximum batch size for HF backend.")
|
help="Maximum batch size for HF backend.")
|
||||||
parser.add_argument('--trust-remote-code',
|
|
||||||
action='store_true',
|
|
||||||
help='trust remote code from huggingface')
|
|
||||||
parser.add_argument(
|
|
||||||
'--max-model-len',
|
|
||||||
type=int,
|
|
||||||
default=None,
|
|
||||||
help='Maximum length of a sequence (including prompt and output). '
|
|
||||||
'If None, will be derived from the model.')
|
|
||||||
parser.add_argument(
|
|
||||||
'--dtype',
|
|
||||||
type=str,
|
|
||||||
default='auto',
|
|
||||||
choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
|
|
||||||
help='data type for model weights and activations. '
|
|
||||||
'The "auto" option will use FP16 precision '
|
|
||||||
'for FP32 and FP16 models, and BF16 precision '
|
|
||||||
'for BF16 models.')
|
|
||||||
parser.add_argument('--gpu-memory-utilization',
|
|
||||||
type=float,
|
|
||||||
default=0.9,
|
|
||||||
help='the fraction of GPU memory to be used for '
|
|
||||||
'the model executor, which can range from 0 to 1.'
|
|
||||||
'If unspecified, will use the default value of 0.9.')
|
|
||||||
parser.add_argument("--enforce-eager",
|
|
||||||
action="store_true",
|
|
||||||
help="enforce eager execution")
|
|
||||||
parser.add_argument(
|
|
||||||
'--kv-cache-dtype',
|
|
||||||
type=str,
|
|
||||||
choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
|
|
||||||
default="auto",
|
|
||||||
help='Data type for kv cache storage. If "auto", will use model '
|
|
||||||
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
|
|
||||||
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
|
|
||||||
parser.add_argument(
|
|
||||||
'--quantization-param-path',
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help='Path to the JSON file containing the KV cache scaling factors. '
|
|
||||||
'This should generally be supplied, when KV cache dtype is FP8. '
|
|
||||||
'Otherwise, KV cache scaling factors default to 1.0, which may cause '
|
|
||||||
'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
|
|
||||||
'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
|
|
||||||
'instead supported for common inference criteria.')
|
|
||||||
parser.add_argument("--device",
|
|
||||||
type=str,
|
|
||||||
default="auto",
|
|
||||||
choices=DEVICE_OPTIONS,
|
|
||||||
help='device type for vLLM execution')
|
|
||||||
parser.add_argument(
|
|
||||||
"--num-scheduler-steps",
|
|
||||||
type=int,
|
|
||||||
default=1,
|
|
||||||
help="Maximum number of forward steps per scheduler call.")
|
|
||||||
parser.add_argument(
|
|
||||||
"--enable-prefix-caching",
|
|
||||||
action='store_true',
|
|
||||||
help="Enable automatic prefix caching for vLLM backend.")
|
|
||||||
parser.add_argument("--enable-chunked-prefill",
|
|
||||||
action='store_true',
|
|
||||||
help="enable chunked prefill for vLLM backend.")
|
|
||||||
parser.add_argument('--max-num-batched-tokens',
|
|
||||||
type=int,
|
|
||||||
default=None,
|
|
||||||
help='maximum number of batched tokens per '
|
|
||||||
'iteration')
|
|
||||||
parser.add_argument('--download-dir',
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help='directory to download and load the weights, '
|
|
||||||
'default to the default cache dir of huggingface')
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--output-json',
|
'--output-json',
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
help='Path to save the throughput results in JSON format.')
|
help='Path to save the throughput results in JSON format.')
|
||||||
parser.add_argument(
|
|
||||||
'--distributed-executor-backend',
|
|
||||||
choices=['ray', 'mp'],
|
|
||||||
default=None,
|
|
||||||
help='Backend to use for distributed serving. When more than 1 GPU '
|
|
||||||
'is used, will be automatically set to "ray" if installed '
|
|
||||||
'or "mp" (multiprocessing) otherwise.')
|
|
||||||
parser.add_argument(
|
|
||||||
'--load-format',
|
|
||||||
type=str,
|
|
||||||
default=EngineArgs.load_format,
|
|
||||||
choices=[
|
|
||||||
'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer',
|
|
||||||
'bitsandbytes'
|
|
||||||
],
|
|
||||||
help='The format of the model weights to load.\n\n'
|
|
||||||
'* "auto" will try to load the weights in the safetensors format '
|
|
||||||
'and fall back to the pytorch bin format if safetensors format '
|
|
||||||
'is not available.\n'
|
|
||||||
'* "pt" will load the weights in the pytorch bin format.\n'
|
|
||||||
'* "safetensors" will load the weights in the safetensors format.\n'
|
|
||||||
'* "npcache" will load the weights in pytorch format and store '
|
|
||||||
'a numpy cache to speed up the loading.\n'
|
|
||||||
'* "dummy" will initialize the weights with random values, '
|
|
||||||
'which is mainly for profiling.\n'
|
|
||||||
'* "tensorizer" will load the weights using tensorizer from '
|
|
||||||
'CoreWeave. See the Tensorize vLLM Model script in the Examples'
|
|
||||||
'section for more information.\n'
|
|
||||||
'* "bitsandbytes" will load the weights using bitsandbytes '
|
|
||||||
'quantization.\n')
|
|
||||||
parser.add_argument(
|
|
||||||
"--disable-async-output-proc",
|
|
||||||
action='store_true',
|
|
||||||
default=False,
|
|
||||||
help="Disable async output processor for vLLM backend.")
|
|
||||||
parser.add_argument("--async-engine",
|
parser.add_argument("--async-engine",
|
||||||
action='store_true',
|
action='store_true',
|
||||||
default=False,
|
default=False,
|
||||||
@ -531,6 +323,7 @@ if __name__ == "__main__":
|
|||||||
action='store_true',
|
action='store_true',
|
||||||
default=False,
|
default=False,
|
||||||
help="Disable decoupled async engine frontend.")
|
help="Disable decoupled async engine frontend.")
|
||||||
|
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
if args.tokenizer is None:
|
if args.tokenizer is None:
|
||||||
args.tokenizer = args.model
|
args.tokenizer = args.model
|
||||||
|
Loading…
x
Reference in New Issue
Block a user