diff --git a/tests/test_config.py b/tests/test_config.py
index 06264c5b..53db91e8 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -1,14 +1,36 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from dataclasses import asdict
+from dataclasses import MISSING, Field, asdict, dataclass, field
 
 import pytest
 
-from vllm.config import ModelConfig, PoolerConfig
+from vllm.config import ModelConfig, PoolerConfig, get_field
 from vllm.model_executor.layers.pooler import PoolingType
 from vllm.platforms import current_platform
 
 
+def test_get_field():
+
+    @dataclass
+    class TestConfig:
+        a: int
+        b: dict = field(default_factory=dict)
+        c: str = "default"
+
+    with pytest.raises(ValueError):
+        get_field(TestConfig, "a")
+
+    b = get_field(TestConfig, "b")
+    assert isinstance(b, Field)
+    assert b.default is MISSING
+    assert b.default_factory is dict
+
+    c = get_field(TestConfig, "c")
+    assert isinstance(c, Field)
+    assert c.default == "default"
+    assert c.default_factory is MISSING
+
+
 @pytest.mark.parametrize(
     ("model_id", "expected_runner_type", "expected_task"),
     [
diff --git a/vllm/config.py b/vllm/config.py
index cca725c7..7e2869e4 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -182,6 +182,23 @@ def config(cls: type[Config]) -> type[Config]:
     return cls
 
 
+def get_field(cls: type[Config], name: str) -> Field:
+    """Get the default factory field of a dataclass by name. Used for getting
+    default factory fields in `EngineArgs`."""
+    if not is_dataclass(cls):
+        raise TypeError("The given class is not a dataclass.")
+    cls_fields = {f.name: f for f in fields(cls)}
+    if name not in cls_fields:
+        raise ValueError(f"Field '{name}' not found in {cls.__name__}.")
+    named_field: Field = cls_fields.get(name)
+    if (default_factory := named_field.default_factory) is not MISSING:
+        return field(default_factory=default_factory)
+    if (default := named_field.default) is not MISSING:
+        return field(default=default)
+    raise ValueError(
+        f"{cls.__name__}.{name} must have a default value or default factory.")
+
+
 class ModelConfig:
     """Configuration for the model.
 
@@ -1364,20 +1381,26 @@ class CacheConfig:
             logger.warning("Possibly too large swap space. %s", msg)
 
 
+PoolType = Literal["ray"]
+
+
+@config
 @dataclass
 class TokenizerPoolConfig:
-    """Configuration for the tokenizer pool.
+    """Configuration for the tokenizer pool."""
 
-    Args:
-        pool_size: Number of tokenizer workers in the pool.
-        pool_type: Type of the pool.
-        extra_config: Additional config for the pool.
-            The way the config will be used depends on the
-            pool type.
-    """
-    pool_size: int
-    pool_type: Union[str, type["BaseTokenizerGroup"]]
-    extra_config: dict
+    pool_size: int = 0
+    """Number of tokenizer workers in the pool to use for asynchronous
+    tokenization. If 0, will use synchronous tokenization."""
+
+    pool_type: Union[PoolType, type["BaseTokenizerGroup"]] = "ray"
+    """Type of tokenizer pool to use for asynchronous tokenization. Ignored if
+    tokenizer_pool_size is 0."""
+
+    extra_config: dict = field(default_factory=dict)
+    """Additional config for the pool. The way the config will be used depends
+    on the pool type. This should be a JSON string that will be parsed into a
+    dictionary. Ignored if tokenizer_pool_size is 0."""
 
     def compute_hash(self) -> str:
         """
@@ -1408,7 +1431,7 @@ class TokenizerPoolConfig:
     @classmethod
     def create_config(
         cls, tokenizer_pool_size: int,
-        tokenizer_pool_type: Union[str, type["BaseTokenizerGroup"]],
+        tokenizer_pool_type: Union[PoolType, type["BaseTokenizerGroup"]],
         tokenizer_pool_extra_config: Optional[Union[str, dict]]
     ) -> Optional["TokenizerPoolConfig"]:
         """Create a TokenizerPoolConfig from the given parameters.
@@ -1483,7 +1506,7 @@ class LoadConfig:
     download_dir: Optional[str] = None
     """Directory to download and load the weights, default to the default
     cache directory of Hugging Face."""
-    model_loader_extra_config: Optional[Union[str, dict]] = None
+    model_loader_extra_config: dict = field(default_factory=dict)
     """Extra config for model loader. This will be passed to the model loader
     corresponding to the chosen load_format. This should be a JSON string that
     will be parsed into a dictionary."""
@@ -1514,10 +1537,6 @@ class LoadConfig:
         return hash_str
 
     def __post_init__(self):
-        model_loader_extra_config = self.model_loader_extra_config or {}
-        if isinstance(model_loader_extra_config, str):
-            self.model_loader_extra_config = json.loads(
-                model_loader_extra_config)
         if isinstance(self.load_format, str):
             load_format = self.load_format.lower()
             self.load_format = LoadFormat(load_format)
@@ -2029,9 +2048,19 @@ class SchedulerConfig:
         return self.num_scheduler_steps > 1
 
 
+Device = Literal["auto", "cuda", "neuron", "cpu", "tpu", "xpu", "hpu"]
+
+
+@config
+@dataclass
 class DeviceConfig:
-    device: Optional[torch.device]
-    device_type: str
+    """Configuration for the device to use for vLLM execution."""
+
+    device: Union[Device, torch.device] = "auto"
+    """Device type for vLLM execution."""
+    device_type: str = field(init=False)
+    """Device type from the current platform. This is set in
+    `__post_init__`."""
 
     def compute_hash(self) -> str:
         """
@@ -2053,8 +2082,8 @@ class DeviceConfig:
                                usedforsecurity=False).hexdigest()
         return hash_str
 
-    def __init__(self, device: str = "auto") -> None:
-        if device == "auto":
+    def __post_init__(self):
+        if self.device == "auto":
             # Automated device type detection
             from vllm.platforms import current_platform
             self.device_type = current_platform.device_type
@@ -2065,7 +2094,7 @@ class DeviceConfig:
                 "to turn on verbose logging to help debug the issue.")
         else:
             # Device type is assigned explicitly
-            self.device_type = device
+            self.device_type = self.device
 
         # Some device types require processing inputs on CPU
         if self.device_type in ["neuron"]:
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 32cb2e90..85b3ddfc 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -16,15 +16,15 @@ from typing_extensions import TypeIs
 
 import vllm.envs as envs
 from vllm import version
-from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat,
-                         DecodingConfig, DeviceConfig,
+from vllm.config import (CacheConfig, CompilationConfig, Config, ConfigFormat,
+                         DecodingConfig, Device, DeviceConfig,
                          DistributedExecutorBackend, HfOverrides,
                          KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
                          ModelConfig, ModelImpl, ObservabilityConfig,
-                         ParallelConfig, PoolerConfig, PromptAdapterConfig,
-                         SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
-                         TaskOption, TokenizerPoolConfig, VllmConfig,
-                         get_attr_docs)
+                         ParallelConfig, PoolerConfig, PoolType,
+                         PromptAdapterConfig, SchedulerConfig, SchedulerPolicy,
+                         SpeculativeConfig, TaskOption, TokenizerPoolConfig,
+                         VllmConfig, get_attr_docs, get_field)
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
@@ -44,27 +44,17 @@ logger = init_logger(__name__)
 
 ALLOWED_DETAILED_TRACE_MODULES = ["model", "worker", "all"]
 
-DEVICE_OPTIONS = [
-    "auto",
-    "cuda",
-    "neuron",
-    "cpu",
-    "tpu",
-    "xpu",
-    "hpu",
-]
-
 # object is used to allow for special typing forms
 T = TypeVar("T")
 TypeHint = Union[type[Any], object]
 TypeHintT = Union[type[T], object]
 
 
-def optional_arg(val: str, return_type: type[T]) -> Optional[T]:
+def optional_arg(val: str, return_type: Callable[[str], T]) -> Optional[T]:
     if val == "" or val == "None":
         return None
     try:
-        return cast(Callable, return_type)(val)
+        return return_type(val)
     except ValueError as e:
         raise argparse.ArgumentTypeError(
            f"Value {val} cannot be converted to {return_type}.") from e
@@ -82,8 +72,11 @@ def optional_float(val: str) -> Optional[float]:
     return optional_arg(val, float)
 
 
-def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
-    """Parses a string containing comma separate key [str] to value [int]
+def nullable_kvs(val: str) -> Optional[dict[str, int]]:
+    """NOTE: This function is deprecated, args should be passed as JSON
+    strings instead.
+
+    Parses a string containing comma-separated key [str] to value [int]
     pairs into a dictionary.
 
     Args:
@@ -117,6 +110,17 @@
     return out_dict
 
 
+def optional_dict(val: str) -> Optional[dict[str, int]]:
+    try:
+        return optional_arg(val, json.loads)
+    except (ValueError, argparse.ArgumentTypeError):
+        # optional_arg re-raises JSON errors as argparse.ArgumentTypeError,
+        # which is not a ValueError subclass, so catch both.
+        logger.warning(
+            "Failed to parse JSON string. Attempting to parse as "
+            "comma-separated key=value pairs. This will be deprecated in a "
+            "future release.")
+        return nullable_kvs(val)
+
+
 @dataclass
 class EngineArgs:
     """Arguments for vLLM engine."""
@@ -178,12 +182,14 @@
     enforce_eager: Optional[bool] = None
     max_seq_len_to_capture: int = 8192
     disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
-    tokenizer_pool_size: int = 0
+    tokenizer_pool_size: int = TokenizerPoolConfig.pool_size
     # Note: Specifying a tokenizer pool by passing a class
     # is intended for expert use only. The API may change without
     # notice.
-    tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
-    tokenizer_pool_extra_config: Optional[Dict[str, Any]] = None
+    tokenizer_pool_type: Union[PoolType, Type["BaseTokenizerGroup"]] = \
+        TokenizerPoolConfig.pool_type
+    tokenizer_pool_extra_config: dict[str, Any] = \
+        get_field(TokenizerPoolConfig, "extra_config")
     limit_mm_per_prompt: Optional[Mapping[str, int]] = None
     mm_processor_kwargs: Optional[Dict[str, Any]] = None
     disable_mm_preprocessor_cache: bool = False
@@ -199,14 +205,14 @@
     long_lora_scaling_factors: Optional[Tuple[float]] = None
     lora_dtype: Optional[Union[str, torch.dtype]] = 'auto'
     max_cpu_loras: Optional[int] = None
-    device: str = 'auto'
+    device: Device = DeviceConfig.device
     num_scheduler_steps: int = SchedulerConfig.num_scheduler_steps
     multi_step_stream_outputs: bool = SchedulerConfig.multi_step_stream_outputs
     ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
     num_gpu_blocks_override: Optional[int] = None
     num_lookahead_slots: int = SchedulerConfig.num_lookahead_slots
-    model_loader_extra_config: Optional[
-        dict] = LoadConfig.model_loader_extra_config
+    model_loader_extra_config: dict = \
+        get_field(LoadConfig, "model_loader_extra_config")
     ignore_patterns: Optional[Union[str,
                                     List[str]]] = LoadConfig.ignore_patterns
     preemption_mode: Optional[str] = SchedulerConfig.preemption_mode
@@ -294,14 +300,15 @@
             """Check if the class is a custom type."""
             return cls.__module__ != "builtins"
 
-        def get_kwargs(cls: type[Any]) -> dict[str, Any]:
+        def get_kwargs(cls: type[Config]) -> dict[str, Any]:
             cls_docs = get_attr_docs(cls)
             kwargs = {}
             for field in fields(cls):
                 name = field.name
-                # One of these will always be present
-                default = (field.default_factory
-                           if field.default is MISSING else field.default)
+                default = field.default
+                # This will only be True if default is MISSING
+                if field.default_factory is not MISSING:
+                    default = field.default_factory()
                 kwargs[name] = {"default": default, "help": cls_docs[name]}
 
                 # Make note of if the field is optional and get the actual
@@ -331,8 +338,9 @@
                 elif can_be_type(field_type, float):
                     kwargs[name][
                         "type"] = optional_float if optional else float
+                elif can_be_type(field_type, dict):
+                    kwargs[name]["type"] = optional_dict
                 elif (can_be_type(field_type, str)
-                      or can_be_type(field_type, dict)
                       or is_custom_type(field_type)):
                     kwargs[name]["type"] = optional_str if optional else str
                 else:
@@ -674,25 +682,19 @@
                             'Additionally for encoder-decoder models, if the '
                             'sequence length of the encoder input is larger '
                             'than this, we fall back to the eager mode.')
-        parser.add_argument('--tokenizer-pool-size',
-                            type=int,
-                            default=EngineArgs.tokenizer_pool_size,
-                            help='Size of tokenizer pool to use for '
-                            'asynchronous tokenization. If 0, will '
-                            'use synchronous tokenization.')
-        parser.add_argument('--tokenizer-pool-type',
-                            type=str,
-                            default=EngineArgs.tokenizer_pool_type,
-                            help='Type of tokenizer pool to use for '
-                            'asynchronous tokenization. Ignored '
-                            'if tokenizer_pool_size is 0.')
-        parser.add_argument('--tokenizer-pool-extra-config',
-                            type=optional_str,
-                            default=EngineArgs.tokenizer_pool_extra_config,
-                            help='Extra config for tokenizer pool. '
-                            'This should be a JSON string that will be '
-                            'parsed into a dictionary. Ignored if '
-                            'tokenizer_pool_size is 0.')
+
+        # Tokenizer arguments
+        tokenizer_kwargs = get_kwargs(TokenizerPoolConfig)
+        tokenizer_group = parser.add_argument_group(
+            title="TokenizerPoolConfig",
+            description=TokenizerPoolConfig.__doc__,
+        )
+        tokenizer_group.add_argument('--tokenizer-pool-size',
+                                     **tokenizer_kwargs["pool_size"])
+        tokenizer_group.add_argument('--tokenizer-pool-type',
+                                     **tokenizer_kwargs["pool_type"])
+        tokenizer_group.add_argument('--tokenizer-pool-extra-config',
+                                     **tokenizer_kwargs["extra_config"])
 
         # Multimodal related configs
         parser.add_argument(
@@ -784,11 +786,15 @@
                             type=int,
                             default=EngineArgs.max_prompt_adapter_token,
                             help='Max number of PromptAdapters tokens')
-        parser.add_argument("--device",
-                            type=str,
-                            default=EngineArgs.device,
-                            choices=DEVICE_OPTIONS,
-                            help='Device type for vLLM execution.')
+
+        # Device arguments
+        device_kwargs = get_kwargs(DeviceConfig)
+        device_group = parser.add_argument_group(
+            title="DeviceConfig",
+            description=DeviceConfig.__doc__,
+        )
+        device_group.add_argument("--device", **device_kwargs["device"])
+
         parser.add_argument('--num-scheduler-steps',
                             type=int,
                             default=1,
@@ -1302,8 +1308,6 @@
         if self.qlora_adapter_name_or_path is not None and \
             self.qlora_adapter_name_or_path != "":
-            if self.model_loader_extra_config is None:
-                self.model_loader_extra_config = {}
             self.model_loader_extra_config[
                 "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path
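
Reviewer note: the reason `EngineArgs` needs the new `get_field` helper is that dataclasses reject mutable defaults like `= {}`, and factory-backed fields such as `TokenizerPoolConfig.extra_config` do not exist as plain class attributes. `get_field` copies the field's `default` or `default_factory` into a fresh `field(...)` specifier. A minimal sketch of the intended usage, not part of the patch (`ExampleArgs` is a hypothetical stand-in for `EngineArgs`):

```python
from dataclasses import dataclass

from vllm.config import TokenizerPoolConfig, get_field


@dataclass
class ExampleArgs:
    # Immutable default: plain attribute access on the config class works.
    tokenizer_pool_size: int = TokenizerPoolConfig.pool_size
    # Mutable default: `= {}` would raise at class-creation time and
    # `TokenizerPoolConfig.extra_config` is not a class attribute, so
    # get_field() hands back a fresh field(default_factory=dict) instead.
    tokenizer_pool_extra_config: dict = get_field(TokenizerPoolConfig,
                                                  "extra_config")


# Each instance gets its own dict, so mutating one never leaks into another.
a, b = ExampleArgs(), ExampleArgs()
a.tokenizer_pool_extra_config["pool_class"] = "custom"
assert b.tokenizer_pool_extra_config == {}
```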
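Similarly, a sketch of how the dict-typed CLI arguments behave after this change, assuming the patched `vllm.engine.arg_utils` (the values are illustrative only):

```python
from vllm.engine.arg_utils import optional_dict

# Preferred input: a JSON string, parsed directly into a dict.
assert optional_dict('{"image": 5}') == {"image": 5}

# Deprecated fallback: comma-separated key=value pairs; logs a warning
# before delegating to nullable_kvs().
assert optional_dict("image=5") == {"image": 5}

# The "" / "None" sentinels still map to None via optional_arg().
assert optional_dict("") is None
```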