Improve configs - TokenizerPoolConfig + DeviceConfig (#16603)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor 2025-04-17 12:19:42 +01:00 committed by GitHub
parent 99ed526101
commit d27ea94034
3 changed files with 136 additions and 81 deletions

View File

@@ -1,14 +1,36 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import asdict
from dataclasses import MISSING, Field, asdict, dataclass, field
import pytest
from vllm.config import ModelConfig, PoolerConfig
from vllm.config import ModelConfig, PoolerConfig, get_field
from vllm.model_executor.layers.pooler import PoolingType
from vllm.platforms import current_platform
def test_get_field():
@dataclass
class TestConfig:
a: int
b: dict = field(default_factory=dict)
c: str = "default"
with pytest.raises(ValueError):
get_field(TestConfig, "a")
b = get_field(TestConfig, "b")
assert isinstance(b, Field)
assert b.default is MISSING
assert b.default_factory is dict
c = get_field(TestConfig, "c")
assert isinstance(c, Field)
assert c.default == "default"
assert c.default_factory is MISSING
@pytest.mark.parametrize(
("model_id", "expected_runner_type", "expected_task"),
[

View File

@@ -182,6 +182,23 @@ def config(cls: type[Config]) -> type[Config]:
return cls
def get_field(cls: type[Config], name: str) -> Field:
"""Get the default factory field of a dataclass by name. Used for getting
default factory fields in `EngineArgs`."""
if not is_dataclass(cls):
raise TypeError("The given class is not a dataclass.")
cls_fields = {f.name: f for f in fields(cls)}
if name not in cls_fields:
raise ValueError(f"Field '{name}' not found in {cls.__name__}.")
named_field: Field = cls_fields.get(name)
if (default_factory := named_field.default_factory) is not MISSING:
return field(default_factory=default_factory)
if (default := named_field.default) is not MISSING:
return field(default=default)
raise ValueError(
f"{cls.__name__}.{name} must have a default value or default factory.")
class ModelConfig:
"""Configuration for the model.
@@ -1364,20 +1381,26 @@ class CacheConfig:
logger.warning("Possibly too large swap space. %s", msg)
PoolType = Literal["ray"]
@config
@dataclass
class TokenizerPoolConfig:
"""Configuration for the tokenizer pool.
"""Configuration for the tokenizer pool."""
Args:
pool_size: Number of tokenizer workers in the pool.
pool_type: Type of the pool.
extra_config: Additional config for the pool.
The way the config will be used depends on the
pool type.
"""
pool_size: int
pool_type: Union[str, type["BaseTokenizerGroup"]]
extra_config: dict
pool_size: int = 0
"""Number of tokenizer workers in the pool to use for asynchronous
tokenization. If 0, will use synchronous tokenization."""
pool_type: Union[PoolType, type["BaseTokenizerGroup"]] = "ray"
"""Type of tokenizer pool to use for asynchronous tokenization. Ignored if
tokenizer_pool_size is 0."""
extra_config: dict = field(default_factory=dict)
"""Additional config for the pool. The way the config will be used depends
on the pool type. This should be a JSON string that will be parsed into a
dictionary. Ignored if tokenizer_pool_size is 0."""
def compute_hash(self) -> str:
"""
@@ -1408,7 +1431,7 @@ class TokenizerPoolConfig:
@classmethod
def create_config(
cls, tokenizer_pool_size: int,
tokenizer_pool_type: Union[str, type["BaseTokenizerGroup"]],
tokenizer_pool_type: Union[PoolType, type["BaseTokenizerGroup"]],
tokenizer_pool_extra_config: Optional[Union[str, dict]]
) -> Optional["TokenizerPoolConfig"]:
"""Create a TokenizerPoolConfig from the given parameters.
@@ -1483,7 +1506,7 @@ class LoadConfig:
download_dir: Optional[str] = None
"""Directory to download and load the weights, default to the default
cache directory of Hugging Face."""
model_loader_extra_config: Optional[Union[str, dict]] = None
model_loader_extra_config: dict = field(default_factory=dict)
"""Extra config for model loader. This will be passed to the model loader
corresponding to the chosen load_format. This should be a JSON string that
will be parsed into a dictionary."""
@@ -1514,10 +1537,6 @@ class LoadConfig:
return hash_str
def __post_init__(self):
model_loader_extra_config = self.model_loader_extra_config or {}
if isinstance(model_loader_extra_config, str):
self.model_loader_extra_config = json.loads(
model_loader_extra_config)
if isinstance(self.load_format, str):
load_format = self.load_format.lower()
self.load_format = LoadFormat(load_format)
@@ -2029,9 +2048,19 @@ class SchedulerConfig:
return self.num_scheduler_steps > 1
Device = Literal["auto", "cuda", "neuron", "cpu", "tpu", "xpu", "hpu"]
@config
@dataclass
class DeviceConfig:
device: Optional[torch.device]
device_type: str
"""Configuration for the device to use for vLLM execution."""
device: Union[Device, torch.device] = "auto"
"""Device type for vLLM execution."""
device_type: str = field(init=False)
"""Device type from the current platform. This is set in
`__post_init__`."""
def compute_hash(self) -> str:
"""
@@ -2053,8 +2082,8 @@ class DeviceConfig:
usedforsecurity=False).hexdigest()
return hash_str
def __init__(self, device: str = "auto") -> None:
if device == "auto":
def __post_init__(self):
if self.device == "auto":
# Automated device type detection
from vllm.platforms import current_platform
self.device_type = current_platform.device_type
@@ -2065,7 +2094,7 @@ class DeviceConfig:
"to turn on verbose logging to help debug the issue.")
else:
# Device type is assigned explicitly
self.device_type = device
self.device_type = self.device
# Some device types require processing inputs on CPU
if self.device_type in ["neuron"]:
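A minimal sketch of the new construction path, assuming the patched vllm.config from this commit; device_type resolution follows the __post_init__ shown above:

from vllm.config import DeviceConfig

# Default "auto": device_type is detected from the current platform in
# __post_init__ (unsupported platforms hit the error path in this hunk).
auto_cfg = DeviceConfig()
print(auto_cfg.device_type)

# Explicit device: the given string becomes the device type directly.
cpu_cfg = DeviceConfig(device="cpu")
assert cpu_cfg.device_type == "cpu"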

View File

@@ -16,15 +16,15 @@ from typing_extensions import TypeIs
import vllm.envs as envs
from vllm import version
from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat,
DecodingConfig, DeviceConfig,
from vllm.config import (CacheConfig, CompilationConfig, Config, ConfigFormat,
DecodingConfig, Device, DeviceConfig,
DistributedExecutorBackend, HfOverrides,
KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
ModelConfig, ModelImpl, ObservabilityConfig,
ParallelConfig, PoolerConfig, PromptAdapterConfig,
SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
TaskOption, TokenizerPoolConfig, VllmConfig,
get_attr_docs)
ParallelConfig, PoolerConfig, PoolType,
PromptAdapterConfig, SchedulerConfig, SchedulerPolicy,
SpeculativeConfig, TaskOption, TokenizerPoolConfig,
VllmConfig, get_attr_docs, get_field)
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
@@ -44,27 +44,17 @@ logger = init_logger(__name__)
ALLOWED_DETAILED_TRACE_MODULES = ["model", "worker", "all"]
DEVICE_OPTIONS = [
"auto",
"cuda",
"neuron",
"cpu",
"tpu",
"xpu",
"hpu",
]
# object is used to allow for special typing forms
T = TypeVar("T")
TypeHint = Union[type[Any], object]
TypeHintT = Union[type[T], object]
def optional_arg(val: str, return_type: type[T]) -> Optional[T]:
def optional_arg(val: str, return_type: Callable[[str], T]) -> Optional[T]:
if val == "" or val == "None":
return None
try:
return cast(Callable, return_type)(val)
return return_type(val)
except ValueError as e:
raise argparse.ArgumentTypeError(
f"Value {val} cannot be converted to {return_type}.") from e
@@ -82,8 +72,11 @@ def optional_float(val: str) -> Optional[float]:
return optional_arg(val, float)
def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
"""Parses a string containing comma-separated key [str] to value [int]
def nullable_kvs(val: str) -> Optional[dict[str, int]]:
"""NOTE: This function is deprecated; args should be passed as JSON
strings instead.
Parses a string containing comma-separated key [str] to value [int]
pairs into a dictionary.
Args:
@@ -117,6 +110,17 @@ def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
return out_dict
def optional_dict(val: str) -> Optional[dict[str, int]]:
try:
return optional_arg(val, json.loads)
except ValueError:
logger.warning(
"Failed to parse JSON string. Attempting to parse as "
"comma-separated key=value pairs. This will be deprecated in a "
"future release.")
return nullable_kvs(val)
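A standalone sketch of the parsing intent, using a hypothetical helper (parse_dict_arg) that mirrors optional_dict and nullable_kvs above without importing vLLM:

import json
from typing import Optional

def parse_dict_arg(val: str) -> Optional[dict]:
    if val in ("", "None"):
        return None                      # optional_arg treats these as unset
    try:
        return json.loads(val)           # preferred: a JSON string
    except ValueError:
        # Deprecated fallback: comma-separated key=value pairs.
        return {k: int(v) for k, v in
                (pair.split("=", 1) for pair in val.split(","))}

print(parse_dict_arg('{"image": 2, "video": 1}'))  # {'image': 2, 'video': 1}
print(parse_dict_arg("image=2,video=1"))           # {'image': 2, 'video': 1}
print(parse_dict_arg("None"))                      # None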
@dataclass
class EngineArgs:
"""Arguments for vLLM engine."""
@@ -178,12 +182,14 @@ class EngineArgs:
enforce_eager: Optional[bool] = None
max_seq_len_to_capture: int = 8192
disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
tokenizer_pool_size: int = 0
tokenizer_pool_size: int = TokenizerPoolConfig.pool_size
# Note: Specifying a tokenizer pool by passing a class
# is intended for expert use only. The API may change without
# notice.
tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
tokenizer_pool_extra_config: Optional[Dict[str, Any]] = None
tokenizer_pool_type: Union[PoolType, Type["BaseTokenizerGroup"]] = \
TokenizerPoolConfig.pool_type
tokenizer_pool_extra_config: dict[str, Any] = \
get_field(TokenizerPoolConfig, "extra_config")
limit_mm_per_prompt: Optional[Mapping[str, int]] = None
mm_processor_kwargs: Optional[Dict[str, Any]] = None
disable_mm_preprocessor_cache: bool = False
@@ -199,14 +205,14 @@ class EngineArgs:
long_lora_scaling_factors: Optional[Tuple[float]] = None
lora_dtype: Optional[Union[str, torch.dtype]] = 'auto'
max_cpu_loras: Optional[int] = None
device: str = 'auto'
device: Device = DeviceConfig.device
num_scheduler_steps: int = SchedulerConfig.num_scheduler_steps
multi_step_stream_outputs: bool = SchedulerConfig.multi_step_stream_outputs
ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
num_gpu_blocks_override: Optional[int] = None
num_lookahead_slots: int = SchedulerConfig.num_lookahead_slots
model_loader_extra_config: Optional[
dict] = LoadConfig.model_loader_extra_config
model_loader_extra_config: dict = \
get_field(LoadConfig, "model_loader_extra_config")
ignore_patterns: Optional[Union[str,
List[str]]] = LoadConfig.ignore_patterns
preemption_mode: Optional[str] = SchedulerConfig.preemption_mode
@@ -294,14 +300,15 @@ class EngineArgs:
"""Check if the class is a custom type."""
return cls.__module__ != "builtins"
def get_kwargs(cls: type[Any]) -> dict[str, Any]:
def get_kwargs(cls: type[Config]) -> dict[str, Any]:
cls_docs = get_attr_docs(cls)
kwargs = {}
for field in fields(cls):
name = field.name
# One of these will always be present
default = (field.default_factory
if field.default is MISSING else field.default)
default = field.default
# This will only be True if default is MISSING
if field.default_factory is not MISSING:
default = field.default_factory()
kwargs[name] = {"default": default, "help": cls_docs[name]}
# Make note of if the field is optional and get the actual
@@ -331,8 +338,9 @@ class EngineArgs:
elif can_be_type(field_type, float):
kwargs[name][
"type"] = optional_float if optional else float
elif can_be_type(field_type, dict):
kwargs[name]["type"] = optional_dict
elif (can_be_type(field_type, str)
or can_be_type(field_type, dict)
or is_custom_type(field_type)):
kwargs[name]["type"] = optional_str if optional else str
else:
@@ -674,25 +682,19 @@ class EngineArgs:
'Additionally for encoder-decoder models, if the '
'sequence length of the encoder input is larger '
'than this, we fall back to the eager mode.')
parser.add_argument('--tokenizer-pool-size',
type=int,
default=EngineArgs.tokenizer_pool_size,
help='Size of tokenizer pool to use for '
'asynchronous tokenization. If 0, will '
'use synchronous tokenization.')
parser.add_argument('--tokenizer-pool-type',
type=str,
default=EngineArgs.tokenizer_pool_type,
help='Type of tokenizer pool to use for '
'asynchronous tokenization. Ignored '
'if tokenizer_pool_size is 0.')
parser.add_argument('--tokenizer-pool-extra-config',
type=optional_str,
default=EngineArgs.tokenizer_pool_extra_config,
help='Extra config for tokenizer pool. '
'This should be a JSON string that will be '
'parsed into a dictionary. Ignored if '
'tokenizer_pool_size is 0.')
# Tokenizer arguments
tokenizer_kwargs = get_kwargs(TokenizerPoolConfig)
tokenizer_group = parser.add_argument_group(
title="TokenizerPoolConfig",
description=TokenizerPoolConfig.__doc__,
)
tokenizer_group.add_argument('--tokenizer-pool-size',
**tokenizer_kwargs["pool_size"])
tokenizer_group.add_argument('--tokenizer-pool-type',
**tokenizer_kwargs["pool_type"])
tokenizer_group.add_argument('--tokenizer-pool-extra-config',
**tokenizer_kwargs["extra_config"])
# Multimodal related configs
parser.add_argument(
@@ -784,11 +786,15 @@ class EngineArgs:
type=int,
default=EngineArgs.max_prompt_adapter_token,
help='Max number of PromptAdapters tokens')
parser.add_argument("--device",
type=str,
default=EngineArgs.device,
choices=DEVICE_OPTIONS,
help='Device type for vLLM execution.')
# Device arguments
device_kwargs = get_kwargs(DeviceConfig)
device_group = parser.add_argument_group(
title="DeviceConfig",
description=DeviceConfig.__doc__,
)
device_group.add_argument("--device", **device_kwargs["device"])
parser.add_argument('--num-scheduler-steps',
type=int,
default=1,
@@ -1302,8 +1308,6 @@ class EngineArgs:
if self.qlora_adapter_name_or_path is not None and \
self.qlora_adapter_name_or_path != "":
if self.model_loader_extra_config is None:
self.model_loader_extra_config = {}
self.model_loader_extra_config[
"qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path