[V1][Usage] Refactor speculative decoding configuration and tests (#14434)

Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
shangmingc 2025-03-23 13:28:10 +08:00 committed by GitHub
parent 0661cfef7a
commit 50c9636d87
20 changed files with 1055 additions and 802 deletions

View File

@ -30,8 +30,10 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
llm = LLM(
model="facebook/opt-6.7b",
tensor_parallel_size=1,
speculative_model="facebook/opt-125m",
num_speculative_tokens=5,
speculative_config={
"model": "facebook/opt-125m",
"num_speculative_tokens": 5,
},
)
outputs = llm.generate(prompts, sampling_params)
@ -45,10 +47,14 @@ To perform the same with an online mode launch the server:
```bash
python -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --port 8000 --model facebook/opt-6.7b \
--seed 42 -tp 1 --speculative_model facebook/opt-125m \
--num_speculative_tokens 5 --gpu_memory_utilization 0.8
--seed 42 -tp 1 --gpu_memory_utilization 0.8 \
--speculative_config '{"model": "facebook/opt-125m", "num_speculative_tokens": 5}'
```
:::{warning}
Please use `--speculative_config` to set all configuration options related to speculative decoding. Specifying the draft model through `--speculative_model` and passing related parameters (e.g., `--num_speculative_tokens`) as separate flags will be deprecated in the next release.
:::
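Since the flag takes the configuration as a single JSON object (as in the launch command above), the same flat dictionary used in the offline `LLM(...)` example can be serialized for the command line. A minimal sketch; the `disable_by_batch_size` entry is included purely as an illustrative extra option:

```python
import json

# The same flat dict accepted by LLM(speculative_config=...); every
# speculative decoding option goes into this single object.
spec_config = {
    "model": "facebook/opt-125m",
    "num_speculative_tokens": 5,
    "disable_by_batch_size": 8,  # illustrative extra option
}

# Paste the printed JSON after --speculative_config when launching the server.
print(json.dumps(spec_config))
```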
Then use a client:
```python
@ -101,9 +107,11 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
llm = LLM(
model="facebook/opt-6.7b",
tensor_parallel_size=1,
speculative_model="[ngram]",
num_speculative_tokens=5,
ngram_prompt_lookup_max=4,
speculative_config={
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 4,
},
)
outputs = llm.generate(prompts, sampling_params)
@ -131,8 +139,10 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
llm = LLM(
model="meta-llama/Meta-Llama-3.1-70B-Instruct",
tensor_parallel_size=4,
speculative_model="ibm-ai-platform/llama3-70b-accelerator",
speculative_draft_tensor_parallel_size=1,
speculative_config={
"model": "ibm-ai-platform/llama3-70b-accelerator",
"draft_tensor_parallel_size": 1,
},
)
outputs = llm.generate(prompts, sampling_params)
@ -175,8 +185,10 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
llm = LLM(
model="meta-llama/Meta-Llama-3-8B-Instruct",
tensor_parallel_size=4,
speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B",
speculative_draft_tensor_parallel_size=1,
speculative_config={
"model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
"draft_tensor_parallel_size": 1,
},
)
outputs = llm.generate(prompts, sampling_params)
@ -194,11 +206,10 @@ A few important things to consider when using the EAGLE based draft models:
be able to be loaded and used directly by vLLM after [PR 12304](https://github.com/vllm-project/vllm/pull/12304).
If you are using vllm version before [PR 12304](https://github.com/vllm-project/vllm/pull/12304), please use the
[script](https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d) to convert the speculative model,
and specify `speculative_model="path/to/modified/eagle/model"`. If weight-loading problems still occur when using
the latest version of vLLM, please leave a comment or raise an issue.
and specify `"model": "path/to/modified/eagle/model"` in `speculative_config` (see the sketch after this list). If weight-loading problems still occur when using the latest version of vLLM, please leave a comment or raise an issue.
2. The EAGLE based draft models need to be run without tensor parallelism
(i.e. speculative_draft_tensor_parallel_size is set to 1), although
(i.e. draft_tensor_parallel_size is set to 1 in `speculative_config`), although
it is possible to run the main model using tensor parallelism (see example above).
3. When using EAGLE-based speculators with vLLM, the observed speedup is lower than what is
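A minimal sketch combining notes 1 and 2 above, assuming a locally converted EAGLE checkpoint at the placeholder path from note 1 (the draft model runs with tensor parallelism of 1 while the target model keeps its own tensor parallel size):

```python
from vllm import LLM

llm = LLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    tensor_parallel_size=4,
    speculative_config={
        # Placeholder: substitute the checkpoint produced by the conversion
        # script referenced in note 1 (or a directly loadable EAGLE model).
        "model": "path/to/modified/eagle/model",
        "num_speculative_tokens": 5,
        "draft_tensor_parallel_size": 1,  # see note 2
    },
)
```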

View File

@ -50,7 +50,9 @@ if __name__ == "__main__":
# Create an LLM with spec decoding
llm = LLM(
model="meta-llama/Llama-2-13b-chat-hf",
speculative_model="ibm-ai-platform/llama-13b-accelerator",
speculative_config={
"model": "ibm-ai-platform/llama-13b-accelerator",
},
)
print("With speculation")

View File

@ -56,7 +56,7 @@ def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
def maybe_assert_ngram_worker(llm):
# Verify the proposer worker is ngram if ngram is specified.
if (llm.llm_engine.speculative_config is not None
and llm.llm_engine.speculative_config.ngram_prompt_lookup_max > 0):
and llm.llm_engine.speculative_config.method == "ngram"):
from vllm.spec_decode.ngram_worker import NGramWorker
assert isinstance(
llm.llm_engine.model_executor.driver_worker.proposer_worker,

View File

@ -7,28 +7,39 @@ from vllm import SamplingParams
from .conftest import get_output_from_llm_generator
@pytest.mark.parametrize("common_llm_kwargs", [{
"model": "meta-llama/Llama-3.2-1B-Instruct",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
}])
@pytest.mark.parametrize("common_llm_kwargs",
[{
"model": "meta-llama/Llama-3.2-1B-Instruct",
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
{
# Speculative max model len > overridden max model len should raise.
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"max_model_len": 129,
},
"max_model_len": 128,
"speculative_max_model_len": 129,
},
{
# Speculative max model len > draft max model len should raise.
# https://huggingface.co/JackFram/llama-68m/blob/3b606af5198a0b26762d589a3ee3d26ee6fa6c85/config.json#L12
"speculative_max_model_len": 2048 + 1,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"max_model_len": 2048 + 1,
},
},
{
# Speculative max model len > target max model len should raise.
# https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct/blob/9213176726f574b556790deb65791e0c5aa438b6/config.json#L18
"speculative_max_model_len": 131072 + 1,
# https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct/blob/9213176726f574b556790deb65791e0c5aa438b6/config.json#L18
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"max_model_len": 131072 + 1,
},
},
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])

View File

@ -57,8 +57,10 @@ PRECISION = "float32"
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize("output_len", [
@ -95,18 +97,19 @@ def test_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs_during_spec_decoding": False,
"disable_logprobs": False,
},
{
"speculative_model": SPEC_MODEL,
}, {
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs_during_spec_decoding": True,
"disable_logprobs": True,
},
])
}])
@pytest.mark.parametrize("output_len", [
128,
])
@ -119,18 +122,19 @@ def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
batch_size: int, output_len: int, seed: int,
logprobs: int):
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs[
'disable_logprobs_during_spec_decoding'])
run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])
@pytest.mark.parametrize(
@ -151,8 +155,10 @@ def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize("output_len", [
@ -193,8 +199,10 @@ def test_eagle_e2e_greedy_correctness_cuda_graph(
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize(
@ -236,8 +244,10 @@ def test_eagle_e2e_greedy_correctness_with_preemption(
"test_llm_kwargs",
[
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": k,
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": k,
},
}
# Try a range of num. speculative tokens
for k in range(1, 1 + MAX_SPEC_TOKENS)
@ -277,12 +287,13 @@ def test_eagle_different_k(vllm_runner, common_llm_kwargs,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_disable_by_batch_size": 4
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_by_batch_size": 4,
},
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
@ -324,8 +335,10 @@ def test_eagle_disable_queue(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "yuhuili/EAGLE-llama2-chat-7B",
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_config": {
"model": "yuhuili/EAGLE-llama2-chat-7B",
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize(
@ -372,8 +385,10 @@ def test_llama2_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_config": {
"model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize(
@ -420,8 +435,10 @@ def test_llama3_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "yuhuili/EAGLE-Qwen2-7B-Instruct",
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_config": {
"model": "yuhuili/EAGLE-Qwen2-7B-Instruct",
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize(

View File

@ -23,8 +23,10 @@ MAIN_MODEL = "JackFram/llama-68m"
[
{
# Identical models.
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@ -57,26 +59,33 @@ def test_spec_decode_cuda_graph(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
"speculative_model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
"num_speculative_tokens": 5,
},
])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
# Explicitly specify draft model quantization
{
"speculative_model_quantization": "gptq",
"speculative_config": {
"model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
"num_speculative_tokens": 5,
"quantization": "gptq",
},
},
# Explicitly specify GPTQ-based draft model to use marlin quantization
{
"speculative_model_quantization": "marlin",
"speculative_config": {
"model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
"num_speculative_tokens": 5,
"quantization": "marlin",
},
},
# Not explicitly specify draft model quantization
{
"speculative_model_quantization": None,
"speculative_config": {
"model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
"num_speculative_tokens": 5,
"quantization": None,
},
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@ -107,15 +116,16 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_disable_mqa_scorer": True,
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
"disable_mqa_scorer": True,
},
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
@ -127,7 +137,7 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs,
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
output_len: int, seed: int):
"""Verify that ngram speculative decoding generates the same output
"""Verify that speculative decoding generates the same output
with batch expansion scorer and mqa scorer.
"""
run_equality_correctness_test(vllm_runner,

View File

@ -27,18 +27,19 @@ from .conftest import run_equality_correctness_test_tp
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
@pytest.mark.parametrize("test_llm_kwargs", [
[
"--speculative-model",
"JackFram/llama-68m",
"--num-speculative-tokens",
"3",
"--speculative_config",
str({
"model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
}),
],
[
"--speculative-model",
"[ngram]",
"--num-speculative-tokens",
"5",
"--ngram-prompt-lookup-max",
"3",
"--speculative_config",
str({
"model": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
}),
],
])
@pytest.mark.parametrize("batch_size", [2])
@ -83,23 +84,24 @@ def test_target_model_tp_gt_1(common_llm_kwargs, per_test_common_llm_kwargs,
]])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [[]])
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
@pytest.mark.parametrize("model, test_llm_kwargs",
[("JackFram/llama-68m", [
"--speculative-model",
"JackFram/llama-68m",
"--num_speculative-tokens",
"5",
"--speculative-draft-tensor-parallel-size",
"1",
]),
("ibm-granite/granite-3b-code-instruct", [
"--speculative-model",
"ibm-granite/granite-3b-code-instruct",
"--num_speculative-tokens",
"5",
"--speculative-draft-tensor-parallel-size",
"1",
])])
@pytest.mark.parametrize(
"model, test_llm_kwargs",
[("JackFram/llama-68m", [
"--speculative_config",
str({
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"draft_tensor_parallel_size": 1,
}),
]),
("ibm-granite/granite-3b-code-instruct", [
"--speculative_config",
str({
"model": "ibm-granite/granite-3b-code-instruct",
"num_speculative_tokens": 5,
"draft_tensor_parallel_size": 1,
}),
])])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs,
@ -144,18 +146,19 @@ def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
@pytest.mark.parametrize("model, test_llm_kwargs",
[("JackFram/llama-68m", [
"--speculative-model",
"JackFram/llama-68m",
"--num_speculative-tokens",
"3",
"--speculative_config",
str({
"model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
}),
]),
("JackFram/llama-68m", [
"--speculative-model",
"JackFram/llama-68m",
"--num_speculative-tokens",
"3",
"--speculative-draft-tensor-parallel-size",
"1",
"--speculative_config",
str({
"model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
"draft_tensor_parallel_size": 1,
}),
])])
@pytest.mark.parametrize("logprobs", [None, 2])
@pytest.mark.parametrize("batch_size", [2])

View File

@ -24,12 +24,7 @@ SPEC_MODEL = "JackFram/llama-68m"
"4",
]])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
[
"--speculative-model",
f"{SPEC_MODEL}",
"--num-speculative-tokens",
"5",
],
[],
])
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
@pytest.mark.parametrize(
@ -37,8 +32,12 @@ SPEC_MODEL = "JackFram/llama-68m"
[
#TODO(wooyeon): add spec_draft_dp=2 case
[
"--speculative-draft-tensor-parallel-size",
"1",
"--speculative_config",
str({
"model": f"{SPEC_MODEL}",
"num_speculative_tokens": 5,
"draft_tensor_parallel_size": 1,
}),
],
])
@pytest.mark.parametrize("batch_size", [2])
@ -78,15 +77,14 @@ def test_draft_model_tp_lt_target_model_tp4(common_llm_kwargs,
"test_llm_kwargs",
[
[
"--speculative-model",
f"{SPEC_MODEL}",
"--num-speculative-tokens",
"5",
# Artificially limit the draft model max model len; this forces vLLM
# to skip speculation once the sequences grow beyond 32-k tokens.
"--speculative-max-model-len",
"32",
"--speculative_config",
str({
"model": f"{SPEC_MODEL}",
"num_speculative_tokens": 5,
"max_model_len": 32,
}),
],
])
@pytest.mark.parametrize("batch_size", [8])

View File

@ -20,16 +20,19 @@ from .conftest import run_equality_correctness_test
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": False,
}, {
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": True,
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
"disable_logprobs": False,
},
}, {
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
"disable_logprobs": True,
},
}])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
"output_len",
@ -48,19 +51,20 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs,
as well as with and without chunked prefill.
"""
maybe_enable_chunked_prefill(prefill_chunk_size, common_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs[
'disable_logprobs_during_spec_decoding'])
run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])
@pytest.mark.parametrize(
@ -73,16 +77,19 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": False,
}, {
"speculative_model": "JackFram/llama-160m",
"num_speculative_tokens": 6,
"disable_logprobs_during_spec_decoding": False,
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
"disable_logprobs": False,
},
}, {
"speculative_config": {
"model": "JackFram/llama-160m",
"num_speculative_tokens": 6,
"disable_logprobs": False,
},
}])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
"output_len",
@ -98,18 +105,19 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs,
output_len: int, seed: int, logprobs: int):
"""Veriy logprob greedy equality with different speculation lens.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0,
logprobs=logprobs,
disable_logprobs=test_llm_kwargs[
'disable_logprobs_during_spec_decoding'])
run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0,
logprobs=logprobs,
disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])
@pytest.mark.parametrize(
@ -125,13 +133,15 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize(
"test_llm_kwargs",
[{
"speculative_model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": False,
# Artificially limit the draft model max model len; this forces vLLM
# to skip speculation once the sequences grow beyond 32-k tokens.
"speculative_max_model_len": 32,
"speculative_config": {
"model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
"disable_logprobs": False,
# Artificially limit the draft model max model len; this forces
# vLLM to skip speculation once the sequences grow beyond 32-k
# tokens.
"max_model_len": 32,
},
}])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
@ -149,18 +159,19 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs,
seed: int, logprobs: int):
"""Verify logprobs greedy equality when some sequences skip speculation.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0,
logprobs=logprobs,
disable_logprobs=test_llm_kwargs[
'disable_logprobs_during_spec_decoding'])
run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0,
logprobs=logprobs,
disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])
@pytest.mark.parametrize(
@ -173,12 +184,13 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": False,
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
"disable_logprobs": False,
},
}])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize(
"output_len",
@ -248,12 +260,13 @@ def test_logprobs_temp_1(vllm_runner, common_llm_kwargs,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": True,
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
"disable_logprobs": True,
},
}])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize(
@ -270,15 +283,16 @@ def test_logprobs_disabled(vllm_runner, common_llm_kwargs,
"""Check the behavior when logprobs are disabled.
Token choices should match with the base model.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0,
logprobs=logprobs,
disable_logprobs=test_llm_kwargs[
'disable_logprobs_during_spec_decoding'])
run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0,
logprobs=logprobs,
disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])

View File

@ -60,8 +60,10 @@ PRECISION = "float32"
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize("output_len", [
@ -107,14 +109,18 @@ def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs_during_spec_decoding": False,
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs": False,
},
},
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs_during_spec_decoding": True,
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs": True,
},
},
])
@pytest.mark.parametrize("output_len", [
@ -132,19 +138,20 @@ def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
prefill_chunk_size: int):
"""Verify greedy equality with different batch size."""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs[
'disable_logprobs_during_spec_decoding'])
run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])
@pytest.mark.parametrize(
@ -165,8 +172,10 @@ def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize("output_len", [
@ -214,8 +223,10 @@ def test_medusa_e2e_greedy_correctness_cuda_graph(
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize(
@ -264,8 +275,10 @@ def test_medusa_e2e_greedy_correctness_with_preemption(
"test_llm_kwargs",
[
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": k,
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": k,
},
}
# Try a range of num. speculative tokens
for k in range(1, 1 + MAX_SPEC_TOKENS)
@ -312,12 +325,13 @@ def test_medusa_different_k(vllm_runner, common_llm_kwargs,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_disable_by_batch_size": 4
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_by_batch_size": 4,
},
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
@ -359,16 +373,17 @@ def test_medusa_disable_queue(vllm_runner, common_llm_kwargs,
# Main model
"model_name": MAIN_MODEL,
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_disable_by_batch_size": 4
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_disable_mqa_scorer": True,
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_by_batch_size": 4,
"disable_mqa_scorer": True,
},
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",

View File

@ -62,7 +62,9 @@ PRECISION = "float32"
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"speculative_config": {
"model": SPEC_MODEL,
},
},
])
@pytest.mark.parametrize("output_len", [
@ -108,12 +110,16 @@ def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"disable_logprobs_during_spec_decoding": False,
"speculative_config": {
"model": SPEC_MODEL,
"disable_logprobs": False,
},
},
{
"speculative_model": SPEC_MODEL,
"disable_logprobs_during_spec_decoding": True,
"speculative_config": {
"model": SPEC_MODEL,
"disable_logprobs": True,
},
},
])
@pytest.mark.parametrize("output_len", [8])
@ -133,19 +139,20 @@ def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
# up sampling different tokens at the tail (ie top tokens don't change).
# TL;DR: sd+cp == org+cp but sd+cp != org..is this expected?
maybe_enable_chunked_prefill(prefill_chunk_size, baseline_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs[
'disable_logprobs_during_spec_decoding'])
run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])
@pytest.mark.parametrize(
@ -167,7 +174,9 @@ def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"speculative_config": {
"model": SPEC_MODEL,
},
},
])
@pytest.mark.parametrize("output_len", [2048])
@ -209,8 +218,10 @@ def test_mlp_e2e_acceptance_rate(vllm_runner, common_llm_kwargs,
# Main model
"model_name": MAIN_MODEL,
# Speculative model
"speculative_model": SPEC_MODEL,
# Speculative config
"speculative_config": {
"model": SPEC_MODEL,
},
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}])
@ -274,7 +285,9 @@ def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"speculative_config": {
"model": SPEC_MODEL,
},
},
])
@pytest.mark.parametrize(
@ -326,7 +339,9 @@ def test_mlp_e2e_greedy_correctness_with_preemption(
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"speculative_config": {
"model": SPEC_MODEL,
},
},
])
@pytest.mark.parametrize(
@ -382,8 +397,10 @@ def test_mlp_e2e_greedy_correctness_with_padding(
"test_llm_kwargs",
[
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": k,
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": k,
},
}
# Try a range of num. speculative tokens
for k in range(1, 1 + MAX_SPEC_TOKENS)
@ -430,11 +447,12 @@ def test_mlp_different_k(vllm_runner, common_llm_kwargs,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": SPEC_MODEL,
"speculative_disable_by_batch_size": 4
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": SPEC_MODEL,
"disable_by_batch_size": 4,
},
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
@ -475,14 +493,15 @@ def test_mlp_disable_queue(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
"speculative_model": SPEC_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_disable_mqa_scorer": True,
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": SPEC_MODEL,
"disable_mqa_scorer": True,
},
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",

View File

@ -57,7 +57,9 @@ PRECISION = "bfloat16"
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize("output_len", [
@ -99,12 +101,16 @@ def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs_during_spec_decoding": False,
"speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs": False,
},
},
{
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs_during_spec_decoding": True,
"speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs": True,
},
},
])
@pytest.mark.parametrize("output_len", [
@ -119,18 +125,19 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
batch_size: int, output_len: int, seed: int,
logprobs: int):
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs[
'disable_logprobs_during_spec_decoding'])
run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])
@pytest.mark.parametrize(
@ -152,7 +159,9 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize("output_len", [
@ -198,7 +207,9 @@ def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize(
@ -243,7 +254,9 @@ def test_mtp_e2e_greedy_correctness_with_preemption(
"test_llm_kwargs",
[
{
"num_speculative_tokens": k,
"speculative_config": {
"num_speculative_tokens": k,
},
}
# Try a range of num. speculative tokens
for k in range(1, 1 + MAX_SPEC_TOKENS)
@ -286,11 +299,12 @@ def test_mtp_different_k(vllm_runner, common_llm_kwargs,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_disable_by_batch_size": 4
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_by_batch_size": 4
},
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",

View File

@ -61,15 +61,19 @@ from .conftest import (get_output_from_llm_generator,
"per_test_common_llm_kwargs",
[
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": False,
},
{
# Chunked prefill enabled with small value
# to make sure we get mixed batches.
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
@ -148,20 +152,23 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": False,
"disable_logprobs_during_spec_decoding": False
}, {
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4,
"disable_logprobs_during_spec_decoding": False
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"disable_logprobs": False,
},
"enable_chunked_prefill": False,
}, {
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
"disable_logprobs": False,
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4,
}])
@pytest.mark.parametrize(
"output_len",
[
@ -184,7 +191,7 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
whether all speculative tokens are accepted.
"""
ensure_all_accepted = per_test_common_llm_kwargs.get(
"model_name") == test_llm_kwargs.get("speculative_model")
"model_name") == test_llm_kwargs.get("speculative_config")["model"]
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
@ -224,13 +231,17 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
@ -283,13 +294,17 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
@ -336,13 +351,17 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
@ -391,13 +410,17 @@ def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
@ -449,13 +472,17 @@ def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
@ -514,13 +541,17 @@ def test_spec_decode_e2e_greedy_correctness_with_preemption(
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
@ -567,21 +598,25 @@ def test_spec_decode_different_block_size(vllm_runner, common_llm_kwargs,
"test_llm_kwargs",
[
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
# Artificially limit the draft model max model len; this forces vLLM
# to skip speculation once the sequences grow beyond 32-k tokens.
"speculative_max_model_len": 32,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"max_model_len": 32,
},
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"max_model_len": 32,
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4,
"speculative_max_model_len": 32,
},
])
@pytest.mark.parametrize("batch_size", [8])
@ -627,15 +662,19 @@ def test_skip_speculation(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_disable_by_batch_size": 2,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"disable_by_batch_size": 2,
},
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_disable_by_batch_size": 2,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"disable_by_batch_size": 2,
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4,
@ -676,15 +715,19 @@ def test_disable_speculation(vllm_runner, common_llm_kwargs,
"test_llm_kwargs",
[
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": k,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": k,
},
"enable_chunked_prefill": False,
}
# Try a range of common k, as well as large speculation.
for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]
] + [{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": k,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": k,
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4,
@ -729,17 +772,21 @@ def test_many_k(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
"test_llm_kwargs",
[
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": k,
"spec_decoding_acceptance_method": "typical_acceptance_sampler",
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": k,
"acceptance_method": "typical_acceptance_sampler",
},
"enable_chunked_prefill": False
}
# Try a range of common k.
for k in [1, 2, 3]
] + [{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": k,
"spec_decoding_acceptance_method": "typical_acceptance_sampler",
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": k,
"acceptance_method": "typical_acceptance_sampler",
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4

View File

@ -48,16 +48,20 @@ from .conftest import run_equality_correctness_test
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"speculative_disable_mqa_scorer": False,
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
"disable_mqa_scorer": False,
},
},
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"speculative_disable_mqa_scorer": True,
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
"disable_mqa_scorer": True,
},
},
])
@pytest.mark.parametrize("output_len", [
@ -101,16 +105,20 @@ def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"disable_logprobs_during_spec_decoding": False,
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
"disable_logprobs": False,
},
},
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"disable_logprobs_during_spec_decoding": True,
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
"disable_logprobs": True,
},
},
])
@pytest.mark.parametrize("output_len", [
@ -125,19 +133,20 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
batch_size: int, output_len: int, seed: int,
logprobs: int):
"""Verify greedy equality on a tiny model with different batch size."""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs[
'disable_logprobs_during_spec_decoding'])
run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])
@pytest.mark.parametrize(
@ -159,17 +168,21 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
},
"enable_chunked_prefill": False,
},
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
"disable_mqa_scorer": True,
},
"enable_chunked_prefill": True,
"speculative_disable_mqa_scorer": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
@ -214,17 +227,21 @@ def test_ngram_e2e_greedy_correctness_with_preemption(
"test_llm_kwargs",
[
{
"speculative_model": "[ngram]",
"num_speculative_tokens": k,
"ngram_prompt_lookup_max": 3,
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": k,
"prompt_lookup_max": 3,
},
}
# Try a range of common k, as well as large speculation.
for k in [1, 3, 5]
] + [
{
"speculative_model": "[ngram]",
"num_speculative_tokens": k,
"ngram_prompt_lookup_max": 1,
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": k,
"prompt_lookup_max": 1,
},
}
# Try a range of common k, as well as large speculation.
for k in [1, 3, 5]
@ -243,7 +260,7 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs,
seed: int):
"""Verify that ngram speculative decoding produces exact equality
to without spec decode with many different values of k and
different ngram_prompt_lookup_max.
different ngram prompt_lookup_max.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
@ -266,22 +283,25 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"speculative_disable_by_batch_size": 4
}, {
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"speculative_disable_by_batch_size": 4,
"enable_chunked_prefill": True,
"speculative_disable_mqa_scorer": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
"disable_by_batch_size": 4
},
}, {
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
"disable_by_batch_size": 4,
"disable_mqa_scorer": True,
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
@ -296,7 +316,7 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs,
seed: int):
"""Verify that ngram speculative decoding produces exact equality
to without spec decode with many different values of k and
different ngram_prompt_lookup_max.
different ngram prompt_lookup_max.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
@ -316,18 +336,17 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_disable_mqa_scorer": True,
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
"disable_mqa_scorer": True,
},
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",

View File

@ -19,11 +19,11 @@ SPEC_MODEL = "JackFram/llama-160m"
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# speculative model
"speculative_model": "JackFram/llama-160m",
# num speculative tokens
"num_speculative_tokens": 3,
# speculative config
"speculative_config": {
"model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
},
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}])

View File

@ -70,12 +70,16 @@ def test_ngram_correctness(
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm
spec_llm = LLM(model=model_name,
speculative_model='[ngram]',
ngram_prompt_lookup_max=5,
ngram_prompt_lookup_min=3,
num_speculative_tokens=3,
max_model_len=1024)
spec_llm = LLM(
model=model_name,
speculative_config={
"method": "ngram",
"prompt_lookup_max": 5,
"prompt_lookup_min": 3,
"num_speculative_tokens": 3,
},
max_model_len=1024,
)
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
matches = 0
misses = 0

View File

@ -1810,12 +1810,139 @@ class DeviceConfig:
self.device = torch.device(self.device_type)
@dataclass
class SpeculativeConfig:
"""Configuration for speculative decoding.
The configuration is currently specialized to draft-model speculative
decoding with top-1 proposals.
"""
Configuration for speculative decoding.
Configurable parameters include:
- General Speculative Decoding Control:
- num_speculative_tokens (int): The number of speculative
tokens, if provided. It will default to the number in the draft
model config if present; otherwise, it is required.
- model (Optional[str]): The name of the draft model, eagle head,
or additional weights, if provided.
- method (Optional[str]): The name of the speculative method to use.
If users provide and set the `model` param, the speculative method
type will be detected automatically if possible; if the `model` param
is not provided, the method name must be provided.
- Possible values:
- ngram
Related additional configuration:
- prompt_lookup_max (Optional[int]):
Maximum size of ngram token window when using Ngram
proposer, required when method is set to ngram.
- prompt_lookup_min (Optional[int]):
Minimum size of ngram token window when using Ngram
proposer, if provided. Defaults to 1.
- eagle
- medusa
- mlp_speculator
- draft_model
- acceptance_method (str): The method to use for accepting draft
tokens. This can take two possible values: 'rejection_sampler' and
'typical_acceptance_sampler' for RejectionSampler and
TypicalAcceptanceSampler respectively. If not specified, it
defaults to 'rejection_sampler'.
- Possible values:
- rejection_sampler
- typical_acceptance_sampler
Related additional configuration:
- posterior_threshold (Optional[float]):
A threshold value that sets a lower bound on the
posterior probability of a token in the target model
for it to be accepted. This threshold is used only
when we use the TypicalAcceptanceSampler for token
acceptance.
- posterior_alpha (Optional[float]):
Scaling factor for entropy-based threshold, applied
when using TypicalAcceptanceSampler.
- draft_tensor_parallel_size (Optional[int]): The degree of the tensor
parallelism for the draft model. Can only be 1 or the same as the
target model's tensor parallel size.
- disable_logprobs (bool): If set to True, token log probabilities are
not returned during speculative decoding. If set to False, token
log probabilities are returned according to the log probability
settings in SamplingParams. If not specified, it defaults to True.
- Draft Model Configuration:
- quantization (Optional[str]): Quantization method that was used to
quantize the draft model weights. If None, we assume the
model weights are not quantized. Note that it only takes effect
when using the draft model-based speculative method.
- max_model_len (Optional[int]): The maximum model length of the
draft model. Used when testing the ability to skip
speculation for some sequences.
- revision: The specific model version to use for the draft model. It
can be a branch name, a tag name, or a commit id. If unspecified,
will use the default version.
- code_revision: The specific revision to use for the draft model code
on Hugging Face Hub. It can be a branch name, a tag name, or a
commit id. If unspecified, will use the default version.
- Advanced Control:
- disable_mqa_scorer (bool): Disable the MQA scorer and fall back to
batch expansion for scoring proposals. If not specified, it
defaults to False.
- disable_by_batch_size (Optional[int]): Disable speculative decoding
for new incoming requests when the number of enqueued requests is
larger than this value, if provided.
Although the parameters above are structured hierarchically, there is no
need to nest them during configuration.
Non-configurable internal parameters include:
- Model Configuration:
- target_model_config (ModelConfig): The configuration of the target
model.
- draft_model_config (ModelConfig): The configuration of the draft
model initialized internally.
- Parallelism Configuration:
- target_parallel_config (ParallelConfig): The parallel configuration
for the target model.
- draft_parallel_config (ParallelConfig): The parallel configuration
for the draft model, initialized internally.
- Execution Control:
- enable_chunked_prefill (bool): Whether vLLM is configured to use
chunked prefill or not. Used for raising an error since it is not
yet compatible with speculative decoding.
- disable_log_stats (bool): Whether to disable the periodic printing of
stage times in speculative decoding.
"""
# speculative configs from cli args
num_speculative_tokens: int = field(default=None,
init=True) # type: ignore
method: Optional[str] = None
acceptance_method: str = "rejection_sampler"
draft_tensor_parallel_size: Optional[int] = None
disable_logprobs: bool = True
model: Optional[str] = None
quantization: Optional[str] = None
max_model_len: Optional[int] = None
revision: Optional[str] = None
code_revision: Optional[str] = None
disable_mqa_scorer: bool = False
disable_by_batch_size: Optional[int] = None
prompt_lookup_max: Optional[int] = None
prompt_lookup_min: Optional[int] = None
posterior_threshold: Optional[float] = None
posterior_alpha: Optional[float] = None
# required configuration params passed from engine
target_model_config: ModelConfig = field(default=None,
init=True) # type: ignore
target_parallel_config: ParallelConfig = field(default=None,
init=True) # type: ignore
enable_chunked_prefill: bool = field(default=None,
init=True) # type: ignore
disable_log_stats: bool = field(default=None, init=True) # type: ignore
# params generated in the post-init stage
draft_model_config: ModelConfig = field(default=None,
init=True) # type: ignore
draft_parallel_config: ParallelConfig = field(default=None,
init=True) # type: ignore
def compute_hash(self) -> str:
"""
@ -1835,6 +1962,11 @@ class SpeculativeConfig:
hash_str = hashlib.md5(str(factors).encode()).hexdigest()
return hash_str
@classmethod
def from_dict(cls, dict_value: dict) -> "SpeculativeConfig":
"""Parse the CLI value for the speculative config."""
return cls(**dict_value)
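As a rough sketch of how this entry point is expected to be exercised (the engine-side code later in this diff injects the non-CLI fields before construction; the literal string and path here are only examples):

```python
# Sketch: turning a CLI-style string into the dict that from_dict consumes.
import ast

cli_value = '{"model": "path/to/draft-model", "num_speculative_tokens": 5}'
spec_dict = ast.literal_eval(cli_value)  # mirrors the parsing in EngineArgs below
# The engine then adds target_model_config, target_parallel_config,
# enable_chunked_prefill, and disable_log_stats before calling:
# config = SpeculativeConfig.from_dict(spec_dict)
```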
@staticmethod
def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
if hf_config.model_type == "deepseek_v3":
@ -1847,230 +1979,160 @@ class SpeculativeConfig:
})
return hf_config
@staticmethod
def maybe_create_spec_config(
target_model_config: ModelConfig,
target_parallel_config: ParallelConfig,
target_dtype: str,
speculative_model: Optional[str],
speculative_model_quantization: Optional[str],
speculative_draft_tensor_parallel_size: Optional[int],
num_speculative_tokens: Optional[int],
speculative_disable_mqa_scorer: Optional[bool],
speculative_max_model_len: Optional[int],
enable_chunked_prefill: bool,
disable_log_stats: bool,
speculative_disable_by_batch_size: Optional[int],
ngram_prompt_lookup_max: Optional[int],
ngram_prompt_lookup_min: Optional[int],
draft_token_acceptance_method: str,
typical_acceptance_sampler_posterior_threshold: Optional[float],
typical_acceptance_sampler_posterior_alpha: Optional[float],
disable_logprobs: Optional[bool],
) -> Optional["SpeculativeConfig"]:
"""Create a SpeculativeConfig if possible, else return None.
def __post_init__(self):
This function attempts to create a SpeculativeConfig object based on the
provided parameters. If the necessary conditions are met, it returns an
instance of SpeculativeConfig. Otherwise, it returns None.
# Note: After the next release, the method parameter will be used to
# specify the speculative method, which makes it easier to configure
# non-model-based proposers; the model parameter will be used when a
# draft model or head is needed.
# If users do not specify the method, the speculative method will be
# detected automatically if possible. If the speculative method cannot
# be detected, it will be treated as the draft-model-based method by
# default.
Args:
target_model_config (ModelConfig): The configuration of the target
model.
target_parallel_config (ParallelConfig): The parallel configuration
for the target model.
target_dtype (str): The data type used for the target model.
speculative_model (Optional[str]): The name of the speculative
model, if provided.
speculative_model_quantization (Optional[str]): Quantization method
that was used to quantize the speculative model weights. If
None, we assume the model weights are not quantized.
speculative_draft_tensor_parallel_size (Optional[int]): The degree
of the tensor parallelism for the draft model.
num_speculative_tokens (Optional[int]): The number of speculative
tokens, if provided. Will default to the number in the draft
model config if present, otherwise is required.
speculative_disable_mqa_scorer (Optional[bool]): Disable the MQA
scorer for the speculative model and fall back to batch
expansion for scoring.
speculative_max_model_len (Optional[int]): The maximum model len of
the speculative model. Used when testing the ability to skip
speculation for some sequences.
enable_chunked_prefill (bool): Whether vLLM is configured to use
chunked prefill or not. Used for raising an error since it's not
yet compatible with spec decode.
speculative_disable_by_batch_size (Optional[int]): Disable
speculative decoding for new incoming requests when the number
of enqueued requests is larger than this value, if provided.
ngram_prompt_lookup_max (Optional[int]): Max size of ngram token
window, if provided.
ngram_prompt_lookup_min (Optional[int]): Min size of ngram token
window, if provided.
draft_token_acceptance_method (str): The method to use for
accepting draft tokens. This can take two possible
values 'rejection_sampler' and 'typical_acceptance_sampler'
for RejectionSampler and TypicalAcceptanceSampler
respectively.
typical_acceptance_sampler_posterior_threshold (Optional[float]):
A threshold value that sets a lower bound on the posterior
probability of a token in the target model for it to be
accepted. This threshold is used only when we use the
TypicalAcceptanceSampler for token acceptance.
typical_acceptance_sampler_posterior_alpha (Optional[float]):
A scaling factor for the entropy-based threshold in the
TypicalAcceptanceSampler.
disable_logprobs (Optional[bool]): If set to True, token log
probabilities are not returned during speculative decoding.
If set to False, token log probabilities are returned
according to the log probability settings in SamplingParams.
If not specified, it defaults to True.
Returns:
Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
the necessary conditions are met, else None.
"""
if speculative_model is None:
if num_speculative_tokens is not None:
if target_model_config.hf_text_config.model_type \
if self.model is None and self.num_speculative_tokens is not None:
# TODO(Shangming): Refactor mtp configuration logic when supporting
# mtp acceleration for more models besides deepseek_v3
if self.target_model_config.hf_text_config.model_type \
== "deepseek_v3":
# use the draft model from the same model:
speculative_model = target_model_config.model
else:
raise ValueError(
"num_speculative_tokens was provided without "
"speculative_model.")
# use the draft model from the same model:
self.model = self.target_model_config.model
elif self.method in ("ngram", "[ngram]"):
self.model = "ngram"
else:
return None
raise ValueError("num_speculative_tokens was provided without "
"speculative model.")
if (speculative_disable_by_batch_size is not None
and speculative_disable_by_batch_size < 2):
raise ValueError("Expect the batch size threshold of disabling "
"speculative decoding is > 1, but got "
f"{speculative_disable_by_batch_size=}")
if (enable_chunked_prefill and speculative_model == "eagle"):
raise ValueError("Chunked prefill and EAGLE are not compatible.")
# TODO: The user should be able to specify revision/max model len
# for the draft model. It is not currently supported.
draft_revision = None
draft_code_revision = None
draft_quantization = speculative_model_quantization
# Automatically configure the ngram method during the configuration
# refactor to ensure a smooth transition.
if self.method is None and (self.model is not None
and self.model in ("ngram", "[ngram]")):
self.method = "ngram"
if speculative_model == "[ngram]":
if ngram_prompt_lookup_min is None:
ngram_prompt_lookup_min = 1
if ngram_prompt_lookup_max is None or ngram_prompt_lookup_max < 1:
raise ValueError(f"{ngram_prompt_lookup_max=} must be > 0")
if ngram_prompt_lookup_min < 1:
raise ValueError(f"{ngram_prompt_lookup_min=} must be > 0")
if ngram_prompt_lookup_min > ngram_prompt_lookup_max:
raise ValueError(f"{ngram_prompt_lookup_min=} cannot be "
f"larger than {ngram_prompt_lookup_max=}")
if self.method in ("ngram", "[ngram]"):
# Unified to "ngram" internally
self.method = "ngram"
if self.prompt_lookup_min is None:
self.prompt_lookup_min = 1
if self.prompt_lookup_max is None or self.prompt_lookup_max < 1:
raise ValueError("prompt_lookup_max="
f"{self.prompt_lookup_max} must be > 0")
if self.prompt_lookup_min < 1:
raise ValueError("prompt_lookup_min="
f"{self.prompt_lookup_min} must be > 0")
if self.prompt_lookup_min > self.prompt_lookup_max:
raise ValueError(f"prompt_lookup_min={self.prompt_lookup_min} "
"cannot be larger than prompt_lookup_max="
f"{self.prompt_lookup_max}")
# TODO: Currently we still need to extract vocab_size from the target
# model config; in the future, we may refactor it out and set the
# draft-related config to None here.
draft_model_config = target_model_config
draft_parallel_config = target_parallel_config
self.draft_model_config = self.target_model_config
self.draft_parallel_config = self.target_parallel_config
else:
ngram_prompt_lookup_max = 0
ngram_prompt_lookup_min = 0
draft_model_config = ModelConfig(
model=speculative_model,
task="draft",
tokenizer=target_model_config.tokenizer,
tokenizer_mode=target_model_config.tokenizer_mode,
trust_remote_code=target_model_config.trust_remote_code,
allowed_local_media_path=target_model_config.
allowed_local_media_path,
dtype=target_model_config.dtype,
seed=target_model_config.seed,
revision=draft_revision,
code_revision=draft_code_revision,
tokenizer_revision=target_model_config.tokenizer_revision,
max_model_len=None,
spec_target_max_model_len=target_model_config.max_model_len,
quantization=draft_quantization,
enforce_eager=target_model_config.enforce_eager,
max_seq_len_to_capture=target_model_config.
max_seq_len_to_capture,
max_logprobs=target_model_config.max_logprobs,
hf_overrides=SpeculativeConfig.hf_config_override,
)
self.prompt_lookup_max = 0
self.prompt_lookup_min = 0
draft_hf_config = draft_model_config.hf_config
if self.model is not None:
self.draft_model_config = ModelConfig(
model=self.model,
task="draft",
tokenizer=self.target_model_config.tokenizer,
tokenizer_mode=self.target_model_config.tokenizer_mode,
trust_remote_code=self.target_model_config.
trust_remote_code,
allowed_local_media_path=self.target_model_config.
allowed_local_media_path,
dtype=self.target_model_config.dtype,
seed=self.target_model_config.seed,
revision=self.revision,
code_revision=self.code_revision,
tokenizer_revision=self.target_model_config.
tokenizer_revision,
max_model_len=None,
spec_target_max_model_len=self.target_model_config.
max_model_len,
quantization=self.quantization,
enforce_eager=self.target_model_config.enforce_eager,
max_seq_len_to_capture=self.target_model_config.
max_seq_len_to_capture,
max_logprobs=self.target_model_config.max_logprobs,
hf_overrides=SpeculativeConfig.hf_config_override,
)
# Detect EAGLE prefix to replace hf_config for EAGLE draft_model
if "eagle-" in draft_model_config.model.lower():
from vllm.transformers_utils.configs.eagle import EAGLEConfig
if isinstance(draft_model_config.hf_config, EAGLEConfig):
pass
# Automatically detect the method
if "eagle-" in self.draft_model_config.model.lower():
self.method = "eagle"
elif self.draft_model_config.hf_config.model_type == "medusa":
self.method = "medusa"
elif (self.draft_model_config.hf_config.model_type ==
"mlp_speculator"):
self.method = "mlp_speculator"
else:
eagle_config = EAGLEConfig(draft_model_config.hf_config)
draft_model_config.hf_config = eagle_config
self.method = "draft_model"
if (num_speculative_tokens is not None
and hasattr(draft_hf_config, "num_lookahead_tokens")):
draft_hf_config.num_lookahead_tokens = num_speculative_tokens
n_predict = getattr(draft_hf_config, "n_predict", None)
if n_predict is not None:
if num_speculative_tokens is None:
# Default to max value defined in draft model config.
num_speculative_tokens = n_predict
elif num_speculative_tokens > n_predict and \
num_speculative_tokens % n_predict != 0:
# Ensure divisibility for MTP module reuse.
raise ValueError(
f"{num_speculative_tokens=} must be divisible by "
f"{n_predict=}")
# Replace hf_config for EAGLE draft_model
if self.method == "eagle":
if self.enable_chunked_prefill:
raise ValueError(
"Chunked prefill and EAGLE are not compatible.")
speculative_draft_tensor_parallel_size = \
SpeculativeConfig._verify_and_get_draft_model_tensor_parallel_size(
target_parallel_config,
speculative_draft_tensor_parallel_size,
draft_hf_config
)
from vllm.transformers_utils.configs.eagle import (
EAGLEConfig)
if isinstance(self.draft_model_config.hf_config,
EAGLEConfig):
pass
else:
eagle_config = EAGLEConfig(
self.draft_model_config.hf_config)
self.draft_model_config.hf_config = eagle_config
draft_model_config.max_model_len = (
SpeculativeConfig._maybe_override_draft_max_model_len(
speculative_max_model_len,
draft_model_config.max_model_len,
target_model_config.max_model_len,
))
if (self.num_speculative_tokens is not None
and hasattr(self.draft_model_config.hf_config,
"num_lookahead_tokens")):
self.draft_model_config.hf_config.num_lookahead_tokens = \
self.num_speculative_tokens
draft_parallel_config = (
SpeculativeConfig.create_draft_parallel_config(
target_parallel_config,
speculative_draft_tensor_parallel_size, draft_hf_config))
n_predict = getattr(self.draft_model_config.hf_config,
"n_predict", None)
if n_predict is not None:
if self.num_speculative_tokens is None:
# Default to max value defined in draft model config.
self.num_speculative_tokens = n_predict
elif self.num_speculative_tokens > n_predict and \
self.num_speculative_tokens % n_predict != 0:
# Ensure divisibility for MTP module reuse.
raise ValueError(
f"num_speculative_tokens:{self.num_speculative_tokens}"
f" must be divisible by {n_predict=}")
if num_speculative_tokens is None:
raise ValueError(
"num_speculative_tokens must be provided with "
"speculative_model unless the draft model config contains an "
"n_predict parameter.")
self.draft_tensor_parallel_size = \
SpeculativeConfig._verify_and_get_draft_tp(
self.target_parallel_config,
self.draft_tensor_parallel_size,
self.draft_model_config.hf_config
)
if typical_acceptance_sampler_posterior_threshold is None:
typical_acceptance_sampler_posterior_threshold = 0.09
if typical_acceptance_sampler_posterior_alpha is None:
typical_acceptance_sampler_posterior_alpha = 0.3
if disable_logprobs is None:
disable_logprobs = True
self.draft_model_config.max_model_len = (
SpeculativeConfig._maybe_override_draft_max_model_len(
self.max_model_len,
self.draft_model_config.max_model_len,
self.target_model_config.max_model_len,
))
return SpeculativeConfig(
draft_model_config,
draft_parallel_config,
num_speculative_tokens,
speculative_disable_mqa_scorer,
speculative_disable_by_batch_size,
ngram_prompt_lookup_max,
ngram_prompt_lookup_min,
draft_token_acceptance_method=draft_token_acceptance_method,
typical_acceptance_sampler_posterior_threshold=\
typical_acceptance_sampler_posterior_threshold,
typical_acceptance_sampler_posterior_alpha=\
typical_acceptance_sampler_posterior_alpha,
disable_logprobs=disable_logprobs,
disable_log_stats=disable_log_stats,
)
self.draft_parallel_config = (
SpeculativeConfig.create_draft_parallel_config(
self.target_parallel_config,
self.draft_tensor_parallel_size))
if self.acceptance_method == "typical_acceptance_sampler":
if self.posterior_threshold is None:
self.posterior_threshold = 0.09
if self.posterior_alpha is None:
self.posterior_alpha = 0.3
self._verify_args()
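For readers skimming the branches above, a minimal distillation of the auto-detection order applied when a draft model is given (an illustrative helper, not part of the diff):

```python
# Sketch of the detection order implemented in __post_init__ above.
def _illustrative_detect_method(draft_model_name: str, hf_model_type: str) -> str:
    if "eagle-" in draft_model_name.lower():
        return "eagle"
    if hf_model_type == "medusa":
        return "medusa"
    if hf_model_type == "mlp_speculator":
        return "mlp_speculator"
    return "draft_model"
```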
@staticmethod
def _maybe_override_draft_max_model_len(
@ -2108,7 +2170,7 @@ class SpeculativeConfig:
)
@staticmethod
def _verify_and_get_draft_model_tensor_parallel_size(
def _verify_and_get_draft_tp(
target_parallel_config: ParallelConfig,
speculative_draft_tensor_parallel_size: Optional[int],
draft_hf_config: PretrainedConfig) -> int:
@ -2140,7 +2202,6 @@ class SpeculativeConfig:
def create_draft_parallel_config(
target_parallel_config: ParallelConfig,
speculative_draft_tensor_parallel_size: int,
draft_hf_config: PretrainedConfig,
) -> ParallelConfig:
"""Create a parallel config for use by the draft worker.
@ -2164,74 +2225,13 @@ class SpeculativeConfig:
return draft_parallel_config
def __init__(
self,
draft_model_config: ModelConfig,
draft_parallel_config: ParallelConfig,
num_speculative_tokens: int,
speculative_disable_mqa_scorer: Optional[bool],
speculative_disable_by_batch_size: Optional[int],
ngram_prompt_lookup_max: Optional[int],
ngram_prompt_lookup_min: Optional[int],
draft_token_acceptance_method: str,
typical_acceptance_sampler_posterior_threshold: float,
typical_acceptance_sampler_posterior_alpha: float,
disable_logprobs: bool,
disable_log_stats: bool,
):
"""Create a SpeculativeConfig object.
Args:
draft_model_config: ModelConfig for the draft model.
draft_parallel_config: ParallelConfig for the draft model.
num_speculative_tokens: The number of tokens to sample from the
draft model before scoring with the target model.
speculative_disable_by_batch_size: Disable speculative
decoding for new incoming requests when the number of
enqueue requests is larger than this value.
ngram_prompt_lookup_max: Max size of ngram token window.
ngram_prompt_lookup_min: Min size of ngram token window.
draft_token_acceptance_method (str): The method to use for
accepting draft tokens. This can take two possible
values 'rejection_sampler' and 'typical_acceptance_sampler'
for RejectionSampler and TypicalAcceptanceSampler
respectively.
typical_acceptance_sampler_posterior_threshold (Optional[float]):
A threshold value that sets a lower bound on the posterior
probability of a token in the target model for it to be
accepted. This threshold is used only when we use the
TypicalAcceptanceSampler for token acceptance.
typical_acceptance_sampler_posterior_alpha (Optional[float]):
A scaling factor for the entropy-based threshold in the
TypicalAcceptanceSampler.
disable_logprobs: If set to True, token log probabilities will not
be returned even if requested by sampling parameters. This
reduces latency by skipping logprob calculation in proposal
sampling, target sampling, and after accepted tokens are
determined. If set to False, log probabilities will be
returned.
disable_log_stats: Whether to disable periodic printing of stage
times in speculative decoding.
"""
self.draft_model_config = draft_model_config
self.draft_parallel_config = draft_parallel_config
self.num_speculative_tokens = num_speculative_tokens
self.speculative_disable_mqa_scorer = speculative_disable_mqa_scorer
self.speculative_disable_by_batch_size = \
speculative_disable_by_batch_size
self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0
self.ngram_prompt_lookup_min = ngram_prompt_lookup_min or 0
self.draft_token_acceptance_method = draft_token_acceptance_method
self.typical_acceptance_sampler_posterior_threshold = \
typical_acceptance_sampler_posterior_threshold
self.typical_acceptance_sampler_posterior_alpha = \
typical_acceptance_sampler_posterior_alpha
self.disable_logprobs = disable_logprobs
self.disable_log_stats = disable_log_stats
self._verify_args()
def _verify_args(self) -> None:
if self.num_speculative_tokens is None:
raise ValueError(
"num_speculative_tokens must be provided with "
"speculative model unless the draft model config contains an "
"n_predict parameter.")
if self.num_speculative_tokens <= 0:
raise ValueError("Expected num_speculative_tokens to be greater "
f"than zero ({self.num_speculative_tokens}).")
@ -2241,29 +2241,34 @@ class SpeculativeConfig:
self.draft_parallel_config)
# Validate and set draft token acceptance related settings.
if (self.draft_token_acceptance_method is None):
raise ValueError("draft_token_acceptance_method is not set. "
if self.acceptance_method is None:
raise ValueError("acceptance_method is not set. "
"Expected values are rejection_sampler or "
"typical_acceptance_sampler.")
if (self.draft_token_acceptance_method != 'rejection_sampler'
and self.draft_token_acceptance_method
!= 'typical_acceptance_sampler'):
if (self.acceptance_method != 'rejection_sampler'
and self.acceptance_method != 'typical_acceptance_sampler'):
raise ValueError(
"Expected draft_token_acceptance_method to be either "
"Expected acceptance_method to be either "
"rejection_sampler or typical_acceptance_sampler. Instead it "
f"is {self.draft_token_acceptance_method}")
f"is {self.acceptance_method}")
if (self.typical_acceptance_sampler_posterior_threshold < 0
or self.typical_acceptance_sampler_posterior_alpha < 0):
if self.acceptance_method == "typical_acceptance_sampler" and (
(self.posterior_threshold is not None
and self.posterior_threshold < 0) or
(self.posterior_alpha is not None and self.posterior_alpha < 0)):
raise ValueError(
"Expected typical_acceptance_sampler_posterior_threshold "
"and typical_acceptance_sampler_posterior_alpha to be > 0. "
"Instead found "
f"typical_acceptance_sampler_posterior_threshold = "
f"{self.typical_acceptance_sampler_posterior_threshold} and "
f"typical_acceptance_sampler_posterior_alpha = "
f"{self.typical_acceptance_sampler_posterior_alpha}")
"Expected the posterior_threshold and posterior_alpha of "
"typical_acceptance_sampler to be > 0. "
"Instead found posterior_threshold = "
f"{self.posterior_threshold} and posterior_alpha = "
f"{self.posterior_alpha}")
if (self.disable_by_batch_size is not None
and self.disable_by_batch_size < 2):
raise ValueError("Expect the batch size threshold of disabling "
"speculative decoding is > 1, but got "
f"{self.disable_by_batch_size=}")
@property
def num_lookahead_slots(self) -> int:
@ -2276,8 +2281,8 @@ class SpeculativeConfig:
return self.num_speculative_tokens
def __repr__(self) -> str:
if self.ngram_prompt_lookup_max > 0:
draft_model = "[ngram]"
if self.prompt_lookup_max is not None and self.prompt_lookup_max > 0:
draft_model = "ngram"
else:
draft_model = self.draft_model_config.model
num_spec_tokens = self.num_speculative_tokens
@ -3285,7 +3290,8 @@ class VllmConfig:
init=True) # type: ignore
load_config: LoadConfig = field(default=None, init=True) # type: ignore
lora_config: Optional[LoRAConfig] = None
speculative_config: Optional[SpeculativeConfig] = None
speculative_config: SpeculativeConfig = field(default=None,
init=True) # type: ignore
decoding_config: Optional[DecodingConfig] = None
observability_config: Optional[ObservabilityConfig] = None
prompt_adapter_config: Optional[PromptAdapterConfig] = None

View File

@ -177,7 +177,10 @@ class EngineArgs:
guided_decoding_backend: str = 'xgrammar'
logits_processor_pattern: Optional[str] = None
# Speculative decoding configuration.
speculative_config: Optional[Union[str, Dict[str, Any]]] = None
# TODO(Shangming): Deprecate these out-of-date params after next release
speculative_model: Optional[str] = None
speculative_model_quantization: Optional[str] = None
speculative_draft_tensor_parallel_size: Optional[int] = None
@ -190,9 +193,9 @@ class EngineArgs:
spec_decoding_acceptance_method: str = 'rejection_sampler'
typical_acceptance_sampler_posterior_threshold: Optional[float] = None
typical_acceptance_sampler_posterior_alpha: Optional[float] = None
qlora_adapter_name_or_path: Optional[str] = None
disable_logprobs_during_spec_decoding: Optional[bool] = None
qlora_adapter_name_or_path: Optional[str] = None
show_hidden_metrics_for_version: Optional[str] = None
otlp_traces_endpoint: Optional[str] = None
collect_detailed_traces: Optional[str] = None
@ -780,7 +783,11 @@ class EngineArgs:
const="True",
help='If set, the prefill requests can be chunked based on the '
'max_num_batched_tokens.')
parser.add_argument('--speculative-config',
type=nullable_str,
default=None,
help='The configurations for speculative decoding.'
' Should be a JSON string.')
parser.add_argument(
'--speculative-model',
type=nullable_str,
@ -1192,6 +1199,82 @@ class EngineArgs:
use_tqdm_on_load=self.use_tqdm_on_load,
)
def create_speculative_config(
self,
target_model_config: ModelConfig,
target_parallel_config: ParallelConfig,
enable_chunked_prefill: bool,
disable_log_stats: bool,
) -> Optional["SpeculativeConfig"]:
"""Initializes and returns a SpeculativeConfig object based on
`speculative_config`.
This function uses `speculative_config` to create a
SpeculativeConfig object. The `speculative_config` can either be
provided as a JSON string via the CLI or directly as a dictionary
from the engine. If `speculative_config` is not set, this function
falls back to constructing a configuration dictionary from the
standalone speculative-decoding parameters, which are scheduled for
deprecation in the next release. From that release on,
`speculative_config` must be provided, and the deprecated standalone
speculative-related parameters will be removed.
"""
if self.speculative_config is None:
if (self.speculative_model is None
and self.num_speculative_tokens is None):
return None
# TODO(Shangming): Deprecate this way of setting SpeculativeConfig and
# only allow '--speculative-config' after the next release
logger.warning_once(
"Please use '--speculative-config' to set all configurations "
"related to speculative decoding. The current method of "
"specifying the model through '--speculative-model' and "
"adding related parameters (e.g., '--num-speculative-tokens') "
"separately will be deprecated in the next release.")
spec_config_dict = {
"model": self.speculative_model,
"quantization": self.speculative_model_quantization,
"max_model_len": self.speculative_max_model_len,
"draft_tensor_parallel_size":
self.speculative_draft_tensor_parallel_size,
"num_speculative_tokens": self.num_speculative_tokens,
"disable_mqa_scorer": self.speculative_disable_mqa_scorer,
"disable_by_batch_size":
self.speculative_disable_by_batch_size,
"prompt_lookup_max": self.ngram_prompt_lookup_max,
"prompt_lookup_min": self.ngram_prompt_lookup_min,
"acceptance_method": self.spec_decoding_acceptance_method,
"posterior_threshold":
self.typical_acceptance_sampler_posterior_threshold,
"posterior_alpha":
self.typical_acceptance_sampler_posterior_alpha,
"disable_logprobs": self.disable_logprobs_during_spec_decoding,
}
self.speculative_config = spec_config_dict
else:
if isinstance(self.speculative_config, str):
import ast
self.speculative_config = ast.literal_eval(
self.speculative_config)
# Note(Shangming): These parameters are not obtained from the cli arg
# '--speculative-config' and must be passed in when creating the engine
# config.
assert isinstance(self.speculative_config, dict)
self.speculative_config.update({
"target_model_config": target_model_config,
"target_parallel_config": target_parallel_config,
"enable_chunked_prefill": enable_chunked_prefill,
"disable_log_stats": disable_log_stats,
})
speculative_config = SpeculativeConfig.from_dict(
self.speculative_config)
return speculative_config
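A hedged before/after sketch of the CLI migration this helper supports; flag spellings are inferred from the EngineArgs fields shown in this diff, and the values are illustrative:

```python
# Deprecated standalone flags (still accepted, with a warning):
old_style = [
    "--speculative-model", "path/to/draft-model",
    "--num-speculative-tokens", "5",
    "--speculative-draft-tensor-parallel-size", "1",
]
# Preferred single-flag form:
new_style = [
    "--speculative-config",
    '{"model": "path/to/draft-model", "num_speculative_tokens": 5, '
    '"draft_tensor_parallel_size": 1}',
]
```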
def create_engine_config(
self,
usage_context: Optional[UsageContext] = None,
@ -1238,6 +1321,8 @@ class EngineArgs:
else:
self._set_default_args_v0(model_config)
assert self.enable_chunked_prefill is not None
cache_config = CacheConfig(
block_size=self.block_size,
gpu_memory_utilization=self.gpu_memory_utilization,
@ -1280,31 +1365,11 @@ class EngineArgs:
worker_extension_cls=self.worker_extension_cls,
)
speculative_config = SpeculativeConfig.maybe_create_spec_config(
speculative_config = self.create_speculative_config(
target_model_config=model_config,
target_parallel_config=parallel_config,
target_dtype=self.dtype,
speculative_model=self.speculative_model,
speculative_model_quantization = \
self.speculative_model_quantization,
speculative_draft_tensor_parallel_size = \
self.speculative_draft_tensor_parallel_size,
num_speculative_tokens=self.num_speculative_tokens,
speculative_disable_mqa_scorer=self.speculative_disable_mqa_scorer,
speculative_disable_by_batch_size=self.
speculative_disable_by_batch_size,
speculative_max_model_len=self.speculative_max_model_len,
enable_chunked_prefill=self.enable_chunked_prefill,
disable_log_stats=self.disable_log_stats,
ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
draft_token_acceptance_method=\
self.spec_decoding_acceptance_method,
typical_acceptance_sampler_posterior_threshold=self.
typical_acceptance_sampler_posterior_threshold,
typical_acceptance_sampler_posterior_alpha=self.
typical_acceptance_sampler_posterior_alpha,
disable_logprobs=self.disable_logprobs_during_spec_decoding,
)
# Reminder: Please update docs/source/features/compatibility_matrix.md
@ -1569,7 +1634,7 @@ class EngineArgs:
if (self.speculative_model is not None
or self.num_speculative_tokens is not None):
# This is supported but experimental (handled below).
if self.speculative_model == "[ngram]":
if self.speculative_model in ("ngram", "[ngram]"):
pass
else:
_raise_or_fallback(feature_name="Speculative Decoding",
@ -1617,7 +1682,8 @@ class EngineArgs:
return False
# ngram is supported on V1, but off by default for now.
if self.speculative_model == "[ngram]" and _warn_or_fallback("ngram"):
if self.speculative_model in (
"ngram", "[ngram]") and _warn_or_fallback("ngram"):
return False
# Non-CUDA is supported on V1, but off by default for now.

View File

@ -92,22 +92,20 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
# Override draft-model specific worker args.
draft_worker_kwargs.update(
vllm_config=draft_worker_config,
ngram_prompt_lookup_max=speculative_config.ngram_prompt_lookup_max,
ngram_prompt_lookup_min=speculative_config.ngram_prompt_lookup_min,
ngram_prompt_lookup_max=speculative_config.prompt_lookup_max,
ngram_prompt_lookup_min=speculative_config.prompt_lookup_min,
)
spec_decode_worker = SpecDecodeWorker.create_worker(
scorer_worker=target_worker,
draft_worker_kwargs=draft_worker_kwargs,
disable_mqa_scorer=speculative_config.speculative_disable_mqa_scorer,
disable_by_batch_size=speculative_config.
speculative_disable_by_batch_size,
draft_token_acceptance_method=speculative_config.
draft_token_acceptance_method,
disable_mqa_scorer=speculative_config.disable_mqa_scorer,
disable_by_batch_size=speculative_config.disable_by_batch_size,
draft_token_acceptance_method=speculative_config.acceptance_method,
typical_acceptance_sampler_posterior_threshold=speculative_config.
typical_acceptance_sampler_posterior_threshold,
posterior_threshold,
typical_acceptance_sampler_posterior_alpha=speculative_config.
typical_acceptance_sampler_posterior_alpha,
posterior_alpha,
disable_logprobs=speculative_config.disable_logprobs,
disable_log_stats=speculative_config.disable_log_stats,
num_speculative_tokens=speculative_config.num_speculative_tokens,
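For quick reference, the attribute renames this call site relies on, compiled from the replacements in this hunk and the config refactor above (shown as a plain dict for readability):

```python
# Old SpeculativeConfig attribute -> new attribute (per this commit).
RENAMED_ATTRS = {
    "ngram_prompt_lookup_max": "prompt_lookup_max",
    "ngram_prompt_lookup_min": "prompt_lookup_min",
    "speculative_disable_mqa_scorer": "disable_mqa_scorer",
    "speculative_disable_by_batch_size": "disable_by_batch_size",
    "draft_token_acceptance_method": "acceptance_method",
    "typical_acceptance_sampler_posterior_threshold": "posterior_threshold",
    "typical_acceptance_sampler_posterior_alpha": "posterior_alpha",
}
```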

View File

@ -151,8 +151,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.use_spec_decode = False
if self.speculative_config:
self.use_spec_decode = True
# TODO: find a better way to check if we are using ngram.
assert self.speculative_config.ngram_prompt_lookup_min, \
assert self.speculative_config.method == "ngram", \
"Currently, only ngram spec decode is supported in V1."
if get_pp_group().is_last_rank:
self.drafter = NgramProposer()
@ -160,7 +159,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# This usually takes less than 1 second.
self.drafter.propose(
np.zeros(1024, dtype=np.int32),
self.speculative_config.ngram_prompt_lookup_min,
self.speculative_config.prompt_lookup_min,
self.speculative_config.num_speculative_tokens,
)
self.rejection_sampler = RejectionSampler()
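For intuition about what `prompt_lookup_min` and `num_speculative_tokens` mean to the drafter warmed up here, a toy, self-contained sketch of n-gram lookup (not vLLM's NgramProposer; simplified to a single window size):

```python
import numpy as np

def toy_ngram_propose(token_ids: np.ndarray, lookup_len: int, k: int):
    """Match the last `lookup_len` tokens against an earlier occurrence in
    the context and propose the k tokens that followed it (toy version;
    the real proposer scans a range of window sizes)."""
    n = len(token_ids)
    if n <= lookup_len:
        return None
    pattern = token_ids[n - lookup_len:]
    for start in range(n - lookup_len - 1, -1, -1):
        if np.array_equal(token_ids[start:start + lookup_len], pattern):
            follow = token_ids[start + lookup_len:start + lookup_len + k]
            if len(follow) > 0:
                return follow
    return None

tokens = np.array([7, 8, 9, 1, 2, 3, 7, 8, 9], dtype=np.int32)
print(toy_ngram_propose(tokens, lookup_len=3, k=2))  # -> [1 2]
```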
@ -1155,7 +1154,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids
drafter_output = self.drafter.propose(
self.input_batch.token_ids_cpu[i, :end_idx],
self.speculative_config.ngram_prompt_lookup_min,
self.speculative_config.prompt_lookup_min,
self.speculative_config.num_speculative_tokens,
)
if drafter_output is None or len(drafter_output) == 0: