Fix missing docs and out of sync EngineArgs (#4219)
Co-authored-by: Harry Mellor <hmellor@oxts.com>
This commit is contained in: parent 138485a82d, commit 682789d402
@@ -11,12 +11,14 @@
# documentation root, use os.path.abspath to make it absolute, like shown here.

import logging
import os
import sys
from typing import List

from sphinx.ext import autodoc

logger = logging.getLogger(__name__)
sys.path.append(os.path.abspath("../.."))

# -- Project information -----------------------------------------------------
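The ".. argparse::" directives introduced below in engine_args.rst are rendered by the sphinx-argparse extension, so conf.py has to register it. A minimal sketch of that registration, assuming the extension module name sphinxarg.ext and an otherwise unchanged extensions list (the exact entries of vLLM's conf.py are not shown in this hunk):

# Hypothetical conf.py excerpt, not part of this diff: register sphinx-argparse
# so Sphinx can render the ".. argparse::" directives used in engine_args.rst.
extensions = [
    "sphinx.ext.autodoc",  # assumption: already enabled, matching the import above
    "sphinxarg.ext",       # sphinx-argparse; generates CLI docs from the parser
]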
@@ -5,133 +5,17 @@ Engine Arguments

Below, you can find an explanation of every engine argument for vLLM:

.. option:: --model <model_name_or_path>

Name or path of the huggingface model to use.

.. option:: --tokenizer <tokenizer_name_or_path>

Name or path of the huggingface tokenizer to use.

.. option:: --revision <revision>

The specific model version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version.

.. option:: --tokenizer-revision <revision>

The specific tokenizer version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version.

.. option:: --tokenizer-mode {auto,slow}

The tokenizer mode.

* "auto" will use the fast tokenizer if available.
* "slow" will always use the slow tokenizer.

.. option:: --trust-remote-code

Trust remote code from huggingface.

.. option:: --download-dir <directory>

Directory to download and load the weights, default to the default cache dir of huggingface.

.. option:: --load-format {auto,pt,safetensors,npcache,dummy,tensorizer}

The format of the model weights to load.

* "auto" will try to load the weights in the safetensors format and fall back to the pytorch bin format if safetensors format is not available.
* "pt" will load the weights in the pytorch bin format.
* "safetensors" will load the weights in the safetensors format.
* "npcache" will load the weights in pytorch format and store a numpy cache to speed up the loading.
* "dummy" will initialize the weights with random values, mainly for profiling.
* "tensorizer" will load serialized weights using `CoreWeave's Tensorizer model deserializer. <https://github.com/coreweave/tensorizer>`_ See `examples/tensorize_vllm_model.py <https://github.com/vllm-project/vllm/blob/main/examples/tensorize_vllm_model.py>`_ to serialize a vLLM model, and for more information.

.. option:: --dtype {auto,half,float16,bfloat16,float,float32}

Data type for model weights and activations.

* "auto" will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models.
* "half" for FP16. Recommended for AWQ quantization.
* "float16" is the same as "half".
* "bfloat16" for a balance between precision and range.
* "float" is shorthand for FP32 precision.
* "float32" for FP32 precision.

.. option:: --max-model-len <length>

Model context length. If unspecified, will be automatically derived from the model config.

.. option:: --worker-use-ray

Use Ray for distributed serving, will be automatically set when using more than 1 GPU.

.. option:: --pipeline-parallel-size (-pp) <size>

Number of pipeline stages.

.. option:: --tensor-parallel-size (-tp) <size>

Number of tensor parallel replicas.

.. option:: --max-parallel-loading-workers <workers>

Load model sequentially in multiple batches, to avoid RAM OOM when using tensor parallel and large models.

.. option:: --block-size {8,16,32}

Token block size for contiguous chunks of tokens.

.. option:: --enable-prefix-caching

Enables automatic prefix caching

.. option:: --seed <seed>

Random seed for operations.

.. option:: --swap-space <size>

CPU swap space size (GiB) per GPU.

.. option:: --gpu-memory-utilization <fraction>

The fraction of GPU memory to be used for the model executor, which can range from 0 to 1.
For example, a value of 0.5 would imply 50% GPU memory utilization.
If unspecified, will use the default value of 0.9.

.. option:: --max-num-batched-tokens <tokens>

Maximum number of batched tokens per iteration.

.. option:: --max-num-seqs <sequences>

Maximum number of sequences per iteration.

.. option:: --max-paddings <paddings>

Maximum number of paddings in a batch.

.. option:: --disable-log-stats

Disable logging statistics.

.. option:: --quantization (-q) {awq,squeezellm,None}

Method used to quantize the weights.
.. argparse::
:module: vllm.engine.arg_utils
:func: _engine_args_parser
:prog: -m vllm.entrypoints.openai.api_server

Async Engine Arguments
----------------------

Below are the additional arguments related to the asynchronous engine:

.. option:: --engine-use-ray

Use Ray to start the LLM engine in a separate process as the server process.

.. option:: --disable-log-requests

Disable logging requests.

.. option:: --max-log-len

Max number of prompt characters or prompt ID numbers being printed in log. Defaults to unlimited.
.. argparse::
:module: vllm.engine.arg_utils
:func: _async_engine_args_parser
:prog: -m vllm.entrypoints.openai.api_server
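Both directives above pull their content from small helper functions added at the bottom of vllm/engine/arg_utils.py in this commit, so the rendered pages are generated from the live parsers instead of hand-written option lists. A minimal sketch of how to preview locally what sphinx-argparse will document, assuming vLLM is importable in the docs environment:

# Sketch only: print the same help text that the ".. argparse::" directives render.
import argparse

from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs

# Mirrors :func: _engine_args_parser (every shared engine argument).
EngineArgs.add_cli_args(argparse.ArgumentParser()).print_help()

# Mirrors :func: _async_engine_args_parser (only the async-engine extras).
AsyncEngineArgs.add_cli_args(argparse.ArgumentParser(),
                             async_args_only=True).print_help()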
@@ -82,57 +82,55 @@ class EngineArgs:
parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
"""Shared CLI arguments for vLLM engine."""

# NOTE: If you update any of the arguments below, please also
# make sure to update docs/source/models/engine_args.rst

# Model arguments
parser.add_argument(
'--model',
type=str,
default='facebook/opt-125m',
help='name or path of the huggingface model to use')
help='Name or path of the huggingface model to use.')
parser.add_argument(
'--tokenizer',
type=str,
default=EngineArgs.tokenizer,
help='name or path of the huggingface tokenizer to use')
help='Name or path of the huggingface tokenizer to use.')
parser.add_argument(
'--revision',
type=str,
default=None,
help='the specific model version to use. It can be a branch '
help='The specific model version to use. It can be a branch '
'name, a tag name, or a commit id. If unspecified, will use '
'the default version.')
parser.add_argument(
'--code-revision',
type=str,
default=None,
help='the specific revision to use for the model code on '
help='The specific revision to use for the model code on '
'Hugging Face Hub. It can be a branch name, a tag name, or a '
'commit id. If unspecified, will use the default version.')
parser.add_argument(
'--tokenizer-revision',
type=str,
default=None,
help='the specific tokenizer version to use. It can be a branch '
help='The specific tokenizer version to use. It can be a branch '
'name, a tag name, or a commit id. If unspecified, will use '
'the default version.')
parser.add_argument('--tokenizer-mode',
type=str,
default=EngineArgs.tokenizer_mode,
choices=['auto', 'slow'],
help='tokenizer mode. "auto" will use the fast '
'tokenizer if available, and "slow" will '
'always use the slow tokenizer.')
parser.add_argument(
'--tokenizer-mode',
type=str,
default=EngineArgs.tokenizer_mode,
choices=['auto', 'slow'],
help='The tokenizer mode.\n\n* "auto" will use the '
'fast tokenizer if available.\n* "slow" will '
'always use the slow tokenizer.')
parser.add_argument('--trust-remote-code',
action='store_true',
help='trust remote code from huggingface')
help='Trust remote code from huggingface.')
parser.add_argument('--download-dir',
type=str,
default=EngineArgs.download_dir,
help='directory to download and load the weights, '
help='Directory to download and load the weights, '
'default to the default cache dir of '
'huggingface')
'huggingface.')
parser.add_argument(
'--load-format',
type=str,
@@ -140,19 +138,19 @@ class EngineArgs:
choices=[
'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer'
],
help='The format of the model weights to load. '
'"auto" will try to load the weights in the safetensors format '
help='The format of the model weights to load.\n\n'
'* "auto" will try to load the weights in the safetensors format '
'and fall back to the pytorch bin format if safetensors format '
'is not available. '
'"pt" will load the weights in the pytorch bin format. '
'"safetensors" will load the weights in the safetensors format. '
'"npcache" will load the weights in pytorch format and store '
'a numpy cache to speed up the loading. '
'"dummy" will initialize the weights with random values, '
'which is mainly for profiling.'
'"tensorizer" will load the weights using tensorizer from CoreWeave'
'which assumes tensorizer_uri is set to the location of the '
'serialized weights.')
'is not available.\n'
'* "pt" will load the weights in the pytorch bin format.\n'
'* "safetensors" will load the weights in the safetensors format.\n'
'* "npcache" will load the weights in pytorch format and store '
'a numpy cache to speed up the loading.\n'
'* "dummy" will initialize the weights with random values, '
'which is mainly for profiling.\n'
'* "tensorizer" will load the weights using tensorizer from '
'CoreWeave which assumes tensorizer_uri is set to the location of '
'the serialized weights.')
parser.add_argument(
'--dtype',
type=str,
@@ -160,10 +158,14 @@ class EngineArgs:
choices=[
'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
],
help='data type for model weights and activations. '
'The "auto" option will use FP16 precision '
'for FP32 and FP16 models, and BF16 precision '
'for BF16 models.')
help='Data type for model weights and activations.\n\n'
'* "auto" will use FP16 precision for FP32 and FP16 models, and '
'BF16 precision for BF16 models.\n'
'* "half" for FP16. Recommended for AWQ quantization.\n'
'* "float16" is the same as "half".\n'
'* "bfloat16" for a balance between precision and range.\n'
'* "float" is shorthand for FP32 precision.\n'
'* "float32" for FP32 precision.')
parser.add_argument(
'--kv-cache-dtype',
type=str,
@@ -172,7 +174,7 @@ class EngineArgs:
help='Data type for kv cache storage. If "auto", will use model '
'data type. FP8_E5M2 (without scaling) is only supported on cuda '
'version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
'supported for common inference criteria. ')
'supported for common inference criteria.')
parser.add_argument(
'--quantization-param-path',
type=str,
@@ -183,58 +185,59 @@ class EngineArgs:
'default to 1.0, which may cause accuracy issues. '
'FP8_E5M2 (without scaling) is only supported on cuda version'
'greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
'supported for common inference criteria. ')
'supported for common inference criteria.')
parser.add_argument('--max-model-len',
type=int,
default=EngineArgs.max_model_len,
help='model context length. If unspecified, '
'will be automatically derived from the model.')
help='Model context length. If unspecified, will '
'be automatically derived from the model config.')
parser.add_argument(
'--guided-decoding-backend',
type=str,
default='outlines',
choices=['outlines', 'lm-format-enforcer'],
help='Which engine will be used for guided decoding'
' (JSON schema / regex etc)')
' (JSON schema / regex etc).')
# Parallel arguments
parser.add_argument('--worker-use-ray',
action='store_true',
help='use Ray for distributed serving, will be '
'automatically set when using more than 1 GPU')
help='Use Ray for distributed serving, will be '
'automatically set when using more than 1 GPU.')
parser.add_argument('--pipeline-parallel-size',
'-pp',
type=int,
default=EngineArgs.pipeline_parallel_size,
help='number of pipeline stages')
help='Number of pipeline stages.')
parser.add_argument('--tensor-parallel-size',
'-tp',
type=int,
default=EngineArgs.tensor_parallel_size,
help='number of tensor parallel replicas')
help='Number of tensor parallel replicas.')
parser.add_argument(
'--max-parallel-loading-workers',
type=int,
default=EngineArgs.max_parallel_loading_workers,
help='load model sequentially in multiple batches, '
help='Load model sequentially in multiple batches, '
'to avoid RAM OOM when using tensor '
'parallel and large models')
'parallel and large models.')
parser.add_argument(
'--ray-workers-use-nsight',
action='store_true',
help='If specified, use nsight to profile ray workers')
help='If specified, use nsight to profile Ray workers.')
# KV cache arguments
parser.add_argument('--block-size',
type=int,
default=EngineArgs.block_size,
choices=[8, 16, 32, 128],
help='token block size')
help='Token block size for contiguous chunks of '
'tokens.')

parser.add_argument('--enable-prefix-caching',
action='store_true',
help='Enables automatic prefix caching')
help='Enables automatic prefix caching.')
parser.add_argument('--use-v2-block-manager',
action='store_true',
help='Use BlockSpaceMangerV2')
help='Use BlockSpaceMangerV2.')
parser.add_argument(
'--num-lookahead-slots',
type=int,
@@ -247,18 +250,19 @@ class EngineArgs:
parser.add_argument('--seed',
type=int,
default=EngineArgs.seed,
help='random seed')
help='Random seed for operations.')
parser.add_argument('--swap-space',
type=int,
default=EngineArgs.swap_space,
help='CPU swap space size (GiB) per GPU')
help='CPU swap space size (GiB) per GPU.')
parser.add_argument(
'--gpu-memory-utilization',
type=float,
default=EngineArgs.gpu_memory_utilization,
help='the fraction of GPU memory to be used for '
'the model executor, which can range from 0 to 1.'
'If unspecified, will use the default value of 0.9.')
help='The fraction of GPU memory to be used for the model '
'executor, which can range from 0 to 1. For example, a value of '
'0.5 would imply 50%% GPU memory utilization. If unspecified, '
'will use the default value of 0.9.')
parser.add_argument(
'--num-gpu-blocks-override',
type=int,
@@ -268,21 +272,21 @@ class EngineArgs:
parser.add_argument('--max-num-batched-tokens',
type=int,
default=EngineArgs.max_num_batched_tokens,
help='maximum number of batched tokens per '
'iteration')
help='Maximum number of batched tokens per '
'iteration.')
parser.add_argument('--max-num-seqs',
type=int,
default=EngineArgs.max_num_seqs,
help='maximum number of sequences per iteration')
help='Maximum number of sequences per iteration.')
parser.add_argument(
'--max-logprobs',
type=int,
default=EngineArgs.max_logprobs,
help=('max number of log probs to return logprobs is specified in'
' SamplingParams'))
help=('Max number of log probs to return logprobs is specified in'
' SamplingParams.'))
parser.add_argument('--disable-log-stats',
action='store_true',
help='disable logging statistics')
help='Disable logging statistics.')
# Quantization settings.
parser.add_argument('--quantization',
'-q',
@@ -303,13 +307,13 @@ class EngineArgs:
parser.add_argument('--max-context-len-to-capture',
type=int,
default=EngineArgs.max_context_len_to_capture,
help='maximum context length covered by CUDA '
help='Maximum context length covered by CUDA '
'graphs. When a sequence has context length '
'larger than this, we fall back to eager mode.')
parser.add_argument('--disable-custom-all-reduce',
action='store_true',
default=EngineArgs.disable_custom_all_reduce,
help='See ParallelConfig')
help='See ParallelConfig.')
parser.add_argument('--tokenizer-pool-size',
type=int,
default=EngineArgs.tokenizer_pool_size,
@@ -402,7 +406,7 @@ class EngineArgs:
'--enable-chunked-prefill',
action='store_true',
help='If set, the prefill requests can be chunked based on the '
'max_num_batched_tokens')
'max_num_batched_tokens.')

parser.add_argument(
'--speculative-model',
@@ -416,7 +420,7 @@ class EngineArgs:
type=int,
default=None,
help='The number of speculative tokens to sample from '
'the draft model in speculative decoding')
'the draft model in speculative decoding.')

parser.add_argument('--model-loader-extra-config',
type=str,
@@ -534,20 +538,31 @@ class AsyncEngineArgs(EngineArgs):
max_log_len: Optional[int] = None

@staticmethod
def add_cli_args(
parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser = EngineArgs.add_cli_args(parser)
def add_cli_args(parser: argparse.ArgumentParser,
async_args_only: bool = False) -> argparse.ArgumentParser:
if not async_args_only:
parser = EngineArgs.add_cli_args(parser)
parser.add_argument('--engine-use-ray',
action='store_true',
help='use Ray to start the LLM engine in a '
help='Use Ray to start the LLM engine in a '
'separate process as the server process.')
parser.add_argument('--disable-log-requests',
action='store_true',
help='disable logging requests')
help='Disable logging requests.')
parser.add_argument('--max-log-len',
type=int,
default=None,
help='max number of prompt characters or prompt '
'ID numbers being printed in log. '
'Default: unlimited.')
help='Max number of prompt characters or prompt '
'ID numbers being printed in log.'
'\n\nDefault: Unlimited')
return parser


# These functions are used by sphinx to build the documentation
def _engine_args_parser():
return EngineArgs.add_cli_args(argparse.ArgumentParser())


def _async_engine_args_parser():
return AsyncEngineArgs.add_cli_args(argparse.ArgumentParser(),
async_args_only=True)
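These module-level helpers exist so the :func: references in the two .. argparse:: directives resolve to ready-made parsers, and they are also handy for checking that the CLI and the docs stay in sync. A small usage sketch (the sample argument values are illustrative, not taken from this commit):

# Sketch: build the documented parsers and parse a sample command line.
from vllm.engine.arg_utils import _async_engine_args_parser, _engine_args_parser

args = _engine_args_parser().parse_args(
    ['--model', 'facebook/opt-125m', '--dtype', 'half'])
print(args.model, args.dtype)

# The async variant exposes only --engine-use-ray, --disable-log-requests
# and --max-log-len, exactly what the Async Engine Arguments page documents.
print(_async_engine_args_parser().parse_args(['--disable-log-requests']))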