[Misc] Human-readable max-model-len
cli arg (#16181)
Signed-off-by: NickLucche <nlucches@redhat.com> Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
parent
ad434d4cfe
commit
090c856d76
@ -1,6 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
from argparse import ArgumentTypeError
|
from argparse import ArgumentError, ArgumentTypeError
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -142,3 +142,39 @@ def test_composite_arg_parser(arg, expected, option):
|
|||||||
else:
|
else:
|
||||||
args = parser.parse_args([f"--{option}", arg])
|
args = parser.parse_args([f"--{option}", arg])
|
||||||
assert getattr(args, option.replace("-", "_")) == expected
|
assert getattr(args, option.replace("-", "_")) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_human_readable_model_len():
|
||||||
|
# `exit_on_error` disabled to test invalid values below
|
||||||
|
parser = EngineArgs.add_cli_args(
|
||||||
|
FlexibleArgumentParser(exit_on_error=False))
|
||||||
|
|
||||||
|
args = parser.parse_args([])
|
||||||
|
assert args.max_model_len is None
|
||||||
|
|
||||||
|
args = parser.parse_args(["--max-model-len", "1024"])
|
||||||
|
assert args.max_model_len == 1024
|
||||||
|
|
||||||
|
# Lower
|
||||||
|
args = parser.parse_args(["--max-model-len", "1m"])
|
||||||
|
assert args.max_model_len == 1_000_000
|
||||||
|
args = parser.parse_args(["--max-model-len", "10k"])
|
||||||
|
assert args.max_model_len == 10_000
|
||||||
|
|
||||||
|
# Capital
|
||||||
|
args = parser.parse_args(["--max-model-len", "3K"])
|
||||||
|
assert args.max_model_len == 1024 * 3
|
||||||
|
args = parser.parse_args(["--max-model-len", "10M"])
|
||||||
|
assert args.max_model_len == 2**20 * 10
|
||||||
|
|
||||||
|
# Decimal values
|
||||||
|
args = parser.parse_args(["--max-model-len", "10.2k"])
|
||||||
|
assert args.max_model_len == 10200
|
||||||
|
# ..truncated to the nearest int
|
||||||
|
args = parser.parse_args(["--max-model-len", "10.212345k"])
|
||||||
|
assert args.max_model_len == 10212
|
||||||
|
|
||||||
|
# Invalid (do not allow decimals with binary multipliers)
|
||||||
|
for invalid in ["1a", "pwd", "10.24", "1.23M"]:
|
||||||
|
with pytest.raises(ArgumentError):
|
||||||
|
args = parser.parse_args(["--max-model-len", invalid])
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import json
|
import json
|
||||||
|
import re
|
||||||
import threading
|
import threading
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
|
from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
|
||||||
@ -368,10 +369,14 @@ class EngineArgs:
|
|||||||
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
|
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
|
||||||
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
|
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
|
||||||
parser.add_argument('--max-model-len',
|
parser.add_argument('--max-model-len',
|
||||||
type=int,
|
type=human_readable_int,
|
||||||
default=EngineArgs.max_model_len,
|
default=EngineArgs.max_model_len,
|
||||||
help='Model context length. If unspecified, will '
|
help='Model context length. If unspecified, will '
|
||||||
'be automatically derived from the model config.')
|
'be automatically derived from the model config. '
|
||||||
|
'Supports k/m/g/K/M/G in human-readable format.\n'
|
||||||
|
'Examples:\n'
|
||||||
|
'- 1k → 1000\n'
|
||||||
|
'- 1K → 1024\n')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--guided-decoding-backend',
|
'--guided-decoding-backend',
|
||||||
type=str,
|
type=str,
|
||||||
@ -1740,6 +1745,47 @@ def _warn_or_fallback(feature_name: str) -> bool:
|
|||||||
return should_exit
|
return should_exit
|
||||||
|
|
||||||
|
|
||||||
|
def human_readable_int(value):
|
||||||
|
"""Parse human-readable integers like '1k', '2M', etc.
|
||||||
|
Including decimal values with decimal multipliers.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
- '1k' -> 1,000
|
||||||
|
- '1K' -> 1,024
|
||||||
|
- '25.6k' -> 25,600
|
||||||
|
"""
|
||||||
|
value = value.strip()
|
||||||
|
match = re.fullmatch(r'(\d+(?:\.\d+)?)([kKmMgGtT])', value)
|
||||||
|
if match:
|
||||||
|
decimal_multiplier = {
|
||||||
|
'k': 10**3,
|
||||||
|
'm': 10**6,
|
||||||
|
'g': 10**9,
|
||||||
|
}
|
||||||
|
binary_multiplier = {
|
||||||
|
'K': 2**10,
|
||||||
|
'M': 2**20,
|
||||||
|
'G': 2**30,
|
||||||
|
}
|
||||||
|
|
||||||
|
number, suffix = match.groups()
|
||||||
|
if suffix in decimal_multiplier:
|
||||||
|
mult = decimal_multiplier[suffix]
|
||||||
|
return int(float(number) * mult)
|
||||||
|
elif suffix in binary_multiplier:
|
||||||
|
mult = binary_multiplier[suffix]
|
||||||
|
# Do not allow decimals with binary multipliers
|
||||||
|
try:
|
||||||
|
return int(number) * mult
|
||||||
|
except ValueError as e:
|
||||||
|
raise argparse.ArgumentTypeError("Decimals are not allowed " \
|
||||||
|
f"with binary suffixes like {suffix}. Did you mean to use " \
|
||||||
|
f"{number}{suffix.lower()} instead?") from e
|
||||||
|
|
||||||
|
# Regular plain number.
|
||||||
|
return int(value)
|
||||||
|
|
||||||
|
|
||||||
# These functions are used by sphinx to build the documentation
|
# These functions are used by sphinx to build the documentation
|
||||||
def _engine_args_parser():
|
def _engine_args_parser():
|
||||||
return EngineArgs.add_cli_args(FlexibleArgumentParser())
|
return EngineArgs.add_cli_args(FlexibleArgumentParser())
|
||||||
|
Loading…
x
Reference in New Issue
Block a user