Improve configs - ParallelConfig (#16332)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor, 2025-04-10 18:34:37 +01:00 (committed via GitHub)
Commit: 0c54fc7273 (parent c1b57855ec)
2 changed files with 182 additions and 85 deletions

vllm/config.py

@@ -4,13 +4,16 @@ import ast
 import copy
 import enum
 import hashlib
+import inspect
 import json
 import sys
+import textwrap
 import warnings
 from collections import Counter
 from collections.abc import Mapping
 from contextlib import contextmanager
-from dataclasses import dataclass, field, replace
+from dataclasses import (MISSING, dataclass, field, fields, is_dataclass,
+                         replace)
 from importlib.util import find_spec
 from pathlib import Path
 from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
@@ -104,6 +107,77 @@ class ModelImpl(str, enum.Enum):
     TRANSFORMERS = "transformers"


+def get_attr_docs(cls: type[Any]) -> dict[str, str]:
+    """
+    Get any docstrings placed after attribute assignments in a class body.
+
+    https://davidism.com/mit-license/
+    """
+
+    def pairwise(iterable):
+        """
+        Manually implement https://docs.python.org/3/library/itertools.html#itertools.pairwise
+
+        Can be removed when Python 3.9 support is dropped.
+        """
+        iterator = iter(iterable)
+        a = next(iterator, None)
+
+        for b in iterator:
+            yield a, b
+            a = b
+
+    cls_node = ast.parse(textwrap.dedent(inspect.getsource(cls))).body[0]
+
+    if not isinstance(cls_node, ast.ClassDef):
+        raise TypeError("Given object was not a class.")
+
+    out = {}
+
+    # Consider each pair of nodes.
+    for a, b in pairwise(cls_node.body):
+        # Must be an assignment then a constant string.
+        if (not isinstance(a, (ast.Assign, ast.AnnAssign))
+                or not isinstance(b, ast.Expr)
+                or not isinstance(b.value, ast.Constant)
+                or not isinstance(b.value.value, str)):
+            continue
+
+        doc = inspect.cleandoc(b.value.value)
+
+        # An assignment can have multiple targets (a = b = v), but an
+        # annotated assignment only has one target.
+        targets = a.targets if isinstance(a, ast.Assign) else [a.target]
+        for target in targets:
+            # Must be assigning to a plain name.
+            if not isinstance(target, ast.Name):
+                continue
+
+            out[target.id] = doc
+
+    return out
+
+
+def config(cls: type[Any]) -> type[Any]:
+    """
+    A decorator that ensures all fields in a dataclass have default values
+    and that each field has a docstring.
+    """
+    if not is_dataclass(cls):
+        raise TypeError("The decorated class must be a dataclass.")
+    attr_docs = get_attr_docs(cls)
+    for f in fields(cls):
+        if f.init and f.default is MISSING and f.default_factory is MISSING:
+            raise ValueError(
+                f"Field '{f.name}' in {cls.__name__} must have a default value."
+            )
+        if f.name not in attr_docs:
+            raise ValueError(
+                f"Field '{f.name}' in {cls.__name__} must have a docstring.")
+    return cls
+
+
 class ModelConfig:
     """Configuration for the model.
@@ -1432,61 +1506,77 @@ class LoadConfig:
             self.ignore_patterns = ["original/**/*"]


+@config
 @dataclass
 class ParallelConfig:
     """Configuration for the distributed execution."""

-    pipeline_parallel_size: int = 1  # Number of pipeline parallel groups.
-    tensor_parallel_size: int = 1  # Number of tensor parallel groups.
-    data_parallel_size: int = 1  # Number of data parallel groups.
-    data_parallel_rank: int = 0  # Rank of the data parallel group.
-    # Local rank of the data parallel group, defaults to global rank.
+    pipeline_parallel_size: int = 1
+    """Number of pipeline parallel groups."""
+    tensor_parallel_size: int = 1
+    """Number of tensor parallel groups."""
+    data_parallel_size: int = 1
+    """Number of data parallel groups. MoE layers will be sharded according to
+    the product of the tensor parallel size and data parallel size."""
+    data_parallel_rank: int = 0
+    """Rank of the data parallel group."""
     data_parallel_rank_local: Optional[int] = None
-    # IP of the data parallel master.
+    """Local rank of the data parallel group, defaults to global rank."""
     data_parallel_master_ip: str = "127.0.0.1"
-    data_parallel_master_port: int = 29500  # Port of the data parallel master.
-    enable_expert_parallel: bool = False  # Use EP instead of TP for MoE layers.
+    """IP of the data parallel master."""
+    data_parallel_master_port: int = 29500
+    """Port of the data parallel master."""
+    enable_expert_parallel: bool = False
+    """Use expert parallelism instead of tensor parallelism for MoE layers."""
-    # Maximum number of multiple batches
-    # when load model sequentially. To avoid RAM OOM when using tensor
-    # parallel and large models.
     max_parallel_loading_workers: Optional[int] = None
+    """Maximum number of parallel loading workers when loading model
+    sequentially in multiple batches. To avoid RAM OOM when using tensor
+    parallel and large models."""
-    # Disable the custom all-reduce kernel and fall back to NCCL.
     disable_custom_all_reduce: bool = False
+    """Disable the custom all-reduce kernel and fall back to NCCL."""
-    # Config for the tokenizer pool. If None, will use synchronous tokenization.
     tokenizer_pool_config: Optional[TokenizerPoolConfig] = None
+    """Config for the tokenizer pool. If None, will use synchronous
+    tokenization."""
-    # Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.
     ray_workers_use_nsight: bool = False
+    """Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""
-    # ray distributed model workers placement group.
     placement_group: Optional["PlacementGroup"] = None
+    """ray distributed model workers placement group."""
-    # Backend to use for distributed model
-    # workers, either "ray" or "mp" (multiprocessing). If the product
-    # of pipeline_parallel_size and tensor_parallel_size is less than
-    # or equal to the number of GPUs available, "mp" will be used to
-    # keep processing on a single host. Otherwise, this will default
-    # to "ray" if Ray is installed and fail otherwise. Note that tpu
-    # and hpu only support Ray for distributed inference.
     distributed_executor_backend: Optional[Union[str,
                                                  type["ExecutorBase"]]] = None
+    """Backend to use for distributed model
+    workers, either "ray" or "mp" (multiprocessing). If the product
+    of pipeline_parallel_size and tensor_parallel_size is less than
+    or equal to the number of GPUs available, "mp" will be used to
+    keep processing on a single host. Otherwise, this will default
+    to "ray" if Ray is installed and fail otherwise. Note that tpu
+    and hpu only support Ray for distributed inference."""
-    # the full name of the worker class to use. If "auto", the worker class
-    # will be determined based on the platform.
     worker_cls: str = "auto"
+    """The full name of the worker class to use. If "auto", the worker class
+    will be determined based on the platform."""
     sd_worker_cls: str = "auto"
+    """The full name of the worker class to use for speculative decoding.
+    If "auto", the worker class will be determined based on the platform."""
     worker_extension_cls: str = ""
+    """The full name of the worker extension class to use. The worker extension
+    class is dynamically inherited by the worker class. This is used to inject
+    new attributes and methods to the worker class for use in collective_rpc
+    calls."""
-    # world_size is TPxPP, it affects the number of workers we create.
     world_size: int = field(init=False)
+    """world_size is TPxPP, it affects the number of workers we create."""
-    # world_size_across_dp is TPxPPxDP, it is the size of the world
-    # including data parallelism.
     world_size_across_dp: int = field(init=False)
+    """world_size_across_dp is TPxPPxDP, it is the size of the world
+    including data parallelism."""
     rank: int = 0
+    """Global rank in distributed setup."""

     def get_next_dp_init_port(self) -> int:
         """

vllm/engine/arg_utils.py

@@ -5,9 +5,9 @@ import dataclasses
 import json
 import re
 import threading
-from dataclasses import dataclass
+from dataclasses import MISSING, dataclass, fields
 from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
-                    Tuple, Type, Union, cast, get_args)
+                    Tuple, Type, Union, cast, get_args, get_origin)

 import torch
@@ -19,7 +19,7 @@ from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat,
                          ModelConfig, ModelImpl, ObservabilityConfig,
                          ParallelConfig, PoolerConfig, PromptAdapterConfig,
                          SchedulerConfig, SpeculativeConfig, TaskOption,
-                         TokenizerPoolConfig, VllmConfig)
+                         TokenizerPoolConfig, VllmConfig, get_attr_docs)
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
@@ -111,14 +111,15 @@ class EngineArgs:
     # Note: Specifying a custom executor backend by passing a class
     # is intended for expert use only. The API may change without
     # notice.
-    distributed_executor_backend: Optional[Union[str,
-                                                 Type[ExecutorBase]]] = None
+    distributed_executor_backend: Optional[Union[
+        str, Type[ExecutorBase]]] = ParallelConfig.distributed_executor_backend
     # number of P/D disaggregation (or other disaggregation) workers
-    pipeline_parallel_size: int = 1
-    tensor_parallel_size: int = 1
-    data_parallel_size: int = 1
-    enable_expert_parallel: bool = False
-    max_parallel_loading_workers: Optional[int] = None
+    pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
+    tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
+    data_parallel_size: int = ParallelConfig.data_parallel_size
+    enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
+    max_parallel_loading_workers: Optional[
+        int] = ParallelConfig.max_parallel_loading_workers
     block_size: Optional[int] = None
     enable_prefix_caching: Optional[bool] = None
     prefix_caching_hash_algo: str = "builtin"
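With the defaults now pulled from ParallelConfig, the two classes cannot drift
apart. An illustrative check, assuming a checkout with this commit:

from vllm.config import ParallelConfig
from vllm.engine.arg_utils import EngineArgs

# Both class-level defaults come from the same place.
assert EngineArgs.tensor_parallel_size == ParallelConfig.tensor_parallel_size == 1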
@@ -145,7 +146,7 @@ class EngineArgs:
     quantization: Optional[str] = None
     enforce_eager: Optional[bool] = None
     max_seq_len_to_capture: int = 8192
-    disable_custom_all_reduce: bool = False
+    disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
     tokenizer_pool_size: int = 0
     # Note: Specifying a tokenizer pool by passing a class
     # is intended for expert use only. The API may change without
@@ -170,7 +171,7 @@ class EngineArgs:
     device: str = 'auto'
     num_scheduler_steps: int = 1
     multi_step_stream_outputs: bool = True
-    ray_workers_use_nsight: bool = False
+    ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
     num_gpu_blocks_override: Optional[int] = None
     num_lookahead_slots: int = 0
     model_loader_extra_config: Optional[dict] = None
@@ -197,8 +198,8 @@ class EngineArgs:
     override_neuron_config: Optional[Dict[str, Any]] = None
     override_pooler_config: Optional[PoolerConfig] = None
     compilation_config: Optional[CompilationConfig] = None
-    worker_cls: str = "auto"
-    worker_extension_cls: str = ""
+    worker_cls: str = ParallelConfig.worker_cls
+    worker_extension_cls: str = ParallelConfig.worker_extension_cls

     kv_transfer_config: Optional[KVTransferConfig] = None
@@ -232,6 +233,31 @@ class EngineArgs:
     @staticmethod
     def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         """Shared CLI arguments for vLLM engine."""

+        def is_optional(cls: type[Any]) -> bool:
+            """Check if the class is an optional type."""
+            return get_origin(cls) is Union and type(None) in get_args(cls)
+
+        def get_kwargs(cls: type[Any]) -> Dict[str, Any]:
+            cls_docs = get_attr_docs(cls)
+            kwargs = {}
+            for field in fields(cls):
+                name = field.name
+                # One of these will always be present
+                default = (field.default_factory
+                           if field.default is MISSING else field.default)
+                kwargs[name] = {"default": default, "help": cls_docs[name]}
+
+                # When using action="store_true"
+                # add_argument doesn't accept type
+                if field.type is bool:
+                    continue
+
+                # Handle optional fields
+                if is_optional(field.type):
+                    kwargs[name]["type"] = nullable_str
+                    continue
+
+                kwargs[name]["type"] = field.type
+            return kwargs
+
         # Model arguments
         parser.add_argument(
             '--model',
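Since get_kwargs is nested inside add_cli_args it cannot be imported directly;
the following standalone sketch (toy names, hard-coded docs) only illustrates
the idea of turning dataclass fields into argparse kwargs:

import argparse
from dataclasses import MISSING, dataclass, fields


@dataclass
class ToyParallelConfig:
    """Toy stand-in for ParallelConfig, for illustration only."""

    tensor_parallel_size: int = 1
    """Number of tensor parallel groups."""


# Stand-in for get_attr_docs(ToyParallelConfig), hard-coded to stay self-contained.
TOY_DOCS = {"tensor_parallel_size": "Number of tensor parallel groups."}


def toy_get_kwargs(cls):
    kwargs = {}
    for f in fields(cls):
        # Fall back to the factory when no plain default is set.
        default = f.default_factory if f.default is MISSING else f.default
        kwargs[f.name] = {"type": f.type, "default": default, "help": TOY_DOCS[f.name]}
    return kwargs


parser = argparse.ArgumentParser()
group = parser.add_argument_group(title="ToyParallelConfig")
group.add_argument("--tensor-parallel-size", "-tp",
                   **toy_get_kwargs(ToyParallelConfig)["tensor_parallel_size"])
print(parser.parse_args(["-tp", "4"]).tensor_parallel_size)  # 4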
@@ -411,52 +437,37 @@ class EngineArgs:
             '* "transformers" will use the Transformers model '
             'implementation.\n')

         # Parallel arguments
-        parser.add_argument(
-            '--distributed-executor-backend',
-            choices=['ray', 'mp', 'uni', 'external_launcher'],
-            default=EngineArgs.distributed_executor_backend,
-            help='Backend to use for distributed model '
-            'workers, either "ray" or "mp" (multiprocessing). If the product '
-            'of pipeline_parallel_size and tensor_parallel_size is less than '
-            'or equal to the number of GPUs available, "mp" will be used to '
-            'keep processing on a single host. Otherwise, this will default '
-            'to "ray" if Ray is installed and fail otherwise. Note that tpu '
-            'only supports Ray for distributed inference.')
-        parser.add_argument('--pipeline-parallel-size',
-                            '-pp',
-                            type=int,
-                            default=EngineArgs.pipeline_parallel_size,
-                            help='Number of pipeline stages.')
-        parser.add_argument('--tensor-parallel-size',
-                            '-tp',
-                            type=int,
-                            default=EngineArgs.tensor_parallel_size,
-                            help='Number of tensor parallel replicas.')
-        parser.add_argument('--data-parallel-size',
-                            '-dp',
-                            type=int,
-                            default=EngineArgs.data_parallel_size,
-                            help='Number of data parallel replicas. '
-                            'MoE layers will be sharded according to the '
-                            'product of the tensor-parallel-size and '
-                            'data-parallel-size.')
-        parser.add_argument(
-            '--enable-expert-parallel',
-            action='store_true',
-            help='Use expert parallelism instead of tensor parallelism '
-            'for MoE layers.')
-        parser.add_argument(
-            '--max-parallel-loading-workers',
-            type=int,
-            default=EngineArgs.max_parallel_loading_workers,
-            help='Load model sequentially in multiple batches, '
-            'to avoid RAM OOM when using tensor '
-            'parallel and large models.')
-        parser.add_argument(
-            '--ray-workers-use-nsight',
-            action='store_true',
-            help='If specified, use nsight to profile Ray workers.')
+        parallel_kwargs = get_kwargs(ParallelConfig)
+        parallel_group = parser.add_argument_group(
+            title="ParallelConfig",
+            description=ParallelConfig.__doc__,
+        )
+        parallel_group.add_argument(
+            '--distributed-executor-backend',
+            choices=['ray', 'mp', 'uni', 'external_launcher'],
+            **parallel_kwargs["distributed_executor_backend"])
+        parallel_group.add_argument(
+            '--pipeline-parallel-size', '-pp',
+            **parallel_kwargs["pipeline_parallel_size"])
+        parallel_group.add_argument('--tensor-parallel-size', '-tp',
+                                    **parallel_kwargs["tensor_parallel_size"])
+        parallel_group.add_argument('--data-parallel-size', '-dp',
+                                    **parallel_kwargs["data_parallel_size"])
+        parallel_group.add_argument(
+            '--enable-expert-parallel',
+            action='store_true',
+            **parallel_kwargs["enable_expert_parallel"])
+        parallel_group.add_argument(
+            '--max-parallel-loading-workers',
+            **parallel_kwargs["max_parallel_loading_workers"])
+        parallel_group.add_argument(
+            '--ray-workers-use-nsight',
+            action='store_true',
+            **parallel_kwargs["ray_workers_use_nsight"])
+        parallel_group.add_argument(
+            '--disable-custom-all-reduce',
+            action='store_true',
+            **parallel_kwargs["disable_custom_all_reduce"])

         # KV cache arguments
         parser.add_argument('--block-size',
                             type=int,
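After this change the parallel flags share one argparse group whose title and
description come from ParallelConfig itself, and each flag's help text comes
from the matching attribute docstring. A rough way to eyeball the result,
assuming a checkout with this commit:

from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser

parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
parser.parse_args(["--help"])  # help output now shows a "ParallelConfig" section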
@@ -639,10 +650,6 @@ class EngineArgs:
                             'Additionally for encoder-decoder models, if the '
                             'sequence length of the encoder input is larger '
                             'than this, we fall back to the eager mode.')
-        parser.add_argument('--disable-custom-all-reduce',
-                            action='store_true',
-                            default=EngineArgs.disable_custom_all_reduce,
-                            help='See ParallelConfig.')
         parser.add_argument('--tokenizer-pool-size',
                             type=int,
                             default=EngineArgs.tokenizer_pool_size,