[V1][Spec Decode] Remove deprecated spec decode config params (#15466)

Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
shangmingc 2025-04-01 00:19:35 +08:00 committed by GitHub
parent 09e974d483
commit 239b7befdd
10 changed files with 125 additions and 220 deletions


@@ -63,10 +63,12 @@
             "model": "meta-llama/Meta-Llama-3.1-70B-Instruct",
             "disable_log_requests": "",
             "tensor_parallel_size": 4,
             "swap_space": 16,
-            "speculative_model": "turboderp/Qwama-0.5B-Instruct",
-            "num_speculative_tokens": 4,
-            "speculative_draft_tensor_parallel_size": 1
+            "speculative_config": {
+                "model": "turboderp/Qwama-0.5B-Instruct",
+                "num_speculative_tokens": 4,
+                "draft_tensor_parallel_size": 1
+            }
         },
         "client_parameters": {
             "model": "meta-llama/Meta-Llama-3.1-70B-Instruct",


@@ -52,7 +52,7 @@ python -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --port 8000 --model
 ```
 :::{warning}
-Note: Please use `--speculative_config` to set all configurations related to speculative decoding. The previous method of specifying the model through `--speculative_model` and adding related parameters (e.g., `--num_speculative_tokens`) separately will be deprecated in the next release.
+Note: Please use `--speculative_config` to set all configurations related to speculative decoding. The previous method of specifying the model through `--speculative_model` and adding related parameters (e.g., `--num_speculative_tokens`) separately has been deprecated now.
 :::
 Then use a client:
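
For reference, a minimal sketch of launching the server with the consolidated flag; the value must be valid JSON, and the model names here are only placeholders:

```python
import json
import subprocess

spec = {
    "model": "turboderp/Qwama-0.5B-Instruct",
    "num_speculative_tokens": 4,
}
subprocess.run([
    "python", "-m", "vllm.entrypoints.openai.api_server",
    "--host", "0.0.0.0", "--port", "8000",
    "--model", "meta-llama/Meta-Llama-3.1-70B-Instruct",
    "--speculative_config", json.dumps(spec),
], check=True)
```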


@@ -69,10 +69,12 @@ llm = LLM(
     max_model_len=max_model_len,
     max_num_seqs=args.max_num_seqs,
     gpu_memory_utilization=0.8,
-    speculative_model=eagle_dir,
-    num_speculative_tokens=args.num_spec_tokens,
-    speculative_draft_tensor_parallel_size=args.draft_tp,
-    speculative_max_model_len=max_model_len,
+    speculative_config={
+        "model": eagle_dir,
+        "num_speculative_tokens": args.num_spec_tokens,
+        "draft_tensor_parallel_size": args.draft_tp,
+        "max_model_len": max_model_len,
+    },
     disable_log_stats=False,
 )
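
The same consolidation applies to the offline API. A minimal usage sketch; the model paths are placeholders, not taken from this PR:

```python
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",    # placeholder target model
    speculative_config={
        "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",  # placeholder EAGLE head
        "num_speculative_tokens": 4,
        "draft_tensor_parallel_size": 1,
    },
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```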


@@ -248,8 +248,10 @@ def test_metric_spec_decode(
             dtype=dtype,
             disable_log_stats=False,
             gpu_memory_utilization=0.4,
-            speculative_model=model,
-            num_speculative_tokens=k,
+            speculative_config={
+                "model": model,
+                "num_speculative_tokens": k,
+            },
     ) as vllm_model:
         # Force log interval to be 0 to catch all metrics.
@@ -300,8 +302,10 @@ def test_metric_spec_decode_interval(
         dtype=dtype,
         disable_log_stats=False,
         gpu_memory_utilization=0.4,
-        speculative_model=model,
-        num_speculative_tokens=k,
+        speculative_config={
+            "model": model,
+            "num_speculative_tokens": k,
+        },
         enforce_eager=True,
     )


@@ -54,8 +54,10 @@ def test_can_initialize(model_arch):
         model_info.default,
         tokenizer=model_info.tokenizer,
         tokenizer_mode=model_info.tokenizer_mode,
-        speculative_model=model_info.speculative_model,
-        num_speculative_tokens=1 if model_info.speculative_model else None,
+        speculative_config={
+            "model": model_info.speculative_model,
+            "num_speculative_tokens": 1,
+        } if model_info.speculative_model else None,
         trust_remote_code=model_info.trust_remote_code,
         load_format="dummy",
         hf_overrides=hf_overrides,


@@ -3,6 +3,7 @@
 tensor parallelism.
 """
+import json
 from typing import Optional

 import pytest
@@ -28,14 +29,14 @@ from .conftest import run_equality_correctness_test_tp
 @pytest.mark.parametrize("test_llm_kwargs", [
     [
         "--speculative_config",
-        str({
+        json.dumps({
             "model": "JackFram/llama-68m",
             "num_speculative_tokens": 3,
         }),
     ],
     [
         "--speculative_config",
-        str({
+        json.dumps({
             "model": "ngram",
             "num_speculative_tokens": 5,
             "prompt_lookup_max": 3,
@@ -88,7 +89,7 @@ def test_target_model_tp_gt_1(common_llm_kwargs, per_test_common_llm_kwargs,
     "model, test_llm_kwargs",
     [("JackFram/llama-68m", [
         "--speculative_config",
-        str({
+        json.dumps({
             "model": "JackFram/llama-68m",
             "num_speculative_tokens": 5,
             "draft_tensor_parallel_size": 1,
@@ -96,7 +97,7 @@ def test_target_model_tp_gt_1(common_llm_kwargs, per_test_common_llm_kwargs,
     ]),
     ("ibm-granite/granite-3b-code-instruct", [
         "--speculative_config",
-        str({
+        json.dumps({
             "model": "ibm-granite/granite-3b-code-instruct",
             "num_speculative_tokens": 5,
             "draft_tensor_parallel_size": 1,
@@ -147,20 +148,20 @@ def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs,
 @pytest.mark.parametrize("model, test_llm_kwargs",
                          [("JackFram/llama-68m", [
                              "--speculative_config",
-                             str({
+                             json.dumps({
                                  "model": "JackFram/llama-68m",
                                  "num_speculative_tokens": 3,
                              }),
                          ]),
                           ("JackFram/llama-68m", [
                               "--speculative_config",
-                              str({
+                              json.dumps({
                                   "model": "JackFram/llama-68m",
                                   "num_speculative_tokens": 3,
                                   "draft_tensor_parallel_size": 1,
                               }),
                           ])])
-@pytest.mark.parametrize("logprobs", [None, 2])
+@pytest.mark.parametrize("logprobs", [None])
 @pytest.mark.parametrize("batch_size", [2])
 @pytest.mark.parametrize("seed", [1])
 def test_spec_decode_chunked_prefill_tp2(model, common_llm_kwargs,
@@ -171,9 +172,68 @@ def test_spec_decode_chunked_prefill_tp2(model, common_llm_kwargs,
     """Verify spec decode works well with same and different TP size for
     the draft model with chunked prefill.
     """
-    if logprobs:
-        test_llm_kwargs.extend(
-            ["--disable_logprobs_during_spec_decoding", "False"])
+    run_equality_correctness_test_tp(model,
+                                     common_llm_kwargs,
+                                     per_test_common_llm_kwargs,
+                                     baseline_llm_kwargs,
+                                     test_llm_kwargs,
+                                     batch_size,
+                                     max_output_len=32,
+                                     seed=seed,
+                                     temperature=0.0,
+                                     logprobs=logprobs)
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2,
+                    reason="Need at least 2 GPUs to run the test.")
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [[
+        # Skip cuda graph recording for fast test.
+        "--enforce-eager",
+        "--tensor_parallel_size",
+        "2",
+        # precision
+        "--dtype",
+        "bfloat16",
+    ]])
+@pytest.mark.parametrize(
+    "per_test_common_llm_kwargs",
+    [["--enable-chunked-prefill", "False"],
+     [
+         "--enable-chunked-prefill", "True", "--max-num-batched-tokens", "4",
+         "--max-num-seqs", "4"
+     ]])
+@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
+@pytest.mark.parametrize("model, test_llm_kwargs",
+                         [("JackFram/llama-68m", [
+                             "--speculative_config",
+                             json.dumps({
+                                 "model": "JackFram/llama-68m",
+                                 "num_speculative_tokens": 3,
+                                 "disable_logprobs": False,
+                             }),
+                         ]),
+                          ("JackFram/llama-68m", [
+                              "--speculative_config",
+                              json.dumps({
+                                  "model": "JackFram/llama-68m",
+                                  "num_speculative_tokens": 3,
+                                  "draft_tensor_parallel_size": 1,
+                                  "disable_logprobs": False,
+                              }),
+                          ])])
+@pytest.mark.parametrize("logprobs", [2])
+@pytest.mark.parametrize("batch_size", [2])
+@pytest.mark.parametrize("seed", [1])
+def test_spec_decode_chunked_prefill_tp2_with_logprobs(
+        model, common_llm_kwargs, per_test_common_llm_kwargs,
+        baseline_llm_kwargs, test_llm_kwargs, logprobs: Optional[int],
+        batch_size: int, seed: int):
+    """Verify spec decode works well with same and different TP size for
+    the draft model with chunked prefill.
+    """
     run_equality_correctness_test_tp(model,
                                      common_llm_kwargs,
                                      per_test_common_llm_kwargs,
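
These tests switch from `str(...)` to `json.dumps(...)` because the flag's value is now decoded with `json.loads` (see the `arg_utils.py` change below), which rejects Python's single-quoted dict repr. A small illustration:

```python
import json

cfg = {"model": "JackFram/llama-68m", "num_speculative_tokens": 3}

json.loads(json.dumps(cfg))   # round-trips cleanly
try:
    json.loads(str(cfg))      # {'model': ...} -> single quotes, not JSON
except json.JSONDecodeError as exc:
    print("str() output rejected:", exc)
```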


@@ -3,6 +3,8 @@
 tensor parallelism.
 """
+import json
+
 import openai
 import pytest
 import torch
@@ -33,7 +35,7 @@ SPEC_MODEL = "JackFram/llama-68m"
     #TODO(wooyeon): add spec_draft_dp=2 case
     [
         "--speculative_config",
-        str({
+        json.dumps({
             "model": f"{SPEC_MODEL}",
             "num_speculative_tokens": 5,
             "draft_tensor_parallel_size": 1,
@@ -80,7 +82,7 @@ def test_draft_model_tp_lt_target_model_tp4(common_llm_kwargs,
         # Artificially limit the draft model max model len; this forces vLLM
         # to skip speculation once the sequences grow beyond 32-k tokens.
         "--speculative_config",
-        str({
+        json.dumps({
             "model": f"{SPEC_MODEL}",
             "num_speculative_tokens": 5,
             "max_model_len": 32,


@@ -49,7 +49,9 @@ def test_unsupported_configs(monkeypatch):
     with pytest.raises(NotImplementedError):
         AsyncEngineArgs(
             model=MODEL,
-            speculative_model=MODEL,
+            speculative_config={
+                "model": MODEL,
+            },
         ).create_engine_config()

     with pytest.raises(NotImplementedError):


@@ -2047,14 +2047,13 @@ class SpeculativeConfig:

     def __post_init__(self):

-        # Note: After next release, the method parameter will be used to
-        # specify the speculative method, which helps to extend the
-        # configuration of non-model-based proposers, and the model parameter
-        # will be used when the draft model or head is needed.
-        # If users do not specify the method, the speculative method will
-        # be detected automatically if possible. If the speculative method can
-        # not be detected, it will be considered as the draft-model-based
-        # method by default.
+        # Note: "method" is a new parameter that helps to extend the
+        # configuration of non-model-based proposers, and the "model" parameter
+        # will be used to set the draft model, eagle head, or additional weight
+        # when needed. If users do not specify "method", the speculative method
+        # will be detected automatically if possible. If the speculative method
+        # can not be detected, it will be considered as the "draft_model" by
+        # default.

         if self.model is None and self.num_speculative_tokens is not None:
             # TODO(Shangming): Refactor mtp configuration logic when supporting
@@ -2069,8 +2068,8 @@ class SpeculativeConfig:
                 raise ValueError("num_speculative_tokens was provided without "
                                  "speculative model.")

-        # Automatically configure the ngram method during configuration
-        # refactoring to ensure a smooth transition.
+        # Automatically configure the method for ngram when "model" is used
+        # instead of "method"
         if self.method is None and (self.model is not None
                                     and self.model in ("ngram", "[ngram]")):
             self.method = "ngram"


@@ -181,22 +181,7 @@ class EngineArgs:
     guided_decoding_backend: str = 'xgrammar'
     logits_processor_pattern: Optional[str] = None

-    speculative_config: Optional[Union[str, Dict[str, Any]]] = None
-    # TODO(Shangming): Deprecate these out-of-date params after next release
-    speculative_model: Optional[str] = None
-    speculative_model_quantization: Optional[str] = None
-    speculative_draft_tensor_parallel_size: Optional[int] = None
-    num_speculative_tokens: Optional[int] = None
-    speculative_disable_mqa_scorer: Optional[bool] = False
-    speculative_max_model_len: Optional[int] = None
-    speculative_disable_by_batch_size: Optional[int] = None
-    ngram_prompt_lookup_max: Optional[int] = None
-    ngram_prompt_lookup_min: Optional[int] = None
-    spec_decoding_acceptance_method: str = 'rejection_sampler'
-    typical_acceptance_sampler_posterior_threshold: Optional[float] = None
-    typical_acceptance_sampler_posterior_alpha: Optional[float] = None
-    disable_logprobs_during_spec_decoding: Optional[bool] = None
+    speculative_config: Optional[Dict[str, Any]] = None

     qlora_adapter_name_or_path: Optional[str] = None
     show_hidden_metrics_for_version: Optional[str] = None
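
With the standalone fields gone, speculative decoding is configured only through the `speculative_config` dict, mirroring what the updated tests do. A minimal sketch; the target model name is a placeholder:

```python
from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",  # placeholder target model
    speculative_config={
        "model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,
    },
)
vllm_config = engine_args.create_engine_config()
```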
@@ -793,122 +778,10 @@ class EngineArgs:
                             help='If set, the prefill requests can be chunked based on the '
                             'max_num_batched_tokens.')

         parser.add_argument('--speculative-config',
-                            type=nullable_str,
+                            type=json.loads,
                             default=None,
                             help='The configurations for speculative decoding.'
                             ' Should be a JSON string.')
-        parser.add_argument(
-            '--speculative-model',
-            type=nullable_str,
-            default=EngineArgs.speculative_model,
-            help=
-            'The name of the draft model to be used in speculative decoding.')
-        # Quantization settings for speculative model.
-        parser.add_argument(
-            '--speculative-model-quantization',
-            type=nullable_str,
-            choices=[*QUANTIZATION_METHODS, None],
-            default=EngineArgs.speculative_model_quantization,
-            help='Method used to quantize the weights of speculative model. '
-            'If None, we first check the `quantization_config` '
-            'attribute in the model config file. If that is '
-            'None, we assume the model weights are not '
-            'quantized and use `dtype` to determine the data '
-            'type of the weights.')
-        parser.add_argument(
-            '--num-speculative-tokens',
-            type=int,
-            default=EngineArgs.num_speculative_tokens,
-            help='The number of speculative tokens to sample from '
-            'the draft model in speculative decoding.')
-        parser.add_argument(
-            '--speculative-disable-mqa-scorer',
-            action='store_true',
-            help=
-            'If set to True, the MQA scorer will be disabled in speculative '
-            ' and fall back to batch expansion')
-        parser.add_argument(
-            '--speculative-draft-tensor-parallel-size',
-            '-spec-draft-tp',
-            type=int,
-            default=EngineArgs.speculative_draft_tensor_parallel_size,
-            help='Number of tensor parallel replicas for '
-            'the draft model in speculative decoding.')
-        parser.add_argument(
-            '--speculative-max-model-len',
-            type=int,
-            default=EngineArgs.speculative_max_model_len,
-            help='The maximum sequence length supported by the '
-            'draft model. Sequences over this length will skip '
-            'speculation.')
-        parser.add_argument(
-            '--speculative-disable-by-batch-size',
-            type=int,
-            default=EngineArgs.speculative_disable_by_batch_size,
-            help='Disable speculative decoding for new incoming requests '
-            'if the number of enqueue requests is larger than this value.')
-        parser.add_argument(
-            '--ngram-prompt-lookup-max',
-            type=int,
-            default=EngineArgs.ngram_prompt_lookup_max,
-            help='Max size of window for ngram prompt lookup in speculative '
-            'decoding.')
-        parser.add_argument(
-            '--ngram-prompt-lookup-min',
-            type=int,
-            default=EngineArgs.ngram_prompt_lookup_min,
-            help='Min size of window for ngram prompt lookup in speculative '
-            'decoding.')
-        parser.add_argument(
-            '--spec-decoding-acceptance-method',
-            type=str,
-            default=EngineArgs.spec_decoding_acceptance_method,
-            choices=['rejection_sampler', 'typical_acceptance_sampler'],
-            help='Specify the acceptance method to use during draft token '
-            'verification in speculative decoding. Two types of acceptance '
-            'routines are supported: '
-            '1) RejectionSampler which does not allow changing the '
-            'acceptance rate of draft tokens, '
-            '2) TypicalAcceptanceSampler which is configurable, allowing for '
-            'a higher acceptance rate at the cost of lower quality, '
-            'and vice versa.')
-        parser.add_argument(
-            '--typical-acceptance-sampler-posterior-threshold',
-            type=float,
-            default=EngineArgs.typical_acceptance_sampler_posterior_threshold,
-            help='Set the lower bound threshold for the posterior '
-            'probability of a token to be accepted. This threshold is '
-            'used by the TypicalAcceptanceSampler to make sampling decisions '
-            'during speculative decoding.')
-        parser.add_argument(
-            '--typical-acceptance-sampler-posterior-alpha',
-            type=float,
-            default=EngineArgs.typical_acceptance_sampler_posterior_alpha,
-            help='A scaling factor for the entropy-based threshold for token '
-            'acceptance in the TypicalAcceptanceSampler. Typically defaults '
-            'to sqrt of --typical-acceptance-sampler-posterior-threshold '
-            'i.e. 0.3')
-        parser.add_argument(
-            '--disable-logprobs-during-spec-decoding',
-            action=StoreBoolean,
-            default=EngineArgs.disable_logprobs_during_spec_decoding,
-            nargs="?",
-            const="True",
-            help='If set to True, token log probabilities are not returned '
-            'during speculative decoding. If set to False, log probabilities '
-            'are returned according to the settings in SamplingParams. If '
-            'not specified, it defaults to True. Disabling log probabilities '
-            'during speculative decoding reduces latency by skipping logprob '
-            'calculation in proposal sampling, target sampling, and after '
-            'accepted tokens are determined.')
         parser.add_argument('--model-loader-extra-config',
                             type=nullable_str,
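
A minimal, self-contained demo of the `type=json.loads` change: argparse now hands downstream code a dict (or `None`), never a raw string:

```python
import argparse
import json

parser = argparse.ArgumentParser()
parser.add_argument("--speculative-config", type=json.loads, default=None)

args = parser.parse_args([
    "--speculative-config",
    '{"model": "JackFram/llama-68m", "num_speculative_tokens": 5}',
])
assert args.speculative_config["num_speculative_tokens"] == 5
```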
@@ -1221,58 +1094,14 @@
         This function utilizes `speculative_config` to create a
         SpeculativeConfig object. The `speculative_config` can either be
         provided as a JSON string input via CLI arguments or directly as a
-        dictionary from the engine. If `speculative_config` is not set, this
-        function will attempt to construct a configuration dictionary using
-        certain parameters, which are scheduled for deprecation in the next
-        release. Note that in next releases, `speculative_config` must be
-        provided, and the deprecated standalone speculative-related parameters
-        will be removed.
+        dictionary from the engine.
         """
         if self.speculative_config is None:
-            if (self.speculative_model is None
-                    and self.num_speculative_tokens is None):
-                return None
-            # TODO(Shangming): Deprecate this way of setting SpeculativeConfig,
-            # only allow '--speculative-config' after next release
-            logger.warning_once(
-                "Please use '--speculative-config' to set all configurations "
-                "related to speculative decoding. The current method of "
-                "specifying the model through '--speculative-model' and "
-                "adding related parameters (e.g., '--num-speculative-tokens') "
-                "separately will be deprecated in the next release.")
-            spec_config_dict = {
-                "model": self.speculative_model,
-                "quantization": self.speculative_model_quantization,
-                "max_model_len": self.speculative_max_model_len,
-                "draft_tensor_parallel_size":
-                self.speculative_draft_tensor_parallel_size,
-                "num_speculative_tokens": self.num_speculative_tokens,
-                "disable_mqa_scorer": self.speculative_disable_mqa_scorer,
-                "disable_by_batch_size":
-                self.speculative_disable_by_batch_size,
-                "prompt_lookup_max": self.ngram_prompt_lookup_max,
-                "prompt_lookup_min": self.ngram_prompt_lookup_min,
-                "acceptance_method": self.spec_decoding_acceptance_method,
-                "posterior_threshold":
-                self.typical_acceptance_sampler_posterior_threshold,
-                "posterior_alpha":
-                self.typical_acceptance_sampler_posterior_alpha,
-                "disable_logprobs": self.disable_logprobs_during_spec_decoding,
-            }
-            self.speculative_config = spec_config_dict
-        else:
-            if isinstance(self.speculative_config, str):
-                import ast
-                self.speculative_config = ast.literal_eval(
-                    self.speculative_config)
+            return None
+
         # Note(Shangming): These parameters are not obtained from the cli arg
         # '--speculative-config' and must be passed in when creating the engine
         # config.
-        assert isinstance(self.speculative_config, dict)
         self.speculative_config.update({
             "target_model_config": target_model_config,
             "target_parallel_config": target_parallel_config,
@@ -1638,11 +1467,15 @@ class EngineArgs:
             return False

         # Only Ngram speculative decoding so far.
-        if (self.speculative_model is not None
-                or self.num_speculative_tokens is not None):
+        is_ngram_enabled = False
+        if self.speculative_config is not None:
             # This is supported but experimental (handled below).
-            if self.speculative_model in ("ngram", "[ngram]"):
-                pass
+            if (("method" in self.speculative_config
+                 and self.speculative_config["method"] in ("ngram", "[ngram]"))
+                    or
+                ("model" in self.speculative_config and
+                 self.speculative_config["model"] in ("ngram", "[ngram]"))):
+                is_ngram_enabled = True
             else:
                 _raise_or_fallback(feature_name="Speculative Decoding",
                                    recommend_to_remove=False)
@@ -1691,8 +1524,7 @@ class EngineArgs:
             return False

         # ngram is supported on V1, but off by default for now.
-        if self.speculative_model in (
-                "ngram", "[ngram]") and _warn_or_fallback("ngram"):
+        if is_ngram_enabled and _warn_or_fallback("ngram"):
             return False

         # Non-CUDA is supported on V1, but off by default for now.
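
The V1-support check now keys off the dict rather than the removed `speculative_model` field. A small sketch, equivalent in spirit to the code above:

```python
from typing import Optional


def is_ngram(speculative_config: Optional[dict]) -> bool:
    if speculative_config is None:
        return False
    return (speculative_config.get("method") in ("ngram", "[ngram]")
            or speculative_config.get("model") in ("ngram", "[ngram]"))


assert is_ngram({"method": "ngram", "num_speculative_tokens": 5})
assert is_ngram({"model": "[ngram]", "prompt_lookup_max": 3})
assert not is_ngram({"model": "JackFram/llama-68m"})
```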
@@ -1721,7 +1553,7 @@ class EngineArgs:
         is_gpu = current_platform.is_cuda()
         use_sliding_window = (model_config.get_sliding_window()
                               is not None)
-        use_spec_decode = self.speculative_model is not None
+        use_spec_decode = self.speculative_config is not None
         if (is_gpu and not use_sliding_window and not use_spec_decode
                 and not self.enable_lora