diff --git a/vllm/config.py b/vllm/config.py
index 5fcc5f46..23541a88 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -4,13 +4,16 @@ import ast
 import copy
 import enum
 import hashlib
+import inspect
 import json
 import sys
+import textwrap
 import warnings
 from collections import Counter
 from collections.abc import Mapping
 from contextlib import contextmanager
-from dataclasses import dataclass, field, replace
+from dataclasses import (MISSING, dataclass, field, fields, is_dataclass,
+                         replace)
 from importlib.util import find_spec
 from pathlib import Path
 from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
@@ -104,6 +107,77 @@ class ModelImpl(str, enum.Enum):
     TRANSFORMERS = "transformers"


+def get_attr_docs(cls: type[Any]) -> dict[str, str]:
+    """
+    Get any docstrings placed after attribute assignments in a class body.
+
+    https://davidism.com/mit-license/
+    """
+
+    def pairwise(iterable):
+        """
+        Manually implement https://docs.python.org/3/library/itertools.html#itertools.pairwise
+
+        Can be removed when Python 3.9 support is dropped.
+        """
+        iterator = iter(iterable)
+        a = next(iterator, None)
+
+        for b in iterator:
+            yield a, b
+            a = b
+
+    cls_node = ast.parse(textwrap.dedent(inspect.getsource(cls))).body[0]
+
+    if not isinstance(cls_node, ast.ClassDef):
+        raise TypeError("Given object was not a class.")
+
+    out = {}
+
+    # Consider each pair of nodes.
+    for a, b in pairwise(cls_node.body):
+        # Must be an assignment then a constant string.
+        if (not isinstance(a, (ast.Assign, ast.AnnAssign))
+                or not isinstance(b, ast.Expr)
+                or not isinstance(b.value, ast.Constant)
+                or not isinstance(b.value.value, str)):
+            continue
+
+        doc = inspect.cleandoc(b.value.value)
+
+        # An assignment can have multiple targets (a = b = v), but an
+        # annotated assignment only has one target.
+        targets = a.targets if isinstance(a, ast.Assign) else [a.target]
+
+        for target in targets:
+            # Must be assigning to a plain name.
+            if not isinstance(target, ast.Name):
+                continue
+
+            out[target.id] = doc
+
+    return out
+
+
+def config(cls: type[Any]) -> type[Any]:
+    """
+    A decorator that ensures all fields in a dataclass have default values
+    and that each field has a docstring.
+    """
+    if not is_dataclass(cls):
+        raise TypeError("The decorated class must be a dataclass.")
+    attr_docs = get_attr_docs(cls)
+    for f in fields(cls):
+        if f.init and f.default is MISSING and f.default_factory is MISSING:
+            raise ValueError(
+                f"Field '{f.name}' in {cls.__name__} must have a default value."
+            )
+        if f.name not in attr_docs:
+            raise ValueError(
+                f"Field '{f.name}' in {cls.__name__} must have a docstring.")
+    return cls
+
+
 class ModelConfig:
     """Configuration for the model.

@@ -1432,61 +1506,77 @@ class LoadConfig:
             self.ignore_patterns = ["original/**/*"]


+@config
 @dataclass
 class ParallelConfig:
     """Configuration for the distributed execution."""

-    pipeline_parallel_size: int = 1  # Number of pipeline parallel groups.
-    tensor_parallel_size: int = 1  # Number of tensor parallel groups.
-    data_parallel_size: int = 1  # Number of data parallel groups.
-    data_parallel_rank: int = 0  # Rank of the data parallel group.
-    # Local rank of the data parallel group, defaults to global rank.
+    pipeline_parallel_size: int = 1
+    """Number of pipeline parallel groups."""
+    tensor_parallel_size: int = 1
+    """Number of tensor parallel groups."""
+    data_parallel_size: int = 1
+    """Number of data parallel groups. MoE layers will be sharded according to
+    the product of the tensor parallel size and data parallel size."""
+    data_parallel_rank: int = 0
+    """Rank of the data parallel group."""
     data_parallel_rank_local: Optional[int] = None
-    # IP of the data parallel master.
+    """Local rank of the data parallel group, defaults to global rank."""
     data_parallel_master_ip: str = "127.0.0.1"
-    data_parallel_master_port: int = 29500  # Port of the data parallel master.
-    enable_expert_parallel: bool = False  # Use EP instead of TP for MoE layers.
+    """IP of the data parallel master."""
+    data_parallel_master_port: int = 29500
+    """Port of the data parallel master."""
+    enable_expert_parallel: bool = False
+    """Use expert parallelism instead of tensor parallelism for MoE layers."""

-    # Maximum number of multiple batches
-    # when load model sequentially. To avoid RAM OOM when using tensor
-    # parallel and large models.
     max_parallel_loading_workers: Optional[int] = None
+    """Maximum number of parallel loading workers when loading model
+    sequentially in multiple batches. To avoid RAM OOM when using tensor
+    parallel and large models."""

-    # Disable the custom all-reduce kernel and fall back to NCCL.
     disable_custom_all_reduce: bool = False
+    """Disable the custom all-reduce kernel and fall back to NCCL."""

-    # Config for the tokenizer pool. If None, will use synchronous tokenization.
     tokenizer_pool_config: Optional[TokenizerPoolConfig] = None
+    """Config for the tokenizer pool. If None, will use synchronous
+    tokenization."""

-    # Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.
     ray_workers_use_nsight: bool = False
+    """Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""

-    # ray distributed model workers placement group.
     placement_group: Optional["PlacementGroup"] = None
+    """ray distributed model workers placement group."""

-    # Backend to use for distributed model
-    # workers, either "ray" or "mp" (multiprocessing). If the product
-    # of pipeline_parallel_size and tensor_parallel_size is less than
-    # or equal to the number of GPUs available, "mp" will be used to
-    # keep processing on a single host. Otherwise, this will default
-    # to "ray" if Ray is installed and fail otherwise. Note that tpu
-    # and hpu only support Ray for distributed inference.
     distributed_executor_backend: Optional[Union[str,
                                                  type["ExecutorBase"]]] = None
+    """Backend to use for distributed model
+    workers, either "ray" or "mp" (multiprocessing). If the product
+    of pipeline_parallel_size and tensor_parallel_size is less than
+    or equal to the number of GPUs available, "mp" will be used to
+    keep processing on a single host. Otherwise, this will default
+    to "ray" if Ray is installed and fail otherwise. Note that tpu
+    and hpu only support Ray for distributed inference."""

-    # the full name of the worker class to use. If "auto", the worker class
-    # will be determined based on the platform.
     worker_cls: str = "auto"
+    """The full name of the worker class to use. If "auto", the worker class
+    will be determined based on the platform."""
     sd_worker_cls: str = "auto"
+    """The full name of the worker class to use for speculative decoding.
+    If "auto", the worker class will be determined based on the platform."""
     worker_extension_cls: str = ""
+    """The full name of the worker extension class to use. The worker extension
+    class is dynamically inherited by the worker class. This is used to inject
+    new attributes and methods to the worker class for use in collective_rpc
+    calls."""

-    # world_size is TPxPP, it affects the number of workers we create.
     world_size: int = field(init=False)
-    # world_size_across_dp is TPxPPxDP, it is the size of the world
-    # including data parallelism.
+    """world_size is TPxPP, it affects the number of workers we create."""
     world_size_across_dp: int = field(init=False)
+    """world_size_across_dp is TPxPPxDP, it is the size of the world
+    including data parallelism."""

     rank: int = 0
+    """Global rank in distributed setup."""

     def get_next_dp_init_port(self) -> int:
         """
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 0c81e3ed..ba71a877 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -5,9 +5,9 @@ import dataclasses
 import json
 import re
 import threading
-from dataclasses import dataclass
+from dataclasses import MISSING, dataclass, fields
 from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
-                    Tuple, Type, Union, cast, get_args)
+                    Tuple, Type, Union, cast, get_args, get_origin)

 import torch

@@ -19,7 +19,7 @@ from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat,
                          ModelConfig, ModelImpl, ObservabilityConfig,
                          ParallelConfig, PoolerConfig, PromptAdapterConfig,
                          SchedulerConfig, SpeculativeConfig, TaskOption,
-                         TokenizerPoolConfig, VllmConfig)
+                         TokenizerPoolConfig, VllmConfig, get_attr_docs)
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
@@ -111,14 +111,15 @@ class EngineArgs:
     # Note: Specifying a custom executor backend by passing a class
     # is intended for expert use only. The API may change without
     # notice.
-    distributed_executor_backend: Optional[Union[str,
-                                                 Type[ExecutorBase]]] = None
+    distributed_executor_backend: Optional[Union[
+        str, Type[ExecutorBase]]] = ParallelConfig.distributed_executor_backend
     # number of P/D disaggregation (or other disaggregation) workers
-    pipeline_parallel_size: int = 1
-    tensor_parallel_size: int = 1
-    data_parallel_size: int = 1
-    enable_expert_parallel: bool = False
-    max_parallel_loading_workers: Optional[int] = None
+    pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
+    tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
+    data_parallel_size: int = ParallelConfig.data_parallel_size
+    enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
+    max_parallel_loading_workers: Optional[
+        int] = ParallelConfig.max_parallel_loading_workers
     block_size: Optional[int] = None
     enable_prefix_caching: Optional[bool] = None
     prefix_caching_hash_algo: str = "builtin"
@@ -145,7 +146,7 @@ class EngineArgs:
     quantization: Optional[str] = None
     enforce_eager: Optional[bool] = None
     max_seq_len_to_capture: int = 8192
-    disable_custom_all_reduce: bool = False
+    disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
     tokenizer_pool_size: int = 0
     # Note: Specifying a tokenizer pool by passing a class
     # is intended for expert use only. The API may change without
@@ -170,7 +171,7 @@ class EngineArgs:
     device: str = 'auto'
     num_scheduler_steps: int = 1
     multi_step_stream_outputs: bool = True
-    ray_workers_use_nsight: bool = False
+    ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
     num_gpu_blocks_override: Optional[int] = None
     num_lookahead_slots: int = 0
     model_loader_extra_config: Optional[dict] = None
@@ -197,8 +198,8 @@ class EngineArgs:
     override_neuron_config: Optional[Dict[str, Any]] = None
     override_pooler_config: Optional[PoolerConfig] = None
     compilation_config: Optional[CompilationConfig] = None
-    worker_cls: str = "auto"
-    worker_extension_cls: str = ""
+    worker_cls: str = ParallelConfig.worker_cls
+    worker_extension_cls: str = ParallelConfig.worker_extension_cls

     kv_transfer_config: Optional[KVTransferConfig] = None

@@ -232,6 +233,31 @@ class EngineArgs:
     @staticmethod
     def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         """Shared CLI arguments for vLLM engine."""
+
+        def is_optional(cls: type[Any]) -> bool:
+            """Check if the class is an optional type."""
+            return get_origin(cls) is Union and type(None) in get_args(cls)
+
+        def get_kwargs(cls: type[Any]) -> Dict[str, Any]:
+            cls_docs = get_attr_docs(cls)
+            kwargs = {}
+            for field in fields(cls):
+                name = field.name
+                # One of these will always be present
+                default = (field.default_factory
+                           if field.default is MISSING else field.default)
+                kwargs[name] = {"default": default, "help": cls_docs[name]}
+                # When using action="store_true"
+                # add_argument doesn't accept type
+                if field.type is bool:
+                    continue
+                # Handle optional fields
+                if is_optional(field.type):
+                    kwargs[name]["type"] = nullable_str
+                    continue
+                kwargs[name]["type"] = field.type
+            return kwargs
+
         # Model arguments
         parser.add_argument(
             '--model',
@@ -411,52 +437,37 @@ class EngineArgs:
             '* "transformers" will use the Transformers model '
             'implementation.\n')
         # Parallel arguments
-        parser.add_argument(
+        parallel_kwargs = get_kwargs(ParallelConfig)
+        parallel_group = parser.add_argument_group(
+            title="ParallelConfig",
+            description=ParallelConfig.__doc__,
+        )
+        parallel_group.add_argument(
             '--distributed-executor-backend',
             choices=['ray', 'mp', 'uni', 'external_launcher'],
-            default=EngineArgs.distributed_executor_backend,
-            help='Backend to use for distributed model '
-            'workers, either "ray" or "mp" (multiprocessing). If the product '
-            'of pipeline_parallel_size and tensor_parallel_size is less than '
-            'or equal to the number of GPUs available, "mp" will be used to '
-            'keep processing on a single host. Otherwise, this will default '
-            'to "ray" if Ray is installed and fail otherwise. Note that tpu '
-            'only supports Ray for distributed inference.')
-
-        parser.add_argument('--pipeline-parallel-size',
-                            '-pp',
-                            type=int,
-                            default=EngineArgs.pipeline_parallel_size,
-                            help='Number of pipeline stages.')
-        parser.add_argument('--tensor-parallel-size',
-                            '-tp',
-                            type=int,
-                            default=EngineArgs.tensor_parallel_size,
-                            help='Number of tensor parallel replicas.')
-        parser.add_argument('--data-parallel-size',
-                            '-dp',
-                            type=int,
-                            default=EngineArgs.data_parallel_size,
-                            help='Number of data parallel replicas. '
-                            'MoE layers will be sharded according to the '
-                            'product of the tensor-parallel-size and '
-                            'data-parallel-size.')
-        parser.add_argument(
+            **parallel_kwargs["distributed_executor_backend"])
+        parallel_group.add_argument(
+            '--pipeline-parallel-size', '-pp',
+            **parallel_kwargs["pipeline_parallel_size"])
+        parallel_group.add_argument('--tensor-parallel-size', '-tp',
+                                    **parallel_kwargs["tensor_parallel_size"])
+        parallel_group.add_argument('--data-parallel-size', '-dp',
+                                    **parallel_kwargs["data_parallel_size"])
+        parallel_group.add_argument(
             '--enable-expert-parallel',
             action='store_true',
-            help='Use expert parallelism instead of tensor parallelism '
-            'for MoE layers.')
-        parser.add_argument(
+            **parallel_kwargs["enable_expert_parallel"])
+        parallel_group.add_argument(
             '--max-parallel-loading-workers',
-            type=int,
-            default=EngineArgs.max_parallel_loading_workers,
-            help='Load model sequentially in multiple batches, '
-            'to avoid RAM OOM when using tensor '
-            'parallel and large models.')
-        parser.add_argument(
+            **parallel_kwargs["max_parallel_loading_workers"])
+        parallel_group.add_argument(
             '--ray-workers-use-nsight',
             action='store_true',
-            help='If specified, use nsight to profile Ray workers.')
+            **parallel_kwargs["ray_workers_use_nsight"])
+        parallel_group.add_argument(
+            '--disable-custom-all-reduce',
+            action='store_true',
+            **parallel_kwargs["disable_custom_all_reduce"])
         # KV cache arguments
         parser.add_argument('--block-size',
                             type=int,
@@ -639,10 +650,6 @@ class EngineArgs:
                             'Additionally for encoder-decoder models, if the '
                             'sequence length of the encoder input is larger '
                            'than this, we fall back to the eager mode.')
-        parser.add_argument('--disable-custom-all-reduce',
-                            action='store_true',
-                            default=EngineArgs.disable_custom_all_reduce,
-                            help='See ParallelConfig.')
         parser.add_argument('--tokenizer-pool-size',
                             type=int,
                             default=EngineArgs.tokenizer_pool_size,
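
A minimal sketch of how the new helpers are expected to compose. The `_DemoConfig` class below is hypothetical and not part of this diff; it only mirrors the pattern applied to ParallelConfig above: `get_attr_docs` turns attribute docstrings into a name-to-help-text mapping, and `@config` enforces that every field has both a default and a docstring, which is what lets `get_kwargs` in `add_cli_args` derive argparse defaults and help strings directly from the config dataclass.

    # Illustrative sketch only; _DemoConfig is hypothetical.
    from dataclasses import dataclass
    from typing import Optional

    from vllm.config import config, get_attr_docs


    @config
    @dataclass
    class _DemoConfig:
        workers: int = 1
        """Number of worker processes to launch."""
        enable_profiling: bool = False
        """Whether to profile the workers."""
        log_dir: Optional[str] = None
        """Directory to write logs to. If None, file logging is disabled."""


    # get_attr_docs parses the class source with ast and pairs each field
    # assignment with the string literal that follows it.
    docs = get_attr_docs(_DemoConfig)
    assert docs["workers"] == "Number of worker processes to launch."

    # @config has already checked that every field has a default and a
    # docstring, so a helper like get_kwargs() can build the argparse
    # "default" and "help" values for an argument group from the class alone.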