Improve configs - ParallelConfig
(#16332)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent
c1b57855ec
commit
0c54fc7273
146
vllm/config.py
146
vllm/config.py
@ -4,13 +4,16 @@ import ast
|
||||
import copy
|
||||
import enum
|
||||
import hashlib
|
||||
import inspect
|
||||
import json
|
||||
import sys
|
||||
import textwrap
|
||||
import warnings
|
||||
from collections import Counter
|
||||
from collections.abc import Mapping
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field, replace
|
||||
from dataclasses import (MISSING, dataclass, field, fields, is_dataclass,
|
||||
replace)
|
||||
from importlib.util import find_spec
|
||||
from pathlib import Path
|
||||
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
|
||||
@ -104,6 +107,77 @@ class ModelImpl(str, enum.Enum):
|
||||
TRANSFORMERS = "transformers"
|
||||
|
||||
|
||||
def get_attr_docs(cls: type[Any]) -> dict[str, str]:
|
||||
"""
|
||||
Get any docstrings placed after attribute assignments in a class body.
|
||||
|
||||
https://davidism.com/mit-license/
|
||||
"""
|
||||
|
||||
def pairwise(iterable):
|
||||
"""
|
||||
Manually implement https://docs.python.org/3/library/itertools.html#itertools.pairwise
|
||||
|
||||
Can be removed when Python 3.9 support is dropped.
|
||||
"""
|
||||
iterator = iter(iterable)
|
||||
a = next(iterator, None)
|
||||
|
||||
for b in iterator:
|
||||
yield a, b
|
||||
a = b
|
||||
|
||||
cls_node = ast.parse(textwrap.dedent(inspect.getsource(cls))).body[0]
|
||||
|
||||
if not isinstance(cls_node, ast.ClassDef):
|
||||
raise TypeError("Given object was not a class.")
|
||||
|
||||
out = {}
|
||||
|
||||
# Consider each pair of nodes.
|
||||
for a, b in pairwise(cls_node.body):
|
||||
# Must be an assignment then a constant string.
|
||||
if (not isinstance(a, (ast.Assign, ast.AnnAssign))
|
||||
or not isinstance(b, ast.Expr)
|
||||
or not isinstance(b.value, ast.Constant)
|
||||
or not isinstance(b.value.value, str)):
|
||||
continue
|
||||
|
||||
doc = inspect.cleandoc(b.value.value)
|
||||
|
||||
# An assignment can have multiple targets (a = b = v), but an
|
||||
# annotated assignment only has one target.
|
||||
targets = a.targets if isinstance(a, ast.Assign) else [a.target]
|
||||
|
||||
for target in targets:
|
||||
# Must be assigning to a plain name.
|
||||
if not isinstance(target, ast.Name):
|
||||
continue
|
||||
|
||||
out[target.id] = doc
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def config(cls: type[Any]) -> type[Any]:
|
||||
"""
|
||||
A decorator that ensures all fields in a dataclass have default values
|
||||
and that each field has a docstring.
|
||||
"""
|
||||
if not is_dataclass(cls):
|
||||
raise TypeError("The decorated class must be a dataclass.")
|
||||
attr_docs = get_attr_docs(cls)
|
||||
for f in fields(cls):
|
||||
if f.init and f.default is MISSING and f.default_factory is MISSING:
|
||||
raise ValueError(
|
||||
f"Field '{f.name}' in {cls.__name__} must have a default value."
|
||||
)
|
||||
if f.name not in attr_docs:
|
||||
raise ValueError(
|
||||
f"Field '{f.name}' in {cls.__name__} must have a docstring.")
|
||||
return cls
|
||||
|
||||
|
||||
class ModelConfig:
|
||||
"""Configuration for the model.
|
||||
|
||||
@ -1432,61 +1506,77 @@ class LoadConfig:
|
||||
self.ignore_patterns = ["original/**/*"]
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class ParallelConfig:
|
||||
"""Configuration for the distributed execution."""
|
||||
|
||||
pipeline_parallel_size: int = 1 # Number of pipeline parallel groups.
|
||||
tensor_parallel_size: int = 1 # Number of tensor parallel groups.
|
||||
data_parallel_size: int = 1 # Number of data parallel groups.
|
||||
data_parallel_rank: int = 0 # Rank of the data parallel group.
|
||||
# Local rank of the data parallel group, defaults to global rank.
|
||||
pipeline_parallel_size: int = 1
|
||||
"""Number of pipeline parallel groups."""
|
||||
tensor_parallel_size: int = 1
|
||||
"""Number of tensor parallel groups."""
|
||||
data_parallel_size: int = 1
|
||||
"""Number of data parallel groups. MoE layers will be sharded according to
|
||||
the product of the tensor parallel size and data parallel size."""
|
||||
data_parallel_rank: int = 0
|
||||
"""Rank of the data parallel group."""
|
||||
data_parallel_rank_local: Optional[int] = None
|
||||
# IP of the data parallel master.
|
||||
"""Local rank of the data parallel group, defaults to global rank."""
|
||||
data_parallel_master_ip: str = "127.0.0.1"
|
||||
data_parallel_master_port: int = 29500 # Port of the data parallel master.
|
||||
enable_expert_parallel: bool = False # Use EP instead of TP for MoE layers.
|
||||
"""IP of the data parallel master."""
|
||||
data_parallel_master_port: int = 29500
|
||||
"""Port of the data parallel master."""
|
||||
enable_expert_parallel: bool = False
|
||||
"""Use expert parallelism instead of tensor parallelism for MoE layers."""
|
||||
|
||||
# Maximum number of multiple batches
|
||||
# when load model sequentially. To avoid RAM OOM when using tensor
|
||||
# parallel and large models.
|
||||
max_parallel_loading_workers: Optional[int] = None
|
||||
"""Maximum number of parallal loading workers when loading model
|
||||
sequentially in multiple batches. To avoid RAM OOM when using tensor
|
||||
parallel and large models."""
|
||||
|
||||
# Disable the custom all-reduce kernel and fall back to NCCL.
|
||||
disable_custom_all_reduce: bool = False
|
||||
"""Disable the custom all-reduce kernel and fall back to NCCL."""
|
||||
|
||||
# Config for the tokenizer pool. If None, will use synchronous tokenization.
|
||||
tokenizer_pool_config: Optional[TokenizerPoolConfig] = None
|
||||
"""Config for the tokenizer pool. If None, will use synchronous
|
||||
tokenization."""
|
||||
|
||||
# Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.
|
||||
ray_workers_use_nsight: bool = False
|
||||
"""Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""
|
||||
|
||||
# ray distributed model workers placement group.
|
||||
placement_group: Optional["PlacementGroup"] = None
|
||||
"""ray distributed model workers placement group."""
|
||||
|
||||
# Backend to use for distributed model
|
||||
# workers, either "ray" or "mp" (multiprocessing). If the product
|
||||
# of pipeline_parallel_size and tensor_parallel_size is less than
|
||||
# or equal to the number of GPUs available, "mp" will be used to
|
||||
# keep processing on a single host. Otherwise, this will default
|
||||
# to "ray" if Ray is installed and fail otherwise. Note that tpu
|
||||
# and hpu only support Ray for distributed inference.
|
||||
distributed_executor_backend: Optional[Union[str,
|
||||
type["ExecutorBase"]]] = None
|
||||
"""Backend to use for distributed model
|
||||
workers, either "ray" or "mp" (multiprocessing). If the product
|
||||
of pipeline_parallel_size and tensor_parallel_size is less than
|
||||
or equal to the number of GPUs available, "mp" will be used to
|
||||
keep processing on a single host. Otherwise, this will default
|
||||
to "ray" if Ray is installed and fail otherwise. Note that tpu
|
||||
and hpu only support Ray for distributed inference."""
|
||||
|
||||
# the full name of the worker class to use. If "auto", the worker class
|
||||
# will be determined based on the platform.
|
||||
worker_cls: str = "auto"
|
||||
"""The full name of the worker class to use. If "auto", the worker class
|
||||
will be determined based on the platform."""
|
||||
sd_worker_cls: str = "auto"
|
||||
"""The full name of the worker class to use for speculative decofing.
|
||||
If "auto", the worker class will be determined based on the platform."""
|
||||
worker_extension_cls: str = ""
|
||||
"""The full name of the worker extension class to use. The worker extension
|
||||
class is dynamically inherited by the worker class. This is used to inject
|
||||
new attributes and methods to the worker class for use in collective_rpc
|
||||
calls."""
|
||||
|
||||
# world_size is TPxPP, it affects the number of workers we create.
|
||||
world_size: int = field(init=False)
|
||||
# world_size_across_dp is TPxPPxDP, it is the size of the world
|
||||
# including data parallelism.
|
||||
"""world_size is TPxPP, it affects the number of workers we create."""
|
||||
world_size_across_dp: int = field(init=False)
|
||||
"""world_size_across_dp is TPxPPxDP, it is the size of the world
|
||||
including data parallelism."""
|
||||
|
||||
rank: int = 0
|
||||
"""Global rank in distributed setup."""
|
||||
|
||||
def get_next_dp_init_port(self) -> int:
|
||||
"""
|
||||
|
@ -5,9 +5,9 @@ import dataclasses
|
||||
import json
|
||||
import re
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import MISSING, dataclass, fields
|
||||
from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
|
||||
Tuple, Type, Union, cast, get_args)
|
||||
Tuple, Type, Union, cast, get_args, get_origin)
|
||||
|
||||
import torch
|
||||
|
||||
@ -19,7 +19,7 @@ from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat,
|
||||
ModelConfig, ModelImpl, ObservabilityConfig,
|
||||
ParallelConfig, PoolerConfig, PromptAdapterConfig,
|
||||
SchedulerConfig, SpeculativeConfig, TaskOption,
|
||||
TokenizerPoolConfig, VllmConfig)
|
||||
TokenizerPoolConfig, VllmConfig, get_attr_docs)
|
||||
from vllm.executor.executor_base import ExecutorBase
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
||||
@ -111,14 +111,15 @@ class EngineArgs:
|
||||
# Note: Specifying a custom executor backend by passing a class
|
||||
# is intended for expert use only. The API may change without
|
||||
# notice.
|
||||
distributed_executor_backend: Optional[Union[str,
|
||||
Type[ExecutorBase]]] = None
|
||||
distributed_executor_backend: Optional[Union[
|
||||
str, Type[ExecutorBase]]] = ParallelConfig.distributed_executor_backend
|
||||
# number of P/D disaggregation (or other disaggregation) workers
|
||||
pipeline_parallel_size: int = 1
|
||||
tensor_parallel_size: int = 1
|
||||
data_parallel_size: int = 1
|
||||
enable_expert_parallel: bool = False
|
||||
max_parallel_loading_workers: Optional[int] = None
|
||||
pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
|
||||
tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
|
||||
data_parallel_size: int = ParallelConfig.data_parallel_size
|
||||
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
|
||||
max_parallel_loading_workers: Optional[
|
||||
int] = ParallelConfig.max_parallel_loading_workers
|
||||
block_size: Optional[int] = None
|
||||
enable_prefix_caching: Optional[bool] = None
|
||||
prefix_caching_hash_algo: str = "builtin"
|
||||
@ -145,7 +146,7 @@ class EngineArgs:
|
||||
quantization: Optional[str] = None
|
||||
enforce_eager: Optional[bool] = None
|
||||
max_seq_len_to_capture: int = 8192
|
||||
disable_custom_all_reduce: bool = False
|
||||
disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
|
||||
tokenizer_pool_size: int = 0
|
||||
# Note: Specifying a tokenizer pool by passing a class
|
||||
# is intended for expert use only. The API may change without
|
||||
@ -170,7 +171,7 @@ class EngineArgs:
|
||||
device: str = 'auto'
|
||||
num_scheduler_steps: int = 1
|
||||
multi_step_stream_outputs: bool = True
|
||||
ray_workers_use_nsight: bool = False
|
||||
ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
|
||||
num_gpu_blocks_override: Optional[int] = None
|
||||
num_lookahead_slots: int = 0
|
||||
model_loader_extra_config: Optional[dict] = None
|
||||
@ -197,8 +198,8 @@ class EngineArgs:
|
||||
override_neuron_config: Optional[Dict[str, Any]] = None
|
||||
override_pooler_config: Optional[PoolerConfig] = None
|
||||
compilation_config: Optional[CompilationConfig] = None
|
||||
worker_cls: str = "auto"
|
||||
worker_extension_cls: str = ""
|
||||
worker_cls: str = ParallelConfig.worker_cls
|
||||
worker_extension_cls: str = ParallelConfig.worker_extension_cls
|
||||
|
||||
kv_transfer_config: Optional[KVTransferConfig] = None
|
||||
|
||||
@ -232,6 +233,31 @@ class EngineArgs:
|
||||
@staticmethod
|
||||
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||
"""Shared CLI arguments for vLLM engine."""
|
||||
|
||||
def is_optional(cls: type[Any]) -> bool:
|
||||
"""Check if the class is an optional type."""
|
||||
return get_origin(cls) is Union and type(None) in get_args(cls)
|
||||
|
||||
def get_kwargs(cls: type[Any]) -> Dict[str, Any]:
|
||||
cls_docs = get_attr_docs(cls)
|
||||
kwargs = {}
|
||||
for field in fields(cls):
|
||||
name = field.name
|
||||
# One of these will always be present
|
||||
default = (field.default_factory
|
||||
if field.default is MISSING else field.default)
|
||||
kwargs[name] = {"default": default, "help": cls_docs[name]}
|
||||
# When using action="store_true"
|
||||
# add_argument doesn't accept type
|
||||
if field.type is bool:
|
||||
continue
|
||||
# Handle optional fields
|
||||
if is_optional(field.type):
|
||||
kwargs[name]["type"] = nullable_str
|
||||
continue
|
||||
kwargs[name]["type"] = field.type
|
||||
return kwargs
|
||||
|
||||
# Model arguments
|
||||
parser.add_argument(
|
||||
'--model',
|
||||
@ -411,52 +437,37 @@ class EngineArgs:
|
||||
'* "transformers" will use the Transformers model '
|
||||
'implementation.\n')
|
||||
# Parallel arguments
|
||||
parser.add_argument(
|
||||
parallel_kwargs = get_kwargs(ParallelConfig)
|
||||
parallel_group = parser.add_argument_group(
|
||||
title="ParallelConfig",
|
||||
description=ParallelConfig.__doc__,
|
||||
)
|
||||
parallel_group.add_argument(
|
||||
'--distributed-executor-backend',
|
||||
choices=['ray', 'mp', 'uni', 'external_launcher'],
|
||||
default=EngineArgs.distributed_executor_backend,
|
||||
help='Backend to use for distributed model '
|
||||
'workers, either "ray" or "mp" (multiprocessing). If the product '
|
||||
'of pipeline_parallel_size and tensor_parallel_size is less than '
|
||||
'or equal to the number of GPUs available, "mp" will be used to '
|
||||
'keep processing on a single host. Otherwise, this will default '
|
||||
'to "ray" if Ray is installed and fail otherwise. Note that tpu '
|
||||
'only supports Ray for distributed inference.')
|
||||
|
||||
parser.add_argument('--pipeline-parallel-size',
|
||||
'-pp',
|
||||
type=int,
|
||||
default=EngineArgs.pipeline_parallel_size,
|
||||
help='Number of pipeline stages.')
|
||||
parser.add_argument('--tensor-parallel-size',
|
||||
'-tp',
|
||||
type=int,
|
||||
default=EngineArgs.tensor_parallel_size,
|
||||
help='Number of tensor parallel replicas.')
|
||||
parser.add_argument('--data-parallel-size',
|
||||
'-dp',
|
||||
type=int,
|
||||
default=EngineArgs.data_parallel_size,
|
||||
help='Number of data parallel replicas. '
|
||||
'MoE layers will be sharded according to the '
|
||||
'product of the tensor-parallel-size and '
|
||||
'data-parallel-size.')
|
||||
parser.add_argument(
|
||||
**parallel_kwargs["distributed_executor_backend"])
|
||||
parallel_group.add_argument(
|
||||
'--pipeline-parallel-size', '-pp',
|
||||
**parallel_kwargs["pipeline_parallel_size"])
|
||||
parallel_group.add_argument('--tensor-parallel-size', '-tp',
|
||||
**parallel_kwargs["tensor_parallel_size"])
|
||||
parallel_group.add_argument('--data-parallel-size', '-dp',
|
||||
**parallel_kwargs["data_parallel_size"])
|
||||
parallel_group.add_argument(
|
||||
'--enable-expert-parallel',
|
||||
action='store_true',
|
||||
help='Use expert parallelism instead of tensor parallelism '
|
||||
'for MoE layers.')
|
||||
parser.add_argument(
|
||||
**parallel_kwargs["enable_expert_parallel"])
|
||||
parallel_group.add_argument(
|
||||
'--max-parallel-loading-workers',
|
||||
type=int,
|
||||
default=EngineArgs.max_parallel_loading_workers,
|
||||
help='Load model sequentially in multiple batches, '
|
||||
'to avoid RAM OOM when using tensor '
|
||||
'parallel and large models.')
|
||||
parser.add_argument(
|
||||
**parallel_kwargs["max_parallel_loading_workers"])
|
||||
parallel_group.add_argument(
|
||||
'--ray-workers-use-nsight',
|
||||
action='store_true',
|
||||
help='If specified, use nsight to profile Ray workers.')
|
||||
**parallel_kwargs["ray_workers_use_nsight"])
|
||||
parallel_group.add_argument(
|
||||
'--disable-custom-all-reduce',
|
||||
action='store_true',
|
||||
**parallel_kwargs["disable_custom_all_reduce"])
|
||||
# KV cache arguments
|
||||
parser.add_argument('--block-size',
|
||||
type=int,
|
||||
@ -639,10 +650,6 @@ class EngineArgs:
|
||||
'Additionally for encoder-decoder models, if the '
|
||||
'sequence length of the encoder input is larger '
|
||||
'than this, we fall back to the eager mode.')
|
||||
parser.add_argument('--disable-custom-all-reduce',
|
||||
action='store_true',
|
||||
default=EngineArgs.disable_custom_all_reduce,
|
||||
help='See ParallelConfig.')
|
||||
parser.add_argument('--tokenizer-pool-size',
|
||||
type=int,
|
||||
default=EngineArgs.tokenizer_pool_size,
|
||||
|
Loading…
x
Reference in New Issue
Block a user