[Frontend] Improve Nullable kv Arg Parsing (#8525)
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
This commit is contained in:
parent
546034b466
commit
1c1bb388e0
@ -1,6 +1,8 @@
|
|||||||
|
from argparse import ArgumentTypeError
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from vllm.engine.arg_utils import EngineArgs
|
from vllm.engine.arg_utils import EngineArgs, nullable_kvs
|
||||||
from vllm.utils import FlexibleArgumentParser
|
from vllm.utils import FlexibleArgumentParser
|
||||||
|
|
||||||
|
|
||||||
@ -13,6 +15,10 @@ from vllm.utils import FlexibleArgumentParser
|
|||||||
"image": 16,
|
"image": 16,
|
||||||
"video": 2
|
"video": 2
|
||||||
}),
|
}),
|
||||||
|
("Image=16, Video=2", {
|
||||||
|
"image": 16,
|
||||||
|
"video": 2
|
||||||
|
}),
|
||||||
])
|
])
|
||||||
def test_limit_mm_per_prompt_parser(arg, expected):
|
def test_limit_mm_per_prompt_parser(arg, expected):
|
||||||
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
|
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
|
||||||
@ -22,3 +28,15 @@ def test_limit_mm_per_prompt_parser(arg, expected):
|
|||||||
args = parser.parse_args(["--limit-mm-per-prompt", arg])
|
args = parser.parse_args(["--limit-mm-per-prompt", arg])
|
||||||
|
|
||||||
assert args.limit_mm_per_prompt == expected
|
assert args.limit_mm_per_prompt == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("arg"),
|
||||||
|
[
|
||||||
|
"image", # Missing =
|
||||||
|
"image=4,image=5", # Conflicting values
|
||||||
|
"image=video=4" # Too many = in tokenized arg
|
||||||
|
])
|
||||||
|
def test_bad_nullable_kvs(arg):
|
||||||
|
with pytest.raises(ArgumentTypeError):
|
||||||
|
nullable_kvs(arg)
|
||||||
|
@ -44,22 +44,36 @@ def nullable_str(val: str):
|
|||||||
|
|
||||||
|
|
||||||
def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
|
def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
|
||||||
|
"""Parses a string containing comma separate key [str] to value [int]
|
||||||
|
pairs into a dictionary.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
val: String value to be parsed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with parsed values.
|
||||||
|
"""
|
||||||
if len(val) == 0:
|
if len(val) == 0:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
out_dict: Dict[str, int] = {}
|
out_dict: Dict[str, int] = {}
|
||||||
for item in val.split(","):
|
for item in val.split(","):
|
||||||
try:
|
kv_parts = [part.lower().strip() for part in item.split("=")]
|
||||||
key, value = item.split("=")
|
if len(kv_parts) != 2:
|
||||||
except TypeError as exc:
|
raise argparse.ArgumentTypeError(
|
||||||
msg = "Each item should be in the form KEY=VALUE"
|
"Each item should be in the form KEY=VALUE")
|
||||||
raise ValueError(msg) from exc
|
key, value = kv_parts
|
||||||
|
|
||||||
try:
|
try:
|
||||||
out_dict[key] = int(value)
|
parsed_value = int(value)
|
||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
msg = f"Failed to parse value of item {key}={value}"
|
msg = f"Failed to parse value of item {key}={value}"
|
||||||
raise ValueError(msg) from exc
|
raise argparse.ArgumentTypeError(msg) from exc
|
||||||
|
|
||||||
|
if key in out_dict and out_dict[key] != parsed_value:
|
||||||
|
raise argparse.ArgumentTypeError(
|
||||||
|
f"Conflicting values specified for key: {key}")
|
||||||
|
out_dict[key] = parsed_value
|
||||||
|
|
||||||
return out_dict
|
return out_dict
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user