[Frontend] Add FlexibleArgumentParser to support both underscore and dash in names (#5718)

This commit is contained in:
Michael Goin 2024-06-20 19:00:13 -04:00 committed by GitHub
parent 3f3b6b2150
commit 8065a7e220
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 72 additions and 45 deletions

View File

@ -13,6 +13,7 @@ from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptStrictInputs
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import FlexibleArgumentParser
def main(args: argparse.Namespace):
@ -120,7 +121,7 @@ def main(args: argparse.Namespace):
if __name__ == '__main__':
parser = argparse.ArgumentParser(
parser = FlexibleArgumentParser(
description='Benchmark the latency of processing a single batch of '
'requests till completion.')
parser.add_argument('--model', type=str, default='facebook/opt-125m')

View File

@ -1,7 +1,7 @@
import argparse
import time
from vllm import LLM, SamplingParams
from vllm.utils import FlexibleArgumentParser
PROMPT = "You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as fellows. You need to answer my question about the table.\n# Table\n|Opening|Opening|Sl. No.|Film|Cast|Director|Music Director|Notes|\n|----|----|----|----|----|----|----|----|\n|J A N|9|1|Agni Pushpam|Jayabharathi, Kamalahasan|Jeassy|M. K. Arjunan||\n|J A N|16|2|Priyamvada|Mohan Sharma, Lakshmi, KPAC Lalitha|K. S. Sethumadhavan|V. Dakshinamoorthy||\n|J A N|23|3|Yakshagaanam|Madhu, Sheela|Sheela|M. S. Viswanathan||\n|J A N|30|4|Paalkkadal|Sheela, Sharada|T. K. Prasad|A. T. Ummer||\n|F E B|5|5|Amma|Madhu, Srividya|M. Krishnan Nair|M. K. Arjunan||\n|F E B|13|6|Appooppan|Thikkurissi Sukumaran Nair, Kamal Haasan|P. Bhaskaran|M. S. Baburaj||\n|F E B|20|7|Srishti|Chowalloor Krishnankutty, Ravi Alummoodu|K. T. Muhammad|M. S. Baburaj||\n|F E B|20|8|Vanadevatha|Prem Nazir, Madhubala|Yusufali Kechery|G. Devarajan||\n|F E B|27|9|Samasya|Madhu, Kamalahaasan|K. Thankappan|Shyam||\n|F E B|27|10|Yudhabhoomi|K. P. Ummer, Vidhubala|Crossbelt Mani|R. K. Shekhar||\n|M A R|5|11|Seemantha Puthran|Prem Nazir, Jayabharathi|A. B. Raj|M. K. Arjunan||\n|M A R|12|12|Swapnadanam|Rani Chandra, Dr. Mohandas|K. G. George|Bhaskar Chandavarkar||\n|M A R|19|13|Thulavarsham|Prem Nazir, sreedevi, Sudheer|N. Sankaran Nair|V. Dakshinamoorthy||\n|M A R|20|14|Aruthu|Kaviyoor Ponnamma, Kamalahasan|Ravi|G. Devarajan||\n|M A R|26|15|Swimming Pool|Kamal Haasan, M. G. Soman|J. Sasikumar|M. K. Arjunan||\n\n# Question\nWhat' s the content in the (1,1) cells\n" # noqa: E501
@ -44,7 +44,7 @@ def main(args):
if __name__ == "__main__":
parser = argparse.ArgumentParser(
parser = FlexibleArgumentParser(
description='Benchmark the performance with or without automatic '
'prefix caching.')
parser.add_argument('--model',

View File

@ -44,6 +44,11 @@ try:
except ImportError:
from backend_request_func import get_tokenizer
try:
from vllm.utils import FlexibleArgumentParser
except ImportError:
from argparse import ArgumentParser as FlexibleArgumentParser
@dataclass
class BenchmarkMetrics:
@ -511,7 +516,7 @@ def main(args: argparse.Namespace):
if __name__ == "__main__":
parser = argparse.ArgumentParser(
parser = FlexibleArgumentParser(
description="Benchmark the online serving throughput.")
parser.add_argument(
"--backend",

View File

@ -12,6 +12,7 @@ from transformers import (AutoModelForCausalLM, AutoTokenizer,
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import FlexibleArgumentParser
def sample_requests(
@ -261,7 +262,7 @@ def main(args: argparse.Namespace):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Benchmark the throughput.")
parser = FlexibleArgumentParser(description="Benchmark the throughput.")
parser.add_argument("--backend",
type=str,
choices=["vllm", "hf", "mii"],

View File

@ -11,6 +11,7 @@ from torch.utils.benchmark import Measurement as TMeasurement
from weight_shapes import WEIGHT_SHAPES
from vllm import _custom_ops as ops
from vllm.utils import FlexibleArgumentParser
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())[1:]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
@ -293,7 +294,7 @@ if __name__ == '__main__':
return torch.float8_e4m3fn
raise ValueError("unsupported dtype")
parser = argparse.ArgumentParser(
parser = FlexibleArgumentParser(
description="""
Benchmark Cutlass GEMM.

View File

@ -1,4 +1,3 @@
import argparse
import os
import sys
from typing import Optional
@ -10,6 +9,7 @@ from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.aqlm import (
dequantize_weight, generic_dequantize_gemm, get_int_dtype,
optimized_dequantize_gemm)
from vllm.utils import FlexibleArgumentParser
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
@ -137,7 +137,7 @@ def dequant_test(k: int, parts: torch.Tensor, nbooks: int, bits: int) -> None:
def main():
parser = argparse.ArgumentParser(description="Benchmark aqlm performance.")
parser = FlexibleArgumentParser(description="Benchmark aqlm performance.")
# Add arguments
parser.add_argument("--nbooks",

View File

@ -1,4 +1,3 @@
import argparse
from typing import List
import torch
@ -16,6 +15,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
MarlinWorkspace, marlin_24_quantize, marlin_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
gptq_pack, quantize_weights, sort_weights)
from vllm.utils import FlexibleArgumentParser
DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
@ -211,7 +211,7 @@ def main(args):
# python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128 --limit-num-bits 4 --limit-act-order 0 --limit-k-full 1 # noqa E501
#
if __name__ == "__main__":
parser = argparse.ArgumentParser(
parser = FlexibleArgumentParser(
description="Benchmark Marlin across specified models/shapes/batches")
parser.add_argument(
"--models",

View File

@ -10,6 +10,7 @@ from ray.experimental.tqdm_ray import tqdm
from transformers import AutoConfig
from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.utils import FlexibleArgumentParser
class BenchmarkConfig(TypedDict):
@ -315,7 +316,7 @@ def main(args: argparse.Namespace):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser = FlexibleArgumentParser()
parser.add_argument("--model",
type=str,
default="mistralai/Mixtral-8x7B-Instruct-v0.1")

View File

@ -1,4 +1,3 @@
import argparse
import random
import time
from typing import List, Optional
@ -6,7 +5,8 @@ from typing import List, Optional
import torch
from vllm import _custom_ops as ops
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
create_kv_caches_with_random)
NUM_BLOCKS = 1024
PARTITION_SIZE = 512
@ -161,7 +161,7 @@ def main(
if __name__ == '__main__':
parser = argparse.ArgumentParser(
parser = FlexibleArgumentParser(
description="Benchmark the paged attention kernel.")
parser.add_argument("--version",
type=str,

View File

@ -1,4 +1,3 @@
import argparse
from itertools import accumulate
from typing import List, Optional
@ -7,6 +6,7 @@ import torch
from vllm.model_executor.layers.rotary_embedding import (RotaryEmbedding,
get_rope)
from vllm.utils import FlexibleArgumentParser
def benchmark_rope_kernels_multi_lora(
@ -86,7 +86,7 @@ def benchmark_rope_kernels_multi_lora(
if __name__ == '__main__':
parser = argparse.ArgumentParser(
parser = FlexibleArgumentParser(
description="Benchmark the rotary embedding kernels.")
parser.add_argument("--is-neox-style", type=bool, default=True)
parser.add_argument("--batch-size", type=int, default=16)

View File

@ -1,8 +1,8 @@
import argparse
import cProfile
import pstats
from vllm import LLM, SamplingParams
from vllm.utils import FlexibleArgumentParser
# A very long prompt, total number of tokens is about 15k.
LONG_PROMPT = ["You are an expert in large language models, aren't you?"
@ -47,7 +47,7 @@ def main(args):
if __name__ == "__main__":
parser = argparse.ArgumentParser(
parser = FlexibleArgumentParser(
description='Benchmark the performance of hashing function in'
'automatic prefix caching.')
parser.add_argument('--model', type=str, default='lmsys/longchat-7b-16k')

View File

@ -1,11 +1,10 @@
import argparse
from vllm import LLM, SamplingParams
from vllm.utils import FlexibleArgumentParser
def main():
parser = argparse.ArgumentParser(description='AQLM examples')
parser = FlexibleArgumentParser(description='AQLM examples')
parser.add_argument('--model',
'-m',

View File

@ -2,6 +2,7 @@ import argparse
from typing import List, Tuple
from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
from vllm.utils import FlexibleArgumentParser
def create_test_prompts() -> List[Tuple[str, SamplingParams]]:
@ -55,7 +56,7 @@ def main(args: argparse.Namespace):
if __name__ == '__main__':
parser = argparse.ArgumentParser(
parser = FlexibleArgumentParser(
description='Demo on using the LLMEngine class directly')
parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args()

View File

@ -20,15 +20,15 @@ llm = LLM(
tensor_parallel_size=8,
)
"""
import argparse
import dataclasses
import os
import shutil
from pathlib import Path
from vllm import LLM, EngineArgs
from vllm.utils import FlexibleArgumentParser
parser = argparse.ArgumentParser()
parser = FlexibleArgumentParser()
EngineArgs.add_cli_args(parser)
parser.add_argument("--output",
"-o",

View File

@ -9,6 +9,7 @@ from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.model_loader.tensorizer import (TensorizerArgs,
TensorizerConfig,
tensorize_vllm_model)
from vllm.utils import FlexibleArgumentParser
# yapf conflicts with isort for this docstring
# yapf: disable
@ -96,7 +97,7 @@ deserialization in this example script, although `--tensorizer-uri` and
def parse_args():
parser = argparse.ArgumentParser(
parser = FlexibleArgumentParser(
description="An example script that can be used to serialize and "
"deserialize vLLM models. These models "
"can be loaded using tensorizer directly to the GPU "

View File

@ -1,5 +1,4 @@
"""vllm.entrypoints.api_server with some extra logging for testing."""
import argparse
from typing import Any, Dict
import uvicorn
@ -8,6 +7,7 @@ from fastapi.responses import JSONResponse, Response
import vllm.entrypoints.api_server
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.utils import FlexibleArgumentParser
app = vllm.entrypoints.api_server.app
@ -33,7 +33,7 @@ def stats() -> Response:
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser = FlexibleArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8000)
parser = AsyncEngineArgs.add_cli_args(parser)

View File

@ -11,7 +11,7 @@ from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
SpeculativeConfig, TokenizerPoolConfig,
VisionLanguageConfig)
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import str_to_int_tuple
from vllm.utils import FlexibleArgumentParser, str_to_int_tuple
def nullable_str(val: str):
@ -110,7 +110,7 @@ class EngineArgs:
@staticmethod
def add_cli_args_for_vlm(
parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument('--image-input-type',
type=nullable_str,
default=None,
@ -156,8 +156,7 @@ class EngineArgs:
return parser
@staticmethod
def add_cli_args(
parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"""Shared CLI arguments for vLLM engine."""
# Model arguments
@ -800,8 +799,8 @@ class AsyncEngineArgs(EngineArgs):
max_log_len: Optional[int] = None
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser,
async_args_only: bool = False) -> argparse.ArgumentParser:
def add_cli_args(parser: FlexibleArgumentParser,
async_args_only: bool = False) -> FlexibleArgumentParser:
if not async_args_only:
parser = EngineArgs.add_cli_args(parser)
parser.add_argument('--engine-use-ray',
@ -822,13 +821,13 @@ class AsyncEngineArgs(EngineArgs):
# These functions are used by sphinx to build the documentation
def _engine_args_parser():
return EngineArgs.add_cli_args(argparse.ArgumentParser())
return EngineArgs.add_cli_args(FlexibleArgumentParser())
def _async_engine_args_parser():
return AsyncEngineArgs.add_cli_args(argparse.ArgumentParser(),
return AsyncEngineArgs.add_cli_args(FlexibleArgumentParser(),
async_args_only=True)
def _vlm_engine_args_parser():
return EngineArgs.add_cli_args_for_vlm(argparse.ArgumentParser())
return EngineArgs.add_cli_args_for_vlm(FlexibleArgumentParser())

View File

@ -6,7 +6,6 @@ We are also not going to accept PRs modifying this file, please
change `vllm/entrypoints/openai/api_server.py` instead.
"""
import argparse
import json
import ssl
from typing import AsyncGenerator
@ -19,7 +18,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.usage.usage_lib import UsageContext
from vllm.utils import random_uuid
from vllm.utils import FlexibleArgumentParser, random_uuid
TIMEOUT_KEEP_ALIVE = 5 # seconds.
app = FastAPI()
@ -80,7 +79,7 @@ async def generate(request: Request) -> Response:
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser = FlexibleArgumentParser()
parser.add_argument("--host", type=str, default=None)
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--ssl-keyfile", type=str, default=None)

View File

@ -10,6 +10,7 @@ import ssl
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.entrypoints.openai.serving_engine import LoRAModulePath
from vllm.utils import FlexibleArgumentParser
class LoRAParserAction(argparse.Action):
@ -23,7 +24,7 @@ class LoRAParserAction(argparse.Action):
def make_arg_parser():
parser = argparse.ArgumentParser(
parser = FlexibleArgumentParser(
description="vLLM OpenAI-Compatible RESTful API server.")
parser.add_argument("--host",
type=nullable_str,

View File

@ -1,4 +1,3 @@
import argparse
import asyncio
import sys
from io import StringIO
@ -16,14 +15,14 @@ from vllm.entrypoints.openai.protocol import (BatchRequestInput,
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.utils import random_uuid
from vllm.utils import FlexibleArgumentParser, random_uuid
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
def parse_args():
parser = argparse.ArgumentParser(
parser = FlexibleArgumentParser(
description="vLLM OpenAI-Compatible batch runner.")
parser.add_argument(
"-i",

View File

@ -21,6 +21,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.utils import FlexibleArgumentParser
tensorizer_error_msg = None
@ -177,8 +178,7 @@ class TensorizerArgs:
self.deserializer_params['encryption'] = decryption_params
@staticmethod
def add_cli_args(
parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"""Tensorizer CLI arguments"""
# Tensorizer options arg group

View File

@ -1,3 +1,4 @@
import argparse
import asyncio
import datetime
import enum
@ -775,3 +776,21 @@ def run_once(f):
wrapper.has_run = False # type: ignore[attr-defined]
return wrapper
class FlexibleArgumentParser(argparse.ArgumentParser):
"""ArgumentParser that allows both underscore and dash in names."""
def parse_args(self, args=None, namespace=None):
if args is None:
args = sys.argv[1:]
# Convert underscores to dashes and vice versa in argument names
processed_args = []
for arg in args:
if arg.startswith('--'):
processed_args.append('--' + arg[len('--'):].replace('_', '-'))
else:
processed_args.append(arg)
return super().parse_args(processed_args, namespace)