[Misc] Human-readable max-model-len cli arg (#16181)

Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
Nicolò Lucchesi 2025-04-07 20:40:58 +02:00 committed by GitHub
parent ad434d4cfe
commit 090c856d76
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 85 additions and 3 deletions

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from argparse import ArgumentTypeError
from argparse import ArgumentError, ArgumentTypeError
import pytest
@ -142,3 +142,39 @@ def test_composite_arg_parser(arg, expected, option):
else:
args = parser.parse_args([f"--{option}", arg])
assert getattr(args, option.replace("-", "_")) == expected
def test_human_readable_model_len():
# `exit_on_error` disabled to test invalid values below
parser = EngineArgs.add_cli_args(
FlexibleArgumentParser(exit_on_error=False))
args = parser.parse_args([])
assert args.max_model_len is None
args = parser.parse_args(["--max-model-len", "1024"])
assert args.max_model_len == 1024
# Lower
args = parser.parse_args(["--max-model-len", "1m"])
assert args.max_model_len == 1_000_000
args = parser.parse_args(["--max-model-len", "10k"])
assert args.max_model_len == 10_000
# Capital
args = parser.parse_args(["--max-model-len", "3K"])
assert args.max_model_len == 1024 * 3
args = parser.parse_args(["--max-model-len", "10M"])
assert args.max_model_len == 2**20 * 10
# Decimal values
args = parser.parse_args(["--max-model-len", "10.2k"])
assert args.max_model_len == 10200
# ..truncated to the nearest int
args = parser.parse_args(["--max-model-len", "10.212345k"])
assert args.max_model_len == 10212
# Invalid (do not allow decimals with binary multipliers)
for invalid in ["1a", "pwd", "10.24", "1.23M"]:
with pytest.raises(ArgumentError):
args = parser.parse_args(["--max-model-len", invalid])

View File

@ -3,6 +3,7 @@
import argparse
import dataclasses
import json
import re
import threading
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
@ -368,10 +369,14 @@ class EngineArgs:
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
parser.add_argument('--max-model-len',
type=int,
type=human_readable_int,
default=EngineArgs.max_model_len,
help='Model context length. If unspecified, will '
'be automatically derived from the model config.')
'be automatically derived from the model config. '
'Supports k/m/g/K/M/G in human-readable format.\n'
'Examples:\n'
'- 1k → 1000\n'
'- 1K → 1024\n')
parser.add_argument(
'--guided-decoding-backend',
type=str,
@ -1740,6 +1745,47 @@ def _warn_or_fallback(feature_name: str) -> bool:
return should_exit
def human_readable_int(value):
"""Parse human-readable integers like '1k', '2M', etc.
Including decimal values with decimal multipliers.
Examples:
- '1k' -> 1,000
- '1K' -> 1,024
- '25.6k' -> 25,600
"""
value = value.strip()
match = re.fullmatch(r'(\d+(?:\.\d+)?)([kKmMgGtT])', value)
if match:
decimal_multiplier = {
'k': 10**3,
'm': 10**6,
'g': 10**9,
}
binary_multiplier = {
'K': 2**10,
'M': 2**20,
'G': 2**30,
}
number, suffix = match.groups()
if suffix in decimal_multiplier:
mult = decimal_multiplier[suffix]
return int(float(number) * mult)
elif suffix in binary_multiplier:
mult = binary_multiplier[suffix]
# Do not allow decimals with binary multipliers
try:
return int(number) * mult
except ValueError as e:
raise argparse.ArgumentTypeError("Decimals are not allowed " \
f"with binary suffixes like {suffix}. Did you mean to use " \
f"{number}{suffix.lower()} instead?") from e
# Regular plain number.
return int(value)
# These functions are used by sphinx to build the documentation
def _engine_args_parser():
return EngineArgs.add_cli_args(FlexibleArgumentParser())