Improve configs - SchedulerConfig (#16533)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

This commit is contained in:
parent dc1b4a6f13
commit e51929ebca

151  vllm/config.py
vllm/config.py

@@ -1522,6 +1522,9 @@ class LoadConfig:
             self.ignore_patterns = ["original/**/*"]
 
 
+DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"]
+
+
 @config
 @dataclass
 class ParallelConfig:
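The new alias (like the SchedulerPolicy alias added further down) lets a single definition serve both the type checker and runtime validation. A standalone illustration, not code from the diff:

from typing import Literal, get_args

DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"]

# The alias a type checker validates against can also be interrogated at
# runtime, e.g. to build argparse choices or to validate a config value.
valid_backends = get_args(DistributedExecutorBackend)
assert "mp" in valid_backends
print(valid_backends)  # ('ray', 'mp', 'uni', 'external_launcher')
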
@@ -1563,7 +1566,7 @@ class ParallelConfig:
     placement_group: Optional["PlacementGroup"] = None
     """ray distributed model workers placement group."""
 
-    distributed_executor_backend: Optional[Union[str,
+    distributed_executor_backend: Optional[Union[DistributedExecutorBackend,
                                                  type["ExecutorBase"]]] = None
     """Backend to use for distributed model
     workers, either "ray" or "mp" (multiprocessing). If the product
@@ -1687,7 +1690,7 @@ class ParallelConfig:
             # current node and we aren't in a ray placement group.
 
             from vllm.executor import ray_utils
-            backend = "mp"
+            backend: DistributedExecutorBackend = "mp"
             ray_found = ray_utils.ray_is_available()
             if current_platform.is_neuron():
                 # neuron uses single process to control multiple devices
@@ -1755,92 +1758,124 @@ class ParallelConfig:
                 "worker_extension_cls must be a string (qualified class name).")
 
 
+SchedulerPolicy = Literal["fcfs", "priority"]
+
+
+@config
 @dataclass
 class SchedulerConfig:
     """Scheduler configuration."""
 
-    runner_type: str = "generate"  # The runner type to launch for the model.
+    runner_type: RunnerType = "generate"
+    """The runner type to launch for the model."""
 
-    # Maximum number of tokens to be processed in a single iteration.
-    max_num_batched_tokens: int = field(default=None)  # type: ignore
+    max_num_batched_tokens: int = None  # type: ignore
+    """Maximum number of tokens to be processed in a single iteration.
 
-    # Maximum number of sequences to be processed in a single iteration.
-    max_num_seqs: int = 128
+    This config has no static default. If left unspecified by the user, it will
+    be set in `EngineArgs.create_engine_config` based on the usage context."""
 
-    # Maximum length of a sequence (including prompt and generated text).
-    max_model_len: int = 8192
+    max_num_seqs: int = None  # type: ignore
+    """Maximum number of sequences to be processed in a single iteration.
+
+    This config has no static default. If left unspecified by the user, it will
+    be set in `EngineArgs.create_engine_config` based on the usage context."""
+
+    max_model_len: int = None  # type: ignore
+    """Maximum length of a sequence (including prompt and generated text). This
+    is primarily set in `ModelConfig` and that value should be manually
+    duplicated here."""
 
-    # Maximum number of sequences that can be partially prefilled concurrently
     max_num_partial_prefills: int = 1
+    """For chunked prefill, the maximum number of sequences that can be
+    partially prefilled concurrently."""
 
-    # Maximum number of "very long prompt" sequences that can be prefilled
-    # concurrently (long is defined by long_prefill_threshold)
     max_long_partial_prefills: int = 1
+    """For chunked prefill, the maximum number of prompts longer than
+    long_prefill_token_threshold that will be prefilled concurrently. Setting
+    this less than max_num_partial_prefills will allow shorter prompts to jump
+    the queue in front of longer prompts in some cases, improving latency."""
 
-    # calculate context length that determines which sequences are
-    # considered "long"
     long_prefill_token_threshold: int = 0
+    """For chunked prefill, a request is considered long if the prompt is
+    longer than this number of tokens."""
 
-    # The number of slots to allocate per sequence per
-    # step, beyond the known token ids. This is used in speculative
-    # decoding to store KV activations of tokens which may or may not be
-    # accepted.
     num_lookahead_slots: int = 0
+    """The number of slots to allocate per sequence per
+    step, beyond the known token ids. This is used in speculative
+    decoding to store KV activations of tokens which may or may not be
+    accepted.
+
+    NOTE: This will be replaced by speculative config in the future; it is
+    present to enable correctness tests until then."""
 
-    # Apply a delay (of delay factor multiplied by previous
-    # prompt latency) before scheduling next prompt.
     delay_factor: float = 0.0
+    """Apply a delay (of delay factor multiplied by previous
+    prompt latency) before scheduling next prompt."""
 
-    # If True, prefill requests can be chunked based
-    # on the remaining max_num_batched_tokens.
-    enable_chunked_prefill: bool = False
+    enable_chunked_prefill: bool = None  # type: ignore
+    """If True, prefill requests can be chunked based
+    on the remaining max_num_batched_tokens."""
 
     is_multimodal_model: bool = False
+    """True if the model is multimodal."""
 
-    # NOTE: The following multimodal encoder budget will be initialized to
-    # max_num_batched_tokens and overridden in case max multimodal embedding
-    # size is larger.
-    # TODO (ywang96): Make these configurable.
-    # Multimodal encoder compute budget, only used in V1
-    max_num_encoder_input_tokens: int = field(default=None)  # type: ignore
-
-    # Multimodal encoder cache size, only used in V1
-    encoder_cache_size: int = field(default=None)  # type: ignore
+    # TODO (ywang96): Make this configurable.
+    max_num_encoder_input_tokens: int = field(init=False)
+    """Multimodal encoder compute budget, only used in V1.
+
+    NOTE: This is not currently configurable. It will be overridden by
+    max_num_batched_tokens in case max multimodal embedding size is larger."""
+
+    # TODO (ywang96): Make this configurable.
+    encoder_cache_size: int = field(init=False)
+    """Multimodal encoder cache size, only used in V1.
+
+    NOTE: This is not currently configurable. It will be overridden by
+    max_num_batched_tokens in case max multimodal embedding size is larger."""
 
-    # Whether to perform preemption by swapping or
-    # recomputation. If not specified, we determine the mode as follows:
-    # We use recomputation by default since it incurs lower overhead than
-    # swapping. However, when the sequence group has multiple sequences
-    # (e.g., beam search), recomputation is not currently supported. In
-    # such a case, we use swapping instead.
     preemption_mode: Optional[str] = None
+    """Whether to perform preemption by swapping or
+    recomputation. If not specified, we determine the mode as follows:
+    We use recomputation by default since it incurs lower overhead than
+    swapping. However, when the sequence group has multiple sequences
+    (e.g., beam search), recomputation is not currently supported. In
+    such a case, we use swapping instead."""
 
     num_scheduler_steps: int = 1
+    """Maximum number of forward steps per scheduler call."""
 
-    multi_step_stream_outputs: bool = False
+    multi_step_stream_outputs: bool = True
+    """If False, then multi-step will stream outputs at the end of all steps"""
 
-    # Private API. If used, scheduler sends delta data to
-    # workers instead of an entire data. It should be enabled only
-    # when SPMD worker architecture is enabled. I.e.,
-    # VLLM_USE_RAY_SPMD_WORKER=1
     send_delta_data: bool = False
+    """Private API. If used, scheduler sends delta data to
+    workers instead of an entire data. It should be enabled only
+    when SPMD worker architecture is enabled. I.e.,
+    VLLM_USE_RAY_SPMD_WORKER=1"""
 
-    # The scheduling policy to use. "fcfs" (default) or "priority".
-    policy: str = "fcfs"
+    policy: SchedulerPolicy = "fcfs"
+    """The scheduling policy to use:\n
+    - "fcfs" means first come first served, i.e. requests are handled in order
+    of arrival.\n
+    - "priority" means requests are handled based on given priority (lower
+    value means earlier handling) and time of arrival deciding any ties)."""
 
     chunked_prefill_enabled: bool = field(init=False)
+    """True if chunked prefill is enabled."""
 
-    # If set to true and chunked prefill is enabled, we do not want to
-    # partially schedule a multimodal item. Only used in V1
-    # This ensures that if a request has a mixed prompt
-    # (like text tokens TTTT followed by image tokens IIIIIIIIII) where only
-    # some image tokens can be scheduled (like TTTTIIIII, leaving IIIII),
-    # it will be scheduled as TTTT in one step and IIIIIIIIII in the next.
     disable_chunked_mm_input: bool = False
+    """If set to true and chunked prefill is enabled, we do not want to
+    partially schedule a multimodal item. Only used in V1
+    This ensures that if a request has a mixed prompt
+    (like text tokens TTTT followed by image tokens IIIIIIIIII) where only
+    some image tokens can be scheduled (like TTTTIIIII, leaving IIIII),
+    it will be scheduled as TTTT in one step and IIIIIIIIII in the next."""
 
-    # scheduler class or path. "vllm.core.scheduler.Scheduler" (default)
-    # or "mod.custom_class".
     scheduler_cls: Union[str, type[object]] = "vllm.core.scheduler.Scheduler"
+    """The scheduler class to use. "vllm.core.scheduler.Scheduler" is the
+    default scheduler. Can be a class directly or the path to a class of form
+    "mod.custom_class"."""
 
     def compute_hash(self) -> str:
         """
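The per-field docstrings above are what the argument-parsing code below harvests (via get_attr_docs, imported from vllm.config) to build CLI help text. A minimal sketch of the idea, assuming the real implementation differs in its details; attr_docs and DemoConfig are hypothetical names, and the script must be run from a file so inspect.getsource can find the class source:

import ast
import inspect
from dataclasses import dataclass


def attr_docs(cls: type) -> dict[str, str]:
    # Map each annotated field to the string literal written directly below
    # it (a simplified stand-in for vLLM's get_attr_docs).
    class_def = ast.parse(inspect.getsource(cls)).body[0]
    docs: dict[str, str] = {}
    for node, nxt in zip(class_def.body, class_def.body[1:]):
        is_doc = (isinstance(nxt, ast.Expr)
                  and isinstance(nxt.value, ast.Constant)
                  and isinstance(nxt.value.value, str))
        if isinstance(node, ast.AnnAssign) and is_doc:
            docs[node.target.id] = inspect.cleandoc(nxt.value.value)
    return docs


@dataclass
class DemoConfig:
    policy: str = "fcfs"
    """The scheduling policy to use."""


print(attr_docs(DemoConfig))  # {'policy': 'The scheduling policy to use.'}
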
@@ -1862,6 +1897,18 @@ class SchedulerConfig:
         return hash_str
 
     def __post_init__(self) -> None:
+        if self.max_model_len is None:
+            self.max_model_len = 8192
+            logger.warning(
+                "max_model_len is not set. Defaulting to arbitrary value "
+                "of %d.", self.max_model_len)
+
+        if self.max_num_seqs is None:
+            self.max_num_seqs = 128
+            logger.warning(
+                "max_num_seqs is not set. Defaulting to arbitrary value "
+                "of %d.", self.max_num_seqs)
+
         if self.max_num_batched_tokens is None:
             if self.enable_chunked_prefill:
                 if self.num_scheduler_steps > 1:
vllm/engine/arg_utils.py

@@ -1,25 +1,30 @@
 # SPDX-License-Identifier: Apache-2.0
 
+# yapf: disable
 import argparse
 import dataclasses
 import json
 import re
 import threading
 from dataclasses import MISSING, dataclass, fields
-from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
-                    Tuple, Type, Union, cast, get_args, get_origin)
+from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Literal, Mapping,
+                    Optional, Tuple, Type, TypeVar, Union, cast, get_args,
+                    get_origin)
 
 import torch
+from typing_extensions import TypeIs
 
 import vllm.envs as envs
 from vllm import version
 from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat,
-                         DecodingConfig, DeviceConfig, HfOverrides,
+                         DecodingConfig, DeviceConfig,
+                         DistributedExecutorBackend, HfOverrides,
                          KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
                          ModelConfig, ModelImpl, ObservabilityConfig,
                          ParallelConfig, PoolerConfig, PromptAdapterConfig,
-                         SchedulerConfig, SpeculativeConfig, TaskOption,
-                         TokenizerPoolConfig, VllmConfig, get_attr_docs)
+                         SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
+                         TaskOption, TokenizerPoolConfig, VllmConfig,
+                         get_attr_docs)
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
@@ -28,7 +33,9 @@ from vllm.reasoning import ReasoningParserManager
 from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
 from vllm.transformers_utils.utils import check_gguf_file
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import FlexibleArgumentParser, StoreBoolean, is_in_ray_actor
+from vllm.utils import FlexibleArgumentParser, is_in_ray_actor
+
+# yapf: enable
 
 if TYPE_CHECKING:
     from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
@@ -47,11 +54,32 @@ DEVICE_OPTIONS = [
     "hpu",
 ]
 
+# object is used to allow for special typing forms
+T = TypeVar("T")
+TypeHint = Union[type[Any], object]
+TypeHintT = Union[type[T], object]
+
 
-def nullable_str(val: str):
-    if not val or val == "None":
+def optional_arg(val: str, return_type: type[T]) -> Optional[T]:
+    if val == "" or val == "None":
         return None
-    return val
+    try:
+        return cast(Callable, return_type)(val)
+    except ValueError as e:
+        raise argparse.ArgumentTypeError(
+            f"Value {val} cannot be converted to {return_type}.") from e
+
+
+def optional_str(val: str) -> Optional[str]:
+    return optional_arg(val, str)
+
+
+def optional_int(val: str) -> Optional[int]:
+    return optional_arg(val, int)
+
+
+def optional_float(val: str) -> Optional[float]:
+    return optional_arg(val, float)
 
 
 def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
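A quick illustration of how the new optional_* converters behave once wired into argparse (a hypothetical throwaway script; optional_int comes from the hunk above): "" and "None" map to None, other values are converted, and unconvertible values raise a clean argparse error.

import argparse

from vllm.engine.arg_utils import optional_int

parser = argparse.ArgumentParser()
parser.add_argument("--max-num-seqs", type=optional_int, default=None)

print(parser.parse_args(["--max-num-seqs", "None"]).max_num_seqs)  # None
print(parser.parse_args(["--max-num-seqs", "256"]).max_num_seqs)   # 256
print(parser.parse_args([]).max_num_seqs)                          # None
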
@@ -112,7 +140,8 @@ class EngineArgs:
     # is intended for expert use only. The API may change without
     # notice.
     distributed_executor_backend: Optional[Union[
-        str, Type[ExecutorBase]]] = ParallelConfig.distributed_executor_backend
+        DistributedExecutorBackend,
+        Type[ExecutorBase]]] = ParallelConfig.distributed_executor_backend
     # number of P/D disaggregation (or other disaggregation) workers
     pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
     tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
@@ -129,11 +158,13 @@ class EngineArgs:
     swap_space: float = 4  # GiB
     cpu_offload_gb: float = 0  # GiB
     gpu_memory_utilization: float = 0.90
-    max_num_batched_tokens: Optional[int] = None
-    max_num_partial_prefills: Optional[int] = 1
-    max_long_partial_prefills: Optional[int] = 1
-    long_prefill_token_threshold: Optional[int] = 0
-    max_num_seqs: Optional[int] = None
+    max_num_batched_tokens: Optional[
+        int] = SchedulerConfig.max_num_batched_tokens
+    max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills
+    max_long_partial_prefills: int = SchedulerConfig.max_long_partial_prefills
+    long_prefill_token_threshold: int = \
+        SchedulerConfig.long_prefill_token_threshold
+    max_num_seqs: Optional[int] = SchedulerConfig.max_num_seqs
     max_logprobs: int = 20  # Default value for OpenAI Chat Completions API
     disable_log_stats: bool = False
     revision: Optional[str] = None
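EngineArgs can point its defaults at SchedulerConfig directly because dataclass field defaults are ordinary class attributes, so they can be read without instantiating the config. A standalone sketch of that single-source-of-truth pattern (the Demo* names are hypothetical):

from dataclasses import dataclass


@dataclass
class DemoSchedulerConfig:
    max_num_partial_prefills: int = 1


@dataclass
class DemoEngineArgs:
    # Reuse the config's default instead of repeating the literal value here.
    max_num_partial_prefills: int = DemoSchedulerConfig.max_num_partial_prefills


print(DemoEngineArgs().max_num_partial_prefills)  # 1
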
@@ -169,20 +200,21 @@ class EngineArgs:
     lora_dtype: Optional[Union[str, torch.dtype]] = 'auto'
     max_cpu_loras: Optional[int] = None
     device: str = 'auto'
-    num_scheduler_steps: int = 1
-    multi_step_stream_outputs: bool = True
+    num_scheduler_steps: int = SchedulerConfig.num_scheduler_steps
+    multi_step_stream_outputs: bool = SchedulerConfig.multi_step_stream_outputs
     ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
     num_gpu_blocks_override: Optional[int] = None
-    num_lookahead_slots: int = 0
+    num_lookahead_slots: int = SchedulerConfig.num_lookahead_slots
     model_loader_extra_config: Optional[
         dict] = LoadConfig.model_loader_extra_config
     ignore_patterns: Optional[Union[str,
                                     List[str]]] = LoadConfig.ignore_patterns
-    preemption_mode: Optional[str] = None
+    preemption_mode: Optional[str] = SchedulerConfig.preemption_mode
 
-    scheduler_delay_factor: float = 0.0
-    enable_chunked_prefill: Optional[bool] = None
-    disable_chunked_mm_input: bool = False
+    scheduler_delay_factor: float = SchedulerConfig.delay_factor
+    enable_chunked_prefill: Optional[
+        bool] = SchedulerConfig.enable_chunked_prefill
+    disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input
 
     guided_decoding_backend: str = DecodingConfig.guided_decoding_backend
     logits_processor_pattern: Optional[str] = None
@@ -194,8 +226,8 @@ class EngineArgs:
     otlp_traces_endpoint: Optional[str] = None
     collect_detailed_traces: Optional[str] = None
     disable_async_output_proc: bool = False
-    scheduling_policy: Literal["fcfs", "priority"] = "fcfs"
-    scheduler_cls: Union[str, Type[object]] = "vllm.core.scheduler.Scheduler"
+    scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
+    scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls
 
     override_neuron_config: Optional[Dict[str, Any]] = None
     override_pooler_config: Optional[PoolerConfig] = None
@@ -236,15 +268,33 @@ class EngineArgs:
     def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         """Shared CLI arguments for vLLM engine."""
 
-        def is_type_in_union(cls: type[Any], type: type[Any]) -> bool:
+        def is_type_in_union(cls: TypeHint, type: TypeHint) -> bool:
             """Check if the class is a type in a union type."""
-            return get_origin(cls) is Union and type in get_args(cls)
+            is_union = get_origin(cls) is Union
+            type_in_union = type in [get_origin(a) or a for a in get_args(cls)]
+            return is_union and type_in_union
 
-        def is_optional(cls: type[Any]) -> bool:
+        def get_type_from_union(cls: TypeHint, type: TypeHintT) -> TypeHintT:
+            """Get the type in a union type."""
+            for arg in get_args(cls):
+                if (get_origin(arg) or arg) is type:
+                    return arg
+            raise ValueError(f"Type {type} not found in union type {cls}.")
+
+        def is_optional(cls: TypeHint) -> TypeIs[Union[Any, None]]:
             """Check if the class is an optional type."""
             return is_type_in_union(cls, type(None))
 
-        def get_kwargs(cls: type[Any]) -> Dict[str, Any]:
+        def can_be_type(cls: TypeHint, type: TypeHintT) -> TypeIs[TypeHintT]:
+            """Check if the class can be of type."""
+            return cls is type or get_origin(cls) is type or is_type_in_union(
+                cls, type)
+
+        def is_custom_type(cls: TypeHint) -> bool:
+            """Check if the class is a custom type."""
+            return cls.__module__ != "builtins"
+
+        def get_kwargs(cls: type[Any]) -> dict[str, Any]:
             cls_docs = get_attr_docs(cls)
             kwargs = {}
             for field in fields(cls):
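These helpers are closures inside add_cli_args, so they cannot be imported directly; the standalone sketch below re-creates the two central checks (with `typ` in place of the shadowed builtin name `type`) to show how they classify the annotations SchedulerConfig uses:

from typing import Literal, Optional, Union, get_args, get_origin


def is_type_in_union(cls, typ) -> bool:
    # Same logic as the closure in the diff above.
    return get_origin(cls) is Union and typ in [
        get_origin(arg) or arg for arg in get_args(cls)
    ]


def can_be_type(cls, typ) -> bool:
    return cls is typ or get_origin(cls) is typ or is_type_in_union(cls, typ)


SchedulerPolicy = Literal["fcfs", "priority"]

print(is_type_in_union(Optional[int], type(None)))      # True -> optional field
print(can_be_type(Optional[SchedulerPolicy], Literal))  # True -> becomes choices
print(get_args(SchedulerPolicy))                        # ('fcfs', 'priority')
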
@@ -253,19 +303,41 @@ class EngineArgs:
                 default = (field.default_factory
                            if field.default is MISSING else field.default)
                 kwargs[name] = {"default": default, "help": cls_docs[name]}
-                # When using action="store_true"
-                # add_argument doesn't accept type
-                if field.type is bool:
-                    continue
-                # Handle optional fields
-                if is_optional(field.type):
-                    kwargs[name]["type"] = nullable_str
-                    continue
-                # Handle str in union fields
-                if is_type_in_union(field.type, str):
-                    kwargs[name]["type"] = str
-                    continue
-                kwargs[name]["type"] = field.type
+
+                # Make note of if the field is optional and get the actual
+                # type of the field if it is
+                optional = is_optional(field.type)
+                field_type = get_args(
+                    field.type)[0] if optional else field.type
+
+                if can_be_type(field_type, bool):
+                    # Creates --no-<name> and --<name> flags
+                    kwargs[name]["action"] = argparse.BooleanOptionalAction
+                    kwargs[name]["type"] = bool
+                elif can_be_type(field_type, Literal):
+                    # Creates choices from Literal arguments
+                    if is_type_in_union(field_type, Literal):
+                        field_type = get_type_from_union(field_type, Literal)
+                    choices = get_args(field_type)
+                    kwargs[name]["choices"] = choices
+                    choice_type = type(choices[0])
+                    assert all(type(c) is choice_type for c in choices), (
+                        f"All choices must be of the same type. "
+                        f"Got {choices} with types {[type(c) for c in choices]}"
+                    )
+                    kwargs[name]["type"] = choice_type
+                elif can_be_type(field_type, int):
+                    kwargs[name]["type"] = optional_int if optional else int
+                elif can_be_type(field_type, float):
+                    kwargs[name][
+                        "type"] = optional_float if optional else float
+                elif (can_be_type(field_type, str)
+                      or can_be_type(field_type, dict)
+                      or is_custom_type(field_type)):
+                    kwargs[name]["type"] = optional_str if optional else str
+                else:
+                    raise ValueError(
+                        f"Unsupported type {field.type} for argument {name}. ")
             return kwargs
 
         # Model arguments
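A self-contained sketch of the Literal branch above: the choices and their type are derived from the annotation, and the field docstring supplies the help text. The demo dataclass and loop are illustrative only, not the vLLM code:

import argparse
from dataclasses import dataclass, fields
from typing import Literal, get_args, get_origin


@dataclass
class DemoSchedulerConfig:
    policy: Literal["fcfs", "priority"] = "fcfs"


parser = argparse.ArgumentParser()
for f in fields(DemoSchedulerConfig):
    kwargs = {"default": f.default}
    if get_origin(f.type) is Literal:
        choices = get_args(f.type)
        kwargs["choices"] = choices
        kwargs["type"] = type(choices[0])  # str here, so argparse converts
    parser.add_argument(f"--{f.name.replace('_', '-')}", **kwargs)

print(parser.parse_args(["--policy", "priority"]).policy)  # priority
print(parser.parse_args([]).policy)                        # fcfs
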
@@ -285,13 +357,13 @@ class EngineArgs:
             'which task to use.')
         parser.add_argument(
             '--tokenizer',
-            type=nullable_str,
+            type=optional_str,
             default=EngineArgs.tokenizer,
             help='Name or path of the huggingface tokenizer to use. '
             'If unspecified, model name or path will be used.')
         parser.add_argument(
             "--hf-config-path",
-            type=nullable_str,
+            type=optional_str,
             default=EngineArgs.hf_config_path,
             help='Name or path of the huggingface config to use. '
             'If unspecified, model name or path will be used.')
@@ -303,21 +375,21 @@ class EngineArgs:
             'the input. The generated output will contain token ids.')
         parser.add_argument(
             '--revision',
-            type=nullable_str,
+            type=optional_str,
             default=None,
             help='The specific model version to use. It can be a branch '
             'name, a tag name, or a commit id. If unspecified, will use '
             'the default version.')
         parser.add_argument(
             '--code-revision',
-            type=nullable_str,
+            type=optional_str,
             default=None,
             help='The specific revision to use for the model code on '
             'Hugging Face Hub. It can be a branch name, a tag name, or a '
             'commit id. If unspecified, will use the default version.')
         parser.add_argument(
             '--tokenizer-revision',
-            type=nullable_str,
+            type=optional_str,
             default=None,
             help='Revision of the huggingface tokenizer to use. '
             'It can be a branch name, a tag name, or a commit id. '
@@ -357,7 +429,6 @@ class EngineArgs:
         load_group.add_argument('--model-loader-extra-config',
                                 **load_kwargs["model_loader_extra_config"])
         load_group.add_argument('--use-tqdm-on-load',
-                                action=argparse.BooleanOptionalAction,
                                 **load_kwargs["use_tqdm_on_load"])
 
         parser.add_argument(
@@ -413,7 +484,7 @@ class EngineArgs:
             'the behavior is subject to change in each release.')
         parser.add_argument(
             '--logits-processor-pattern',
-            type=nullable_str,
+            type=optional_str,
             default=None,
             help='Optional regex pattern specifying valid logits processor '
             'qualified names that can be passed with the `logits_processors` '
@@ -439,7 +510,6 @@ class EngineArgs:
         )
         parallel_group.add_argument(
             '--distributed-executor-backend',
-            choices=['ray', 'mp', 'uni', 'external_launcher'],
             **parallel_kwargs["distributed_executor_backend"])
         parallel_group.add_argument(
             '--pipeline-parallel-size', '-pp',
@@ -450,18 +520,15 @@ class EngineArgs:
             **parallel_kwargs["data_parallel_size"])
         parallel_group.add_argument(
             '--enable-expert-parallel',
-            action='store_true',
             **parallel_kwargs["enable_expert_parallel"])
         parallel_group.add_argument(
             '--max-parallel-loading-workers',
             **parallel_kwargs["max_parallel_loading_workers"])
         parallel_group.add_argument(
             '--ray-workers-use-nsight',
-            action='store_true',
             **parallel_kwargs["ray_workers_use_nsight"])
         parallel_group.add_argument(
             '--disable-custom-all-reduce',
-            action='store_true',
             **parallel_kwargs["disable_custom_all_reduce"])
         # KV cache arguments
         parser.add_argument('--block-size',
@@ -502,14 +569,6 @@ class EngineArgs:
             'block manager v2) is now the default. '
             'Setting this flag to True or False'
             ' has no effect on vLLM behavior.')
-        parser.add_argument(
-            '--num-lookahead-slots',
-            type=int,
-            default=EngineArgs.num_lookahead_slots,
-            help='Experimental scheduling config necessary for '
-            'speculative decoding. This will be replaced by '
-            'speculative config in the future; it is present '
-            'to enable correctness tests until then.')
 
         parser.add_argument('--seed',
                             type=int,
@@ -552,36 +611,6 @@ class EngineArgs:
             default=None,
             help='If specified, ignore GPU profiling result and use this number'
             ' of GPU blocks. Used for testing preemption.')
-        parser.add_argument('--max-num-batched-tokens',
-                            type=int,
-                            default=EngineArgs.max_num_batched_tokens,
-                            help='Maximum number of batched tokens per '
-                            'iteration.')
-        parser.add_argument(
-            "--max-num-partial-prefills",
-            type=int,
-            default=EngineArgs.max_num_partial_prefills,
-            help="For chunked prefill, the max number of concurrent \
-            partial prefills.")
-        parser.add_argument(
-            "--max-long-partial-prefills",
-            type=int,
-            default=EngineArgs.max_long_partial_prefills,
-            help="For chunked prefill, the maximum number of prompts longer "
-            "than --long-prefill-token-threshold that will be prefilled "
-            "concurrently. Setting this less than --max-num-partial-prefills "
-            "will allow shorter prompts to jump the queue in front of longer "
-            "prompts in some cases, improving latency.")
-        parser.add_argument(
-            "--long-prefill-token-threshold",
-            type=float,
-            default=EngineArgs.long_prefill_token_threshold,
-            help="For chunked prefill, a request is considered long if the "
-            "prompt is longer than this number of tokens.")
-        parser.add_argument('--max-num-seqs',
-                            type=int,
-                            default=EngineArgs.max_num_seqs,
-                            help='Maximum number of sequences per iteration.')
         parser.add_argument(
             '--max-logprobs',
             type=int,
@@ -594,7 +623,7 @@ class EngineArgs:
         # Quantization settings.
         parser.add_argument('--quantization',
                             '-q',
-                            type=nullable_str,
+                            type=optional_str,
                             choices=[*QUANTIZATION_METHODS, None],
                             default=EngineArgs.quantization,
                             help='Method used to quantize the weights. If '
@@ -658,7 +687,7 @@ class EngineArgs:
                             'asynchronous tokenization. Ignored '
                             'if tokenizer_pool_size is 0.')
         parser.add_argument('--tokenizer-pool-extra-config',
-                            type=nullable_str,
+                            type=optional_str,
                             default=EngineArgs.tokenizer_pool_extra_config,
                             help='Extra config for tokenizer pool. '
                             'This should be a JSON string that will be '
@@ -721,7 +750,7 @@ class EngineArgs:
                   'base model dtype.'))
         parser.add_argument(
             '--long-lora-scaling-factors',
-            type=nullable_str,
+            type=optional_str,
             default=EngineArgs.long_lora_scaling_factors,
             help=('Specify multiple scaling factors (which can '
                   'be different from base model scaling factor '
@@ -766,28 +795,6 @@ class EngineArgs:
                             help=('Maximum number of forward steps per '
                                   'scheduler call.'))
 
-        parser.add_argument(
-            '--multi-step-stream-outputs',
-            action=StoreBoolean,
-            default=EngineArgs.multi_step_stream_outputs,
-            nargs="?",
-            const="True",
-            help='If False, then multi-step will stream outputs at the end '
-            'of all steps')
-        parser.add_argument(
-            '--scheduler-delay-factor',
-            type=float,
-            default=EngineArgs.scheduler_delay_factor,
-            help='Apply a delay (of delay factor multiplied by previous '
-            'prompt latency) before scheduling next prompt.')
-        parser.add_argument(
-            '--enable-chunked-prefill',
-            action=StoreBoolean,
-            default=EngineArgs.enable_chunked_prefill,
-            nargs="?",
-            const="True",
-            help='If set, the prefill requests can be chunked based on the '
-            'max_num_batched_tokens.')
         parser.add_argument('--speculative-config',
                             type=json.loads,
                             default=None,
@@ -863,22 +870,43 @@ class EngineArgs:
             help="Disable async output processing. This may result in "
             "lower performance.")
 
-        parser.add_argument(
-            '--scheduling-policy',
-            choices=['fcfs', 'priority'],
-            default="fcfs",
-            help='The scheduling policy to use. "fcfs" (first come first served'
-            ', i.e. requests are handled in order of arrival; default) '
-            'or "priority" (requests are handled based on given '
-            'priority (lower value means earlier handling) and time of '
-            'arrival deciding any ties).')
-
-        parser.add_argument(
-            '--scheduler-cls',
-            default=EngineArgs.scheduler_cls,
-            help='The scheduler class to use. "vllm.core.scheduler.Scheduler" '
-            'is the default scheduler. Can be a class directly or the path to '
-            'a class of form "mod.custom_class".')
+        # Scheduler arguments
+        scheduler_kwargs = get_kwargs(SchedulerConfig)
+        scheduler_group = parser.add_argument_group(
+            title="SchedulerConfig",
+            description=SchedulerConfig.__doc__,
+        )
+        scheduler_group.add_argument(
+            '--max-num-batched-tokens',
+            **scheduler_kwargs["max_num_batched_tokens"])
+        scheduler_group.add_argument('--max-num-seqs',
+                                     **scheduler_kwargs["max_num_seqs"])
+        scheduler_group.add_argument(
+            "--max-num-partial-prefills",
+            **scheduler_kwargs["max_num_partial_prefills"])
+        scheduler_group.add_argument(
+            "--max-long-partial-prefills",
+            **scheduler_kwargs["max_long_partial_prefills"])
+        scheduler_group.add_argument(
+            "--long-prefill-token-threshold",
+            **scheduler_kwargs["long_prefill_token_threshold"])
+        scheduler_group.add_argument('--num-lookahead-slots',
+                                     **scheduler_kwargs["num_lookahead_slots"])
+        scheduler_group.add_argument('--scheduler-delay-factor',
+                                     **scheduler_kwargs["delay_factor"])
+        scheduler_group.add_argument(
+            '--enable-chunked-prefill',
+            **scheduler_kwargs["enable_chunked_prefill"])
+        scheduler_group.add_argument(
+            '--multi-step-stream-outputs',
+            **scheduler_kwargs["multi_step_stream_outputs"])
+        scheduler_group.add_argument('--scheduling-policy',
+                                     **scheduler_kwargs["policy"])
+        scheduler_group.add_argument(
+            "--disable-chunked-mm-input",
+            **scheduler_kwargs["disable_chunked_mm_input"])
+        parser.add_argument('--scheduler-cls',
+                            **scheduler_kwargs["scheduler_cls"])
 
         parser.add_argument(
             '--override-neuron-config',
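End to end, the generated group should then accept the scheduler flags straight from the field definitions, including the --no- variants produced by BooleanOptionalAction. A hedged smoke test under those assumptions (the model name is just the EngineArgs default, and exact behavior depends on code outside this diff):

from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser

parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
args = parser.parse_args([
    "--model", "facebook/opt-125m",
    "--max-num-seqs", "256",
    "--scheduling-policy", "priority",
    "--no-enable-chunked-prefill",
])
print(args.max_num_seqs, args.scheduling_policy, args.enable_chunked_prefill)
# Expected: 256 priority False
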
@@ -930,7 +958,7 @@ class EngineArgs:
             'class without changing the existing functions.')
         parser.add_argument(
             "--generation-config",
-            type=nullable_str,
+            type=optional_str,
             default="auto",
             help="The folder path to the generation config. "
             "Defaults to 'auto', the generation config will be loaded from "
@@ -1003,20 +1031,6 @@ class EngineArgs:
             "Note that even if this is set to False, cascade attention will be "
             "only used when the heuristic tells that it's beneficial.")
 
-        parser.add_argument(
-            "--disable-chunked-mm-input",
-            action=StoreBoolean,
-            default=EngineArgs.disable_chunked_mm_input,
-            nargs="?",
-            const="True",
-            help="Disable multimodal input chunking attention for V1. "
-            "If set to true and chunked prefill is enabled, we do not want to"
-            " partially schedule a multimodal item. This ensures that if a "
-            "request has a mixed prompt (like text tokens TTTT followed by "
-            "image tokens IIIIIIIIII) where only some image tokens can be "
-            "scheduled (like TTTTIIIII, leaving IIIII), it will be scheduled "
-            "as TTTT in one step and IIIIIIIIII in the next.")
-
         return parser
 
     @classmethod
@@ -1370,7 +1384,7 @@ class EngineArgs:
                                recommend_to_remove=False)
             return False
 
-        if self.preemption_mode != EngineArgs.preemption_mode:
+        if self.preemption_mode != SchedulerConfig.preemption_mode:
             _raise_or_fallback(feature_name="--preemption-mode",
                                recommend_to_remove=True)
             return False
@@ -1381,17 +1395,17 @@ class EngineArgs:
                                recommend_to_remove=True)
             return False
 
-        if self.scheduling_policy != EngineArgs.scheduling_policy:
+        if self.scheduling_policy != SchedulerConfig.policy:
             _raise_or_fallback(feature_name="--scheduling-policy",
                                recommend_to_remove=False)
             return False
 
-        if self.num_scheduler_steps != EngineArgs.num_scheduler_steps:
+        if self.num_scheduler_steps != SchedulerConfig.num_scheduler_steps:
             _raise_or_fallback(feature_name="--num-scheduler-steps",
                                recommend_to_remove=True)
             return False
 
-        if self.scheduler_delay_factor != EngineArgs.scheduler_delay_factor:
+        if self.scheduler_delay_factor != SchedulerConfig.delay_factor:
             _raise_or_fallback(feature_name="--scheduler-delay-factor",
                                recommend_to_remove=True)
             return False
@@ -1475,9 +1489,9 @@ class EngineArgs:
 
         # No Concurrent Partial Prefills so far.
         if (self.max_num_partial_prefills
-                != EngineArgs.max_num_partial_prefills
+                != SchedulerConfig.max_num_partial_prefills
                 or self.max_long_partial_prefills
-                != EngineArgs.max_long_partial_prefills):
+                != SchedulerConfig.max_long_partial_prefills):
             _raise_or_fallback(feature_name="Concurrent Partial Prefill",
                                recommend_to_remove=False)
             return False
vllm/entrypoints/openai/cli_args.py

@@ -11,7 +11,7 @@ import ssl
 from collections.abc import Sequence
 from typing import Optional, Union, get_args
 
-from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
+from vllm.engine.arg_utils import AsyncEngineArgs, optional_str
 from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
                                          validate_chat_template)
 from vllm.entrypoints.openai.serving_models import (LoRAModulePath,
@@ -79,7 +79,7 @@ class PromptAdapterParserAction(argparse.Action):
 
 def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
     parser.add_argument("--host",
-                        type=nullable_str,
+                        type=optional_str,
                         default=None,
                         help="Host name.")
     parser.add_argument("--port", type=int, default=8000, help="Port number.")
@@ -108,13 +108,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                         default=["*"],
                         help="Allowed headers.")
     parser.add_argument("--api-key",
-                        type=nullable_str,
+                        type=optional_str,
                         default=None,
                         help="If provided, the server will require this key "
                         "to be presented in the header.")
     parser.add_argument(
         "--lora-modules",
-        type=nullable_str,
+        type=optional_str,
         default=None,
         nargs='+',
         action=LoRAParserAction,
@@ -126,14 +126,14 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         "\"base_model_name\": \"id\"}``")
     parser.add_argument(
         "--prompt-adapters",
-        type=nullable_str,
+        type=optional_str,
         default=None,
         nargs='+',
         action=PromptAdapterParserAction,
         help="Prompt adapter configurations in the format name=path. "
         "Multiple adapters can be specified.")
     parser.add_argument("--chat-template",
-                        type=nullable_str,
+                        type=optional_str,
                         default=None,
                         help="The file path to the chat template, "
                         "or the template in single-line form "
@@ -151,20 +151,20 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                        'similar to OpenAI schema. '
                        'Example: ``[{"type": "text", "text": "Hello world!"}]``')
     parser.add_argument("--response-role",
-                        type=nullable_str,
+                        type=optional_str,
                         default="assistant",
                         help="The role name to return if "
                         "``request.add_generation_prompt=true``.")
     parser.add_argument("--ssl-keyfile",
-                        type=nullable_str,
+                        type=optional_str,
                         default=None,
                         help="The file path to the SSL key file.")
     parser.add_argument("--ssl-certfile",
-                        type=nullable_str,
+                        type=optional_str,
                         default=None,
                         help="The file path to the SSL cert file.")
     parser.add_argument("--ssl-ca-certs",
-                        type=nullable_str,
+                        type=optional_str,
                         default=None,
                         help="The CA certificates file.")
     parser.add_argument(
@@ -180,13 +180,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
     )
     parser.add_argument(
         "--root-path",
-        type=nullable_str,
+        type=optional_str,
         default=None,
         help="FastAPI root_path when app is behind a path based routing proxy."
     )
     parser.add_argument(
         "--middleware",
-        type=nullable_str,
+        type=optional_str,
         action="append",
         default=[],
         help="Additional ASGI middleware to apply to the app. "
vllm/entrypoints/openai/run_batch.py

@@ -12,7 +12,7 @@ import torch
 from prometheus_client import start_http_server
 from tqdm import tqdm
 
-from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
+from vllm.engine.arg_utils import AsyncEngineArgs, optional_str
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.logger import RequestLogger, logger
 # yapf: disable
@@ -61,7 +61,7 @@ def parse_args():
         "to the output URL.",
     )
     parser.add_argument("--response-role",
-                        type=nullable_str,
+                        type=optional_str,
                         default="assistant",
                         help="The role name to return if "
                         "`request.add_generation_prompt=True`.")