[Misc] Human-readable max-model-len
cli arg (#16181)
Signed-off-by: NickLucche <nlucches@redhat.com> Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
parent
ad434d4cfe
commit
090c856d76
@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from argparse import ArgumentTypeError
|
||||
from argparse import ArgumentError, ArgumentTypeError
|
||||
|
||||
import pytest
|
||||
|
||||
@ -142,3 +142,39 @@ def test_composite_arg_parser(arg, expected, option):
|
||||
else:
|
||||
args = parser.parse_args([f"--{option}", arg])
|
||||
assert getattr(args, option.replace("-", "_")) == expected
|
||||
|
||||
|
||||
def test_human_readable_model_len():
|
||||
# `exit_on_error` disabled to test invalid values below
|
||||
parser = EngineArgs.add_cli_args(
|
||||
FlexibleArgumentParser(exit_on_error=False))
|
||||
|
||||
args = parser.parse_args([])
|
||||
assert args.max_model_len is None
|
||||
|
||||
args = parser.parse_args(["--max-model-len", "1024"])
|
||||
assert args.max_model_len == 1024
|
||||
|
||||
# Lower
|
||||
args = parser.parse_args(["--max-model-len", "1m"])
|
||||
assert args.max_model_len == 1_000_000
|
||||
args = parser.parse_args(["--max-model-len", "10k"])
|
||||
assert args.max_model_len == 10_000
|
||||
|
||||
# Capital
|
||||
args = parser.parse_args(["--max-model-len", "3K"])
|
||||
assert args.max_model_len == 1024 * 3
|
||||
args = parser.parse_args(["--max-model-len", "10M"])
|
||||
assert args.max_model_len == 2**20 * 10
|
||||
|
||||
# Decimal values
|
||||
args = parser.parse_args(["--max-model-len", "10.2k"])
|
||||
assert args.max_model_len == 10200
|
||||
# ..truncated to the nearest int
|
||||
args = parser.parse_args(["--max-model-len", "10.212345k"])
|
||||
assert args.max_model_len == 10212
|
||||
|
||||
# Invalid (do not allow decimals with binary multipliers)
|
||||
for invalid in ["1a", "pwd", "10.24", "1.23M"]:
|
||||
with pytest.raises(ArgumentError):
|
||||
args = parser.parse_args(["--max-model-len", invalid])
|
||||
|
@ -3,6 +3,7 @@
|
||||
import argparse
|
||||
import dataclasses
|
||||
import json
|
||||
import re
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
|
||||
@ -368,10 +369,14 @@ class EngineArgs:
|
||||
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
|
||||
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
|
||||
parser.add_argument('--max-model-len',
|
||||
type=int,
|
||||
type=human_readable_int,
|
||||
default=EngineArgs.max_model_len,
|
||||
help='Model context length. If unspecified, will '
|
||||
'be automatically derived from the model config.')
|
||||
'be automatically derived from the model config. '
|
||||
'Supports k/m/g/K/M/G in human-readable format.\n'
|
||||
'Examples:\n'
|
||||
'- 1k → 1000\n'
|
||||
'- 1K → 1024\n')
|
||||
parser.add_argument(
|
||||
'--guided-decoding-backend',
|
||||
type=str,
|
||||
@ -1740,6 +1745,47 @@ def _warn_or_fallback(feature_name: str) -> bool:
|
||||
return should_exit
|
||||
|
||||
|
||||
def human_readable_int(value):
|
||||
"""Parse human-readable integers like '1k', '2M', etc.
|
||||
Including decimal values with decimal multipliers.
|
||||
|
||||
Examples:
|
||||
- '1k' -> 1,000
|
||||
- '1K' -> 1,024
|
||||
- '25.6k' -> 25,600
|
||||
"""
|
||||
value = value.strip()
|
||||
match = re.fullmatch(r'(\d+(?:\.\d+)?)([kKmMgGtT])', value)
|
||||
if match:
|
||||
decimal_multiplier = {
|
||||
'k': 10**3,
|
||||
'm': 10**6,
|
||||
'g': 10**9,
|
||||
}
|
||||
binary_multiplier = {
|
||||
'K': 2**10,
|
||||
'M': 2**20,
|
||||
'G': 2**30,
|
||||
}
|
||||
|
||||
number, suffix = match.groups()
|
||||
if suffix in decimal_multiplier:
|
||||
mult = decimal_multiplier[suffix]
|
||||
return int(float(number) * mult)
|
||||
elif suffix in binary_multiplier:
|
||||
mult = binary_multiplier[suffix]
|
||||
# Do not allow decimals with binary multipliers
|
||||
try:
|
||||
return int(number) * mult
|
||||
except ValueError as e:
|
||||
raise argparse.ArgumentTypeError("Decimals are not allowed " \
|
||||
f"with binary suffixes like {suffix}. Did you mean to use " \
|
||||
f"{number}{suffix.lower()} instead?") from e
|
||||
|
||||
# Regular plain number.
|
||||
return int(value)
|
||||
|
||||
|
||||
# These functions are used by sphinx to build the documentation
|
||||
def _engine_args_parser():
|
||||
return EngineArgs.add_cli_args(FlexibleArgumentParser())
|
||||
|
Loading…
x
Reference in New Issue
Block a user