Improve configs - LoadConfig
(#16422)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent
71b9cde010
commit
cd77382ac1
@ -17,7 +17,7 @@ from dataclasses import (MISSING, dataclass, field, fields, is_dataclass,
|
||||
from importlib.util import find_spec
|
||||
from pathlib import Path
|
||||
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
|
||||
Optional, Protocol, Union)
|
||||
Optional, Protocol, TypeVar, Union)
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
@ -45,6 +45,7 @@ from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
|
||||
random_uuid, resolve_obj_by_qualname)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from _typeshed import DataclassInstance
|
||||
from ray.util.placement_group import PlacementGroup
|
||||
|
||||
from vllm.executor.executor_base import ExecutorBase
|
||||
@ -53,8 +54,11 @@ if TYPE_CHECKING:
|
||||
from vllm.model_executor.model_loader.loader import BaseModelLoader
|
||||
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
|
||||
BaseTokenizerGroup)
|
||||
|
||||
Config = TypeVar("Config", bound=DataclassInstance)
|
||||
else:
|
||||
QuantizationConfig = None
|
||||
Config = TypeVar("Config")
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -159,7 +163,7 @@ def get_attr_docs(cls: type[Any]) -> dict[str, str]:
|
||||
return out
|
||||
|
||||
|
||||
def config(cls: type[Any]) -> type[Any]:
|
||||
def config(cls: type[Config]) -> type[Config]:
|
||||
"""
|
||||
A decorator that ensures all fields in a dataclass have default values
|
||||
and that each field has a docstring.
|
||||
@ -1431,44 +1435,47 @@ class LoadFormat(str, enum.Enum):
|
||||
FASTSAFETENSORS = "fastsafetensors"
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class LoadConfig:
|
||||
"""
|
||||
download_dir: Directory to download and load the weights, default to the
|
||||
default cache directory of huggingface.
|
||||
load_format: The format of the model weights to load:
|
||||
"auto" will try to load the weights in the safetensors format and
|
||||
fall back to the pytorch bin format if safetensors format is
|
||||
not available.
|
||||
"pt" will load the weights in the pytorch bin format.
|
||||
"safetensors" will load the weights in the safetensors format.
|
||||
"npcache" will load the weights in pytorch format and store
|
||||
a numpy cache to speed up the loading.
|
||||
"dummy" will initialize the weights with random values, which is
|
||||
mainly for profiling.
|
||||
"tensorizer" will use CoreWeave's tensorizer library for
|
||||
fast weight loading.
|
||||
"bitsandbytes" will load nf4 type weights.
|
||||
"sharded_state" will load weights from pre-sharded checkpoint files,
|
||||
supporting efficient loading of tensor-parallel models.
|
||||
"gguf" will load weights from GGUF format files.
|
||||
"mistral" will load weights from consolidated safetensors files used
|
||||
by Mistral models.
|
||||
"runai_streamer" will load weights from RunAI streamer format files.
|
||||
model_loader_extra_config: The extra config for the model loader.
|
||||
ignore_patterns: The list of patterns to ignore when loading the model.
|
||||
Default to "original/**/*" to avoid repeated loading of llama's
|
||||
checkpoints.
|
||||
use_tqdm_on_load: Whether to enable tqdm for showing progress bar during
|
||||
loading. Default to True
|
||||
"""
|
||||
"""Configuration for loading the model weights."""
|
||||
|
||||
load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
|
||||
load_format: Union[str, LoadFormat,
|
||||
"BaseModelLoader"] = LoadFormat.AUTO.value
|
||||
"""The format of the model weights to load:\n
|
||||
- "auto" will try to load the weights in the safetensors format and fall
|
||||
back to the pytorch bin format if safetensors format is not available.\n
|
||||
- "pt" will load the weights in the pytorch bin format.\n
|
||||
- "safetensors" will load the weights in the safetensors format.\n
|
||||
- "npcache" will load the weights in pytorch format and store a numpy cache
|
||||
to speed up the loading.\n
|
||||
- "dummy" will initialize the weights with random values, which is mainly
|
||||
for profiling.\n
|
||||
- "tensorizer" will use CoreWeave's tensorizer library for fast weight
|
||||
loading. See the Tensorize vLLM Model script in the Examples section for
|
||||
more information.\n
|
||||
- "runai_streamer" will load the Safetensors weights using Run:ai Model
|
||||
Streamer.\n
|
||||
- "bitsandbytes" will load the weights using bitsandbytes quantization.\n
|
||||
- "sharded_state" will load weights from pre-sharded checkpoint files,
|
||||
supporting efficient loading of tensor-parallel models.\n
|
||||
- "gguf" will load weights from GGUF format files (details specified in
|
||||
https://github.com/ggml-org/ggml/blob/master/docs/gguf.md).\n
|
||||
- "mistral" will load weights from consolidated safetensors files used by
|
||||
Mistral models."""
|
||||
download_dir: Optional[str] = None
|
||||
model_loader_extra_config: Optional[Union[str, dict]] = field(
|
||||
default_factory=dict)
|
||||
"""Directory to download and load the weights, default to the default
|
||||
cache directory of Hugging Face."""
|
||||
model_loader_extra_config: Optional[Union[str, dict]] = None
|
||||
"""Extra config for model loader. This will be passed to the model loader
|
||||
corresponding to the chosen load_format. This should be a JSON string that
|
||||
will be parsed into a dictionary."""
|
||||
ignore_patterns: Optional[Union[list[str], str]] = None
|
||||
"""The list of patterns to ignore when loading the model. Default to
|
||||
"original/**/*" to avoid repeated loading of llama's checkpoints."""
|
||||
use_tqdm_on_load: bool = True
|
||||
"""Whether to enable tqdm for showing progress bar when loading model
|
||||
weights."""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
|
@ -101,8 +101,8 @@ class EngineArgs:
|
||||
tokenizer_mode: str = 'auto'
|
||||
trust_remote_code: bool = False
|
||||
allowed_local_media_path: str = ""
|
||||
download_dir: Optional[str] = None
|
||||
load_format: str = 'auto'
|
||||
download_dir: Optional[str] = LoadConfig.download_dir
|
||||
load_format: str = LoadConfig.load_format
|
||||
config_format: ConfigFormat = ConfigFormat.AUTO
|
||||
dtype: str = 'auto'
|
||||
kv_cache_dtype: str = 'auto'
|
||||
@ -174,8 +174,10 @@ class EngineArgs:
|
||||
ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
|
||||
num_gpu_blocks_override: Optional[int] = None
|
||||
num_lookahead_slots: int = 0
|
||||
model_loader_extra_config: Optional[dict] = None
|
||||
ignore_patterns: Optional[Union[str, List[str]]] = None
|
||||
model_loader_extra_config: Optional[
|
||||
dict] = LoadConfig.model_loader_extra_config
|
||||
ignore_patterns: Optional[Union[str,
|
||||
List[str]]] = LoadConfig.ignore_patterns
|
||||
preemption_mode: Optional[str] = None
|
||||
|
||||
scheduler_delay_factor: float = 0.0
|
||||
@ -213,7 +215,7 @@ class EngineArgs:
|
||||
additional_config: Optional[Dict[str, Any]] = None
|
||||
enable_reasoning: Optional[bool] = None
|
||||
reasoning_parser: Optional[str] = None
|
||||
use_tqdm_on_load: bool = True
|
||||
use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.tokenizer:
|
||||
@ -234,9 +236,13 @@ class EngineArgs:
|
||||
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||
"""Shared CLI arguments for vLLM engine."""
|
||||
|
||||
def is_type_in_union(cls: type[Any], type: type[Any]) -> bool:
|
||||
"""Check if the class is a type in a union type."""
|
||||
return get_origin(cls) is Union and type in get_args(cls)
|
||||
|
||||
def is_optional(cls: type[Any]) -> bool:
|
||||
"""Check if the class is an optional type."""
|
||||
return get_origin(cls) is Union and type(None) in get_args(cls)
|
||||
return is_type_in_union(cls, type(None))
|
||||
|
||||
def get_kwargs(cls: type[Any]) -> Dict[str, Any]:
|
||||
cls_docs = get_attr_docs(cls)
|
||||
@ -255,6 +261,10 @@ class EngineArgs:
|
||||
if is_optional(field.type):
|
||||
kwargs[name]["type"] = nullable_str
|
||||
continue
|
||||
# Handle str in union fields
|
||||
if is_type_in_union(field.type, str):
|
||||
kwargs[name]["type"] = str
|
||||
continue
|
||||
kwargs[name]["type"] = field.type
|
||||
return kwargs
|
||||
|
||||
@ -333,38 +343,23 @@ class EngineArgs:
|
||||
"from directories specified by the server file system. "
|
||||
"This is a security risk. "
|
||||
"Should only be enabled in trusted environments.")
|
||||
parser.add_argument('--download-dir',
|
||||
type=nullable_str,
|
||||
default=EngineArgs.download_dir,
|
||||
help='Directory to download and load the weights.')
|
||||
parser.add_argument(
|
||||
'--load-format',
|
||||
type=str,
|
||||
default=EngineArgs.load_format,
|
||||
choices=[f.value for f in LoadFormat],
|
||||
help='The format of the model weights to load.\n\n'
|
||||
'* "auto" will try to load the weights in the safetensors format '
|
||||
'and fall back to the pytorch bin format if safetensors format '
|
||||
'is not available.\n'
|
||||
'* "pt" will load the weights in the pytorch bin format.\n'
|
||||
'* "safetensors" will load the weights in the safetensors format.\n'
|
||||
'* "npcache" will load the weights in pytorch format and store '
|
||||
'a numpy cache to speed up the loading.\n'
|
||||
'* "dummy" will initialize the weights with random values, '
|
||||
'which is mainly for profiling.\n'
|
||||
'* "tensorizer" will load the weights using tensorizer from '
|
||||
'CoreWeave. See the Tensorize vLLM Model script in the Examples '
|
||||
'section for more information.\n'
|
||||
'* "runai_streamer" will load the Safetensors weights using Run:ai'
|
||||
'Model Streamer.\n'
|
||||
'* "bitsandbytes" will load the weights using bitsandbytes '
|
||||
'quantization.\n'
|
||||
'* "sharded_state" will load weights from pre-sharded checkpoint '
|
||||
'files, supporting efficient loading of tensor-parallel models\n'
|
||||
'* "gguf" will load weights from GGUF format files (details '
|
||||
'specified in https://github.com/ggml-org/ggml/blob/master/docs/gguf.md).\n'
|
||||
'* "mistral" will load weights from consolidated safetensors files '
|
||||
'used by Mistral models.\n')
|
||||
# Model loading arguments
|
||||
load_kwargs = get_kwargs(LoadConfig)
|
||||
load_group = parser.add_argument_group(
|
||||
title="LoadConfig",
|
||||
description=LoadConfig.__doc__,
|
||||
)
|
||||
load_group.add_argument('--load-format',
|
||||
choices=[f.value for f in LoadFormat],
|
||||
**load_kwargs["load_format"])
|
||||
load_group.add_argument('--download-dir',
|
||||
**load_kwargs["download_dir"])
|
||||
load_group.add_argument('--model-loader-extra-config',
|
||||
**load_kwargs["model_loader_extra_config"])
|
||||
load_group.add_argument('--use-tqdm-on-load',
|
||||
action=argparse.BooleanOptionalAction,
|
||||
**load_kwargs["use_tqdm_on_load"])
|
||||
|
||||
parser.add_argument(
|
||||
'--config-format',
|
||||
default=EngineArgs.config_format,
|
||||
@ -770,14 +765,6 @@ class EngineArgs:
|
||||
default=1,
|
||||
help=('Maximum number of forward steps per '
|
||||
'scheduler call.'))
|
||||
parser.add_argument(
|
||||
'--use-tqdm-on-load',
|
||||
dest='use_tqdm_on_load',
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=EngineArgs.use_tqdm_on_load,
|
||||
help='Whether to enable/disable progress bar '
|
||||
'when loading model weights.',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--multi-step-stream-outputs',
|
||||
@ -806,15 +793,6 @@ class EngineArgs:
|
||||
default=None,
|
||||
help='The configurations for speculative decoding.'
|
||||
' Should be a JSON string.')
|
||||
|
||||
parser.add_argument('--model-loader-extra-config',
|
||||
type=nullable_str,
|
||||
default=EngineArgs.model_loader_extra_config,
|
||||
help='Extra config for model loader. '
|
||||
'This will be passed to the model loader '
|
||||
'corresponding to the chosen load_format. '
|
||||
'This should be a JSON string that will be '
|
||||
'parsed into a dictionary.')
|
||||
parser.add_argument(
|
||||
'--ignore-patterns',
|
||||
action="append",
|
||||
|
@ -2,7 +2,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import concurrent
|
||||
import contextlib
|
||||
@ -25,6 +24,7 @@ import socket
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import textwrap
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
@ -32,6 +32,8 @@ import types
|
||||
import uuid
|
||||
import warnings
|
||||
import weakref
|
||||
from argparse import (Action, ArgumentDefaultsHelpFormatter, ArgumentParser,
|
||||
ArgumentTypeError)
|
||||
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
|
||||
from collections import UserDict, defaultdict
|
||||
from collections.abc import (AsyncGenerator, Awaitable, Generator, Hashable,
|
||||
@ -1209,7 +1211,7 @@ def run_once(f: Callable[P, None]) -> Callable[P, None]:
|
||||
return wrapper
|
||||
|
||||
|
||||
class StoreBoolean(argparse.Action):
|
||||
class StoreBoolean(Action):
|
||||
|
||||
def __call__(self, parser, namespace, values, option_string=None):
|
||||
if values.lower() == "true":
|
||||
@ -1221,15 +1223,28 @@ class StoreBoolean(argparse.Action):
|
||||
"Expected 'true' or 'false'.")
|
||||
|
||||
|
||||
class SortedHelpFormatter(argparse.ArgumentDefaultsHelpFormatter):
|
||||
class SortedHelpFormatter(ArgumentDefaultsHelpFormatter):
|
||||
"""SortedHelpFormatter that sorts arguments by their option strings."""
|
||||
|
||||
def _split_lines(self, text, width):
|
||||
"""
|
||||
1. Sentences split across lines have their single newlines removed.
|
||||
2. Paragraphs and explicit newlines are split into separate lines.
|
||||
3. Each line is wrapped to the specified width (width of terminal).
|
||||
"""
|
||||
# The patterns also include whitespace after the newline
|
||||
single_newline = re.compile("(?<!\n)\n(?!\n)\s*")
|
||||
multiple_newlines = re.compile("\n{2,}\s*")
|
||||
text = single_newline.sub(' ', text)
|
||||
lines = re.split(multiple_newlines, text)
|
||||
return sum([textwrap.wrap(line, width) for line in lines], [])
|
||||
|
||||
def add_arguments(self, actions):
|
||||
actions = sorted(actions, key=lambda x: x.option_strings)
|
||||
super().add_arguments(actions)
|
||||
|
||||
|
||||
class FlexibleArgumentParser(argparse.ArgumentParser):
|
||||
class FlexibleArgumentParser(ArgumentParser):
|
||||
"""ArgumentParser that allows both underscore and dash in names."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
@ -1280,11 +1295,10 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
|
||||
value = int(value)
|
||||
except ValueError:
|
||||
msg = "Port must be an integer"
|
||||
raise argparse.ArgumentTypeError(msg) from None
|
||||
raise ArgumentTypeError(msg) from None
|
||||
|
||||
if not (1024 <= value <= 65535):
|
||||
raise argparse.ArgumentTypeError(
|
||||
"Port must be between 1024 and 65535")
|
||||
raise ArgumentTypeError("Port must be between 1024 and 65535")
|
||||
|
||||
return value
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user