[Misc] add the "download-dir" option to the latency/throughput benchmarks (#3621)

AmadeusChan authored 2024-03-27 16:39:05 -04:00; committed by GitHub
parent e24336b5a7
commit 1956931436
2 changed files with 32 additions and 19 deletions

benchmarks/benchmark_latency.py

@@ -16,8 +16,7 @@ def main(args: argparse.Namespace):
     # NOTE(woosuk): If the request cannot be processed in a single batch,
     # the engine will automatically process the request in multiple batches.
-    llm = LLM(
-        model=args.model,
+    llm = LLM(model=args.model,
         tokenizer=args.tokenizer,
         quantization=args.quantization,
         tensor_parallel_size=args.tensor_parallel_size,
@@ -27,7 +26,7 @@ def main(args: argparse.Namespace):
         kv_cache_dtype=args.kv_cache_dtype,
         device=args.device,
         ray_workers_use_nsight=args.ray_workers_use_nsight,
-    )
+        download_dir=args.download_dir)
     sampling_params = SamplingParams(
         n=args.n,
@@ -151,5 +150,10 @@ if __name__ == '__main__':
         action='store_true',
         help="If specified, use nsight to profile ray workers",
     )
+    parser.add_argument('--download-dir',
+                        type=str,
+                        default=None,
+                        help='directory to download and load the weights, '
+                        'default to the default cache dir of huggingface')
     args = parser.parse_args()
     main(args)
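For orientation, here is a minimal sketch of what the new flag feeds into, assuming a working vLLM install; the model name and directory below are illustrative, not part of the commit. download_dir is an existing vLLM engine argument that the benchmark now simply exposes on the command line.

from vllm import LLM, SamplingParams

# With download_dir set, weights are downloaded to and loaded from this
# path; with download_dir=None they come from the Hugging Face cache.
llm = LLM(model="facebook/opt-125m",          # illustrative model
          download_dir="/tmp/vllm-weights")   # illustrative path
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(n=1, temperature=0.8, max_tokens=16))
print(outputs[0].outputs[0].text)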

benchmarks/benchmark_throughput.py

@@ -75,6 +75,7 @@ def run_vllm(
     device: str,
     enable_prefix_caching: bool,
     gpu_memory_utilization: float = 0.9,
+    download_dir: Optional[str] = None,
 ) -> float:
     from vllm import LLM, SamplingParams
     llm = LLM(model=model,
@@ -89,7 +90,8 @@ def run_vllm(
               enforce_eager=enforce_eager,
               kv_cache_dtype=kv_cache_dtype,
               device=device,
-              enable_prefix_caching=enable_prefix_caching)
+              enable_prefix_caching=enable_prefix_caching,
+              download_dir=download_dir)
     # Add the requests to the engine.
     for prompt, _, output_len in requests:
@@ -208,12 +210,14 @@ def main(args: argparse.Namespace):
                                    args.output_len)
     if args.backend == "vllm":
-        elapsed_time = run_vllm(
-            requests, args.model, args.tokenizer, args.quantization,
-            args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
-            args.trust_remote_code, args.dtype, args.max_model_len,
-            args.enforce_eager, args.kv_cache_dtype, args.device,
-            args.enable_prefix_caching, args.gpu_memory_utilization)
+        elapsed_time = run_vllm(requests, args.model, args.tokenizer,
+                                args.quantization, args.tensor_parallel_size,
+                                args.seed, args.n, args.use_beam_search,
+                                args.trust_remote_code, args.dtype,
+                                args.max_model_len, args.enforce_eager,
+                                args.kv_cache_dtype, args.device,
+                                args.enable_prefix_caching,
+                                args.gpu_memory_utilization, args.download_dir)
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
         elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -314,6 +318,11 @@ if __name__ == "__main__":
         "--enable-prefix-caching",
         action='store_true',
         help="enable automatic prefix caching for vLLM backend.")
+    parser.add_argument('--download-dir',
+                        type=str,
+                        default=None,
+                        help='directory to download and load the weights, '
+                        'default to the default cache dir of huggingface')
     args = parser.parse_args()
     if args.tokenizer is None:
         args.tokenizer = args.model
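When --download-dir is omitted, both scripts pass download_dir=None and vLLM falls back to the standard Hugging Face cache. A sketch of that resolution order, under the assumption that huggingface_hub's usual environment variables apply; resolve_download_dir is a hypothetical helper, not part of the commit:

import os
from typing import Optional

def resolve_download_dir(download_dir: Optional[str]) -> str:
    # Hypothetical helper mirroring huggingface_hub's cache resolution:
    # an explicit directory wins, then HF_HUB_CACHE, then HF_HOME/hub,
    # then ~/.cache/huggingface/hub.
    if download_dir is not None:
        return download_dir
    hf_home = os.environ.get("HF_HOME",
                             os.path.expanduser("~/.cache/huggingface"))
    return os.environ.get("HF_HUB_CACHE", os.path.join(hf_home, "hub"))

print(resolve_download_dir(None))            # HF default cache
print(resolve_download_dir("/data/models"))  # explicit --download-dir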