Improve configs - ParallelConfig (#16332)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

parent c1b57855ec
commit 0c54fc7273

vllm/config.py (146 lines changed)
@@ -4,13 +4,16 @@ import ast
 import copy
 import enum
 import hashlib
+import inspect
 import json
 import sys
+import textwrap
 import warnings
 from collections import Counter
 from collections.abc import Mapping
 from contextlib import contextmanager
-from dataclasses import dataclass, field, replace
+from dataclasses import (MISSING, dataclass, field, fields, is_dataclass,
+                         replace)
 from importlib.util import find_spec
 from pathlib import Path
 from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
@@ -104,6 +107,77 @@ class ModelImpl(str, enum.Enum):
     TRANSFORMERS = "transformers"
 
 
+def get_attr_docs(cls: type[Any]) -> dict[str, str]:
+    """
+    Get any docstrings placed after attribute assignments in a class body.
+
+    https://davidism.com/mit-license/
+    """
+
+    def pairwise(iterable):
+        """
+        Manually implement https://docs.python.org/3/library/itertools.html#itertools.pairwise
+
+        Can be removed when Python 3.9 support is dropped.
+        """
+        iterator = iter(iterable)
+        a = next(iterator, None)
+
+        for b in iterator:
+            yield a, b
+            a = b
+
+    cls_node = ast.parse(textwrap.dedent(inspect.getsource(cls))).body[0]
+
+    if not isinstance(cls_node, ast.ClassDef):
+        raise TypeError("Given object was not a class.")
+
+    out = {}
+
+    # Consider each pair of nodes.
+    for a, b in pairwise(cls_node.body):
+        # Must be an assignment then a constant string.
+        if (not isinstance(a, (ast.Assign, ast.AnnAssign))
+                or not isinstance(b, ast.Expr)
+                or not isinstance(b.value, ast.Constant)
+                or not isinstance(b.value.value, str)):
+            continue
+
+        doc = inspect.cleandoc(b.value.value)
+
+        # An assignment can have multiple targets (a = b = v), but an
+        # annotated assignment only has one target.
+        targets = a.targets if isinstance(a, ast.Assign) else [a.target]
+
+        for target in targets:
+            # Must be assigning to a plain name.
+            if not isinstance(target, ast.Name):
+                continue
+
+            out[target.id] = doc
+
+    return out
+
+
+def config(cls: type[Any]) -> type[Any]:
+    """
+    A decorator that ensures all fields in a dataclass have default values
+    and that each field has a docstring.
+    """
+    if not is_dataclass(cls):
+        raise TypeError("The decorated class must be a dataclass.")
+    attr_docs = get_attr_docs(cls)
+    for f in fields(cls):
+        if f.init and f.default is MISSING and f.default_factory is MISSING:
+            raise ValueError(
+                f"Field '{f.name}' in {cls.__name__} must have a default value."
+            )
+        if f.name not in attr_docs:
+            raise ValueError(
+                f"Field '{f.name}' in {cls.__name__} must have a docstring.")
+    return cls
+
+
 class ModelConfig:
     """Configuration for the model.
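For context, a minimal sketch of how the two new helpers compose. The ToyConfig class here is hypothetical, not part of the commit, and the snippet assumes it runs from a source file so inspect.getsource can see the class body:

from dataclasses import dataclass

from vllm.config import config, get_attr_docs


@config
@dataclass
class ToyConfig:
    workers: int = 1
    """Number of worker processes."""


# get_attr_docs parses the class source with ast and pairs each assignment
# with the string constant that immediately follows it.
assert get_attr_docs(ToyConfig)["workers"] == "Number of worker processes."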
@@ -1432,61 +1506,77 @@ class LoadConfig:
             self.ignore_patterns = ["original/**/*"]
 
 
+@config
 @dataclass
 class ParallelConfig:
     """Configuration for the distributed execution."""
 
-    pipeline_parallel_size: int = 1  # Number of pipeline parallel groups.
-    tensor_parallel_size: int = 1  # Number of tensor parallel groups.
-    data_parallel_size: int = 1  # Number of data parallel groups.
-    data_parallel_rank: int = 0  # Rank of the data parallel group.
-    # Local rank of the data parallel group, defaults to global rank.
+    pipeline_parallel_size: int = 1
+    """Number of pipeline parallel groups."""
+    tensor_parallel_size: int = 1
+    """Number of tensor parallel groups."""
+    data_parallel_size: int = 1
+    """Number of data parallel groups. MoE layers will be sharded according to
+    the product of the tensor parallel size and data parallel size."""
+    data_parallel_rank: int = 0
+    """Rank of the data parallel group."""
     data_parallel_rank_local: Optional[int] = None
-    # IP of the data parallel master.
+    """Local rank of the data parallel group, defaults to global rank."""
     data_parallel_master_ip: str = "127.0.0.1"
-    data_parallel_master_port: int = 29500  # Port of the data parallel master.
-    enable_expert_parallel: bool = False  # Use EP instead of TP for MoE layers.
+    """IP of the data parallel master."""
+    data_parallel_master_port: int = 29500
+    """Port of the data parallel master."""
+    enable_expert_parallel: bool = False
+    """Use expert parallelism instead of tensor parallelism for MoE layers."""
 
-    # Maximum number of multiple batches
-    # when load model sequentially. To avoid RAM OOM when using tensor
-    # parallel and large models.
     max_parallel_loading_workers: Optional[int] = None
+    """Maximum number of parallel loading workers when loading model
+    sequentially in multiple batches. To avoid RAM OOM when using tensor
+    parallel and large models."""
 
-    # Disable the custom all-reduce kernel and fall back to NCCL.
     disable_custom_all_reduce: bool = False
+    """Disable the custom all-reduce kernel and fall back to NCCL."""
 
-    # Config for the tokenizer pool. If None, will use synchronous tokenization.
     tokenizer_pool_config: Optional[TokenizerPoolConfig] = None
+    """Config for the tokenizer pool. If None, will use synchronous
+    tokenization."""
 
-    # Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.
     ray_workers_use_nsight: bool = False
+    """Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""
 
-    # ray distributed model workers placement group.
     placement_group: Optional["PlacementGroup"] = None
+    """ray distributed model workers placement group."""
 
-    # Backend to use for distributed model
-    # workers, either "ray" or "mp" (multiprocessing). If the product
-    # of pipeline_parallel_size and tensor_parallel_size is less than
-    # or equal to the number of GPUs available, "mp" will be used to
-    # keep processing on a single host. Otherwise, this will default
-    # to "ray" if Ray is installed and fail otherwise. Note that tpu
-    # and hpu only support Ray for distributed inference.
     distributed_executor_backend: Optional[Union[str,
                                                  type["ExecutorBase"]]] = None
+    """Backend to use for distributed model
+    workers, either "ray" or "mp" (multiprocessing). If the product
+    of pipeline_parallel_size and tensor_parallel_size is less than
+    or equal to the number of GPUs available, "mp" will be used to
+    keep processing on a single host. Otherwise, this will default
+    to "ray" if Ray is installed and fail otherwise. Note that tpu
+    and hpu only support Ray for distributed inference."""
 
-    # the full name of the worker class to use. If "auto", the worker class
-    # will be determined based on the platform.
     worker_cls: str = "auto"
+    """The full name of the worker class to use. If "auto", the worker class
+    will be determined based on the platform."""
     sd_worker_cls: str = "auto"
+    """The full name of the worker class to use for speculative decoding.
+    If "auto", the worker class will be determined based on the platform."""
     worker_extension_cls: str = ""
+    """The full name of the worker extension class to use. The worker extension
+    class is dynamically inherited by the worker class. This is used to inject
+    new attributes and methods to the worker class for use in collective_rpc
+    calls."""
 
-    # world_size is TPxPP, it affects the number of workers we create.
     world_size: int = field(init=False)
-    # world_size_across_dp is TPxPPxDP, it is the size of the world
-    # including data parallelism.
+    """world_size is TPxPP, it affects the number of workers we create."""
+
     world_size_across_dp: int = field(init=False)
+    """world_size_across_dp is TPxPPxDP, it is the size of the world
+    including data parallelism."""
 
     rank: int = 0
+    """Global rank in distributed setup."""
 
     def get_next_dp_init_port(self) -> int:
         """
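The decorator fails fast at class-definition time, which is what keeps every ParallelConfig field above documented. A sketch of the failure mode, using a hypothetical BadConfig:

from dataclasses import dataclass

from vllm.config import config

try:

    @config
    @dataclass
    class BadConfig:
        undocumented: int = 0  # no docstring follows this assignment

except ValueError as e:
    # Prints: Field 'undocumented' in BadConfig must have a docstring.
    print(e)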
vllm/engine/arg_utils.py

@@ -5,9 +5,9 @@ import dataclasses
 import json
 import re
 import threading
-from dataclasses import dataclass
+from dataclasses import MISSING, dataclass, fields
 from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
-                    Tuple, Type, Union, cast, get_args)
+                    Tuple, Type, Union, cast, get_args, get_origin)
 
 import torch
@@ -19,7 +19,7 @@ from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat,
                          ModelConfig, ModelImpl, ObservabilityConfig,
                          ParallelConfig, PoolerConfig, PromptAdapterConfig,
                          SchedulerConfig, SpeculativeConfig, TaskOption,
-                         TokenizerPoolConfig, VllmConfig)
+                         TokenizerPoolConfig, VllmConfig, get_attr_docs)
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
@@ -111,14 +111,15 @@ class EngineArgs:
     # Note: Specifying a custom executor backend by passing a class
     # is intended for expert use only. The API may change without
     # notice.
-    distributed_executor_backend: Optional[Union[str,
-                                                 Type[ExecutorBase]]] = None
+    distributed_executor_backend: Optional[Union[
+        str, Type[ExecutorBase]]] = ParallelConfig.distributed_executor_backend
     # number of P/D disaggregation (or other disaggregation) workers
-    pipeline_parallel_size: int = 1
-    tensor_parallel_size: int = 1
-    data_parallel_size: int = 1
-    enable_expert_parallel: bool = False
-    max_parallel_loading_workers: Optional[int] = None
+    pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
+    tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
+    data_parallel_size: int = ParallelConfig.data_parallel_size
+    enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
+    max_parallel_loading_workers: Optional[
+        int] = ParallelConfig.max_parallel_loading_workers
     block_size: Optional[int] = None
     enable_prefix_caching: Optional[bool] = None
     prefix_caching_hash_algo: str = "builtin"
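These defaults can reference ParallelConfig directly because a dataclass field with a plain default stays readable as a class attribute, so EngineArgs no longer duplicates the values. A quick sketch of the behaviour relied on, with a stand-in class:

from dataclasses import dataclass


@dataclass
class Example:  # stands in for ParallelConfig
    pipeline_parallel_size: int = 1


# The default is accessible on the class itself, without instantiation.
assert Example.pipeline_parallel_size == 1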
@@ -145,7 +146,7 @@ class EngineArgs:
     quantization: Optional[str] = None
     enforce_eager: Optional[bool] = None
     max_seq_len_to_capture: int = 8192
-    disable_custom_all_reduce: bool = False
+    disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
     tokenizer_pool_size: int = 0
     # Note: Specifying a tokenizer pool by passing a class
     # is intended for expert use only. The API may change without
@@ -170,7 +171,7 @@ class EngineArgs:
     device: str = 'auto'
     num_scheduler_steps: int = 1
     multi_step_stream_outputs: bool = True
-    ray_workers_use_nsight: bool = False
+    ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
     num_gpu_blocks_override: Optional[int] = None
     num_lookahead_slots: int = 0
     model_loader_extra_config: Optional[dict] = None
@@ -197,8 +198,8 @@ class EngineArgs:
     override_neuron_config: Optional[Dict[str, Any]] = None
     override_pooler_config: Optional[PoolerConfig] = None
     compilation_config: Optional[CompilationConfig] = None
-    worker_cls: str = "auto"
-    worker_extension_cls: str = ""
+    worker_cls: str = ParallelConfig.worker_cls
+    worker_extension_cls: str = ParallelConfig.worker_extension_cls
 
     kv_transfer_config: Optional[KVTransferConfig] = None
@@ -232,6 +233,31 @@ class EngineArgs:
     @staticmethod
     def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         """Shared CLI arguments for vLLM engine."""
 
+        def is_optional(cls: type[Any]) -> bool:
+            """Check if the class is an optional type."""
+            return get_origin(cls) is Union and type(None) in get_args(cls)
+
+        def get_kwargs(cls: type[Any]) -> Dict[str, Any]:
+            cls_docs = get_attr_docs(cls)
+            kwargs = {}
+            for field in fields(cls):
+                name = field.name
+                # One of these will always be present
+                default = (field.default_factory
+                           if field.default is MISSING else field.default)
+                kwargs[name] = {"default": default, "help": cls_docs[name]}
+                # When using action="store_true"
+                # add_argument doesn't accept type
+                if field.type is bool:
+                    continue
+                # Handle optional fields
+                if is_optional(field.type):
+                    kwargs[name]["type"] = nullable_str
+                    continue
+                kwargs[name]["type"] = field.type
+            return kwargs
+
         # Model arguments
         parser.add_argument(
             '--model',
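To make the mechanics concrete: for an int field such as tensor_parallel_size, get_kwargs pulls the default from the dataclass and the help string from the attribute docstring, so the grouped add_argument call in the next hunk expands to roughly the following (the snippet is illustrative; the values are taken from the ParallelConfig diff above):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--tensor-parallel-size', '-tp',
                    type=int,
                    default=1,
                    help='Number of tensor parallel groups.')
print(parser.parse_args(['-tp', '4']).tensor_parallel_size)  # 4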
@@ -411,52 +437,37 @@ class EngineArgs:
             '* "transformers" will use the Transformers model '
             'implementation.\n')
         # Parallel arguments
-        parser.add_argument(
+        parallel_kwargs = get_kwargs(ParallelConfig)
+        parallel_group = parser.add_argument_group(
+            title="ParallelConfig",
+            description=ParallelConfig.__doc__,
+        )
+        parallel_group.add_argument(
             '--distributed-executor-backend',
             choices=['ray', 'mp', 'uni', 'external_launcher'],
-            default=EngineArgs.distributed_executor_backend,
-            help='Backend to use for distributed model '
-            'workers, either "ray" or "mp" (multiprocessing). If the product '
-            'of pipeline_parallel_size and tensor_parallel_size is less than '
-            'or equal to the number of GPUs available, "mp" will be used to '
-            'keep processing on a single host. Otherwise, this will default '
-            'to "ray" if Ray is installed and fail otherwise. Note that tpu '
-            'only supports Ray for distributed inference.')
-
-        parser.add_argument('--pipeline-parallel-size',
-                            '-pp',
-                            type=int,
-                            default=EngineArgs.pipeline_parallel_size,
-                            help='Number of pipeline stages.')
-        parser.add_argument('--tensor-parallel-size',
-                            '-tp',
-                            type=int,
-                            default=EngineArgs.tensor_parallel_size,
-                            help='Number of tensor parallel replicas.')
-        parser.add_argument('--data-parallel-size',
-                            '-dp',
-                            type=int,
-                            default=EngineArgs.data_parallel_size,
-                            help='Number of data parallel replicas. '
-                            'MoE layers will be sharded according to the '
-                            'product of the tensor-parallel-size and '
-                            'data-parallel-size.')
-        parser.add_argument(
+            **parallel_kwargs["distributed_executor_backend"])
+        parallel_group.add_argument(
+            '--pipeline-parallel-size', '-pp',
+            **parallel_kwargs["pipeline_parallel_size"])
+        parallel_group.add_argument('--tensor-parallel-size', '-tp',
+                                    **parallel_kwargs["tensor_parallel_size"])
+        parallel_group.add_argument('--data-parallel-size', '-dp',
+                                    **parallel_kwargs["data_parallel_size"])
+        parallel_group.add_argument(
             '--enable-expert-parallel',
             action='store_true',
-            help='Use expert parallelism instead of tensor parallelism '
-            'for MoE layers.')
-        parser.add_argument(
+            **parallel_kwargs["enable_expert_parallel"])
+        parallel_group.add_argument(
             '--max-parallel-loading-workers',
-            type=int,
-            default=EngineArgs.max_parallel_loading_workers,
-            help='Load model sequentially in multiple batches, '
-            'to avoid RAM OOM when using tensor '
-            'parallel and large models.')
-        parser.add_argument(
+            **parallel_kwargs["max_parallel_loading_workers"])
+        parallel_group.add_argument(
             '--ray-workers-use-nsight',
             action='store_true',
-            help='If specified, use nsight to profile Ray workers.')
+            **parallel_kwargs["ray_workers_use_nsight"])
+        parallel_group.add_argument(
+            '--disable-custom-all-reduce',
+            action='store_true',
+            **parallel_kwargs["disable_custom_all_reduce"])
         # KV cache arguments
         parser.add_argument('--block-size',
                             type=int,
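A side effect of registering these flags on an argument group is that --help now sections them under the config class's name and docstring. A small standalone sketch of that behaviour:

import argparse

parser = argparse.ArgumentParser()
parallel_group = parser.add_argument_group(
    title="ParallelConfig",
    description="Configuration for the distributed execution.")
parallel_group.add_argument('--data-parallel-size', '-dp', type=int,
                            default=1, help='Number of data parallel groups.')
parser.print_help()  # the flag is listed under a "ParallelConfig" heading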
@@ -639,10 +650,6 @@ class EngineArgs:
                             'Additionally for encoder-decoder models, if the '
                             'sequence length of the encoder input is larger '
                             'than this, we fall back to the eager mode.')
-        parser.add_argument('--disable-custom-all-reduce',
-                            action='store_true',
-                            default=EngineArgs.disable_custom_all_reduce,
-                            help='See ParallelConfig.')
         parser.add_argument('--tokenizer-pool-size',
                             type=int,
                             default=EngineArgs.tokenizer_pool_size,