Improve configs - SchedulerConfig (#16533)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor 2025-04-14 10:24:16 +01:00 committed by GitHub
parent dc1b4a6f13
commit e51929ebca
4 changed files with 279 additions and 218 deletions

View File

@@ -1522,6 +1522,9 @@ class LoadConfig:
             self.ignore_patterns = ["original/**/*"]
 
 
+DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"]
+
+
 @config
 @dataclass
 class ParallelConfig:
@@ -1563,7 +1566,7 @@ class ParallelConfig:
     placement_group: Optional["PlacementGroup"] = None
     """ray distributed model workers placement group."""
 
-    distributed_executor_backend: Optional[Union[str,
+    distributed_executor_backend: Optional[Union[DistributedExecutorBackend,
                                                  type["ExecutorBase"]]] = None
     """Backend to use for distributed model
     workers, either "ray" or "mp" (multiprocessing). If the product
@@ -1687,7 +1690,7 @@ class ParallelConfig:
            # current node and we aren't in a ray placement group.
 
            from vllm.executor import ray_utils
-           backend = "mp"
+           backend: DistributedExecutorBackend = "mp"
            ray_found = ray_utils.ray_is_available()
            if current_platform.is_neuron():
                # neuron uses single process to control multiple devices
@@ -1755,92 +1758,124 @@ class ParallelConfig:
                 "worker_extension_cls must be a string (qualified class name).")
 
 
+SchedulerPolicy = Literal["fcfs", "priority"]
+
+
+@config
 @dataclass
 class SchedulerConfig:
     """Scheduler configuration."""
 
-    runner_type: str = "generate"  # The runner type to launch for the model.
+    runner_type: RunnerType = "generate"
+    """The runner type to launch for the model."""
 
-    # Maximum number of tokens to be processed in a single iteration.
-    max_num_batched_tokens: int = field(default=None)  # type: ignore
+    max_num_batched_tokens: int = None  # type: ignore
+    """Maximum number of tokens to be processed in a single iteration.
+    This config has no static default. If left unspecified by the user, it will
+    be set in `EngineArgs.create_engine_config` based on the usage context."""
 
-    # Maximum number of sequences to be processed in a single iteration.
-    max_num_seqs: int = 128
+    max_num_seqs: int = None  # type: ignore
+    """Maximum number of sequences to be processed in a single iteration.
+    This config has no static default. If left unspecified by the user, it will
+    be set in `EngineArgs.create_engine_config` based on the usage context."""
 
-    # Maximum length of a sequence (including prompt and generated text).
-    max_model_len: int = 8192
+    max_model_len: int = None  # type: ignore
+    """Maximum length of a sequence (including prompt and generated text). This
+    is primarily set in `ModelConfig` and that value should be manually
+    duplicated here."""
 
-    # Maximum number of sequences that can be partially prefilled concurrently
     max_num_partial_prefills: int = 1
+    """For chunked prefill, the maximum number of sequences that can be
+    partially prefilled concurrently."""
 
-    # Maximum number of "very long prompt" sequences that can be prefilled
-    # concurrently (long is defined by long_prefill_threshold)
     max_long_partial_prefills: int = 1
+    """For chunked prefill, the maximum number of prompts longer than
+    long_prefill_token_threshold that will be prefilled concurrently. Setting
+    this less than max_num_partial_prefills will allow shorter prompts to jump
+    the queue in front of longer prompts in some cases, improving latency."""
 
-    # calculate context length that determines which sequences are
-    # considered "long"
     long_prefill_token_threshold: int = 0
+    """For chunked prefill, a request is considered long if the prompt is
+    longer than this number of tokens."""
 
-    # The number of slots to allocate per sequence per
-    # step, beyond the known token ids. This is used in speculative
-    # decoding to store KV activations of tokens which may or may not be
-    # accepted.
     num_lookahead_slots: int = 0
+    """The number of slots to allocate per sequence per
+    step, beyond the known token ids. This is used in speculative
+    decoding to store KV activations of tokens which may or may not be
+    accepted.
+    NOTE: This will be replaced by speculative config in the future; it is
+    present to enable correctness tests until then."""
 
-    # Apply a delay (of delay factor multiplied by previous
-    # prompt latency) before scheduling next prompt.
     delay_factor: float = 0.0
+    """Apply a delay (of delay factor multiplied by previous
+    prompt latency) before scheduling next prompt."""
 
-    # If True, prefill requests can be chunked based
-    # on the remaining max_num_batched_tokens.
-    enable_chunked_prefill: bool = False
+    enable_chunked_prefill: bool = None  # type: ignore
+    """If True, prefill requests can be chunked based
+    on the remaining max_num_batched_tokens."""
 
     is_multimodal_model: bool = False
+    """True if the model is multimodal."""
 
-    # NOTE: The following multimodal encoder budget will be initialized to
-    # max_num_batched_tokens and overridden in case max multimodal embedding
-    # size is larger.
-    # TODO (ywang96): Make these configurable.
-    # Multimodal encoder compute budget, only used in V1
-    max_num_encoder_input_tokens: int = field(default=None)  # type: ignore
-    # Multimodal encoder cache size, only used in V1
-    encoder_cache_size: int = field(default=None)  # type: ignore
+    # TODO (ywang96): Make this configurable.
+    max_num_encoder_input_tokens: int = field(init=False)
+    """Multimodal encoder compute budget, only used in V1.
+    NOTE: This is not currently configurable. It will be overridden by
+    max_num_batched_tokens in case max multimodal embedding size is larger."""
+
+    # TODO (ywang96): Make this configurable.
+    encoder_cache_size: int = field(init=False)
+    """Multimodal encoder cache size, only used in V1.
+    NOTE: This is not currently configurable. It will be overridden by
+    max_num_batched_tokens in case max multimodal embedding size is larger."""
 
-    # Whether to perform preemption by swapping or
-    # recomputation. If not specified, we determine the mode as follows:
-    # We use recomputation by default since it incurs lower overhead than
-    # swapping. However, when the sequence group has multiple sequences
-    # (e.g., beam search), recomputation is not currently supported. In
-    # such a case, we use swapping instead.
     preemption_mode: Optional[str] = None
+    """Whether to perform preemption by swapping or
+    recomputation. If not specified, we determine the mode as follows:
+    We use recomputation by default since it incurs lower overhead than
+    swapping. However, when the sequence group has multiple sequences
+    (e.g., beam search), recomputation is not currently supported. In
+    such a case, we use swapping instead."""
 
     num_scheduler_steps: int = 1
+    """Maximum number of forward steps per scheduler call."""
 
-    multi_step_stream_outputs: bool = False
+    multi_step_stream_outputs: bool = True
+    """If False, then multi-step will stream outputs at the end of all steps"""
 
-    # Private API. If used, scheduler sends delta data to
-    # workers instead of an entire data. It should be enabled only
-    # when SPMD worker architecture is enabled. I.e.,
-    # VLLM_USE_RAY_SPMD_WORKER=1
     send_delta_data: bool = False
+    """Private API. If used, scheduler sends delta data to
+    workers instead of an entire data. It should be enabled only
+    when SPMD worker architecture is enabled. I.e.,
+    VLLM_USE_RAY_SPMD_WORKER=1"""
 
-    # The scheduling policy to use. "fcfs" (default) or "priority".
-    policy: str = "fcfs"
+    policy: SchedulerPolicy = "fcfs"
+    """The scheduling policy to use:\n
+    - "fcfs" means first come first served, i.e. requests are handled in order
+    of arrival.\n
+    - "priority" means requests are handled based on given priority (lower
+    value means earlier handling) and time of arrival deciding any ties)."""
 
     chunked_prefill_enabled: bool = field(init=False)
+    """True if chunked prefill is enabled."""
 
-    # If set to true and chunked prefill is enabled, we do not want to
-    # partially schedule a multimodal item. Only used in V1
-    # This ensures that if a request has a mixed prompt
-    # (like text tokens TTTT followed by image tokens IIIIIIIIII) where only
-    # some image tokens can be scheduled (like TTTTIIIII, leaving IIIII),
-    # it will be scheduled as TTTT in one step and IIIIIIIIII in the next.
     disable_chunked_mm_input: bool = False
+    """If set to true and chunked prefill is enabled, we do not want to
+    partially schedule a multimodal item. Only used in V1
+    This ensures that if a request has a mixed prompt
+    (like text tokens TTTT followed by image tokens IIIIIIIIII) where only
+    some image tokens can be scheduled (like TTTTIIIII, leaving IIIII),
+    it will be scheduled as TTTT in one step and IIIIIIIIII in the next."""
 
-    # scheduler class or path. "vllm.core.scheduler.Scheduler" (default)
-    # or "mod.custom_class".
     scheduler_cls: Union[str, type[object]] = "vllm.core.scheduler.Scheduler"
+    """The scheduler class to use. "vllm.core.scheduler.Scheduler" is the
+    default scheduler. Can be a class directly or the path to a class of form
+    "mod.custom_class"."""
 
     def compute_hash(self) -> str:
         """
@@ -1862,6 +1897,18 @@ class SchedulerConfig:
         return hash_str
 
     def __post_init__(self) -> None:
+        if self.max_model_len is None:
+            self.max_model_len = 8192
+            logger.warning(
+                "max_model_len was is not set. Defaulting to arbitrary value "
+                "of %d.", self.max_model_len)
+
+        if self.max_num_seqs is None:
+            self.max_num_seqs = 128
+            logger.warning(
+                "max_num_seqs was is not set. Defaulting to arbitrary value "
+                "of %d.", self.max_num_seqs)
+
         if self.max_num_batched_tokens is None:
             if self.enable_chunked_prefill:
                 if self.num_scheduler_steps > 1:
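
The refactor above moves each field's documentation out of leading # comments and into bare string literals placed directly beneath the field. Such literals are not docstrings at runtime, but they can be recovered from the class source with ast, which is how a helper like get_attr_docs (imported further down in arg_utils.py) can reuse them as CLI help text. A minimal self-contained sketch of that pattern, for illustration only (this is not vLLM's actual implementation):

import ast
import inspect
from dataclasses import dataclass


def attr_docs(cls: type) -> dict[str, str]:
    # Map each annotated field to the bare string literal that follows it.
    class_def = ast.parse(inspect.getsource(cls)).body[0]
    docs: dict[str, str] = {}
    for stmt, nxt in zip(class_def.body, class_def.body[1:]):
        is_field = isinstance(stmt, ast.AnnAssign) and isinstance(
            stmt.target, ast.Name)
        is_doc = (isinstance(nxt, ast.Expr)
                  and isinstance(nxt.value, ast.Constant)
                  and isinstance(nxt.value.value, str))
        if is_field and is_doc:
            docs[stmt.target.id] = inspect.cleandoc(nxt.value.value)
    return docs


@dataclass
class DemoConfig:
    max_num_seqs: int = 128
    """Maximum number of sequences to be processed in a single iteration."""


# Run as a script: prints {'max_num_seqs': 'Maximum number of sequences ...'}
print(attr_docs(DemoConfig))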

View File

@@ -1,25 +1,30 @@
 # SPDX-License-Identifier: Apache-2.0
 
+# yapf: disable
 import argparse
 import dataclasses
 import json
 import re
 import threading
 from dataclasses import MISSING, dataclass, fields
-from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
-                    Tuple, Type, Union, cast, get_args, get_origin)
+from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Literal, Mapping,
+                    Optional, Tuple, Type, TypeVar, Union, cast, get_args,
+                    get_origin)
 
 import torch
+from typing_extensions import TypeIs
 
 import vllm.envs as envs
 from vllm import version
 from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat,
-                         DecodingConfig, DeviceConfig, HfOverrides,
+                         DecodingConfig, DeviceConfig,
+                         DistributedExecutorBackend, HfOverrides,
                          KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
                          ModelConfig, ModelImpl, ObservabilityConfig,
                          ParallelConfig, PoolerConfig, PromptAdapterConfig,
-                         SchedulerConfig, SpeculativeConfig, TaskOption,
-                         TokenizerPoolConfig, VllmConfig, get_attr_docs)
+                         SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
+                         TaskOption, TokenizerPoolConfig, VllmConfig,
+                         get_attr_docs)
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
@@ -28,7 +33,9 @@ from vllm.reasoning import ReasoningParserManager
 from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
 from vllm.transformers_utils.utils import check_gguf_file
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import FlexibleArgumentParser, StoreBoolean, is_in_ray_actor
+from vllm.utils import FlexibleArgumentParser, is_in_ray_actor
+# yapf: enable
+
 
 if TYPE_CHECKING:
     from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
@@ -47,11 +54,32 @@ DEVICE_OPTIONS = [
     "hpu",
 ]
 
+# object is used to allow for special typing forms
+T = TypeVar("T")
+TypeHint = Union[type[Any], object]
+TypeHintT = Union[type[T], object]
+
 
-def nullable_str(val: str):
-    if not val or val == "None":
+def optional_arg(val: str, return_type: type[T]) -> Optional[T]:
+    if val == "" or val == "None":
         return None
-    return val
+    try:
+        return cast(Callable, return_type)(val)
+    except ValueError as e:
+        raise argparse.ArgumentTypeError(
+            f"Value {val} cannot be converted to {return_type}.") from e
+
+
+def optional_str(val: str) -> Optional[str]:
+    return optional_arg(val, str)
+
+
+def optional_int(val: str) -> Optional[int]:
+    return optional_arg(val, int)
+
+
+def optional_float(val: str) -> Optional[float]:
+    return optional_arg(val, float)
 
 
 def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
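
The optional_* helpers above replace nullable_str as argparse type= converters: "" and "None" map to Python None, any other value is converted to the requested type, and a failed conversion surfaces as an argparse error rather than a raw ValueError. A small usage sketch, assuming the helpers land exactly as shown in this diff (the flag names below are illustrative):

import argparse

from vllm.engine.arg_utils import optional_int, optional_str

parser = argparse.ArgumentParser()
parser.add_argument("--revision", type=optional_str, default=None)
parser.add_argument("--max-num-seqs", type=optional_int, default=None)

assert parser.parse_args(["--revision", "None"]).revision is None
assert parser.parse_args(["--max-num-seqs", "256"]).max_num_seqs == 256
# "--max-num-seqs lots" exits with "Value lots cannot be converted to <class 'int'>."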
@@ -112,7 +140,8 @@ class EngineArgs:
     # is intended for expert use only. The API may change without
     # notice.
     distributed_executor_backend: Optional[Union[
-        str, Type[ExecutorBase]]] = ParallelConfig.distributed_executor_backend
+        DistributedExecutorBackend,
+        Type[ExecutorBase]]] = ParallelConfig.distributed_executor_backend
     # number of P/D disaggregation (or other disaggregation) workers
     pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
     tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
@@ -129,11 +158,13 @@ class EngineArgs:
     swap_space: float = 4  # GiB
     cpu_offload_gb: float = 0  # GiB
     gpu_memory_utilization: float = 0.90
-    max_num_batched_tokens: Optional[int] = None
-    max_num_partial_prefills: Optional[int] = 1
-    max_long_partial_prefills: Optional[int] = 1
-    long_prefill_token_threshold: Optional[int] = 0
-    max_num_seqs: Optional[int] = None
+    max_num_batched_tokens: Optional[
+        int] = SchedulerConfig.max_num_batched_tokens
+    max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills
+    max_long_partial_prefills: int = SchedulerConfig.max_long_partial_prefills
+    long_prefill_token_threshold: int = \
+        SchedulerConfig.long_prefill_token_threshold
+    max_num_seqs: Optional[int] = SchedulerConfig.max_num_seqs
     max_logprobs: int = 20  # Default value for OpenAI Chat Completions API
     disable_log_stats: bool = False
     revision: Optional[str] = None
@@ -169,20 +200,21 @@ class EngineArgs:
     lora_dtype: Optional[Union[str, torch.dtype]] = 'auto'
     max_cpu_loras: Optional[int] = None
     device: str = 'auto'
-    num_scheduler_steps: int = 1
-    multi_step_stream_outputs: bool = True
+    num_scheduler_steps: int = SchedulerConfig.num_scheduler_steps
+    multi_step_stream_outputs: bool = SchedulerConfig.multi_step_stream_outputs
     ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
     num_gpu_blocks_override: Optional[int] = None
-    num_lookahead_slots: int = 0
+    num_lookahead_slots: int = SchedulerConfig.num_lookahead_slots
     model_loader_extra_config: Optional[
         dict] = LoadConfig.model_loader_extra_config
     ignore_patterns: Optional[Union[str,
                                     List[str]]] = LoadConfig.ignore_patterns
 
-    preemption_mode: Optional[str] = None
-    scheduler_delay_factor: float = 0.0
-    enable_chunked_prefill: Optional[bool] = None
-    disable_chunked_mm_input: bool = False
+    preemption_mode: Optional[str] = SchedulerConfig.preemption_mode
+    scheduler_delay_factor: float = SchedulerConfig.delay_factor
+    enable_chunked_prefill: Optional[
+        bool] = SchedulerConfig.enable_chunked_prefill
+    disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input
 
     guided_decoding_backend: str = DecodingConfig.guided_decoding_backend
     logits_processor_pattern: Optional[str] = None
@@ -194,8 +226,8 @@ class EngineArgs:
     otlp_traces_endpoint: Optional[str] = None
     collect_detailed_traces: Optional[str] = None
     disable_async_output_proc: bool = False
-    scheduling_policy: Literal["fcfs", "priority"] = "fcfs"
-    scheduler_cls: Union[str, Type[object]] = "vllm.core.scheduler.Scheduler"
+    scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
+    scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls
 
     override_neuron_config: Optional[Dict[str, Any]] = None
     override_pooler_config: Optional[PoolerConfig] = None
@@ -236,15 +268,33 @@ class EngineArgs:
     def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         """Shared CLI arguments for vLLM engine."""
 
-        def is_type_in_union(cls: type[Any], type: type[Any]) -> bool:
+        def is_type_in_union(cls: TypeHint, type: TypeHint) -> bool:
             """Check if the class is a type in a union type."""
-            return get_origin(cls) is Union and type in get_args(cls)
+            is_union = get_origin(cls) is Union
+            type_in_union = type in [get_origin(a) or a for a in get_args(cls)]
+            return is_union and type_in_union
 
-        def is_optional(cls: type[Any]) -> bool:
+        def get_type_from_union(cls: TypeHint, type: TypeHintT) -> TypeHintT:
+            """Get the type in a union type."""
+            for arg in get_args(cls):
+                if (get_origin(arg) or arg) is type:
+                    return arg
+            raise ValueError(f"Type {type} not found in union type {cls}.")
+
+        def is_optional(cls: TypeHint) -> TypeIs[Union[Any, None]]:
             """Check if the class is an optional type."""
             return is_type_in_union(cls, type(None))
 
-        def get_kwargs(cls: type[Any]) -> Dict[str, Any]:
+        def can_be_type(cls: TypeHint, type: TypeHintT) -> TypeIs[TypeHintT]:
+            """Check if the class can be of type."""
+            return cls is type or get_origin(cls) is type or is_type_in_union(
+                cls, type)
+
+        def is_custom_type(cls: TypeHint) -> bool:
+            """Check if the class is a custom type."""
+            return cls.__module__ != "builtins"
+
+        def get_kwargs(cls: type[Any]) -> dict[str, Any]:
             cls_docs = get_attr_docs(cls)
             kwargs = {}
             for field in fields(cls):
@@ -253,19 +303,41 @@ class EngineArgs:
                 default = (field.default_factory
                            if field.default is MISSING else field.default)
                 kwargs[name] = {"default": default, "help": cls_docs[name]}
-                # When using action="store_true"
-                # add_argument doesn't accept type
-                if field.type is bool:
-                    continue
-                # Handle optional fields
-                if is_optional(field.type):
-                    kwargs[name]["type"] = nullable_str
-                    continue
-                # Handle str in union fields
-                if is_type_in_union(field.type, str):
-                    kwargs[name]["type"] = str
-                    continue
-                kwargs[name]["type"] = field.type
+
+                # Make note of if the field is optional and get the actual
+                # type of the field if it is
+                optional = is_optional(field.type)
+                field_type = get_args(
+                    field.type)[0] if optional else field.type
+
+                if can_be_type(field_type, bool):
+                    # Creates --no-<name> and --<name> flags
+                    kwargs[name]["action"] = argparse.BooleanOptionalAction
+                    kwargs[name]["type"] = bool
+                elif can_be_type(field_type, Literal):
+                    # Creates choices from Literal arguments
+                    if is_type_in_union(field_type, Literal):
+                        field_type = get_type_from_union(field_type, Literal)
+                    choices = get_args(field_type)
+                    kwargs[name]["choices"] = choices
+                    choice_type = type(choices[0])
+                    assert all(type(c) is choice_type for c in choices), (
+                        f"All choices must be of the same type. "
+                        f"Got {choices} with types {[type(c) for c in choices]}"
+                    )
+                    kwargs[name]["type"] = choice_type
+                elif can_be_type(field_type, int):
+                    kwargs[name]["type"] = optional_int if optional else int
+                elif can_be_type(field_type, float):
+                    kwargs[name][
+                        "type"] = optional_float if optional else float
+                elif (can_be_type(field_type, str)
+                      or can_be_type(field_type, dict)
+                      or is_custom_type(field_type)):
+                    kwargs[name]["type"] = optional_str if optional else str
+                else:
+                    raise ValueError(
+                        f"Unsupported type {field.type} for argument {name}. ")
             return kwargs
 
         # Model arguments
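
The loop above derives the whole argparse specification from the dataclass field types: bool fields become paired --<name>/--no-<name> flags via BooleanOptionalAction, Literal annotations become choices, and Optional[...] fields get the optional_* converters. A condensed standalone sketch of the same dispatch, using an invented DemoConfig (the real code also handles dict, custom types and docstring-derived help):

import argparse
from dataclasses import MISSING, dataclass, fields
from typing import Literal, Optional, Union, get_args, get_origin


@dataclass
class DemoConfig:
    enable_chunked_prefill: Optional[bool] = None
    policy: Literal["fcfs", "priority"] = "fcfs"
    max_num_seqs: Optional[int] = None


def add_config_args(parser: argparse.ArgumentParser, cls: type) -> None:
    for field in fields(cls):
        default = None if field.default is MISSING else field.default
        flag = f"--{field.name.replace('_', '-')}"
        # Unwrap Optional[...] to find the underlying type.
        optional = (get_origin(field.type) is Union
                    and type(None) in get_args(field.type))
        field_type = get_args(field.type)[0] if optional else field.type
        if field_type is bool:
            # Generates both --<name> and --no-<name>.
            parser.add_argument(flag,
                                action=argparse.BooleanOptionalAction,
                                default=default)
        elif get_origin(field_type) is Literal:
            choices = get_args(field_type)
            parser.add_argument(flag,
                                choices=choices,
                                type=type(choices[0]),
                                default=default)
        else:
            parser.add_argument(flag, type=field_type, default=default)


parser = argparse.ArgumentParser()
add_config_args(parser, DemoConfig)
args = parser.parse_args(["--no-enable-chunked-prefill", "--policy", "priority"])
assert args.enable_chunked_prefill is False and args.policy == "priority"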
@@ -285,13 +357,13 @@ class EngineArgs:
             'which task to use.')
         parser.add_argument(
             '--tokenizer',
-            type=nullable_str,
+            type=optional_str,
             default=EngineArgs.tokenizer,
             help='Name or path of the huggingface tokenizer to use. '
             'If unspecified, model name or path will be used.')
         parser.add_argument(
             "--hf-config-path",
-            type=nullable_str,
+            type=optional_str,
             default=EngineArgs.hf_config_path,
             help='Name or path of the huggingface config to use. '
             'If unspecified, model name or path will be used.')
@@ -303,21 +375,21 @@ class EngineArgs:
             'the input. The generated output will contain token ids.')
         parser.add_argument(
             '--revision',
-            type=nullable_str,
+            type=optional_str,
             default=None,
             help='The specific model version to use. It can be a branch '
             'name, a tag name, or a commit id. If unspecified, will use '
             'the default version.')
         parser.add_argument(
             '--code-revision',
-            type=nullable_str,
+            type=optional_str,
             default=None,
             help='The specific revision to use for the model code on '
             'Hugging Face Hub. It can be a branch name, a tag name, or a '
             'commit id. If unspecified, will use the default version.')
         parser.add_argument(
             '--tokenizer-revision',
-            type=nullable_str,
+            type=optional_str,
             default=None,
             help='Revision of the huggingface tokenizer to use. '
             'It can be a branch name, a tag name, or a commit id. '
@@ -357,7 +429,6 @@
         load_group.add_argument('--model-loader-extra-config',
                                 **load_kwargs["model_loader_extra_config"])
         load_group.add_argument('--use-tqdm-on-load',
-                                action=argparse.BooleanOptionalAction,
                                 **load_kwargs["use_tqdm_on_load"])
 
         parser.add_argument(
@@ -413,7 +484,7 @@
             'the behavior is subject to change in each release.')
         parser.add_argument(
             '--logits-processor-pattern',
-            type=nullable_str,
+            type=optional_str,
             default=None,
             help='Optional regex pattern specifying valid logits processor '
             'qualified names that can be passed with the `logits_processors` '
@@ -439,7 +510,6 @@
         )
         parallel_group.add_argument(
             '--distributed-executor-backend',
-            choices=['ray', 'mp', 'uni', 'external_launcher'],
             **parallel_kwargs["distributed_executor_backend"])
         parallel_group.add_argument(
             '--pipeline-parallel-size', '-pp',
@@ -450,18 +520,15 @@
             **parallel_kwargs["data_parallel_size"])
         parallel_group.add_argument(
             '--enable-expert-parallel',
-            action='store_true',
             **parallel_kwargs["enable_expert_parallel"])
         parallel_group.add_argument(
             '--max-parallel-loading-workers',
             **parallel_kwargs["max_parallel_loading_workers"])
         parallel_group.add_argument(
             '--ray-workers-use-nsight',
-            action='store_true',
             **parallel_kwargs["ray_workers_use_nsight"])
         parallel_group.add_argument(
             '--disable-custom-all-reduce',
-            action='store_true',
             **parallel_kwargs["disable_custom_all_reduce"])
         # KV cache arguments
         parser.add_argument('--block-size',
@@ -502,14 +569,6 @@
             'block manager v2) is now the default. '
             'Setting this flag to True or False'
             ' has no effect on vLLM behavior.')
-        parser.add_argument(
-            '--num-lookahead-slots',
-            type=int,
-            default=EngineArgs.num_lookahead_slots,
-            help='Experimental scheduling config necessary for '
-            'speculative decoding. This will be replaced by '
-            'speculative config in the future; it is present '
-            'to enable correctness tests until then.')
 
         parser.add_argument('--seed',
                             type=int,
@@ -552,36 +611,6 @@
             default=None,
             help='If specified, ignore GPU profiling result and use this number'
             ' of GPU blocks. Used for testing preemption.')
-        parser.add_argument('--max-num-batched-tokens',
-                            type=int,
-                            default=EngineArgs.max_num_batched_tokens,
-                            help='Maximum number of batched tokens per '
-                            'iteration.')
-        parser.add_argument(
-            "--max-num-partial-prefills",
-            type=int,
-            default=EngineArgs.max_num_partial_prefills,
-            help="For chunked prefill, the max number of concurrent \
-            partial prefills.")
-        parser.add_argument(
-            "--max-long-partial-prefills",
-            type=int,
-            default=EngineArgs.max_long_partial_prefills,
-            help="For chunked prefill, the maximum number of prompts longer "
-            "than --long-prefill-token-threshold that will be prefilled "
-            "concurrently. Setting this less than --max-num-partial-prefills "
-            "will allow shorter prompts to jump the queue in front of longer "
-            "prompts in some cases, improving latency.")
-        parser.add_argument(
-            "--long-prefill-token-threshold",
-            type=float,
-            default=EngineArgs.long_prefill_token_threshold,
-            help="For chunked prefill, a request is considered long if the "
-            "prompt is longer than this number of tokens.")
-        parser.add_argument('--max-num-seqs',
-                            type=int,
-                            default=EngineArgs.max_num_seqs,
-                            help='Maximum number of sequences per iteration.')
         parser.add_argument(
             '--max-logprobs',
             type=int,
@@ -594,7 +623,7 @@
         # Quantization settings.
         parser.add_argument('--quantization',
                             '-q',
-                            type=nullable_str,
+                            type=optional_str,
                             choices=[*QUANTIZATION_METHODS, None],
                             default=EngineArgs.quantization,
                             help='Method used to quantize the weights. If '
@@ -658,7 +687,7 @@
                             'asynchronous tokenization. Ignored '
                             'if tokenizer_pool_size is 0.')
         parser.add_argument('--tokenizer-pool-extra-config',
-                            type=nullable_str,
+                            type=optional_str,
                             default=EngineArgs.tokenizer_pool_extra_config,
                             help='Extra config for tokenizer pool. '
                             'This should be a JSON string that will be '
@@ -721,7 +750,7 @@
             'base model dtype.'))
         parser.add_argument(
             '--long-lora-scaling-factors',
-            type=nullable_str,
+            type=optional_str,
             default=EngineArgs.long_lora_scaling_factors,
             help=('Specify multiple scaling factors (which can '
                   'be different from base model scaling factor '
@@ -766,28 +795,6 @@
                             help=('Maximum number of forward steps per '
                                   'scheduler call.'))
 
-        parser.add_argument(
-            '--multi-step-stream-outputs',
-            action=StoreBoolean,
-            default=EngineArgs.multi_step_stream_outputs,
-            nargs="?",
-            const="True",
-            help='If False, then multi-step will stream outputs at the end '
-            'of all steps')
-        parser.add_argument(
-            '--scheduler-delay-factor',
-            type=float,
-            default=EngineArgs.scheduler_delay_factor,
-            help='Apply a delay (of delay factor multiplied by previous '
-            'prompt latency) before scheduling next prompt.')
-        parser.add_argument(
-            '--enable-chunked-prefill',
-            action=StoreBoolean,
-            default=EngineArgs.enable_chunked_prefill,
-            nargs="?",
-            const="True",
-            help='If set, the prefill requests can be chunked based on the '
-            'max_num_batched_tokens.')
         parser.add_argument('--speculative-config',
                             type=json.loads,
                             default=None,
@@ -863,22 +870,43 @@
                             help="Disable async output processing. This may result in "
                             "lower performance.")
 
-        parser.add_argument(
-            '--scheduling-policy',
-            choices=['fcfs', 'priority'],
-            default="fcfs",
-            help='The scheduling policy to use. "fcfs" (first come first served'
-            ', i.e. requests are handled in order of arrival; default) '
-            'or "priority" (requests are handled based on given '
-            'priority (lower value means earlier handling) and time of '
-            'arrival deciding any ties).')
-
-        parser.add_argument(
-            '--scheduler-cls',
-            default=EngineArgs.scheduler_cls,
-            help='The scheduler class to use. "vllm.core.scheduler.Scheduler" '
-            'is the default scheduler. Can be a class directly or the path to '
-            'a class of form "mod.custom_class".')
+        # Scheduler arguments
+        scheduler_kwargs = get_kwargs(SchedulerConfig)
+        scheduler_group = parser.add_argument_group(
+            title="SchedulerConfig",
+            description=SchedulerConfig.__doc__,
+        )
+        scheduler_group.add_argument(
+            '--max-num-batched-tokens',
+            **scheduler_kwargs["max_num_batched_tokens"])
+        scheduler_group.add_argument('--max-num-seqs',
+                                     **scheduler_kwargs["max_num_seqs"])
+        scheduler_group.add_argument(
+            "--max-num-partial-prefills",
+            **scheduler_kwargs["max_num_partial_prefills"])
+        scheduler_group.add_argument(
+            "--max-long-partial-prefills",
+            **scheduler_kwargs["max_long_partial_prefills"])
+        scheduler_group.add_argument(
+            "--long-prefill-token-threshold",
+            **scheduler_kwargs["long_prefill_token_threshold"])
+        scheduler_group.add_argument('--num-lookahead-slots',
+                                     **scheduler_kwargs["num_lookahead_slots"])
+        scheduler_group.add_argument('--scheduler-delay-factor',
+                                     **scheduler_kwargs["delay_factor"])
+        scheduler_group.add_argument(
+            '--enable-chunked-prefill',
+            **scheduler_kwargs["enable_chunked_prefill"])
+        scheduler_group.add_argument(
+            '--multi-step-stream-outputs',
+            **scheduler_kwargs["multi_step_stream_outputs"])
+        scheduler_group.add_argument('--scheduling-policy',
+                                     **scheduler_kwargs["policy"])
+        scheduler_group.add_argument(
+            "--disable-chunked-mm-input",
+            **scheduler_kwargs["disable_chunked_mm_input"])
+        parser.add_argument('--scheduler-cls',
+                            **scheduler_kwargs["scheduler_cls"])
 
         parser.add_argument(
             '--override-neuron-config',
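
With the scheduler flags now generated from SchedulerConfig, boolean options gain --no-* spellings and --scheduling-policy is validated against the SchedulerPolicy literal. A rough usage sketch (hypothetical invocation; the exact flag set depends on the vLLM version this diff applies to):

from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser

parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
args = parser.parse_args([
    "--model", "facebook/opt-125m",
    "--scheduling-policy", "priority",   # checked against Literal["fcfs", "priority"]
    "--no-enable-chunked-prefill",       # generated by BooleanOptionalAction
    "--max-num-seqs", "64",
])
engine_args = EngineArgs.from_cli_args(args)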
@@ -930,7 +958,7 @@
             'class without changing the existing functions.')
         parser.add_argument(
             "--generation-config",
-            type=nullable_str,
+            type=optional_str,
             default="auto",
             help="The folder path to the generation config. "
             "Defaults to 'auto', the generation config will be loaded from "
@@ -1003,20 +1031,6 @@
             "Note that even if this is set to False, cascade attention will be "
             "only used when the heuristic tells that it's beneficial.")
 
-        parser.add_argument(
-            "--disable-chunked-mm-input",
-            action=StoreBoolean,
-            default=EngineArgs.disable_chunked_mm_input,
-            nargs="?",
-            const="True",
-            help="Disable multimodal input chunking attention for V1. "
-            "If set to true and chunked prefill is enabled, we do not want to"
-            " partially schedule a multimodal item. This ensures that if a "
-            "request has a mixed prompt (like text tokens TTTT followed by "
-            "image tokens IIIIIIIIII) where only some image tokens can be "
-            "scheduled (like TTTTIIIII, leaving IIIII), it will be scheduled "
-            "as TTTT in one step and IIIIIIIIII in the next.")
-
         return parser
 
     @classmethod
@@ -1370,7 +1384,7 @@
                                recommend_to_remove=False)
             return False
 
-        if self.preemption_mode != EngineArgs.preemption_mode:
+        if self.preemption_mode != SchedulerConfig.preemption_mode:
             _raise_or_fallback(feature_name="--preemption-mode",
                                recommend_to_remove=True)
             return False
@@ -1381,17 +1395,17 @@
                                recommend_to_remove=True)
             return False
 
-        if self.scheduling_policy != EngineArgs.scheduling_policy:
+        if self.scheduling_policy != SchedulerConfig.policy:
             _raise_or_fallback(feature_name="--scheduling-policy",
                                recommend_to_remove=False)
             return False
 
-        if self.num_scheduler_steps != EngineArgs.num_scheduler_steps:
+        if self.num_scheduler_steps != SchedulerConfig.num_scheduler_steps:
             _raise_or_fallback(feature_name="--num-scheduler-steps",
                                recommend_to_remove=True)
             return False
 
-        if self.scheduler_delay_factor != EngineArgs.scheduler_delay_factor:
+        if self.scheduler_delay_factor != SchedulerConfig.delay_factor:
             _raise_or_fallback(feature_name="--scheduler-delay-factor",
                                recommend_to_remove=True)
             return False
@@ -1475,9 +1489,9 @@
 
         # No Concurrent Partial Prefills so far.
         if (self.max_num_partial_prefills
-                != EngineArgs.max_num_partial_prefills
+                != SchedulerConfig.max_num_partial_prefills
                 or self.max_long_partial_prefills
-                != EngineArgs.max_long_partial_prefills):
+                != SchedulerConfig.max_long_partial_prefills):
             _raise_or_fallback(feature_name="Concurrent Partial Prefill",
                                recommend_to_remove=False)
             return False

View File

@@ -11,7 +11,7 @@ import ssl
 from collections.abc import Sequence
 from typing import Optional, Union, get_args
 
-from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
+from vllm.engine.arg_utils import AsyncEngineArgs, optional_str
 from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
                                          validate_chat_template)
 from vllm.entrypoints.openai.serving_models import (LoRAModulePath,
@@ -79,7 +79,7 @@ class PromptAdapterParserAction(argparse.Action):
 
 def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
     parser.add_argument("--host",
-                        type=nullable_str,
+                        type=optional_str,
                         default=None,
                         help="Host name.")
     parser.add_argument("--port", type=int, default=8000, help="Port number.")
@@ -108,13 +108,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                         default=["*"],
                         help="Allowed headers.")
     parser.add_argument("--api-key",
-                        type=nullable_str,
+                        type=optional_str,
                         default=None,
                         help="If provided, the server will require this key "
                         "to be presented in the header.")
     parser.add_argument(
         "--lora-modules",
-        type=nullable_str,
+        type=optional_str,
         default=None,
         nargs='+',
         action=LoRAParserAction,
@@ -126,14 +126,14 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         "\"base_model_name\": \"id\"}``")
     parser.add_argument(
         "--prompt-adapters",
-        type=nullable_str,
+        type=optional_str,
         default=None,
         nargs='+',
         action=PromptAdapterParserAction,
         help="Prompt adapter configurations in the format name=path. "
         "Multiple adapters can be specified.")
     parser.add_argument("--chat-template",
-                        type=nullable_str,
+                        type=optional_str,
                         default=None,
                         help="The file path to the chat template, "
                         "or the template in single-line form "
@@ -151,20 +151,20 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                         'similar to OpenAI schema. '
                        'Example: ``[{"type": "text", "text": "Hello world!"}]``')
     parser.add_argument("--response-role",
-                        type=nullable_str,
+                        type=optional_str,
                         default="assistant",
                         help="The role name to return if "
                         "``request.add_generation_prompt=true``.")
     parser.add_argument("--ssl-keyfile",
-                        type=nullable_str,
+                        type=optional_str,
                         default=None,
                         help="The file path to the SSL key file.")
     parser.add_argument("--ssl-certfile",
-                        type=nullable_str,
+                        type=optional_str,
                         default=None,
                         help="The file path to the SSL cert file.")
     parser.add_argument("--ssl-ca-certs",
-                        type=nullable_str,
+                        type=optional_str,
                         default=None,
                         help="The CA certificates file.")
     parser.add_argument(
@@ -180,13 +180,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
     )
     parser.add_argument(
         "--root-path",
-        type=nullable_str,
+        type=optional_str,
         default=None,
         help="FastAPI root_path when app is behind a path based routing proxy."
     )
     parser.add_argument(
         "--middleware",
-        type=nullable_str,
+        type=optional_str,
         action="append",
         default=[],
         help="Additional ASGI middleware to apply to the app. "

View File

@@ -12,7 +12,7 @@ import torch
 from prometheus_client import start_http_server
 from tqdm import tqdm
 
-from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
+from vllm.engine.arg_utils import AsyncEngineArgs, optional_str
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.logger import RequestLogger, logger
 # yapf: disable
@@ -61,7 +61,7 @@ def parse_args():
                         "to the output URL.",
                         )
     parser.add_argument("--response-role",
-                        type=nullable_str,
+                        type=optional_str,
                         default="assistant",
                         help="The role name to return if "
                         "`request.add_generation_prompt=True`.")