[Frontend] Add FlexibleArgumentParser to support both underscore and dash in names (#5718)
This commit is contained in:
parent
3f3b6b2150
commit
8065a7e220
@ -13,6 +13,7 @@ from vllm import LLM, SamplingParams
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.inputs import PromptStrictInputs
|
||||
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
@ -120,7 +121,7 @@ def main(args: argparse.Namespace):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(
|
||||
parser = FlexibleArgumentParser(
|
||||
description='Benchmark the latency of processing a single batch of '
|
||||
'requests till completion.')
|
||||
parser.add_argument('--model', type=str, default='facebook/opt-125m')
|
||||
|
@ -1,7 +1,7 @@
|
||||
import argparse
|
||||
import time
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
PROMPT = "You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as fellows. You need to answer my question about the table.\n# Table\n|Opening|Opening|Sl. No.|Film|Cast|Director|Music Director|Notes|\n|----|----|----|----|----|----|----|----|\n|J A N|9|1|Agni Pushpam|Jayabharathi, Kamalahasan|Jeassy|M. K. Arjunan||\n|J A N|16|2|Priyamvada|Mohan Sharma, Lakshmi, KPAC Lalitha|K. S. Sethumadhavan|V. Dakshinamoorthy||\n|J A N|23|3|Yakshagaanam|Madhu, Sheela|Sheela|M. S. Viswanathan||\n|J A N|30|4|Paalkkadal|Sheela, Sharada|T. K. Prasad|A. T. Ummer||\n|F E B|5|5|Amma|Madhu, Srividya|M. Krishnan Nair|M. K. Arjunan||\n|F E B|13|6|Appooppan|Thikkurissi Sukumaran Nair, Kamal Haasan|P. Bhaskaran|M. S. Baburaj||\n|F E B|20|7|Srishti|Chowalloor Krishnankutty, Ravi Alummoodu|K. T. Muhammad|M. S. Baburaj||\n|F E B|20|8|Vanadevatha|Prem Nazir, Madhubala|Yusufali Kechery|G. Devarajan||\n|F E B|27|9|Samasya|Madhu, Kamalahaasan|K. Thankappan|Shyam||\n|F E B|27|10|Yudhabhoomi|K. P. Ummer, Vidhubala|Crossbelt Mani|R. K. Shekhar||\n|M A R|5|11|Seemantha Puthran|Prem Nazir, Jayabharathi|A. B. Raj|M. K. Arjunan||\n|M A R|12|12|Swapnadanam|Rani Chandra, Dr. Mohandas|K. G. George|Bhaskar Chandavarkar||\n|M A R|19|13|Thulavarsham|Prem Nazir, sreedevi, Sudheer|N. Sankaran Nair|V. Dakshinamoorthy||\n|M A R|20|14|Aruthu|Kaviyoor Ponnamma, Kamalahasan|Ravi|G. Devarajan||\n|M A R|26|15|Swimming Pool|Kamal Haasan, M. G. Soman|J. Sasikumar|M. K. Arjunan||\n\n# Question\nWhat' s the content in the (1,1) cells\n" # noqa: E501
|
||||
|
||||
@ -44,7 +44,7 @@ def main(args):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
parser = FlexibleArgumentParser(
|
||||
description='Benchmark the performance with or without automatic '
|
||||
'prefix caching.')
|
||||
parser.add_argument('--model',
|
||||
|
@ -44,6 +44,11 @@ try:
|
||||
except ImportError:
|
||||
from backend_request_func import get_tokenizer
|
||||
|
||||
try:
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
except ImportError:
|
||||
from argparse import ArgumentParser as FlexibleArgumentParser
|
||||
|
||||
|
||||
@dataclass
|
||||
class BenchmarkMetrics:
|
||||
@ -511,7 +516,7 @@ def main(args: argparse.Namespace):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
parser = FlexibleArgumentParser(
|
||||
description="Benchmark the online serving throughput.")
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
|
@ -12,6 +12,7 @@ from transformers import (AutoModelForCausalLM, AutoTokenizer,
|
||||
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
def sample_requests(
|
||||
@ -261,7 +262,7 @@ def main(args: argparse.Namespace):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Benchmark the throughput.")
|
||||
parser = FlexibleArgumentParser(description="Benchmark the throughput.")
|
||||
parser.add_argument("--backend",
|
||||
type=str,
|
||||
choices=["vllm", "hf", "mii"],
|
||||
|
@ -11,6 +11,7 @@ from torch.utils.benchmark import Measurement as TMeasurement
|
||||
from weight_shapes import WEIGHT_SHAPES
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())[1:]
|
||||
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
|
||||
@ -293,7 +294,7 @@ if __name__ == '__main__':
|
||||
return torch.float8_e4m3fn
|
||||
raise ValueError("unsupported dtype")
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
parser = FlexibleArgumentParser(
|
||||
description="""
|
||||
Benchmark Cutlass GEMM.
|
||||
|
||||
|
@ -1,4 +1,3 @@
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from typing import Optional
|
||||
@ -10,6 +9,7 @@ from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.aqlm import (
|
||||
dequantize_weight, generic_dequantize_gemm, get_int_dtype,
|
||||
optimized_dequantize_gemm)
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
||||
|
||||
@ -137,7 +137,7 @@ def dequant_test(k: int, parts: torch.Tensor, nbooks: int, bits: int) -> None:
|
||||
|
||||
def main():
|
||||
|
||||
parser = argparse.ArgumentParser(description="Benchmark aqlm performance.")
|
||||
parser = FlexibleArgumentParser(description="Benchmark aqlm performance.")
|
||||
|
||||
# Add arguments
|
||||
parser.add_argument("--nbooks",
|
||||
|
@ -1,4 +1,3 @@
|
||||
import argparse
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
@ -16,6 +15,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
MarlinWorkspace, marlin_24_quantize, marlin_quantize)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
gptq_pack, quantize_weights, sort_weights)
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"]
|
||||
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
|
||||
@ -211,7 +211,7 @@ def main(args):
|
||||
# python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128 --limit-num-bits 4 --limit-act-order 0 --limit-k-full 1 # noqa E501
|
||||
#
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
parser = FlexibleArgumentParser(
|
||||
description="Benchmark Marlin across specified models/shapes/batches")
|
||||
parser.add_argument(
|
||||
"--models",
|
||||
|
@ -10,6 +10,7 @@ from ray.experimental.tqdm_ray import tqdm
|
||||
from transformers import AutoConfig
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import *
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
class BenchmarkConfig(TypedDict):
|
||||
@ -315,7 +316,7 @@ def main(args: argparse.Namespace):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = FlexibleArgumentParser()
|
||||
parser.add_argument("--model",
|
||||
type=str,
|
||||
default="mistralai/Mixtral-8x7B-Instruct-v0.1")
|
||||
|
@ -1,4 +1,3 @@
|
||||
import argparse
|
||||
import random
|
||||
import time
|
||||
from typing import List, Optional
|
||||
@ -6,7 +5,8 @@ from typing import List, Optional
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random
|
||||
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
|
||||
create_kv_caches_with_random)
|
||||
|
||||
NUM_BLOCKS = 1024
|
||||
PARTITION_SIZE = 512
|
||||
@ -161,7 +161,7 @@ def main(
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(
|
||||
parser = FlexibleArgumentParser(
|
||||
description="Benchmark the paged attention kernel.")
|
||||
parser.add_argument("--version",
|
||||
type=str,
|
||||
|
@ -1,4 +1,3 @@
|
||||
import argparse
|
||||
from itertools import accumulate
|
||||
from typing import List, Optional
|
||||
|
||||
@ -7,6 +6,7 @@ import torch
|
||||
|
||||
from vllm.model_executor.layers.rotary_embedding import (RotaryEmbedding,
|
||||
get_rope)
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
def benchmark_rope_kernels_multi_lora(
|
||||
@ -86,7 +86,7 @@ def benchmark_rope_kernels_multi_lora(
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(
|
||||
parser = FlexibleArgumentParser(
|
||||
description="Benchmark the rotary embedding kernels.")
|
||||
parser.add_argument("--is-neox-style", type=bool, default=True)
|
||||
parser.add_argument("--batch-size", type=int, default=16)
|
||||
|
@ -1,8 +1,8 @@
|
||||
import argparse
|
||||
import cProfile
|
||||
import pstats
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
# A very long prompt, total number of tokens is about 15k.
|
||||
LONG_PROMPT = ["You are an expert in large language models, aren't you?"
|
||||
@ -47,7 +47,7 @@ def main(args):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
parser = FlexibleArgumentParser(
|
||||
description='Benchmark the performance of hashing function in'
|
||||
'automatic prefix caching.')
|
||||
parser.add_argument('--model', type=str, default='lmsys/longchat-7b-16k')
|
||||
|
@ -1,11 +1,10 @@
|
||||
import argparse
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
parser = argparse.ArgumentParser(description='AQLM examples')
|
||||
parser = FlexibleArgumentParser(description='AQLM examples')
|
||||
|
||||
parser.add_argument('--model',
|
||||
'-m',
|
||||
|
@ -2,6 +2,7 @@ import argparse
|
||||
from typing import List, Tuple
|
||||
|
||||
from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
def create_test_prompts() -> List[Tuple[str, SamplingParams]]:
|
||||
@ -55,7 +56,7 @@ def main(args: argparse.Namespace):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(
|
||||
parser = FlexibleArgumentParser(
|
||||
description='Demo on using the LLMEngine class directly')
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
@ -20,15 +20,15 @@ llm = LLM(
|
||||
tensor_parallel_size=8,
|
||||
)
|
||||
"""
|
||||
import argparse
|
||||
import dataclasses
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from vllm import LLM, EngineArgs
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = FlexibleArgumentParser()
|
||||
EngineArgs.add_cli_args(parser)
|
||||
parser.add_argument("--output",
|
||||
"-o",
|
||||
|
@ -9,6 +9,7 @@ from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.model_executor.model_loader.tensorizer import (TensorizerArgs,
|
||||
TensorizerConfig,
|
||||
tensorize_vllm_model)
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
# yapf conflicts with isort for this docstring
|
||||
# yapf: disable
|
||||
@ -96,7 +97,7 @@ deserialization in this example script, although `--tensorizer-uri` and
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
parser = FlexibleArgumentParser(
|
||||
description="An example script that can be used to serialize and "
|
||||
"deserialize vLLM models. These models "
|
||||
"can be loaded using tensorizer directly to the GPU "
|
||||
|
@ -1,5 +1,4 @@
|
||||
"""vllm.entrypoints.api_server with some extra logging for testing."""
|
||||
import argparse
|
||||
from typing import Any, Dict
|
||||
|
||||
import uvicorn
|
||||
@ -8,6 +7,7 @@ from fastapi.responses import JSONResponse, Response
|
||||
import vllm.entrypoints.api_server
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
app = vllm.entrypoints.api_server.app
|
||||
|
||||
@ -33,7 +33,7 @@ def stats() -> Response:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = FlexibleArgumentParser()
|
||||
parser.add_argument("--host", type=str, default="localhost")
|
||||
parser.add_argument("--port", type=int, default=8000)
|
||||
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||
|
@ -11,7 +11,7 @@ from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
|
||||
SpeculativeConfig, TokenizerPoolConfig,
|
||||
VisionLanguageConfig)
|
||||
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
||||
from vllm.utils import str_to_int_tuple
|
||||
from vllm.utils import FlexibleArgumentParser, str_to_int_tuple
|
||||
|
||||
|
||||
def nullable_str(val: str):
|
||||
@ -110,7 +110,7 @@ class EngineArgs:
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args_for_vlm(
|
||||
parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
|
||||
parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||
parser.add_argument('--image-input-type',
|
||||
type=nullable_str,
|
||||
default=None,
|
||||
@ -156,8 +156,7 @@ class EngineArgs:
|
||||
return parser
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(
|
||||
parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
|
||||
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||
"""Shared CLI arguments for vLLM engine."""
|
||||
|
||||
# Model arguments
|
||||
@ -800,8 +799,8 @@ class AsyncEngineArgs(EngineArgs):
|
||||
max_log_len: Optional[int] = None
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(parser: argparse.ArgumentParser,
|
||||
async_args_only: bool = False) -> argparse.ArgumentParser:
|
||||
def add_cli_args(parser: FlexibleArgumentParser,
|
||||
async_args_only: bool = False) -> FlexibleArgumentParser:
|
||||
if not async_args_only:
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
parser.add_argument('--engine-use-ray',
|
||||
@ -822,13 +821,13 @@ class AsyncEngineArgs(EngineArgs):
|
||||
|
||||
# These functions are used by sphinx to build the documentation
|
||||
def _engine_args_parser():
|
||||
return EngineArgs.add_cli_args(argparse.ArgumentParser())
|
||||
return EngineArgs.add_cli_args(FlexibleArgumentParser())
|
||||
|
||||
|
||||
def _async_engine_args_parser():
|
||||
return AsyncEngineArgs.add_cli_args(argparse.ArgumentParser(),
|
||||
return AsyncEngineArgs.add_cli_args(FlexibleArgumentParser(),
|
||||
async_args_only=True)
|
||||
|
||||
|
||||
def _vlm_engine_args_parser():
|
||||
return EngineArgs.add_cli_args_for_vlm(argparse.ArgumentParser())
|
||||
return EngineArgs.add_cli_args_for_vlm(FlexibleArgumentParser())
|
||||
|
@ -6,7 +6,6 @@ We are also not going to accept PRs modifying this file, please
|
||||
change `vllm/entrypoints/openai/api_server.py` instead.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import ssl
|
||||
from typing import AsyncGenerator
|
||||
@ -19,7 +18,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.utils import FlexibleArgumentParser, random_uuid
|
||||
|
||||
TIMEOUT_KEEP_ALIVE = 5 # seconds.
|
||||
app = FastAPI()
|
||||
@ -80,7 +79,7 @@ async def generate(request: Request) -> Response:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = FlexibleArgumentParser()
|
||||
parser.add_argument("--host", type=str, default=None)
|
||||
parser.add_argument("--port", type=int, default=8000)
|
||||
parser.add_argument("--ssl-keyfile", type=str, default=None)
|
||||
|
@ -10,6 +10,7 @@ import ssl
|
||||
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
|
||||
from vllm.entrypoints.openai.serving_engine import LoRAModulePath
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
class LoRAParserAction(argparse.Action):
|
||||
@ -23,7 +24,7 @@ class LoRAParserAction(argparse.Action):
|
||||
|
||||
|
||||
def make_arg_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
parser = FlexibleArgumentParser(
|
||||
description="vLLM OpenAI-Compatible RESTful API server.")
|
||||
parser.add_argument("--host",
|
||||
type=nullable_str,
|
||||
|
@ -1,4 +1,3 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import sys
|
||||
from io import StringIO
|
||||
@ -16,14 +15,14 @@ from vllm.entrypoints.openai.protocol import (BatchRequestInput,
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from vllm.logger import init_logger
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.utils import FlexibleArgumentParser, random_uuid
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
parser = FlexibleArgumentParser(
|
||||
description="vLLM OpenAI-Compatible batch runner.")
|
||||
parser.add_argument(
|
||||
"-i",
|
||||
|
@ -21,6 +21,7 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
tensorizer_error_msg = None
|
||||
|
||||
@ -177,8 +178,7 @@ class TensorizerArgs:
|
||||
self.deserializer_params['encryption'] = decryption_params
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(
|
||||
parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
|
||||
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||
"""Tensorizer CLI arguments"""
|
||||
|
||||
# Tensorizer options arg group
|
||||
|
@ -1,3 +1,4 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import datetime
|
||||
import enum
|
||||
@ -775,3 +776,21 @@ def run_once(f):
|
||||
|
||||
wrapper.has_run = False # type: ignore[attr-defined]
|
||||
return wrapper
|
||||
|
||||
|
||||
class FlexibleArgumentParser(argparse.ArgumentParser):
|
||||
"""ArgumentParser that allows both underscore and dash in names."""
|
||||
|
||||
def parse_args(self, args=None, namespace=None):
|
||||
if args is None:
|
||||
args = sys.argv[1:]
|
||||
|
||||
# Convert underscores to dashes and vice versa in argument names
|
||||
processed_args = []
|
||||
for arg in args:
|
||||
if arg.startswith('--'):
|
||||
processed_args.append('--' + arg[len('--'):].replace('_', '-'))
|
||||
else:
|
||||
processed_args.append(arg)
|
||||
|
||||
return super().parse_args(processed_args, namespace)
|
||||
|
Loading…
x
Reference in New Issue
Block a user