Improve configs - ParallelConfig (#16332)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor 2025-04-10 18:34:37 +01:00 committed by GitHub
parent c1b57855ec
commit 0c54fc7273
2 changed files with 182 additions and 85 deletions

View File

@@ -4,13 +4,16 @@ import ast
import copy
import enum
import hashlib
import inspect
import json
import sys
import textwrap
import warnings
from collections import Counter
from collections.abc import Mapping
from contextlib import contextmanager
from dataclasses import dataclass, field, replace
from dataclasses import (MISSING, dataclass, field, fields, is_dataclass,
replace)
from importlib.util import find_spec
from pathlib import Path
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
@@ -104,6 +107,77 @@ class ModelImpl(str, enum.Enum):
TRANSFORMERS = "transformers"
def get_attr_docs(cls: type[Any]) -> dict[str, str]:
    """
    Get any docstrings placed after attribute assignments in a class body.

    https://davidism.com/mit-license/
    """

    def pairwise(iterable):
        """
        Manually implement https://docs.python.org/3/library/itertools.html#itertools.pairwise

        Can be removed when Python 3.9 support is dropped.
        """
        iterator = iter(iterable)
        a = next(iterator, None)

        for b in iterator:
            yield a, b
            a = b

    cls_node = ast.parse(textwrap.dedent(inspect.getsource(cls))).body[0]

    if not isinstance(cls_node, ast.ClassDef):
        raise TypeError("Given object was not a class.")

    out = {}

    # Consider each pair of nodes.
    for a, b in pairwise(cls_node.body):
        # Must be an assignment then a constant string.
        if (not isinstance(a, (ast.Assign, ast.AnnAssign))
                or not isinstance(b, ast.Expr)
                or not isinstance(b.value, ast.Constant)
                or not isinstance(b.value.value, str)):
            continue

        doc = inspect.cleandoc(b.value.value)

        # An assignment can have multiple targets (a = b = v), but an
        # annotated assignment only has one target.
        targets = a.targets if isinstance(a, ast.Assign) else [a.target]

        for target in targets:
            # Must be assigning to a plain name.
            if not isinstance(target, ast.Name):
                continue

            out[target.id] = doc

    return out
def config(cls: type[Any]) -> type[Any]:
    """
    A decorator that ensures all fields in a dataclass have default values
    and that each field has a docstring.
    """
    if not is_dataclass(cls):
        raise TypeError("The decorated class must be a dataclass.")
    attr_docs = get_attr_docs(cls)
    for f in fields(cls):
        if f.init and f.default is MISSING and f.default_factory is MISSING:
            raise ValueError(
                f"Field '{f.name}' in {cls.__name__} must have a default value."
            )
        if f.name not in attr_docs:
            raise ValueError(
                f"Field '{f.name}' in {cls.__name__} must have a docstring.")
    return cls
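A hedged sketch of how the decorator is meant to be applied (DemoConfig is hypothetical, not from this commit); note that @config sits above @dataclass, the same order used for ParallelConfig below:

@config
@dataclass
class DemoConfig:
    retries: int = 3
    """Number of times to retry a failed request."""

# Both checks run at class-definition time: a field with no default raises
# ValueError("Field '...' in DemoConfig must have a default value."), and a
# field with no trailing docstring raises
# ValueError("Field '...' in DemoConfig must have a docstring.").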
class ModelConfig:
"""Configuration for the model.
@@ -1432,61 +1506,77 @@ class LoadConfig:
self.ignore_patterns = ["original/**/*"]
@config
@dataclass
class ParallelConfig:
"""Configuration for the distributed execution."""
pipeline_parallel_size: int = 1 # Number of pipeline parallel groups.
tensor_parallel_size: int = 1 # Number of tensor parallel groups.
data_parallel_size: int = 1 # Number of data parallel groups.
data_parallel_rank: int = 0 # Rank of the data parallel group.
# Local rank of the data parallel group, defaults to global rank.
pipeline_parallel_size: int = 1
"""Number of pipeline parallel groups."""
tensor_parallel_size: int = 1
"""Number of tensor parallel groups."""
data_parallel_size: int = 1
"""Number of data parallel groups. MoE layers will be sharded according to
the product of the tensor parallel size and data parallel size."""
data_parallel_rank: int = 0
"""Rank of the data parallel group."""
data_parallel_rank_local: Optional[int] = None
# IP of the data parallel master.
"""Local rank of the data parallel group, defaults to global rank."""
data_parallel_master_ip: str = "127.0.0.1"
data_parallel_master_port: int = 29500 # Port of the data parallel master.
enable_expert_parallel: bool = False # Use EP instead of TP for MoE layers.
"""IP of the data parallel master."""
data_parallel_master_port: int = 29500
"""Port of the data parallel master."""
enable_expert_parallel: bool = False
"""Use expert parallelism instead of tensor parallelism for MoE layers."""
# Maximum number of multiple batches
# when load model sequentially. To avoid RAM OOM when using tensor
# parallel and large models.
max_parallel_loading_workers: Optional[int] = None
"""Maximum number of parallal loading workers when loading model
sequentially in multiple batches. To avoid RAM OOM when using tensor
parallel and large models."""
# Disable the custom all-reduce kernel and fall back to NCCL.
disable_custom_all_reduce: bool = False
"""Disable the custom all-reduce kernel and fall back to NCCL."""
# Config for the tokenizer pool. If None, will use synchronous tokenization.
tokenizer_pool_config: Optional[TokenizerPoolConfig] = None
"""Config for the tokenizer pool. If None, will use synchronous
tokenization."""
# Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.
ray_workers_use_nsight: bool = False
"""Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""
# ray distributed model workers placement group.
placement_group: Optional["PlacementGroup"] = None
"""ray distributed model workers placement group."""
# Backend to use for distributed model
# workers, either "ray" or "mp" (multiprocessing). If the product
# of pipeline_parallel_size and tensor_parallel_size is less than
# or equal to the number of GPUs available, "mp" will be used to
# keep processing on a single host. Otherwise, this will default
# to "ray" if Ray is installed and fail otherwise. Note that tpu
# and hpu only support Ray for distributed inference.
distributed_executor_backend: Optional[Union[str,
type["ExecutorBase"]]] = None
"""Backend to use for distributed model
workers, either "ray" or "mp" (multiprocessing). If the product
of pipeline_parallel_size and tensor_parallel_size is less than
or equal to the number of GPUs available, "mp" will be used to
keep processing on a single host. Otherwise, this will default
to "ray" if Ray is installed and fail otherwise. Note that tpu
and hpu only support Ray for distributed inference."""
# the full name of the worker class to use. If "auto", the worker class
# will be determined based on the platform.
worker_cls: str = "auto"
"""The full name of the worker class to use. If "auto", the worker class
will be determined based on the platform."""
sd_worker_cls: str = "auto"
"""The full name of the worker class to use for speculative decofing.
If "auto", the worker class will be determined based on the platform."""
worker_extension_cls: str = ""
"""The full name of the worker extension class to use. The worker extension
class is dynamically inherited by the worker class. This is used to inject
new attributes and methods to the worker class for use in collective_rpc
calls."""
# world_size is TPxPP, it affects the number of workers we create.
world_size: int = field(init=False)
# world_size_across_dp is TPxPPxDP, it is the size of the world
# including data parallelism.
"""world_size is TPxPP, it affects the number of workers we create."""
world_size_across_dp: int = field(init=False)
"""world_size_across_dp is TPxPPxDP, it is the size of the world
including data parallelism."""
rank: int = 0
"""Global rank in distributed setup."""
def get_next_dp_init_port(self) -> int:
"""

View File

@@ -5,9 +5,9 @@ import dataclasses
import json
import re
import threading
from dataclasses import dataclass
from dataclasses import MISSING, dataclass, fields
from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
Tuple, Type, Union, cast, get_args)
Tuple, Type, Union, cast, get_args, get_origin)
import torch
@@ -19,7 +19,7 @@ from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat,
ModelConfig, ModelImpl, ObservabilityConfig,
ParallelConfig, PoolerConfig, PromptAdapterConfig,
SchedulerConfig, SpeculativeConfig, TaskOption,
TokenizerPoolConfig, VllmConfig)
TokenizerPoolConfig, VllmConfig, get_attr_docs)
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
@@ -111,14 +111,15 @@ class EngineArgs:
# Note: Specifying a custom executor backend by passing a class
# is intended for expert use only. The API may change without
# notice.
distributed_executor_backend: Optional[Union[str,
Type[ExecutorBase]]] = None
distributed_executor_backend: Optional[Union[
str, Type[ExecutorBase]]] = ParallelConfig.distributed_executor_backend
# number of P/D disaggregation (or other disaggregation) workers
pipeline_parallel_size: int = 1
tensor_parallel_size: int = 1
data_parallel_size: int = 1
enable_expert_parallel: bool = False
max_parallel_loading_workers: Optional[int] = None
pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
data_parallel_size: int = ParallelConfig.data_parallel_size
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
max_parallel_loading_workers: Optional[
int] = ParallelConfig.max_parallel_loading_workers
block_size: Optional[int] = None
enable_prefix_caching: Optional[bool] = None
prefix_caching_hash_algo: str = "builtin"
@@ -145,7 +146,7 @@ class EngineArgs:
quantization: Optional[str] = None
enforce_eager: Optional[bool] = None
max_seq_len_to_capture: int = 8192
disable_custom_all_reduce: bool = False
disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
tokenizer_pool_size: int = 0
# Note: Specifying a tokenizer pool by passing a class
# is intended for expert use only. The API may change without
@@ -170,7 +171,7 @@ class EngineArgs:
device: str = 'auto'
num_scheduler_steps: int = 1
multi_step_stream_outputs: bool = True
ray_workers_use_nsight: bool = False
ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
num_gpu_blocks_override: Optional[int] = None
num_lookahead_slots: int = 0
model_loader_extra_config: Optional[dict] = None
@@ -197,8 +198,8 @@ class EngineArgs:
override_neuron_config: Optional[Dict[str, Any]] = None
override_pooler_config: Optional[PoolerConfig] = None
compilation_config: Optional[CompilationConfig] = None
worker_cls: str = "auto"
worker_extension_cls: str = ""
worker_cls: str = ParallelConfig.worker_cls
worker_extension_cls: str = ParallelConfig.worker_extension_cls
kv_transfer_config: Optional[KVTransferConfig] = None
@@ -232,6 +233,31 @@ class EngineArgs:
    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
        """Shared CLI arguments for vLLM engine."""

        def is_optional(cls: type[Any]) -> bool:
            """Check if the class is an optional type."""
            return get_origin(cls) is Union and type(None) in get_args(cls)

        def get_kwargs(cls: type[Any]) -> Dict[str, Any]:
            cls_docs = get_attr_docs(cls)
            kwargs = {}
            for field in fields(cls):
                name = field.name
                # One of these will always be present
                default = (field.default_factory
                           if field.default is MISSING else field.default)
                kwargs[name] = {"default": default, "help": cls_docs[name]}

                # When using action="store_true"
                # add_argument doesn't accept type
                if field.type is bool:
                    continue

                # Handle optional fields
                if is_optional(field.type):
                    kwargs[name]["type"] = nullable_str
                    continue

                kwargs[name]["type"] = field.type
            return kwargs
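For orientation, a hedged sketch of what one entry of the returned mapping looks like when get_kwargs is called inside add_cli_args (values inferred from the ParallelConfig fields in the first file, not printed from a real run):

kwargs = get_kwargs(ParallelConfig)
# kwargs["data_parallel_size"] is roughly:
# {"default": 1,
#  "help": "Number of data parallel groups. MoE layers will be sharded "
#          "according to the product of the tensor parallel size and data "
#          "parallel size.",
#  "type": int}
# Bool fields (e.g. enable_expert_parallel) omit "type" so the caller can
# still pass action='store_true' to add_argument.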
        # Model arguments
        parser.add_argument(
            '--model',
@@ -411,52 +437,37 @@ class EngineArgs:
'* "transformers" will use the Transformers model '
'implementation.\n')
# Parallel arguments
parser.add_argument(
parallel_kwargs = get_kwargs(ParallelConfig)
parallel_group = parser.add_argument_group(
title="ParallelConfig",
description=ParallelConfig.__doc__,
)
parallel_group.add_argument(
'--distributed-executor-backend',
choices=['ray', 'mp', 'uni', 'external_launcher'],
default=EngineArgs.distributed_executor_backend,
help='Backend to use for distributed model '
'workers, either "ray" or "mp" (multiprocessing). If the product '
'of pipeline_parallel_size and tensor_parallel_size is less than '
'or equal to the number of GPUs available, "mp" will be used to '
'keep processing on a single host. Otherwise, this will default '
'to "ray" if Ray is installed and fail otherwise. Note that tpu '
'only supports Ray for distributed inference.')
parser.add_argument('--pipeline-parallel-size',
'-pp',
type=int,
default=EngineArgs.pipeline_parallel_size,
help='Number of pipeline stages.')
parser.add_argument('--tensor-parallel-size',
'-tp',
type=int,
default=EngineArgs.tensor_parallel_size,
help='Number of tensor parallel replicas.')
parser.add_argument('--data-parallel-size',
'-dp',
type=int,
default=EngineArgs.data_parallel_size,
help='Number of data parallel replicas. '
'MoE layers will be sharded according to the '
'product of the tensor-parallel-size and '
'data-parallel-size.')
parser.add_argument(
**parallel_kwargs["distributed_executor_backend"])
parallel_group.add_argument(
'--pipeline-parallel-size', '-pp',
**parallel_kwargs["pipeline_parallel_size"])
parallel_group.add_argument('--tensor-parallel-size', '-tp',
**parallel_kwargs["tensor_parallel_size"])
parallel_group.add_argument('--data-parallel-size', '-dp',
**parallel_kwargs["data_parallel_size"])
parallel_group.add_argument(
'--enable-expert-parallel',
action='store_true',
help='Use expert parallelism instead of tensor parallelism '
'for MoE layers.')
parser.add_argument(
**parallel_kwargs["enable_expert_parallel"])
parallel_group.add_argument(
'--max-parallel-loading-workers',
type=int,
default=EngineArgs.max_parallel_loading_workers,
help='Load model sequentially in multiple batches, '
'to avoid RAM OOM when using tensor '
'parallel and large models.')
parser.add_argument(
**parallel_kwargs["max_parallel_loading_workers"])
parallel_group.add_argument(
'--ray-workers-use-nsight',
action='store_true',
help='If specified, use nsight to profile Ray workers.')
**parallel_kwargs["ray_workers_use_nsight"])
parallel_group.add_argument(
'--disable-custom-all-reduce',
action='store_true',
**parallel_kwargs["disable_custom_all_reduce"])
# KV cache arguments
parser.add_argument('--block-size',
type=int,
@@ -639,10 +650,6 @@ class EngineArgs:
'Additionally for encoder-decoder models, if the '
'sequence length of the encoder input is larger '
'than this, we fall back to the eager mode.')
parser.add_argument('--disable-custom-all-reduce',
action='store_true',
default=EngineArgs.disable_custom_all_reduce,
help='See ParallelConfig.')
parser.add_argument('--tokenizer-pool-size',
type=int,
default=EngineArgs.tokenizer_pool_size,