Improve configs - LoadConfig (#16422)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor 2025-04-11 21:27:27 +01:00 committed by GitHub
parent 71b9cde010
commit cd77382ac1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 95 additions and 96 deletions

View File

@ -17,7 +17,7 @@ from dataclasses import (MISSING, dataclass, field, fields, is_dataclass,
from importlib.util import find_spec
from pathlib import Path
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
Optional, Protocol, Union)
Optional, Protocol, TypeVar, Union)
import torch
from pydantic import BaseModel, Field, PrivateAttr
@ -45,6 +45,7 @@ from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
random_uuid, resolve_obj_by_qualname)
if TYPE_CHECKING:
from _typeshed import DataclassInstance
from ray.util.placement_group import PlacementGroup
from vllm.executor.executor_base import ExecutorBase
@ -53,8 +54,11 @@ if TYPE_CHECKING:
from vllm.model_executor.model_loader.loader import BaseModelLoader
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup)
Config = TypeVar("Config", bound=DataclassInstance)
else:
QuantizationConfig = None
Config = TypeVar("Config")
logger = init_logger(__name__)
@ -159,7 +163,7 @@ def get_attr_docs(cls: type[Any]) -> dict[str, str]:
return out
def config(cls: type[Any]) -> type[Any]:
def config(cls: type[Config]) -> type[Config]:
"""
A decorator that ensures all fields in a dataclass have default values
and that each field has a docstring.
@ -1431,44 +1435,47 @@ class LoadFormat(str, enum.Enum):
FASTSAFETENSORS = "fastsafetensors"
@config
@dataclass
class LoadConfig:
"""
download_dir: Directory to download and load the weights, default to the
default cache directory of huggingface.
load_format: The format of the model weights to load:
"auto" will try to load the weights in the safetensors format and
fall back to the pytorch bin format if safetensors format is
not available.
"pt" will load the weights in the pytorch bin format.
"safetensors" will load the weights in the safetensors format.
"npcache" will load the weights in pytorch format and store
a numpy cache to speed up the loading.
"dummy" will initialize the weights with random values, which is
mainly for profiling.
"tensorizer" will use CoreWeave's tensorizer library for
fast weight loading.
"bitsandbytes" will load nf4 type weights.
"sharded_state" will load weights from pre-sharded checkpoint files,
supporting efficient loading of tensor-parallel models.
"gguf" will load weights from GGUF format files.
"mistral" will load weights from consolidated safetensors files used
by Mistral models.
"runai_streamer" will load weights from RunAI streamer format files.
model_loader_extra_config: The extra config for the model loader.
ignore_patterns: The list of patterns to ignore when loading the model.
Default to "original/**/*" to avoid repeated loading of llama's
checkpoints.
use_tqdm_on_load: Whether to enable tqdm for showing progress bar during
loading. Default to True
"""
"""Configuration for loading the model weights."""
load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
load_format: Union[str, LoadFormat,
"BaseModelLoader"] = LoadFormat.AUTO.value
"""The format of the model weights to load:\n
- "auto" will try to load the weights in the safetensors format and fall
back to the pytorch bin format if safetensors format is not available.\n
- "pt" will load the weights in the pytorch bin format.\n
- "safetensors" will load the weights in the safetensors format.\n
- "npcache" will load the weights in pytorch format and store a numpy cache
to speed up the loading.\n
- "dummy" will initialize the weights with random values, which is mainly
for profiling.\n
- "tensorizer" will use CoreWeave's tensorizer library for fast weight
loading. See the Tensorize vLLM Model script in the Examples section for
more information.\n
- "runai_streamer" will load the Safetensors weights using Run:ai Model
Streamer.\n
- "bitsandbytes" will load the weights using bitsandbytes quantization.\n
- "sharded_state" will load weights from pre-sharded checkpoint files,
supporting efficient loading of tensor-parallel models.\n
- "gguf" will load weights from GGUF format files (details specified in
https://github.com/ggml-org/ggml/blob/master/docs/gguf.md).\n
- "mistral" will load weights from consolidated safetensors files used by
Mistral models."""
download_dir: Optional[str] = None
model_loader_extra_config: Optional[Union[str, dict]] = field(
default_factory=dict)
"""Directory to download and load the weights, default to the default
cache directory of Hugging Face."""
model_loader_extra_config: Optional[Union[str, dict]] = None
"""Extra config for model loader. This will be passed to the model loader
corresponding to the chosen load_format. This should be a JSON string that
will be parsed into a dictionary."""
ignore_patterns: Optional[Union[list[str], str]] = None
"""The list of patterns to ignore when loading the model. Default to
"original/**/*" to avoid repeated loading of llama's checkpoints."""
use_tqdm_on_load: bool = True
"""Whether to enable tqdm for showing progress bar when loading model
weights."""
def compute_hash(self) -> str:
"""

View File

@ -101,8 +101,8 @@ class EngineArgs:
tokenizer_mode: str = 'auto'
trust_remote_code: bool = False
allowed_local_media_path: str = ""
download_dir: Optional[str] = None
load_format: str = 'auto'
download_dir: Optional[str] = LoadConfig.download_dir
load_format: str = LoadConfig.load_format
config_format: ConfigFormat = ConfigFormat.AUTO
dtype: str = 'auto'
kv_cache_dtype: str = 'auto'
@ -174,8 +174,10 @@ class EngineArgs:
ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
num_gpu_blocks_override: Optional[int] = None
num_lookahead_slots: int = 0
model_loader_extra_config: Optional[dict] = None
ignore_patterns: Optional[Union[str, List[str]]] = None
model_loader_extra_config: Optional[
dict] = LoadConfig.model_loader_extra_config
ignore_patterns: Optional[Union[str,
List[str]]] = LoadConfig.ignore_patterns
preemption_mode: Optional[str] = None
scheduler_delay_factor: float = 0.0
@ -213,7 +215,7 @@ class EngineArgs:
additional_config: Optional[Dict[str, Any]] = None
enable_reasoning: Optional[bool] = None
reasoning_parser: Optional[str] = None
use_tqdm_on_load: bool = True
use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
def __post_init__(self):
if not self.tokenizer:
@ -234,9 +236,13 @@ class EngineArgs:
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"""Shared CLI arguments for vLLM engine."""
def is_type_in_union(cls: type[Any], type: type[Any]) -> bool:
    """Check if the class is a type in a union type.

    Args:
        cls: The annotation to inspect.
        type: The member to look for among the union's arguments. The
            name shadows the builtin inside this function only; it is
            kept so the existing signature stays unchanged.

    Returns:
        True only when `cls` is a `typing.Union` whose arguments include
        `type`; False for anything that is not a `Union`.
    """
    return get_origin(cls) is Union and type in get_args(cls)


def is_optional(cls: type[Any]) -> bool:
    """Check if the class is an optional type.

    An optional type is a `typing.Union` that has `NoneType` among its
    arguments, so the check is delegated to `is_type_in_union`.
    """
    # A single return: the earlier inline check duplicated this logic and
    # left the delegating call unreachable.
    return is_type_in_union(cls, type(None))
def get_kwargs(cls: type[Any]) -> Dict[str, Any]:
cls_docs = get_attr_docs(cls)
@ -255,6 +261,10 @@ class EngineArgs:
if is_optional(field.type):
kwargs[name]["type"] = nullable_str
continue
# Handle str in union fields
if is_type_in_union(field.type, str):
kwargs[name]["type"] = str
continue
kwargs[name]["type"] = field.type
return kwargs
@ -333,38 +343,23 @@ class EngineArgs:
"from directories specified by the server file system. "
"This is a security risk. "
"Should only be enabled in trusted environments.")
parser.add_argument('--download-dir',
type=nullable_str,
default=EngineArgs.download_dir,
help='Directory to download and load the weights.')
parser.add_argument(
'--load-format',
type=str,
default=EngineArgs.load_format,
choices=[f.value for f in LoadFormat],
help='The format of the model weights to load.\n\n'
'* "auto" will try to load the weights in the safetensors format '
'and fall back to the pytorch bin format if safetensors format '
'is not available.\n'
'* "pt" will load the weights in the pytorch bin format.\n'
'* "safetensors" will load the weights in the safetensors format.\n'
'* "npcache" will load the weights in pytorch format and store '
'a numpy cache to speed up the loading.\n'
'* "dummy" will initialize the weights with random values, '
'which is mainly for profiling.\n'
'* "tensorizer" will load the weights using tensorizer from '
'CoreWeave. See the Tensorize vLLM Model script in the Examples '
'section for more information.\n'
'* "runai_streamer" will load the Safetensors weights using Run:ai'
'Model Streamer.\n'
'* "bitsandbytes" will load the weights using bitsandbytes '
'quantization.\n'
'* "sharded_state" will load weights from pre-sharded checkpoint '
'files, supporting efficient loading of tensor-parallel models\n'
'* "gguf" will load weights from GGUF format files (details '
'specified in https://github.com/ggml-org/ggml/blob/master/docs/gguf.md).\n'
'* "mistral" will load weights from consolidated safetensors files '
'used by Mistral models.\n')
# Model loading arguments
load_kwargs = get_kwargs(LoadConfig)
load_group = parser.add_argument_group(
title="LoadConfig",
description=LoadConfig.__doc__,
)
load_group.add_argument('--load-format',
choices=[f.value for f in LoadFormat],
**load_kwargs["load_format"])
load_group.add_argument('--download-dir',
**load_kwargs["download_dir"])
load_group.add_argument('--model-loader-extra-config',
**load_kwargs["model_loader_extra_config"])
load_group.add_argument('--use-tqdm-on-load',
action=argparse.BooleanOptionalAction,
**load_kwargs["use_tqdm_on_load"])
parser.add_argument(
'--config-format',
default=EngineArgs.config_format,
@ -770,14 +765,6 @@ class EngineArgs:
default=1,
help=('Maximum number of forward steps per '
'scheduler call.'))
parser.add_argument(
'--use-tqdm-on-load',
dest='use_tqdm_on_load',
action=argparse.BooleanOptionalAction,
default=EngineArgs.use_tqdm_on_load,
help='Whether to enable/disable progress bar '
'when loading model weights.',
)
parser.add_argument(
'--multi-step-stream-outputs',
@ -806,15 +793,6 @@ class EngineArgs:
default=None,
help='The configurations for speculative decoding.'
' Should be a JSON string.')
parser.add_argument('--model-loader-extra-config',
type=nullable_str,
default=EngineArgs.model_loader_extra_config,
help='Extra config for model loader. '
'This will be passed to the model loader '
'corresponding to the chosen load_format. '
'This should be a JSON string that will be '
'parsed into a dictionary.')
parser.add_argument(
'--ignore-patterns',
action="append",

View File

@ -2,7 +2,6 @@
from __future__ import annotations
import argparse
import asyncio
import concurrent
import contextlib
@ -25,6 +24,7 @@ import socket
import subprocess
import sys
import tempfile
import textwrap
import threading
import time
import traceback
@ -32,6 +32,8 @@ import types
import uuid
import warnings
import weakref
from argparse import (Action, ArgumentDefaultsHelpFormatter, ArgumentParser,
ArgumentTypeError)
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
from collections import UserDict, defaultdict
from collections.abc import (AsyncGenerator, Awaitable, Generator, Hashable,
@ -1209,7 +1211,7 @@ def run_once(f: Callable[P, None]) -> Callable[P, None]:
return wrapper
class StoreBoolean(argparse.Action):
class StoreBoolean(Action):
def __call__(self, parser, namespace, values, option_string=None):
if values.lower() == "true":
@ -1221,15 +1223,28 @@ class StoreBoolean(argparse.Action):
"Expected 'true' or 'false'.")
class SortedHelpFormatter(ArgumentDefaultsHelpFormatter):
    """SortedHelpFormatter that sorts arguments by their option strings.

    Help text is also re-flowed by `_split_lines` so that multi-line
    help strings render as paragraphs instead of keeping their source
    line breaks.
    """

    def _split_lines(self, text, width):
        """
        Split help `text` into a list of lines at most `width` wide.

        1. Sentences split across lines have their single newlines removed.
        2. Paragraphs and explicit newlines are split into separate lines.
        3. Each line is wrapped to the specified width (width of terminal).
        """
        # Raw strings: `\s` is not a valid string escape, so a plain
        # literal triggers an invalid-escape warning. The `re` module
        # itself interprets `\n` and `\s` inside the raw pattern.
        # The patterns also include whitespace after the newline
        single_newline = re.compile(r"(?<!\n)\n(?!\n)\s*")
        multiple_newlines = re.compile(r"\n{2,}\s*")
        text = single_newline.sub(" ", text)
        paragraphs = multiple_newlines.split(text)
        # Flatten with a comprehension; `sum(list_of_lists, [])` copies the
        # accumulated list on every addition.
        return [
            line for paragraph in paragraphs
            for line in textwrap.wrap(paragraph, width)
        ]

    def add_arguments(self, actions):
        """Add `actions` to the help output ordered by option strings."""
        actions = sorted(actions, key=lambda x: x.option_strings)
        super().add_arguments(actions)
class FlexibleArgumentParser(argparse.ArgumentParser):
class FlexibleArgumentParser(ArgumentParser):
"""ArgumentParser that allows both underscore and dash in names."""
def __init__(self, *args, **kwargs):
@ -1280,11 +1295,10 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
value = int(value)
except ValueError:
msg = "Port must be an integer"
raise argparse.ArgumentTypeError(msg) from None
raise ArgumentTypeError(msg) from None
if not (1024 <= value <= 65535):
raise argparse.ArgumentTypeError(
"Port must be between 1024 and 65535")
raise ArgumentTypeError("Port must be between 1024 and 65535")
return value