Improve configs - LoadConfig (#16422)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

parent 71b9cde010
commit cd77382ac1
vllm/config.py

@@ -17,7 +17,7 @@ from dataclasses import (MISSING, dataclass, field, fields, is_dataclass,
 from importlib.util import find_spec
 from pathlib import Path
 from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
-                    Optional, Protocol, Union)
+                    Optional, Protocol, TypeVar, Union)
 
 import torch
 from pydantic import BaseModel, Field, PrivateAttr
@@ -45,6 +45,7 @@ from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
                         random_uuid, resolve_obj_by_qualname)
 
 if TYPE_CHECKING:
+    from _typeshed import DataclassInstance
     from ray.util.placement_group import PlacementGroup
 
     from vllm.executor.executor_base import ExecutorBase
@@ -53,8 +54,11 @@ if TYPE_CHECKING:
     from vllm.model_executor.model_loader.loader import BaseModelLoader
     from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
         BaseTokenizerGroup)
 
+    Config = TypeVar("Config", bound=DataclassInstance)
 else:
     QuantizationConfig = None
+    Config = TypeVar("Config")
 
 logger = init_logger(__name__)
 
@@ -159,7 +163,7 @@ def get_attr_docs(cls: type[Any]) -> dict[str, str]:
     return out
 
 
-def config(cls: type[Any]) -> type[Any]:
+def config(cls: type[Config]) -> type[Config]:
     """
     A decorator that ensures all fields in a dataclass have default values
     and that each field has a docstring.
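Why the TypeVar matters: with `type[Any] -> type[Any]`, every class passed through `@config` lost its concrete type for static checkers. A minimal sketch of the generic pattern; the simplified validation body and `DemoConfig` are illustrative, not vLLM's actual implementation:

    from dataclasses import MISSING, dataclass, fields
    from typing import TypeVar

    Config = TypeVar("Config")  # bound to DataclassInstance when type checking

    def config(cls: type[Config]) -> type[Config]:
        # Illustrative check: every field must carry a default.
        for f in fields(cls):
            if f.default is MISSING and f.default_factory is MISSING:
                raise ValueError(f"{cls.__name__}.{f.name} has no default")
        return cls

    @config
    @dataclass
    class DemoConfig:
        retries: int = 3

    # The decorator now returns type[DemoConfig], so checkers still see
    # `retries` here instead of treating the class as Any.
    print(DemoConfig().retries)  # 3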
@@ -1431,44 +1435,47 @@ class LoadFormat(str, enum.Enum):
     FASTSAFETENSORS = "fastsafetensors"
 
 
+@config
 @dataclass
 class LoadConfig:
-    """
-    download_dir: Directory to download and load the weights, default to the
-        default cache directory of huggingface.
-    load_format: The format of the model weights to load:
-        "auto" will try to load the weights in the safetensors format and
-            fall back to the pytorch bin format if safetensors format is
-            not available.
-        "pt" will load the weights in the pytorch bin format.
-        "safetensors" will load the weights in the safetensors format.
-        "npcache" will load the weights in pytorch format and store
-            a numpy cache to speed up the loading.
-        "dummy" will initialize the weights with random values, which is
-            mainly for profiling.
-        "tensorizer" will use CoreWeave's tensorizer library for
-            fast weight loading.
-        "bitsandbytes" will load nf4 type weights.
-        "sharded_state" will load weights from pre-sharded checkpoint files,
-            supporting efficient loading of tensor-parallel models.
-        "gguf" will load weights from GGUF format files.
-        "mistral" will load weights from consolidated safetensors files used
-            by Mistral models.
-        "runai_streamer" will load weights from RunAI streamer format files.
-    model_loader_extra_config: The extra config for the model loader.
-    ignore_patterns: The list of patterns to ignore when loading the model.
-        Default to "original/**/*" to avoid repeated loading of llama's
-        checkpoints.
-    use_tqdm_on_load: Whether to enable tqdm for showing progress bar during
-        loading. Default to True
-    """
+    """Configuration for loading the model weights."""
 
-    load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
+    load_format: Union[str, LoadFormat,
+                       "BaseModelLoader"] = LoadFormat.AUTO.value
+    """The format of the model weights to load:\n
+    - "auto" will try to load the weights in the safetensors format and fall
+    back to the pytorch bin format if safetensors format is not available.\n
+    - "pt" will load the weights in the pytorch bin format.\n
+    - "safetensors" will load the weights in the safetensors format.\n
+    - "npcache" will load the weights in pytorch format and store a numpy cache
+    to speed up the loading.\n
+    - "dummy" will initialize the weights with random values, which is mainly
+    for profiling.\n
+    - "tensorizer" will use CoreWeave's tensorizer library for fast weight
+    loading. See the Tensorize vLLM Model script in the Examples section for
+    more information.\n
+    - "runai_streamer" will load the Safetensors weights using Run:ai Model
+    Streamer.\n
+    - "bitsandbytes" will load the weights using bitsandbytes quantization.\n
+    - "sharded_state" will load weights from pre-sharded checkpoint files,
+    supporting efficient loading of tensor-parallel models.\n
+    - "gguf" will load weights from GGUF format files (details specified in
+    https://github.com/ggml-org/ggml/blob/master/docs/gguf.md).\n
+    - "mistral" will load weights from consolidated safetensors files used by
+    Mistral models."""
     download_dir: Optional[str] = None
-    model_loader_extra_config: Optional[Union[str, dict]] = field(
-        default_factory=dict)
+    """Directory to download and load the weights, default to the default
+    cache directory of Hugging Face."""
+    model_loader_extra_config: Optional[Union[str, dict]] = None
+    """Extra config for model loader. This will be passed to the model loader
+    corresponding to the chosen load_format. This should be a JSON string that
+    will be parsed into a dictionary."""
     ignore_patterns: Optional[Union[list[str], str]] = None
+    """The list of patterns to ignore when loading the model. Default to
+    "original/**/*" to avoid repeated loading of llama's checkpoints."""
     use_tqdm_on_load: bool = True
+    """Whether to enable tqdm for showing progress bar when loading model
+    weights."""
 
     def compute_hash(self) -> str:
         """
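The per-field strings above are not ordinary docstrings: Python evaluates and discards a bare string literal that follows an assignment, so `get_attr_docs` has to recover them from the class source. A rough sketch of that technique, assuming a source-parsing approach; `MiniLoadConfig` and `attr_docs` are illustrative stand-ins, not vLLM's implementation:

    import ast
    import inspect
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class MiniLoadConfig:
        download_dir: Optional[str] = None
        """Directory to download and load the weights."""
        use_tqdm_on_load: bool = True
        """Whether to show a progress bar while loading weights."""

    def attr_docs(cls: type) -> dict[str, str]:
        """Collect string literals that directly follow field assignments."""
        body = ast.parse(inspect.getsource(cls)).body[0].body
        docs, last = {}, None
        for node in body:
            if isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name):
                last = node.target.id  # remember the field just assigned
            elif (last is not None and isinstance(node, ast.Expr)
                  and isinstance(node.value, ast.Constant)
                  and isinstance(node.value.value, str)):
                docs[last] = inspect.cleandoc(node.value.value)
                last = None
            else:
                last = None
        return docs

    print(attr_docs(MiniLoadConfig)["use_tqdm_on_load"])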
vllm/engine/arg_utils.py

@@ -101,8 +101,8 @@ class EngineArgs:
     tokenizer_mode: str = 'auto'
     trust_remote_code: bool = False
     allowed_local_media_path: str = ""
-    download_dir: Optional[str] = None
-    load_format: str = 'auto'
+    download_dir: Optional[str] = LoadConfig.download_dir
+    load_format: str = LoadConfig.load_format
     config_format: ConfigFormat = ConfigFormat.AUTO
     dtype: str = 'auto'
     kv_cache_dtype: str = 'auto'
@@ -174,8 +174,10 @@ class EngineArgs:
     ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
     num_gpu_blocks_override: Optional[int] = None
     num_lookahead_slots: int = 0
-    model_loader_extra_config: Optional[dict] = None
-    ignore_patterns: Optional[Union[str, List[str]]] = None
+    model_loader_extra_config: Optional[
+        dict] = LoadConfig.model_loader_extra_config
+    ignore_patterns: Optional[Union[str,
+                                    List[str]]] = LoadConfig.ignore_patterns
     preemption_mode: Optional[str] = None
 
     scheduler_delay_factor: float = 0.0
@@ -213,7 +215,7 @@ class EngineArgs:
     additional_config: Optional[Dict[str, Any]] = None
     enable_reasoning: Optional[bool] = None
     reasoning_parser: Optional[str] = None
-    use_tqdm_on_load: bool = True
+    use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
 
     def __post_init__(self):
         if not self.tokenizer:
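Because dataclass field defaults are readable as plain class attributes, pointing the `EngineArgs` defaults at `LoadConfig` keeps the two in sync automatically. A tiny illustration with stand-in classes (names are hypothetical):

    from dataclasses import dataclass

    @dataclass
    class LoadConfigDemo:
        use_tqdm_on_load: bool = True

    @dataclass
    class EngineArgsDemo:
        # Reuses the dataclass field default instead of repeating the literal.
        use_tqdm_on_load: bool = LoadConfigDemo.use_tqdm_on_load

    assert EngineArgsDemo().use_tqdm_on_load is True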
@@ -234,9 +236,13 @@ class EngineArgs:
     def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         """Shared CLI arguments for vLLM engine."""
 
+        def is_type_in_union(cls: type[Any], type: type[Any]) -> bool:
+            """Check if the class is a type in a union type."""
+            return get_origin(cls) is Union and type in get_args(cls)
+
         def is_optional(cls: type[Any]) -> bool:
             """Check if the class is an optional type."""
-            return get_origin(cls) is Union and type(None) in get_args(cls)
+            return is_type_in_union(cls, type(None))
 
         def get_kwargs(cls: type[Any]) -> Dict[str, Any]:
             cls_docs = get_attr_docs(cls)
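A standalone sketch of what the new helpers report, using the same `get_origin`/`get_args` logic (the parameter is renamed to `typ` here to avoid shadowing the builtin):

    from typing import Any, Optional, Union, get_args, get_origin

    def is_type_in_union(cls: type[Any], typ: type[Any]) -> bool:
        return get_origin(cls) is Union and typ in get_args(cls)

    def is_optional(cls: type[Any]) -> bool:
        return is_type_in_union(cls, type(None))

    print(is_optional(Optional[str]))               # True
    print(is_type_in_union(Union[str, dict], str))  # True
    print(is_type_in_union(str, str))               # False: not a union at all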
@@ -255,6 +261,10 @@ class EngineArgs:
                 if is_optional(field.type):
                     kwargs[name]["type"] = nullable_str
                     continue
+                # Handle str in union fields
+                if is_type_in_union(field.type, str):
+                    kwargs[name]["type"] = str
+                    continue
                 kwargs[name]["type"] = field.type
             return kwargs
 
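The reason union-with-`str` fields map to `type=str`: a value such as `--model-loader-extra-config '{...}'` arrives as a single string, and per the docstring above it is parsed into a dictionary later, so argparse should not try to coerce it itself. A small sketch of that handoff:

    import json

    raw = '{"device": "cpu"}'   # what argparse hands over when type=str
    parsed = json.loads(raw)    # what happens downstream, per the docstring
    assert parsed == {"device": "cpu"}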
@@ -333,38 +343,23 @@ class EngineArgs:
             "from directories specified by the server file system. "
             "This is a security risk. "
             "Should only be enabled in trusted environments.")
-        parser.add_argument('--download-dir',
-                            type=nullable_str,
-                            default=EngineArgs.download_dir,
-                            help='Directory to download and load the weights.')
-        parser.add_argument(
-            '--load-format',
-            type=str,
-            default=EngineArgs.load_format,
-            choices=[f.value for f in LoadFormat],
-            help='The format of the model weights to load.\n\n'
-            '* "auto" will try to load the weights in the safetensors format '
-            'and fall back to the pytorch bin format if safetensors format '
-            'is not available.\n'
-            '* "pt" will load the weights in the pytorch bin format.\n'
-            '* "safetensors" will load the weights in the safetensors format.\n'
-            '* "npcache" will load the weights in pytorch format and store '
-            'a numpy cache to speed up the loading.\n'
-            '* "dummy" will initialize the weights with random values, '
-            'which is mainly for profiling.\n'
-            '* "tensorizer" will load the weights using tensorizer from '
-            'CoreWeave. See the Tensorize vLLM Model script in the Examples '
-            'section for more information.\n'
-            '* "runai_streamer" will load the Safetensors weights using Run:ai'
-            'Model Streamer.\n'
-            '* "bitsandbytes" will load the weights using bitsandbytes '
-            'quantization.\n'
-            '* "sharded_state" will load weights from pre-sharded checkpoint '
-            'files, supporting efficient loading of tensor-parallel models\n'
-            '* "gguf" will load weights from GGUF format files (details '
-            'specified in https://github.com/ggml-org/ggml/blob/master/docs/gguf.md).\n'
-            '* "mistral" will load weights from consolidated safetensors files '
-            'used by Mistral models.\n')
+        # Model loading arguments
+        load_kwargs = get_kwargs(LoadConfig)
+        load_group = parser.add_argument_group(
+            title="LoadConfig",
+            description=LoadConfig.__doc__,
+        )
+        load_group.add_argument('--load-format',
+                                choices=[f.value for f in LoadFormat],
+                                **load_kwargs["load_format"])
+        load_group.add_argument('--download-dir',
+                                **load_kwargs["download_dir"])
+        load_group.add_argument('--model-loader-extra-config',
+                                **load_kwargs["model_loader_extra_config"])
+        load_group.add_argument('--use-tqdm-on-load',
+                                action=argparse.BooleanOptionalAction,
+                                **load_kwargs["use_tqdm_on_load"])
         parser.add_argument(
             '--config-format',
             default=EngineArgs.config_format,
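The net effect of the refactor: CLI help and defaults now derive from the dataclass instead of being maintained by hand. A condensed sketch of the pattern, with hand-written kwargs standing in for what `get_kwargs(LoadConfig)` derives automatically from field types, defaults, and attribute docstrings:

    import argparse

    # Stand-in for get_kwargs(LoadConfig); vLLM computes these automatically.
    load_kwargs = {
        "load_format": {"type": str, "default": "auto",
                        "help": "The format of the model weights to load."},
        "download_dir": {"type": str, "default": None,
                         "help": "Directory to download and load the weights."},
    }

    parser = argparse.ArgumentParser()
    load_group = parser.add_argument_group(title="LoadConfig")
    load_group.add_argument("--load-format", **load_kwargs["load_format"])
    load_group.add_argument("--download-dir", **load_kwargs["download_dir"])

    args = parser.parse_args(["--load-format", "safetensors"])
    print(args.load_format)  # safetensors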
@@ -770,14 +765,6 @@ class EngineArgs:
             default=1,
             help=('Maximum number of forward steps per '
                   'scheduler call.'))
-        parser.add_argument(
-            '--use-tqdm-on-load',
-            dest='use_tqdm_on_load',
-            action=argparse.BooleanOptionalAction,
-            default=EngineArgs.use_tqdm_on_load,
-            help='Whether to enable/disable progress bar '
-            'when loading model weights.',
-        )
 
         parser.add_argument(
             '--multi-step-stream-outputs',
@@ -806,15 +793,6 @@ class EngineArgs:
             default=None,
             help='The configurations for speculative decoding.'
             ' Should be a JSON string.')
-
-        parser.add_argument('--model-loader-extra-config',
-                            type=nullable_str,
-                            default=EngineArgs.model_loader_extra_config,
-                            help='Extra config for model loader. '
-                            'This will be passed to the model loader '
-                            'corresponding to the chosen load_format. '
-                            'This should be a JSON string that will be '
-                            'parsed into a dictionary.')
         parser.add_argument(
             '--ignore-patterns',
             action="append",
vllm/utils.py

@@ -2,7 +2,6 @@
 
 from __future__ import annotations
 
-import argparse
 import asyncio
 import concurrent
 import contextlib
@@ -25,6 +24,7 @@ import socket
 import subprocess
 import sys
 import tempfile
+import textwrap
 import threading
 import time
 import traceback
@@ -32,6 +32,8 @@ import types
 import uuid
 import warnings
 import weakref
+from argparse import (Action, ArgumentDefaultsHelpFormatter, ArgumentParser,
+                      ArgumentTypeError)
 from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
 from collections import UserDict, defaultdict
 from collections.abc import (AsyncGenerator, Awaitable, Generator, Hashable,
@@ -1209,7 +1211,7 @@ def run_once(f: Callable[P, None]) -> Callable[P, None]:
     return wrapper
 
 
-class StoreBoolean(argparse.Action):
+class StoreBoolean(Action):
 
     def __call__(self, parser, namespace, values, option_string=None):
         if values.lower() == "true":
@@ -1221,15 +1223,28 @@ class StoreBoolean(argparse.Action):
                 "Expected 'true' or 'false'.")
 
 
-class SortedHelpFormatter(argparse.ArgumentDefaultsHelpFormatter):
+class SortedHelpFormatter(ArgumentDefaultsHelpFormatter):
     """SortedHelpFormatter that sorts arguments by their option strings."""
 
+    def _split_lines(self, text, width):
+        """
+        1. Sentences split across lines have their single newlines removed.
+        2. Paragraphs and explicit newlines are split into separate lines.
+        3. Each line is wrapped to the specified width (width of terminal).
+        """
+        # The patterns also include whitespace after the newline
+        single_newline = re.compile("(?<!\n)\n(?!\n)\s*")
+        multiple_newlines = re.compile("\n{2,}\s*")
+        text = single_newline.sub(' ', text)
+        lines = re.split(multiple_newlines, text)
+        return sum([textwrap.wrap(line, width) for line in lines], [])
+
     def add_arguments(self, actions):
         actions = sorted(actions, key=lambda x: x.option_strings)
         super().add_arguments(actions)
 
 
-class FlexibleArgumentParser(argparse.ArgumentParser):
+class FlexibleArgumentParser(ArgumentParser):
     """ArgumentParser that allows both underscore and dash in names."""
 
     def __init__(self, *args, **kwargs):
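To see what the new `_split_lines` buys: single newlines inside a sentence (common in the multi-line attribute docstrings above) are joined, while blank lines and the explicit `\n` markers survive as real breaks before terminal-width wrapping. A standalone sketch of the same transformation (raw strings are used here for the regex escapes):

    import re
    import textwrap

    def split_lines(text: str, width: int) -> list[str]:
        # Join single newlines (a wrapped sentence), keep paragraph breaks.
        single_newline = re.compile(r"(?<!\n)\n(?!\n)\s*")
        multiple_newlines = re.compile(r"\n{2,}\s*")
        text = single_newline.sub(" ", text)
        lines = re.split(multiple_newlines, text)
        return sum([textwrap.wrap(line, width) for line in lines], [])

    help_text = "The format of the model\nweights to load.\n\nDefaults to auto."
    for line in split_lines(help_text, width=30):
        print(line)
    # The format of the model
    # weights to load.
    # Defaults to auto.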
@@ -1280,11 +1295,10 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
             value = int(value)
         except ValueError:
             msg = "Port must be an integer"
-            raise argparse.ArgumentTypeError(msg) from None
+            raise ArgumentTypeError(msg) from None
 
         if not (1024 <= value <= 65535):
-            raise argparse.ArgumentTypeError(
-                "Port must be between 1024 and 65535")
+            raise ArgumentTypeError("Port must be between 1024 and 65535")
 
         return value