Improve configs - TokenizerPoolConfig + DeviceConfig (#16603)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor authored on 2025-04-17 12:19:42 +01:00 · committed by GitHub
parent 99ed526101
commit d27ea94034
3 changed files with 136 additions and 81 deletions

[File 1 of 3]

@@ -1,14 +1,36 @@
 # SPDX-License-Identifier: Apache-2.0
-from dataclasses import asdict
+from dataclasses import MISSING, Field, asdict, dataclass, field

 import pytest

-from vllm.config import ModelConfig, PoolerConfig
+from vllm.config import ModelConfig, PoolerConfig, get_field
 from vllm.model_executor.layers.pooler import PoolingType
 from vllm.platforms import current_platform


+def test_get_field():
+
+    @dataclass
+    class TestConfig:
+        a: int
+        b: dict = field(default_factory=dict)
+        c: str = "default"
+
+    with pytest.raises(ValueError):
+        get_field(TestConfig, "a")
+
+    b = get_field(TestConfig, "b")
+    assert isinstance(b, Field)
+    assert b.default is MISSING
+    assert b.default_factory is dict
+
+    c = get_field(TestConfig, "c")
+    assert isinstance(c, Field)
+    assert c.default == "default"
+    assert c.default_factory is MISSING
+
+
 @pytest.mark.parametrize(
     ("model_id", "expected_runner_type", "expected_task"),
     [

[File 2 of 3]

@@ -182,6 +182,23 @@ def config(cls: type[Config]) -> type[Config]:
     return cls


+def get_field(cls: type[Config], name: str) -> Field:
+    """Get the default factory field of a dataclass by name. Used for getting
+    default factory fields in `EngineArgs`."""
+    if not is_dataclass(cls):
+        raise TypeError("The given class is not a dataclass.")
+    cls_fields = {f.name: f for f in fields(cls)}
+    if name not in cls_fields:
+        raise ValueError(f"Field '{name}' not found in {cls.__name__}.")
+    named_field: Field = cls_fields.get(name)
+    if (default_factory := named_field.default_factory) is not MISSING:
+        return field(default_factory=default_factory)
+    if (default := named_field.default) is not MISSING:
+        return field(default=default)
+    raise ValueError(
+        f"{cls.__name__}.{name} must have a default value or default factory.")
+
+
 class ModelConfig:
     """Configuration for the model.
@@ -1364,20 +1381,26 @@ class CacheConfig:
             logger.warning("Possibly too large swap space. %s", msg)


+PoolType = Literal["ray"]
+
+
+@config
 @dataclass
 class TokenizerPoolConfig:
-    """Configuration for the tokenizer pool.
+    """Configuration for the tokenizer pool."""

-    Args:
-        pool_size: Number of tokenizer workers in the pool.
-        pool_type: Type of the pool.
-        extra_config: Additional config for the pool.
-            The way the config will be used depends on the
-            pool type.
-    """
-    pool_size: int
-    pool_type: Union[str, type["BaseTokenizerGroup"]]
-    extra_config: dict
+    pool_size: int = 0
+    """Number of tokenizer workers in the pool to use for asynchronous
+    tokenization. If 0, will use synchronous tokenization."""
+
+    pool_type: Union[PoolType, type["BaseTokenizerGroup"]] = "ray"
+    """Type of tokenizer pool to use for asynchronous tokenization. Ignored if
+    tokenizer_pool_size is 0."""
+
+    extra_config: dict = field(default_factory=dict)
+    """Additional config for the pool. The way the config will be used depends
+    on the pool type. This should be a JSON string that will be parsed into a
+    dictionary. Ignored if tokenizer_pool_size is 0."""

     def compute_hash(self) -> str:
         """
@@ -1408,7 +1431,7 @@ class TokenizerPoolConfig:
     @classmethod
     def create_config(
         cls, tokenizer_pool_size: int,
-        tokenizer_pool_type: Union[str, type["BaseTokenizerGroup"]],
+        tokenizer_pool_type: Union[PoolType, type["BaseTokenizerGroup"]],
         tokenizer_pool_extra_config: Optional[Union[str, dict]]
     ) -> Optional["TokenizerPoolConfig"]:
         """Create a TokenizerPoolConfig from the given parameters.
@@ -1483,7 +1506,7 @@ class LoadConfig:
     download_dir: Optional[str] = None
     """Directory to download and load the weights, default to the default
     cache directory of Hugging Face."""
-    model_loader_extra_config: Optional[Union[str, dict]] = None
+    model_loader_extra_config: dict = field(default_factory=dict)
     """Extra config for model loader. This will be passed to the model loader
     corresponding to the chosen load_format. This should be a JSON string that
     will be parsed into a dictionary."""
@@ -1514,10 +1537,6 @@ class LoadConfig:
         return hash_str

     def __post_init__(self):
-        model_loader_extra_config = self.model_loader_extra_config or {}
-        if isinstance(model_loader_extra_config, str):
-            self.model_loader_extra_config = json.loads(
-                model_loader_extra_config)
         if isinstance(self.load_format, str):
             load_format = self.load_format.lower()
             self.load_format = LoadFormat(load_format)
@@ -2029,9 +2048,19 @@ class SchedulerConfig:
         return self.num_scheduler_steps > 1


+Device = Literal["auto", "cuda", "neuron", "cpu", "tpu", "xpu", "hpu"]
+
+
+@config
 @dataclass
 class DeviceConfig:
-    device: Optional[torch.device]
-    device_type: str
+    """Configuration for the device to use for vLLM execution."""
+
+    device: Union[Device, torch.device] = "auto"
+    """Device type for vLLM execution."""
+
+    device_type: str = field(init=False)
+    """Device type from the current platform. This is set in
+    `__post_init__`."""

     def compute_hash(self) -> str:
         """
@@ -2053,8 +2082,8 @@ class DeviceConfig:
                                usedforsecurity=False).hexdigest()
         return hash_str

-    def __init__(self, device: str = "auto") -> None:
-        if device == "auto":
+    def __post_init__(self):
+        if self.device == "auto":
             # Automated device type detection
             from vllm.platforms import current_platform
             self.device_type = current_platform.device_type
@@ -2065,7 +2094,7 @@ class DeviceConfig:
                     "to turn on verbose logging to help debug the issue.")
         else:
             # Device type is assigned explicitly
-            self.device_type = device
+            self.device_type = self.device
         # Some device types require processing inputs on CPU
         if self.device_type in ["neuron"]:
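Note: `device_type` is now derived in `__post_init__` rather than passed in. A rough sketch of the behaviour; the "auto" result depends on the platform vLLM is running on:

    from vllm.config import DeviceConfig

    auto_cfg = DeviceConfig()              # "auto" -> current_platform decides device_type
    cpu_cfg = DeviceConfig(device="cpu")   # explicit value is copied into device_type
    print(auto_cfg.device_type, cpu_cfg.device_type)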

[File 3 of 3]

@@ -16,15 +16,15 @@ from typing_extensions import TypeIs
 import vllm.envs as envs
 from vllm import version
-from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat,
-                         DecodingConfig, DeviceConfig,
+from vllm.config import (CacheConfig, CompilationConfig, Config, ConfigFormat,
+                         DecodingConfig, Device, DeviceConfig,
                          DistributedExecutorBackend, HfOverrides,
                          KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
                          ModelConfig, ModelImpl, ObservabilityConfig,
-                         ParallelConfig, PoolerConfig, PromptAdapterConfig,
-                         SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
-                         TaskOption, TokenizerPoolConfig, VllmConfig,
-                         get_attr_docs)
+                         ParallelConfig, PoolerConfig, PoolType,
+                         PromptAdapterConfig, SchedulerConfig, SchedulerPolicy,
+                         SpeculativeConfig, TaskOption, TokenizerPoolConfig,
+                         VllmConfig, get_attr_docs, get_field)
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
@@ -44,27 +44,17 @@ logger = init_logger(__name__)
 ALLOWED_DETAILED_TRACE_MODULES = ["model", "worker", "all"]

-DEVICE_OPTIONS = [
-    "auto",
-    "cuda",
-    "neuron",
-    "cpu",
-    "tpu",
-    "xpu",
-    "hpu",
-]
-
 # object is used to allow for special typing forms
 T = TypeVar("T")
 TypeHint = Union[type[Any], object]
 TypeHintT = Union[type[T], object]


-def optional_arg(val: str, return_type: type[T]) -> Optional[T]:
+def optional_arg(val: str, return_type: Callable[[str], T]) -> Optional[T]:
     if val == "" or val == "None":
         return None
     try:
-        return cast(Callable, return_type)(val)
+        return return_type(val)
     except ValueError as e:
         raise argparse.ArgumentTypeError(
             f"Value {val} cannot be converted to {return_type}.") from e
@@ -82,8 +72,11 @@ def optional_float(val: str) -> Optional[float]:
     return optional_arg(val, float)


-def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
-    """Parses a string containing comma separate key [str] to value [int]
+def nullable_kvs(val: str) -> Optional[dict[str, int]]:
+    """NOTE: This function is deprecated, args should be passed as JSON
+    strings instead.
+
+    Parses a string containing comma separate key [str] to value [int]
     pairs into a dictionary.

     Args:
@@ -117,6 +110,17 @@ def nullable_kvs(val: str) -> Optional[dict[str, int]]:
     return out_dict


+def optional_dict(val: str) -> Optional[dict[str, int]]:
+    try:
+        return optional_arg(val, json.loads)
+    except ValueError:
+        logger.warning(
+            "Failed to parse JSON string. Attempting to parse as "
+            "comma-separated key=value pairs. This will be deprecated in a "
+            "future release.")
+        return nullable_kvs(val)
+
+
 @dataclass
 class EngineArgs:
     """Arguments for vLLM engine."""
@@ -178,12 +182,14 @@ class EngineArgs:
     enforce_eager: Optional[bool] = None
     max_seq_len_to_capture: int = 8192
     disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
-    tokenizer_pool_size: int = 0
+    tokenizer_pool_size: int = TokenizerPoolConfig.pool_size
     # Note: Specifying a tokenizer pool by passing a class
     # is intended for expert use only. The API may change without
     # notice.
-    tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
-    tokenizer_pool_extra_config: Optional[Dict[str, Any]] = None
+    tokenizer_pool_type: Union[PoolType, Type["BaseTokenizerGroup"]] = \
+        TokenizerPoolConfig.pool_type
+    tokenizer_pool_extra_config: dict[str, Any] = \
+        get_field(TokenizerPoolConfig, "extra_config")
     limit_mm_per_prompt: Optional[Mapping[str, int]] = None
     mm_processor_kwargs: Optional[Dict[str, Any]] = None
     disable_mm_preprocessor_cache: bool = False
@@ -199,14 +205,14 @@ class EngineArgs:
     long_lora_scaling_factors: Optional[Tuple[float]] = None
     lora_dtype: Optional[Union[str, torch.dtype]] = 'auto'
     max_cpu_loras: Optional[int] = None
-    device: str = 'auto'
+    device: Device = DeviceConfig.device
     num_scheduler_steps: int = SchedulerConfig.num_scheduler_steps
     multi_step_stream_outputs: bool = SchedulerConfig.multi_step_stream_outputs
     ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
     num_gpu_blocks_override: Optional[int] = None
     num_lookahead_slots: int = SchedulerConfig.num_lookahead_slots
-    model_loader_extra_config: Optional[
-        dict] = LoadConfig.model_loader_extra_config
+    model_loader_extra_config: dict = \
+        get_field(LoadConfig, "model_loader_extra_config")
     ignore_patterns: Optional[Union[str,
                                     List[str]]] = LoadConfig.ignore_patterns
     preemption_mode: Optional[str] = SchedulerConfig.preemption_mode
@@ -294,14 +300,15 @@ class EngineArgs:
             """Check if the class is a custom type."""
             return cls.__module__ != "builtins"

-        def get_kwargs(cls: type[Any]) -> dict[str, Any]:
+        def get_kwargs(cls: type[Config]) -> dict[str, Any]:
             cls_docs = get_attr_docs(cls)
             kwargs = {}
             for field in fields(cls):
                 name = field.name
-                # One of these will always be present
-                default = (field.default_factory
-                           if field.default is MISSING else field.default)
+                default = field.default
+                # This will only be True if default is MISSING
+                if field.default_factory is not MISSING:
+                    default = field.default_factory()
                 kwargs[name] = {"default": default, "help": cls_docs[name]}

                 # Make note of if the field is optional and get the actual
@@ -331,8 +338,9 @@ class EngineArgs:
                 elif can_be_type(field_type, float):
                     kwargs[name][
                         "type"] = optional_float if optional else float
+                elif can_be_type(field_type, dict):
+                    kwargs[name]["type"] = optional_dict
                 elif (can_be_type(field_type, str)
-                      or can_be_type(field_type, dict)
                       or is_custom_type(field_type)):
                     kwargs[name]["type"] = optional_str if optional else str
                 else:
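Note: the default for a factory-backed field is now materialized by calling the factory, so argparse shows `{}` instead of the factory object. A minimal standalone sketch of that logic; the `Example` class is illustrative, not the vLLM helper itself:

    from dataclasses import MISSING, dataclass, field, fields

    @dataclass
    class Example:
        extra_config: dict = field(default_factory=dict)

    f = fields(Example)[0]
    default = f.default
    if f.default_factory is not MISSING:   # mirrors the updated get_kwargs logic
        default = f.default_factory()      # materialize {} rather than the factory object
    print(default)                         # -> {}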
@@ -674,25 +682,19 @@ class EngineArgs:
             'Additionally for encoder-decoder models, if the '
             'sequence length of the encoder input is larger '
             'than this, we fall back to the eager mode.')
-        parser.add_argument('--tokenizer-pool-size',
-                            type=int,
-                            default=EngineArgs.tokenizer_pool_size,
-                            help='Size of tokenizer pool to use for '
-                            'asynchronous tokenization. If 0, will '
-                            'use synchronous tokenization.')
-        parser.add_argument('--tokenizer-pool-type',
-                            type=str,
-                            default=EngineArgs.tokenizer_pool_type,
-                            help='Type of tokenizer pool to use for '
-                            'asynchronous tokenization. Ignored '
-                            'if tokenizer_pool_size is 0.')
-        parser.add_argument('--tokenizer-pool-extra-config',
-                            type=optional_str,
-                            default=EngineArgs.tokenizer_pool_extra_config,
-                            help='Extra config for tokenizer pool. '
-                            'This should be a JSON string that will be '
-                            'parsed into a dictionary. Ignored if '
-                            'tokenizer_pool_size is 0.')
+
+        # Tokenizer arguments
+        tokenizer_kwargs = get_kwargs(TokenizerPoolConfig)
+        tokenizer_group = parser.add_argument_group(
+            title="TokenizerPoolConfig",
+            description=TokenizerPoolConfig.__doc__,
+        )
+        tokenizer_group.add_argument('--tokenizer-pool-size',
+                                     **tokenizer_kwargs["pool_size"])
+        tokenizer_group.add_argument('--tokenizer-pool-type',
+                                     **tokenizer_kwargs["pool_type"])
+        tokenizer_group.add_argument('--tokenizer-pool-extra-config',
+                                     **tokenizer_kwargs["extra_config"])

         # Multimodal related configs
         parser.add_argument(
@@ -784,11 +786,15 @@ class EngineArgs:
                             type=int,
                             default=EngineArgs.max_prompt_adapter_token,
                             help='Max number of PromptAdapters tokens')
-        parser.add_argument("--device",
-                            type=str,
-                            default=EngineArgs.device,
-                            choices=DEVICE_OPTIONS,
-                            help='Device type for vLLM execution.')
+
+        # Device arguments
+        device_kwargs = get_kwargs(DeviceConfig)
+        device_group = parser.add_argument_group(
+            title="DeviceConfig",
+            description=DeviceConfig.__doc__,
+        )
+        device_group.add_argument("--device", **device_kwargs["device"])
+
         parser.add_argument('--num-scheduler-steps',
                             type=int,
                             default=1,
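Note: the CLI flags keep their names; only their defaults and help text now come from the config dataclasses. A rough sketch of driving the updated parser from Python (model name and values are illustrative):

    from vllm.engine.arg_utils import EngineArgs
    from vllm.utils import FlexibleArgumentParser

    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
    args = parser.parse_args([
        '--model', 'facebook/opt-125m',                      # illustrative model
        '--device', 'cpu',
        '--tokenizer-pool-size', '4',
        '--tokenizer-pool-extra-config', '{"key": "value"}',  # parsed by optional_dict
    ])
    engine_args = EngineArgs.from_cli_args(args)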
@@ -1302,8 +1308,6 @@ class EngineArgs:
         if self.qlora_adapter_name_or_path is not None and \
             self.qlora_adapter_name_or_path != "":
-            if self.model_loader_extra_config is None:
-                self.model_loader_extra_config = {}
             self.model_loader_extra_config[
                 "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path