[V1][Spec Decode] Remove deprecated spec decode config params (#15466)

Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
This commit is contained in:
shangmingc 2025-04-01 00:19:35 +08:00 committed by GitHub
parent 09e974d483
commit 239b7befdd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 125 additions and 220 deletions

View File

@ -63,10 +63,12 @@
"model": "meta-llama/Meta-Llama-3.1-70B-Instruct",
"disable_log_requests": "",
"tensor_parallel_size": 4,
"swap_space": 16,
"speculative_model": "turboderp/Qwama-0.5B-Instruct",
"num_speculative_tokens": 4,
"speculative_draft_tensor_parallel_size": 1
"swap_space": 16,
"speculative_config": {
"model": "turboderp/Qwama-0.5B-Instruct",
"num_speculative_tokens": 4,
"draft_tensor_parallel_size": 1
}
},
"client_parameters": {
"model": "meta-llama/Meta-Llama-3.1-70B-Instruct",

View File

@ -52,7 +52,7 @@ python -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --port 8000 --model
```
:::{warning}
Note: Please use `--speculative_config` to set all configurations related to speculative decoding. The previous method of specifying the model through `--speculative_model` and adding related parameters (e.g., `--num_speculative_tokens`) separately will be deprecated in the next release.
Note: Please use `--speculative_config` to set all configurations related to speculative decoding. The previous method of specifying the model through `--speculative_model` and adding related parameters (e.g., `--num_speculative_tokens`) separately has been deprecated now.
:::
Then use a client:

View File

@ -69,10 +69,12 @@ llm = LLM(
max_model_len=max_model_len,
max_num_seqs=args.max_num_seqs,
gpu_memory_utilization=0.8,
speculative_model=eagle_dir,
num_speculative_tokens=args.num_spec_tokens,
speculative_draft_tensor_parallel_size=args.draft_tp,
speculative_max_model_len=max_model_len,
speculative_config={
"model": eagle_dir,
"num_speculative_tokens": args.num_spec_tokens,
"draft_tensor_parallel_size": args.draft_tp,
"max_model_len": max_model_len,
},
disable_log_stats=False,
)

View File

@ -248,8 +248,10 @@ def test_metric_spec_decode(
dtype=dtype,
disable_log_stats=False,
gpu_memory_utilization=0.4,
speculative_model=model,
num_speculative_tokens=k,
speculative_config={
"model": model,
"num_speculative_tokens": k,
},
) as vllm_model:
# Force log interval to be 0 to catch all metrics.
@ -300,8 +302,10 @@ def test_metric_spec_decode_interval(
dtype=dtype,
disable_log_stats=False,
gpu_memory_utilization=0.4,
speculative_model=model,
num_speculative_tokens=k,
speculative_config={
"model": model,
"num_speculative_tokens": k,
},
enforce_eager=True,
)

View File

@ -54,8 +54,10 @@ def test_can_initialize(model_arch):
model_info.default,
tokenizer=model_info.tokenizer,
tokenizer_mode=model_info.tokenizer_mode,
speculative_model=model_info.speculative_model,
num_speculative_tokens=1 if model_info.speculative_model else None,
speculative_config={
"model": model_info.speculative_model,
"num_speculative_tokens": 1,
} if model_info.speculative_model else None,
trust_remote_code=model_info.trust_remote_code,
load_format="dummy",
hf_overrides=hf_overrides,

View File

@ -3,6 +3,7 @@
tensor parallelism.
"""
import json
from typing import Optional
import pytest
@ -28,14 +29,14 @@ from .conftest import run_equality_correctness_test_tp
@pytest.mark.parametrize("test_llm_kwargs", [
[
"--speculative_config",
str({
json.dumps({
"model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
}),
],
[
"--speculative_config",
str({
json.dumps({
"model": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
@ -88,7 +89,7 @@ def test_target_model_tp_gt_1(common_llm_kwargs, per_test_common_llm_kwargs,
"model, test_llm_kwargs",
[("JackFram/llama-68m", [
"--speculative_config",
str({
json.dumps({
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"draft_tensor_parallel_size": 1,
@ -96,7 +97,7 @@ def test_target_model_tp_gt_1(common_llm_kwargs, per_test_common_llm_kwargs,
]),
("ibm-granite/granite-3b-code-instruct", [
"--speculative_config",
str({
json.dumps({
"model": "ibm-granite/granite-3b-code-instruct",
"num_speculative_tokens": 5,
"draft_tensor_parallel_size": 1,
@ -147,20 +148,20 @@ def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs,
@pytest.mark.parametrize("model, test_llm_kwargs",
[("JackFram/llama-68m", [
"--speculative_config",
str({
json.dumps({
"model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
}),
]),
("JackFram/llama-68m", [
"--speculative_config",
str({
json.dumps({
"model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
"draft_tensor_parallel_size": 1,
}),
])])
@pytest.mark.parametrize("logprobs", [None, 2])
@pytest.mark.parametrize("logprobs", [None])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_chunked_prefill_tp2(model, common_llm_kwargs,
@ -171,9 +172,68 @@ def test_spec_decode_chunked_prefill_tp2(model, common_llm_kwargs,
"""Verify spec decode works well with same and different TP size for
the draft model with chunked prefill.
"""
if logprobs:
test_llm_kwargs.extend(
["--disable_logprobs_during_spec_decoding", "False"])
run_equality_correctness_test_tp(model,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=32,
seed=seed,
temperature=0.0,
logprobs=logprobs)
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize(
"common_llm_kwargs",
[[
# Skip cuda graph recording for fast test.
"--enforce-eager",
"--tensor_parallel_size",
"2",
# precision
"--dtype",
"bfloat16",
]])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[["--enable-chunked-prefill", "False"],
[
"--enable-chunked-prefill", "True", "--max-num-batched-tokens", "4",
"--max-num-seqs", "4"
]])
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
@pytest.mark.parametrize("model, test_llm_kwargs",
[("JackFram/llama-68m", [
"--speculative_config",
json.dumps({
"model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
"disable_logprobs": False,
}),
]),
("JackFram/llama-68m", [
"--speculative_config",
json.dumps({
"model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
"draft_tensor_parallel_size": 1,
"disable_logprobs": False,
}),
])])
@pytest.mark.parametrize("logprobs", [2])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_chunked_prefill_tp2_with_logprobs(
model, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, logprobs: Optional[int],
batch_size: int, seed: int):
"""Verify spec decode works well with same and different TP size for
the draft model with chunked prefill.
"""
run_equality_correctness_test_tp(model,
common_llm_kwargs,
per_test_common_llm_kwargs,

View File

@ -3,6 +3,8 @@
tensor parallelism.
"""
import json
import openai
import pytest
import torch
@ -33,7 +35,7 @@ SPEC_MODEL = "JackFram/llama-68m"
#TODO(wooyeon): add spec_draft_dp=2 case
[
"--speculative_config",
str({
json.dumps({
"model": f"{SPEC_MODEL}",
"num_speculative_tokens": 5,
"draft_tensor_parallel_size": 1,
@ -80,7 +82,7 @@ def test_draft_model_tp_lt_target_model_tp4(common_llm_kwargs,
# Artificially limit the draft model max model len; this forces vLLM
# to skip speculation once the sequences grow beyond 32-k tokens.
"--speculative_config",
str({
json.dumps({
"model": f"{SPEC_MODEL}",
"num_speculative_tokens": 5,
"max_model_len": 32,

View File

@ -49,7 +49,9 @@ def test_unsupported_configs(monkeypatch):
with pytest.raises(NotImplementedError):
AsyncEngineArgs(
model=MODEL,
speculative_model=MODEL,
speculative_config={
"model": MODEL,
},
).create_engine_config()
with pytest.raises(NotImplementedError):

View File

@ -2047,14 +2047,13 @@ class SpeculativeConfig:
def __post_init__(self):
# Note: After next release, the method parameter will be used to
# specify the speculative method, which helps to extend the
# configuration of non-model-based proposers, and the model parameter
# will be used when the draft model or head is needed.
# If users do not specify the method, the speculative method will
# be detected automatically if possible. If the speculative method can
# not be detected, it will be considered as the draft-model-based
# method by default.
# Note: "method" is a new parameter that helps to extend the
# configuration of non-model-based proposers, and the "model" parameter
# will be used to set the draft model, eagle head, or additional weight
# when needed. If users do not specify "method", the speculative method
# will be detected automatically if possible. If the speculative method
# can not be detected, it will be considered as the "draft_model" by
# default.
if self.model is None and self.num_speculative_tokens is not None:
# TODO(Shangming): Refactor mtp configuration logic when supporting
@ -2069,8 +2068,8 @@ class SpeculativeConfig:
raise ValueError("num_speculative_tokens was provided without "
"speculative model.")
# Automatically configure the ngram method during configuration
# refactoring to ensure a smooth transition.
# Automatically configure the method for ngram when "model" is used
# instead of "method"
if self.method is None and (self.model is not None
and self.model in ("ngram", "[ngram]")):
self.method = "ngram"

View File

@ -181,22 +181,7 @@ class EngineArgs:
guided_decoding_backend: str = 'xgrammar'
logits_processor_pattern: Optional[str] = None
speculative_config: Optional[Union[str, Dict[str, Any]]] = None
# TODO(Shangming): Deprecate these out-of-date params after next release
speculative_model: Optional[str] = None
speculative_model_quantization: Optional[str] = None
speculative_draft_tensor_parallel_size: Optional[int] = None
num_speculative_tokens: Optional[int] = None
speculative_disable_mqa_scorer: Optional[bool] = False
speculative_max_model_len: Optional[int] = None
speculative_disable_by_batch_size: Optional[int] = None
ngram_prompt_lookup_max: Optional[int] = None
ngram_prompt_lookup_min: Optional[int] = None
spec_decoding_acceptance_method: str = 'rejection_sampler'
typical_acceptance_sampler_posterior_threshold: Optional[float] = None
typical_acceptance_sampler_posterior_alpha: Optional[float] = None
disable_logprobs_during_spec_decoding: Optional[bool] = None
speculative_config: Optional[Dict[str, Any]] = None
qlora_adapter_name_or_path: Optional[str] = None
show_hidden_metrics_for_version: Optional[str] = None
@ -793,122 +778,10 @@ class EngineArgs:
help='If set, the prefill requests can be chunked based on the '
'max_num_batched_tokens.')
parser.add_argument('--speculative-config',
type=nullable_str,
type=json.loads,
default=None,
help='The configurations for speculative decoding.'
' Should be a JSON string.')
parser.add_argument(
'--speculative-model',
type=nullable_str,
default=EngineArgs.speculative_model,
help=
'The name of the draft model to be used in speculative decoding.')
# Quantization settings for speculative model.
parser.add_argument(
'--speculative-model-quantization',
type=nullable_str,
choices=[*QUANTIZATION_METHODS, None],
default=EngineArgs.speculative_model_quantization,
help='Method used to quantize the weights of speculative model. '
'If None, we first check the `quantization_config` '
'attribute in the model config file. If that is '
'None, we assume the model weights are not '
'quantized and use `dtype` to determine the data '
'type of the weights.')
parser.add_argument(
'--num-speculative-tokens',
type=int,
default=EngineArgs.num_speculative_tokens,
help='The number of speculative tokens to sample from '
'the draft model in speculative decoding.')
parser.add_argument(
'--speculative-disable-mqa-scorer',
action='store_true',
help=
'If set to True, the MQA scorer will be disabled in speculative '
' and fall back to batch expansion')
parser.add_argument(
'--speculative-draft-tensor-parallel-size',
'-spec-draft-tp',
type=int,
default=EngineArgs.speculative_draft_tensor_parallel_size,
help='Number of tensor parallel replicas for '
'the draft model in speculative decoding.')
parser.add_argument(
'--speculative-max-model-len',
type=int,
default=EngineArgs.speculative_max_model_len,
help='The maximum sequence length supported by the '
'draft model. Sequences over this length will skip '
'speculation.')
parser.add_argument(
'--speculative-disable-by-batch-size',
type=int,
default=EngineArgs.speculative_disable_by_batch_size,
help='Disable speculative decoding for new incoming requests '
'if the number of enqueue requests is larger than this value.')
parser.add_argument(
'--ngram-prompt-lookup-max',
type=int,
default=EngineArgs.ngram_prompt_lookup_max,
help='Max size of window for ngram prompt lookup in speculative '
'decoding.')
parser.add_argument(
'--ngram-prompt-lookup-min',
type=int,
default=EngineArgs.ngram_prompt_lookup_min,
help='Min size of window for ngram prompt lookup in speculative '
'decoding.')
parser.add_argument(
'--spec-decoding-acceptance-method',
type=str,
default=EngineArgs.spec_decoding_acceptance_method,
choices=['rejection_sampler', 'typical_acceptance_sampler'],
help='Specify the acceptance method to use during draft token '
'verification in speculative decoding. Two types of acceptance '
'routines are supported: '
'1) RejectionSampler which does not allow changing the '
'acceptance rate of draft tokens, '
'2) TypicalAcceptanceSampler which is configurable, allowing for '
'a higher acceptance rate at the cost of lower quality, '
'and vice versa.')
parser.add_argument(
'--typical-acceptance-sampler-posterior-threshold',
type=float,
default=EngineArgs.typical_acceptance_sampler_posterior_threshold,
help='Set the lower bound threshold for the posterior '
'probability of a token to be accepted. This threshold is '
'used by the TypicalAcceptanceSampler to make sampling decisions '
'during speculative decoding.')
parser.add_argument(
'--typical-acceptance-sampler-posterior-alpha',
type=float,
default=EngineArgs.typical_acceptance_sampler_posterior_alpha,
help='A scaling factor for the entropy-based threshold for token '
'acceptance in the TypicalAcceptanceSampler. Typically defaults '
'to sqrt of --typical-acceptance-sampler-posterior-threshold '
'i.e. 0.3')
parser.add_argument(
'--disable-logprobs-during-spec-decoding',
action=StoreBoolean,
default=EngineArgs.disable_logprobs_during_spec_decoding,
nargs="?",
const="True",
help='If set to True, token log probabilities are not returned '
'during speculative decoding. If set to False, log probabilities '
'are returned according to the settings in SamplingParams. If '
'not specified, it defaults to True. Disabling log probabilities '
'during speculative decoding reduces latency by skipping logprob '
'calculation in proposal sampling, target sampling, and after '
'accepted tokens are determined.')
parser.add_argument('--model-loader-extra-config',
type=nullable_str,
@ -1221,58 +1094,14 @@ class EngineArgs:
This function utilizes `speculative_config` to create a
SpeculativeConfig object. The `speculative_config` can either be
provided as a JSON string input via CLI arguments or directly as a
dictionary from the engine. If `speculative_config` is not set, this
function will attempt to construct a configuration dictionary using
certain parameters, which are scheduled for deprecation in the next
release. Note that in next releases, `speculative_config` must be
provided, and the deprecated standalone speculative-related parameters
will be removed.
dictionary from the engine.
"""
if self.speculative_config is None:
if (self.speculative_model is None
and self.num_speculative_tokens is None):
return None
return None
# TODO(Shangming): Deprecate this way of setting SpeculativeConfig,
# only allow '--speculative-config' after next release
logger.warning_once(
"Please use '--speculative-config' to set all configurations "
"related to speculative decoding. The current method of "
"specifying the model through '--speculative-model' and "
"adding related parameters (e.g., '--num-speculative-tokens') "
"separately will be deprecated in the next release.")
spec_config_dict = {
"model": self.speculative_model,
"quantization": self.speculative_model_quantization,
"max_model_len": self.speculative_max_model_len,
"draft_tensor_parallel_size":
self.speculative_draft_tensor_parallel_size,
"num_speculative_tokens": self.num_speculative_tokens,
"disable_mqa_scorer": self.speculative_disable_mqa_scorer,
"disable_by_batch_size":
self.speculative_disable_by_batch_size,
"prompt_lookup_max": self.ngram_prompt_lookup_max,
"prompt_lookup_min": self.ngram_prompt_lookup_min,
"acceptance_method": self.spec_decoding_acceptance_method,
"posterior_threshold":
self.typical_acceptance_sampler_posterior_threshold,
"posterior_alpha":
self.typical_acceptance_sampler_posterior_alpha,
"disable_logprobs": self.disable_logprobs_during_spec_decoding,
}
self.speculative_config = spec_config_dict
else:
if isinstance(self.speculative_config, str):
import ast
self.speculative_config = ast.literal_eval(
self.speculative_config)
# Note(Shangming): These parameters are not obtained from the cli arg
# '--speculative-config' and must be passed in when creating the engine
# config.
assert isinstance(self.speculative_config, dict)
self.speculative_config.update({
"target_model_config": target_model_config,
"target_parallel_config": target_parallel_config,
@ -1638,11 +1467,15 @@ class EngineArgs:
return False
# Only Ngram speculative decoding so far.
if (self.speculative_model is not None
or self.num_speculative_tokens is not None):
is_ngram_enabled = False
if self.speculative_config is not None:
# This is supported but experimental (handled below).
if self.speculative_model in ("ngram", "[ngram]"):
pass
if (("method" in self.speculative_config
and self.speculative_config["method"] in ("ngram", "[ngram]"))
or
("model" in self.speculative_config and
self.speculative_config["model"] in ("ngram", "[ngram]"))):
is_ngram_enabled = True
else:
_raise_or_fallback(feature_name="Speculative Decoding",
recommend_to_remove=False)
@ -1691,8 +1524,7 @@ class EngineArgs:
return False
# ngram is supported on V1, but off by default for now.
if self.speculative_model in (
"ngram", "[ngram]") and _warn_or_fallback("ngram"):
if is_ngram_enabled and _warn_or_fallback("ngram"):
return False
# Non-CUDA is supported on V1, but off by default for now.
@ -1721,7 +1553,7 @@ class EngineArgs:
is_gpu = current_platform.is_cuda()
use_sliding_window = (model_config.get_sliding_window()
is not None)
use_spec_decode = self.speculative_model is not None
use_spec_decode = self.speculative_config is not None
if (is_gpu and not use_sliding_window and not use_spec_decode
and not self.enable_lora