[V1][Usage] Refactor speculative decoding configuration and tests (#14434)
Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
Parent: 0661cfef7a
Commit: 50c9636d87
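This change collapses the separate `speculative_model`, `num_speculative_tokens`, `speculative_draft_tensor_parallel_size`, `speculative_disable_by_batch_size`, `ngram_prompt_lookup_max`, `disable_logprobs_during_spec_decoding`, and related arguments into a single nested `speculative_config` dict (a JSON string for the `--speculative_config` CLI flag), and updates the docs, examples, and spec-decode tests to match. A minimal before/after sketch of the offline API, based on the documentation hunks below; it is illustrative only and does not claim to list every supported key:

```python
from vllm import LLM, SamplingParams

# Old style (to be deprecated per the warning added in this commit):
#   llm = LLM(model="facebook/opt-6.7b",
#             speculative_model="facebook/opt-125m",
#             num_speculative_tokens=5)

# New style: all speculative-decoding options live inside one dict.
llm = LLM(
    model="facebook/opt-6.7b",
    tensor_parallel_size=1,
    speculative_config={
        "model": "facebook/opt-125m",
        "num_speculative_tokens": 5,
    },
)

sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
outputs = llm.generate(["The future of AI is"], sampling_params)
```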
@@ -30,8 +30,10 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
 llm = LLM(
     model="facebook/opt-6.7b",
     tensor_parallel_size=1,
-    speculative_model="facebook/opt-125m",
-    num_speculative_tokens=5,
+    speculative_config={
+        "model": "facebook/opt-125m",
+        "num_speculative_tokens": 5,
+    },
 )
 outputs = llm.generate(prompts, sampling_params)
 
@@ -45,10 +47,14 @@ To perform the same with an online mode launch the server:
 
 ```bash
 python -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --port 8000 --model facebook/opt-6.7b \
-    --seed 42 -tp 1 --speculative_model facebook/opt-125m \
-    --num_speculative_tokens 5 --gpu_memory_utilization 0.8
+    --seed 42 -tp 1 --gpu_memory_utilization 0.8 \
+    --speculative_config '{"model": "facebook/opt-125m", "num_speculative_tokens": 5}'
 ```
 
+:::{warning}
+Note: Please use `--speculative_config` to set all configurations related to speculative decoding. The previous method of specifying the model through `--speculative_model` and adding related parameters (e.g., `--num_speculative_tokens`) separately will be deprecated in the next release.
+:::
+
 Then use a client:
 
 ```python
@@ -101,9 +107,11 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
 llm = LLM(
     model="facebook/opt-6.7b",
     tensor_parallel_size=1,
-    speculative_model="[ngram]",
-    num_speculative_tokens=5,
-    ngram_prompt_lookup_max=4,
+    speculative_config={
+        "method": "ngram",
+        "num_speculative_tokens": 5,
+        "prompt_lookup_max": 4,
+    },
 )
 outputs = llm.generate(prompts, sampling_params)
 
@@ -131,8 +139,10 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
 llm = LLM(
     model="meta-llama/Meta-Llama-3.1-70B-Instruct",
     tensor_parallel_size=4,
-    speculative_model="ibm-ai-platform/llama3-70b-accelerator",
-    speculative_draft_tensor_parallel_size=1,
+    speculative_config={
+        "model": "ibm-ai-platform/llama3-70b-accelerator",
+        "draft_tensor_parallel_size": 1,
+    },
 )
 outputs = llm.generate(prompts, sampling_params)
 
@@ -175,8 +185,10 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
 llm = LLM(
     model="meta-llama/Meta-Llama-3-8B-Instruct",
     tensor_parallel_size=4,
-    speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B",
-    speculative_draft_tensor_parallel_size=1,
+    speculative_config={
+        "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
+        "draft_tensor_parallel_size": 1,
+    },
 )
 
 outputs = llm.generate(prompts, sampling_params)
@@ -194,11 +206,10 @@ A few important things to consider when using the EAGLE based draft models:
    be able to be loaded and used directly by vLLM after [PR 12304](https://github.com/vllm-project/vllm/pull/12304).
    If you are using vllm version before [PR 12304](https://github.com/vllm-project/vllm/pull/12304), please use the
    [script](https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d) to convert the speculative model,
-   and specify `speculative_model="path/to/modified/eagle/model"`. If weight-loading problems still occur when using
-   the latest version of vLLM, please leave a comment or raise an issue.
+   and specify `"model": "path/to/modified/eagle/model"` in `speculative_config`. If weight-loading problems still occur when using the latest version of vLLM, please leave a comment or raise an issue.
 
 2. The EAGLE based draft models need to be run without tensor parallelism
-   (i.e. speculative_draft_tensor_parallel_size is set to 1), although
+   (i.e. draft_tensor_parallel_size is set to 1 in `speculative_config`), although
    it is possible to run the main model using tensor parallelism (see example above).
 
 3. When using EAGLE-based speculators with vLLM, the observed speedup is lower than what is
@@ -50,7 +50,9 @@ if __name__ == "__main__":
     # Create an LLM with spec decoding
     llm = LLM(
         model="meta-llama/Llama-2-13b-chat-hf",
-        speculative_model="ibm-ai-platform/llama-13b-accelerator",
+        speculative_config={
+            "model": "ibm-ai-platform/llama-13b-accelerator",
+        },
     )
 
     print("With speculation")
@@ -56,7 +56,7 @@ def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
 def maybe_assert_ngram_worker(llm):
     # Verify the proposer worker is ngram if ngram is specified.
     if (llm.llm_engine.speculative_config is not None
-            and llm.llm_engine.speculative_config.ngram_prompt_lookup_max > 0):
+            and llm.llm_engine.speculative_config.method == "ngram"):
         from vllm.spec_decode.ngram_worker import NGramWorker
         assert isinstance(
             llm.llm_engine.model_executor.driver_worker.proposer_worker,
@@ -7,28 +7,39 @@ from vllm import SamplingParams
 from .conftest import get_output_from_llm_generator
 
 
-@pytest.mark.parametrize("common_llm_kwargs", [{
+@pytest.mark.parametrize("common_llm_kwargs",
+                         [{
     "model": "meta-llama/Llama-3.2-1B-Instruct",
-    "speculative_model": "JackFram/llama-68m",
-    "num_speculative_tokens": 5,
 }])
 @pytest.mark.parametrize(
     "per_test_common_llm_kwargs",
     [
         {
             # Speculative max model len > overridden max model len should raise.
+            "speculative_config": {
+                "model": "JackFram/llama-68m",
+                "num_speculative_tokens": 5,
+                "max_model_len": 129,
+            },
             "max_model_len": 128,
-            "speculative_max_model_len": 129,
         },
         {
             # Speculative max model len > draft max model len should raise.
             # https://huggingface.co/JackFram/llama-68m/blob/3b606af5198a0b26762d589a3ee3d26ee6fa6c85/config.json#L12
-            "speculative_max_model_len": 2048 + 1,
+            "speculative_config": {
+                "model": "JackFram/llama-68m",
+                "num_speculative_tokens": 5,
+                "max_model_len": 2048 + 1,
+            },
         },
         {
             # Speculative max model len > target max model len should raise.
-            # https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct/blob/9213176726f574b556790deb65791e0c5aa438b6/config.json#L18
-            "speculative_max_model_len": 131072 + 1,
+            # https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct/blob/9213176726f574b556790deb65791e0c5aa438b6/config.json#L18
+            "speculative_config": {
+                "model": "JackFram/llama-68m",
+                "num_speculative_tokens": 5,
+                "max_model_len": 131072 + 1,
+            },
         },
     ])
 @pytest.mark.parametrize("test_llm_kwargs", [{}])
@@ -57,9 +57,11 @@ PRECISION = "float32"
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
 @pytest.mark.parametrize("test_llm_kwargs", [
     {
-        "speculative_model": SPEC_MODEL,
+        "speculative_config": {
+            "model": SPEC_MODEL,
             "num_speculative_tokens": MAX_SPEC_TOKENS,
+        },
     },
 ])
 @pytest.mark.parametrize("output_len", [
     128,
@@ -95,18 +97,19 @@ def test_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
 }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
-@pytest.mark.parametrize("test_llm_kwargs", [
-    {
-        "speculative_model": SPEC_MODEL,
+@pytest.mark.parametrize("test_llm_kwargs", [{
+    "speculative_config": {
+        "model": SPEC_MODEL,
         "num_speculative_tokens": MAX_SPEC_TOKENS,
-        "disable_logprobs_during_spec_decoding": False,
+        "disable_logprobs": False,
     },
-    {
-        "speculative_model": SPEC_MODEL,
+}, {
+    "speculative_config": {
+        "model": SPEC_MODEL,
         "num_speculative_tokens": MAX_SPEC_TOKENS,
-        "disable_logprobs_during_spec_decoding": True,
+        "disable_logprobs": True,
     },
-])
+}])
 @pytest.mark.parametrize("output_len", [
     128,
 ])
@@ -119,7 +122,8 @@ def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
                                    batch_size: int, output_len: int, seed: int,
                                    logprobs: int):
 
-    run_equality_correctness_test(vllm_runner,
+    run_equality_correctness_test(
+        vllm_runner,
         common_llm_kwargs,
         per_test_common_llm_kwargs,
         baseline_llm_kwargs,
@@ -129,8 +133,8 @@ def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
         seed,
         logprobs=logprobs,
         prompt_logprobs=logprobs,
-        disable_logprobs=test_llm_kwargs[
-            'disable_logprobs_during_spec_decoding'])
+        disable_logprobs=test_llm_kwargs["speculative_config"]
+        ["disable_logprobs"])
 
 
 @pytest.mark.parametrize(
@@ -151,9 +155,11 @@ def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
 @pytest.mark.parametrize("test_llm_kwargs", [
     {
-        "speculative_model": SPEC_MODEL,
+        "speculative_config": {
+            "model": SPEC_MODEL,
             "num_speculative_tokens": MAX_SPEC_TOKENS,
+        },
     },
 ])
 @pytest.mark.parametrize("output_len", [
     128,
@@ -193,9 +199,11 @@ def test_eagle_e2e_greedy_correctness_cuda_graph(
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
 @pytest.mark.parametrize("test_llm_kwargs", [
     {
-        "speculative_model": SPEC_MODEL,
+        "speculative_config": {
+            "model": SPEC_MODEL,
             "num_speculative_tokens": MAX_SPEC_TOKENS,
+        },
     },
 ])
 @pytest.mark.parametrize(
     "output_len",
@@ -236,8 +244,10 @@ def test_eagle_e2e_greedy_correctness_with_preemption(
     "test_llm_kwargs",
     [
         {
-            "speculative_model": SPEC_MODEL,
+            "speculative_config": {
+                "model": SPEC_MODEL,
                 "num_speculative_tokens": k,
+            },
         }
         # Try a range of num. speculative tokens
         for k in range(1, 1 + MAX_SPEC_TOKENS)
@@ -277,11 +287,12 @@ def test_eagle_different_k(vllm_runner, common_llm_kwargs,
 }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
-@pytest.mark.parametrize("test_llm_kwargs",
-    [{
-        "speculative_model": SPEC_MODEL,
+@pytest.mark.parametrize("test_llm_kwargs", [{
+    "speculative_config": {
+        "model": SPEC_MODEL,
         "num_speculative_tokens": MAX_SPEC_TOKENS,
-        "speculative_disable_by_batch_size": 4
+        "disable_by_batch_size": 4,
+    },
 }])
 @pytest.mark.parametrize("batch_size", [1, 5])
 @pytest.mark.parametrize(
@@ -324,9 +335,11 @@ def test_eagle_disable_queue(vllm_runner, common_llm_kwargs,
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
 @pytest.mark.parametrize("test_llm_kwargs", [
     {
-        "speculative_model": "yuhuili/EAGLE-llama2-chat-7B",
+        "speculative_config": {
+            "model": "yuhuili/EAGLE-llama2-chat-7B",
             "num_speculative_tokens": MAX_SPEC_TOKENS,
+        },
     },
 ])
 @pytest.mark.parametrize(
     "output_len",
@@ -372,9 +385,11 @@ def test_llama2_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
 @pytest.mark.parametrize("test_llm_kwargs", [
     {
-        "speculative_model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
+        "speculative_config": {
+            "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
             "num_speculative_tokens": MAX_SPEC_TOKENS,
+        },
     },
 ])
 @pytest.mark.parametrize(
     "output_len",
@@ -420,9 +435,11 @@ def test_llama3_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
 @pytest.mark.parametrize("test_llm_kwargs", [
     {
-        "speculative_model": "yuhuili/EAGLE-Qwen2-7B-Instruct",
+        "speculative_config": {
+            "model": "yuhuili/EAGLE-Qwen2-7B-Instruct",
             "num_speculative_tokens": MAX_SPEC_TOKENS,
+        },
     },
 ])
 @pytest.mark.parametrize(
     "output_len",
@@ -23,9 +23,11 @@ MAIN_MODEL = "JackFram/llama-68m"
     [
         {
             # Identical models.
-            "speculative_model": "JackFram/llama-68m",
+            "speculative_config": {
+                "model": "JackFram/llama-68m",
                 "num_speculative_tokens": 5,
+            },
         },
     ])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
 @pytest.mark.parametrize("test_llm_kwargs", [{}])
@@ -57,26 +59,33 @@ def test_spec_decode_cuda_graph(vllm_runner, common_llm_kwargs,
     # Skip cuda graph recording for fast test.
     "enforce_eager": True,
 }])
-@pytest.mark.parametrize("per_test_common_llm_kwargs", [
-    {
-        "speculative_model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
-        "num_speculative_tokens": 5,
-    },
-])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [])
 @pytest.mark.parametrize(
     "test_llm_kwargs",
     [
         # Explicitly specify draft model quantization
         {
-            "speculative_model_quantization": "gptq",
+            "speculative_config": {
+                "model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
+                "num_speculative_tokens": 5,
+                "quantization": "gptq",
+            },
         },
         # Explicitly specify GPTQ-based draft model to use marlin quantization
        {
-            "speculative_model_quantization": "marlin",
+            "speculative_config": {
+                "model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
+                "num_speculative_tokens": 5,
+                "quantization": "marlin",
+            },
        },
         # Not explicitly specify draft model quantization
         {
-            "speculative_model_quantization": None,
+            "speculative_config": {
+                "model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
+                "num_speculative_tokens": 5,
+                "quantization": None,
+            },
         },
     ])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -107,14 +116,15 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs,
 
     # Skip cuda graph recording for fast test.
     "enforce_eager": True,
-    "speculative_model": "JackFram/llama-68m",
-    "num_speculative_tokens": 3,
 }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
-@pytest.mark.parametrize("test_llm_kwargs",
-    [{
-        "speculative_disable_mqa_scorer": True,
+@pytest.mark.parametrize("test_llm_kwargs", [{
+    "speculative_config": {
+        "model": "JackFram/llama-68m",
+        "num_speculative_tokens": 3,
+        "disable_mqa_scorer": True,
+    },
 }])
 @pytest.mark.parametrize("batch_size", [1, 5])
 @pytest.mark.parametrize(
@@ -127,7 +137,7 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs,
 def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
                     baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
                     output_len: int, seed: int):
-    """Verify that ngram speculative decoding generates the same output
+    """Verify that speculative decoding generates the same output
     with batch expansion scorer and mqa scorer.
     """
     run_equality_correctness_test(vllm_runner,
@@ -27,18 +27,19 @@ from .conftest import run_equality_correctness_test_tp
 @pytest.mark.parametrize("baseline_llm_kwargs", [[]])
 @pytest.mark.parametrize("test_llm_kwargs", [
     [
-        "--speculative-model",
-        "JackFram/llama-68m",
-        "--num-speculative-tokens",
-        "3",
+        "--speculative_config",
+        str({
+            "model": "JackFram/llama-68m",
+            "num_speculative_tokens": 3,
+        }),
     ],
     [
-        "--speculative-model",
-        "[ngram]",
-        "--num-speculative-tokens",
-        "5",
-        "--ngram-prompt-lookup-max",
-        "3",
+        "--speculative_config",
+        str({
+            "model": "ngram",
+            "num_speculative_tokens": 5,
+            "prompt_lookup_max": 3,
+        }),
     ],
 ])
 @pytest.mark.parametrize("batch_size", [2])
@@ -83,22 +84,23 @@ def test_target_model_tp_gt_1(common_llm_kwargs, per_test_common_llm_kwargs,
 ]])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [[]])
 @pytest.mark.parametrize("baseline_llm_kwargs", [[]])
-@pytest.mark.parametrize("model, test_llm_kwargs",
+@pytest.mark.parametrize(
+    "model, test_llm_kwargs",
     [("JackFram/llama-68m", [
-        "--speculative-model",
-        "JackFram/llama-68m",
-        "--num_speculative-tokens",
-        "5",
-        "--speculative-draft-tensor-parallel-size",
-        "1",
+        "--speculative_config",
+        str({
+            "model": "JackFram/llama-68m",
+            "num_speculative_tokens": 5,
+            "draft_tensor_parallel_size": 1,
+        }),
     ]),
      ("ibm-granite/granite-3b-code-instruct", [
-         "--speculative-model",
-         "ibm-granite/granite-3b-code-instruct",
-         "--num_speculative-tokens",
-         "5",
-         "--speculative-draft-tensor-parallel-size",
-         "1",
+         "--speculative_config",
+         str({
+             "model": "ibm-granite/granite-3b-code-instruct",
+             "num_speculative_tokens": 5,
+             "draft_tensor_parallel_size": 1,
+         }),
      ])])
 @pytest.mark.parametrize("batch_size", [2])
 @pytest.mark.parametrize("seed", [1])
@@ -144,18 +146,19 @@ def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs,
 @pytest.mark.parametrize("baseline_llm_kwargs", [[]])
 @pytest.mark.parametrize("model, test_llm_kwargs",
                          [("JackFram/llama-68m", [
-                             "--speculative-model",
-                             "JackFram/llama-68m",
-                             "--num_speculative-tokens",
-                             "3",
+                             "--speculative_config",
+                             str({
+                                 "model": "JackFram/llama-68m",
+                                 "num_speculative_tokens": 3,
+                             }),
                          ]),
                           ("JackFram/llama-68m", [
-                              "--speculative-model",
-                              "JackFram/llama-68m",
-                              "--num_speculative-tokens",
-                              "3",
-                              "--speculative-draft-tensor-parallel-size",
-                              "1",
+                              "--speculative_config",
+                              str({
+                                  "model": "JackFram/llama-68m",
+                                  "num_speculative_tokens": 3,
+                                  "draft_tensor_parallel_size": 1,
+                              }),
                           ])])
 @pytest.mark.parametrize("logprobs", [None, 2])
 @pytest.mark.parametrize("batch_size", [2])
@@ -24,12 +24,7 @@ SPEC_MODEL = "JackFram/llama-68m"
     "4",
 ]])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [
-    [
-        "--speculative-model",
-        f"{SPEC_MODEL}",
-        "--num-speculative-tokens",
-        "5",
-    ],
+    [],
 ])
 @pytest.mark.parametrize("baseline_llm_kwargs", [[]])
 @pytest.mark.parametrize(
@@ -37,8 +32,12 @@ SPEC_MODEL = "JackFram/llama-68m"
     [
         #TODO(wooyeon): add spec_draft_dp=2 case
         [
-            "--speculative-draft-tensor-parallel-size",
-            "1",
+            "--speculative_config",
+            str({
+                "model": f"{SPEC_MODEL}",
+                "num_speculative_tokens": 5,
+                "draft_tensor_parallel_size": 1,
+            }),
         ],
     ])
 @pytest.mark.parametrize("batch_size", [2])
@@ -78,15 +77,14 @@ def test_draft_model_tp_lt_target_model_tp4(common_llm_kwargs,
     "test_llm_kwargs",
     [
         [
-            "--speculative-model",
-            f"{SPEC_MODEL}",
-            "--num-speculative-tokens",
-            "5",
 
             # Artificially limit the draft model max model len; this forces vLLM
             # to skip speculation once the sequences grow beyond 32-k tokens.
-            "--speculative-max-model-len",
-            "32",
+            "--speculative_config",
+            str({
+                "model": f"{SPEC_MODEL}",
+                "num_speculative_tokens": 5,
+                "max_model_len": 32,
+            }),
         ],
     ])
 @pytest.mark.parametrize("batch_size", [8])
@@ -20,15 +20,18 @@ from .conftest import run_equality_correctness_test
 }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
-@pytest.mark.parametrize("test_llm_kwargs",
-    [{
-        "speculative_model": "JackFram/llama-68m",
+@pytest.mark.parametrize("test_llm_kwargs", [{
+    "speculative_config": {
+        "model": "JackFram/llama-68m",
         "num_speculative_tokens": 3,
-        "disable_logprobs_during_spec_decoding": False,
+        "disable_logprobs": False,
+    },
 }, {
-    "speculative_model": "JackFram/llama-68m",
+    "speculative_config": {
+        "model": "JackFram/llama-68m",
         "num_speculative_tokens": 3,
-        "disable_logprobs_during_spec_decoding": True,
+        "disable_logprobs": True,
+    },
 }])
 @pytest.mark.parametrize("batch_size", [8])
 @pytest.mark.parametrize(
@@ -48,7 +51,8 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs,
     as well as with and without chunked prefill.
     """
     maybe_enable_chunked_prefill(prefill_chunk_size, common_llm_kwargs)
-    run_equality_correctness_test(vllm_runner,
+    run_equality_correctness_test(
+        vllm_runner,
         common_llm_kwargs,
         per_test_common_llm_kwargs,
         baseline_llm_kwargs,
@@ -59,8 +63,8 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs,
         temperature=0.0,
         logprobs=logprobs,
         prompt_logprobs=logprobs,
-        disable_logprobs=test_llm_kwargs[
-            'disable_logprobs_during_spec_decoding'])
+        disable_logprobs=test_llm_kwargs["speculative_config"]
+        ["disable_logprobs"])
 
 
 @pytest.mark.parametrize(
@@ -73,15 +77,18 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs,
 }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
-@pytest.mark.parametrize("test_llm_kwargs",
-    [{
-        "speculative_model": "JackFram/llama-160m",
+@pytest.mark.parametrize("test_llm_kwargs", [{
+    "speculative_config": {
+        "model": "JackFram/llama-160m",
         "num_speculative_tokens": 3,
-        "disable_logprobs_during_spec_decoding": False,
+        "disable_logprobs": False,
+    },
 }, {
-    "speculative_model": "JackFram/llama-160m",
+    "speculative_config": {
+        "model": "JackFram/llama-160m",
         "num_speculative_tokens": 6,
-        "disable_logprobs_during_spec_decoding": False,
+        "disable_logprobs": False,
+    },
 }])
 @pytest.mark.parametrize("batch_size", [8])
 @pytest.mark.parametrize(
@@ -98,7 +105,8 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs,
                               output_len: int, seed: int, logprobs: int):
     """Veriy logprob greedy equality with different speculation lens.
     """
-    run_equality_correctness_test(vllm_runner,
+    run_equality_correctness_test(
+        vllm_runner,
         common_llm_kwargs,
         per_test_common_llm_kwargs,
         baseline_llm_kwargs,
@@ -108,8 +116,8 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs,
         seed,
         temperature=0.0,
         logprobs=logprobs,
-        disable_logprobs=test_llm_kwargs[
-            'disable_logprobs_during_spec_decoding'])
+        disable_logprobs=test_llm_kwargs["speculative_config"]
+        ["disable_logprobs"])
 
 
 @pytest.mark.parametrize(
@@ -125,13 +133,15 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs,
 @pytest.mark.parametrize(
     "test_llm_kwargs",
     [{
-        "speculative_model": "JackFram/llama-160m",
+        "speculative_config": {
+            "model": "JackFram/llama-160m",
             "num_speculative_tokens": 3,
-        "disable_logprobs_during_spec_decoding": False,
-        # Artificially limit the draft model max model len; this forces vLLM
-        # to skip speculation once the sequences grow beyond 32-k tokens.
-        "speculative_max_model_len": 32,
+            "disable_logprobs": False,
+            # Artificially limit the draft model max model len; this forces
+            # vLLM to skip speculation once the sequences grow beyond 32-k
+            # tokens.
+            "max_model_len": 32,
+        },
     }])
 @pytest.mark.parametrize("batch_size", [8])
 @pytest.mark.parametrize(
@@ -149,7 +159,8 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs,
                                         seed: int, logprobs: int):
     """Verify logprobs greedy equality when some sequences skip speculation.
     """
-    run_equality_correctness_test(vllm_runner,
+    run_equality_correctness_test(
+        vllm_runner,
         common_llm_kwargs,
         per_test_common_llm_kwargs,
         baseline_llm_kwargs,
@@ -159,8 +170,8 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs,
         seed,
         temperature=0.0,
         logprobs=logprobs,
-        disable_logprobs=test_llm_kwargs[
-            'disable_logprobs_during_spec_decoding'])
+        disable_logprobs=test_llm_kwargs["speculative_config"]
+        ["disable_logprobs"])
 
 
 @pytest.mark.parametrize(
@@ -173,11 +184,12 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs,
 }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
-@pytest.mark.parametrize("test_llm_kwargs",
-    [{
-        "speculative_model": "JackFram/llama-160m",
+@pytest.mark.parametrize("test_llm_kwargs", [{
+    "speculative_config": {
+        "model": "JackFram/llama-160m",
         "num_speculative_tokens": 3,
-        "disable_logprobs_during_spec_decoding": False,
+        "disable_logprobs": False,
+    },
 }])
 @pytest.mark.parametrize("batch_size", [1])
 @pytest.mark.parametrize(
@@ -248,11 +260,12 @@ def test_logprobs_temp_1(vllm_runner, common_llm_kwargs,
 }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
-@pytest.mark.parametrize("test_llm_kwargs",
-    [{
-        "speculative_model": "JackFram/llama-68m",
+@pytest.mark.parametrize("test_llm_kwargs", [{
+    "speculative_config": {
+        "model": "JackFram/llama-68m",
         "num_speculative_tokens": 3,
-        "disable_logprobs_during_spec_decoding": True,
+        "disable_logprobs": True,
+    },
 }])
 @pytest.mark.parametrize("seed", [1])
 @pytest.mark.parametrize("batch_size", [4])
@@ -270,7 +283,8 @@ def test_logprobs_disabled(vllm_runner, common_llm_kwargs,
     """Check the behavior when logprobs are disabled.
     Token choices should match with the base model.
     """
-    run_equality_correctness_test(vllm_runner,
+    run_equality_correctness_test(
+        vllm_runner,
         common_llm_kwargs,
         per_test_common_llm_kwargs,
         baseline_llm_kwargs,
@@ -280,5 +294,5 @@ def test_logprobs_disabled(vllm_runner, common_llm_kwargs,
         seed,
         temperature=0.0,
         logprobs=logprobs,
-        disable_logprobs=test_llm_kwargs[
-            'disable_logprobs_during_spec_decoding'])
+        disable_logprobs=test_llm_kwargs["speculative_config"]
+        ["disable_logprobs"])
@@ -60,9 +60,11 @@ PRECISION = "float32"
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
 @pytest.mark.parametrize("test_llm_kwargs", [
     {
-        "speculative_model": SPEC_MODEL,
+        "speculative_config": {
+            "model": SPEC_MODEL,
             "num_speculative_tokens": MAX_SPEC_TOKENS,
+        },
     },
 ])
 @pytest.mark.parametrize("output_len", [
     128,
@@ -107,14 +109,18 @@ def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
 @pytest.mark.parametrize("test_llm_kwargs", [
     {
-        "speculative_model": SPEC_MODEL,
+        "speculative_config": {
+            "model": SPEC_MODEL,
             "num_speculative_tokens": MAX_SPEC_TOKENS,
-        "disable_logprobs_during_spec_decoding": False,
+            "disable_logprobs": False,
+        },
     },
     {
-        "speculative_model": SPEC_MODEL,
+        "speculative_config": {
+            "model": SPEC_MODEL,
             "num_speculative_tokens": MAX_SPEC_TOKENS,
-        "disable_logprobs_during_spec_decoding": True,
+            "disable_logprobs": True,
+        },
     },
 ])
 @pytest.mark.parametrize("output_len", [
@@ -132,7 +138,8 @@ def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
                                     prefill_chunk_size: int):
     """Verify greedy equality with different batch size."""
     maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
-    run_equality_correctness_test(vllm_runner,
+    run_equality_correctness_test(
+        vllm_runner,
         common_llm_kwargs,
         per_test_common_llm_kwargs,
         baseline_llm_kwargs,
@@ -143,8 +150,8 @@ def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
         temperature=0.0,
         logprobs=logprobs,
         prompt_logprobs=logprobs,
-        disable_logprobs=test_llm_kwargs[
-            'disable_logprobs_during_spec_decoding'])
+        disable_logprobs=test_llm_kwargs["speculative_config"]
+        ["disable_logprobs"])
 
 
 @pytest.mark.parametrize(
@@ -165,9 +172,11 @@ def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
 @pytest.mark.parametrize("test_llm_kwargs", [
     {
-        "speculative_model": SPEC_MODEL,
+        "speculative_config": {
+            "model": SPEC_MODEL,
             "num_speculative_tokens": MAX_SPEC_TOKENS,
+        },
     },
 ])
 @pytest.mark.parametrize("output_len", [
     128,
@@ -214,9 +223,11 @@ def test_medusa_e2e_greedy_correctness_cuda_graph(
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
 @pytest.mark.parametrize("test_llm_kwargs", [
     {
-        "speculative_model": SPEC_MODEL,
+        "speculative_config": {
+            "model": SPEC_MODEL,
             "num_speculative_tokens": MAX_SPEC_TOKENS,
+        },
     },
 ])
 @pytest.mark.parametrize(
     "output_len",
@@ -264,8 +275,10 @@ def test_medusa_e2e_greedy_correctness_with_preemption(
     "test_llm_kwargs",
     [
         {
-            "speculative_model": SPEC_MODEL,
+            "speculative_config": {
+                "model": SPEC_MODEL,
                 "num_speculative_tokens": k,
+            },
         }
         # Try a range of num. speculative tokens
         for k in range(1, 1 + MAX_SPEC_TOKENS)
@@ -312,11 +325,12 @@ def test_medusa_different_k(vllm_runner, common_llm_kwargs,
 }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
-@pytest.mark.parametrize("test_llm_kwargs",
-    [{
-        "speculative_model": SPEC_MODEL,
+@pytest.mark.parametrize("test_llm_kwargs", [{
+    "speculative_config": {
+        "model": SPEC_MODEL,
         "num_speculative_tokens": MAX_SPEC_TOKENS,
-        "speculative_disable_by_batch_size": 4
+        "disable_by_batch_size": 4,
+    },
 }])
 @pytest.mark.parametrize("batch_size", [1, 5])
 @pytest.mark.parametrize(
@@ -359,15 +373,16 @@ def test_medusa_disable_queue(vllm_runner, common_llm_kwargs,
 
     # Main model
     "model_name": MAIN_MODEL,
-    "speculative_model": SPEC_MODEL,
-    "num_speculative_tokens": MAX_SPEC_TOKENS,
-    "speculative_disable_by_batch_size": 4
 }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
-@pytest.mark.parametrize("test_llm_kwargs",
-    [{
-        "speculative_disable_mqa_scorer": True,
+@pytest.mark.parametrize("test_llm_kwargs", [{
+    "speculative_config": {
+        "model": SPEC_MODEL,
+        "num_speculative_tokens": MAX_SPEC_TOKENS,
+        "disable_by_batch_size": 4,
+        "disable_mqa_scorer": True,
+    },
 }])
 @pytest.mark.parametrize("batch_size", [1, 5])
 @pytest.mark.parametrize(
@@ -62,7 +62,9 @@ PRECISION = "float32"
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
 @pytest.mark.parametrize("test_llm_kwargs", [
     {
-        "speculative_model": SPEC_MODEL,
+        "speculative_config": {
+            "model": SPEC_MODEL,
+        },
     },
 ])
 @pytest.mark.parametrize("output_len", [
@@ -108,12 +110,16 @@ def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
 @pytest.mark.parametrize("test_llm_kwargs", [
     {
-        "speculative_model": SPEC_MODEL,
-        "disable_logprobs_during_spec_decoding": False,
+        "speculative_config": {
+            "model": SPEC_MODEL,
+            "disable_logprobs": False,
+        },
     },
     {
-        "speculative_model": SPEC_MODEL,
-        "disable_logprobs_during_spec_decoding": True,
+        "speculative_config": {
+            "model": SPEC_MODEL,
+            "disable_logprobs": True,
+        },
     },
 ])
 @pytest.mark.parametrize("output_len", [8])
@@ -133,7 +139,8 @@ def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
     # up sampling different tokens at the tail (ie top tokens don't change).
     # TL;DR: sd+cp == org+cp but sd+cp != org..is this expected?
     maybe_enable_chunked_prefill(prefill_chunk_size, baseline_llm_kwargs)
-    run_equality_correctness_test(vllm_runner,
+    run_equality_correctness_test(
+        vllm_runner,
         common_llm_kwargs,
         per_test_common_llm_kwargs,
         baseline_llm_kwargs,
@@ -144,8 +151,8 @@ def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
         temperature=0.0,
         logprobs=logprobs,
         prompt_logprobs=logprobs,
-        disable_logprobs=test_llm_kwargs[
-            'disable_logprobs_during_spec_decoding'])
+        disable_logprobs=test_llm_kwargs["speculative_config"]
+        ["disable_logprobs"])
 
 
 @pytest.mark.parametrize(
@@ -167,7 +174,9 @@ def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
 @pytest.mark.parametrize("test_llm_kwargs", [
     {
-        "speculative_model": SPEC_MODEL,
+        "speculative_config": {
+            "model": SPEC_MODEL,
+        },
     },
 ])
 @pytest.mark.parametrize("output_len", [2048])
@@ -209,8 +218,10 @@ def test_mlp_e2e_acceptance_rate(vllm_runner, common_llm_kwargs,
     # Main model
     "model_name": MAIN_MODEL,
 
-    # Speculative model
-    "speculative_model": SPEC_MODEL,
+    # Speculative config
+    "speculative_config": {
+        "model": SPEC_MODEL,
+    },
 }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}])
@@ -274,7 +285,9 @@ def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs,
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
 @pytest.mark.parametrize("test_llm_kwargs", [
     {
-        "speculative_model": SPEC_MODEL,
+        "speculative_config": {
+            "model": SPEC_MODEL,
+        },
     },
 ])
 @pytest.mark.parametrize(
@@ -326,7 +339,9 @@ def test_mlp_e2e_greedy_correctness_with_preemption(
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
 @pytest.mark.parametrize("test_llm_kwargs", [
     {
-        "speculative_model": SPEC_MODEL,
+        "speculative_config": {
+            "model": SPEC_MODEL,
+        },
     },
 ])
 @pytest.mark.parametrize(
@@ -382,8 +397,10 @@ def test_mlp_e2e_greedy_correctness_with_padding(
     "test_llm_kwargs",
     [
         {
-            "speculative_model": SPEC_MODEL,
+            "speculative_config": {
+                "model": SPEC_MODEL,
                 "num_speculative_tokens": k,
+            },
         }
         # Try a range of num. speculative tokens
         for k in range(1, 1 + MAX_SPEC_TOKENS)
@@ -430,10 +447,11 @@ def test_mlp_different_k(vllm_runner, common_llm_kwargs,
 }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
-@pytest.mark.parametrize("test_llm_kwargs",
-    [{
-        "speculative_model": SPEC_MODEL,
-        "speculative_disable_by_batch_size": 4
+@pytest.mark.parametrize("test_llm_kwargs", [{
+    "speculative_config": {
+        "model": SPEC_MODEL,
+        "disable_by_batch_size": 4,
+    },
 }])
 @pytest.mark.parametrize("batch_size", [1, 5])
 @pytest.mark.parametrize(
@@ -475,13 +493,14 @@ def test_mlp_disable_queue(vllm_runner, common_llm_kwargs,
 
     # Skip cuda graph recording for fast test.
     "enforce_eager": True,
-    "speculative_model": SPEC_MODEL,
 }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
-@pytest.mark.parametrize("test_llm_kwargs",
-    [{
-        "speculative_disable_mqa_scorer": True,
+@pytest.mark.parametrize("test_llm_kwargs", [{
+    "speculative_config": {
+        "model": SPEC_MODEL,
+        "disable_mqa_scorer": True,
+    },
 }])
 @pytest.mark.parametrize("batch_size", [1, 5])
 @pytest.mark.parametrize(
@ -57,8 +57,10 @@ PRECISION = "bfloat16"
|
|||||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||||
{
|
{
|
||||||
|
"speculative_config": {
|
||||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||||
},
|
},
|
||||||
|
},
|
||||||
])
|
])
|
||||||
@pytest.mark.parametrize("output_len", [
|
@pytest.mark.parametrize("output_len", [
|
||||||
128,
|
128,
|
||||||
@ -99,12 +101,16 @@ def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
|
|||||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||||
{
|
{
|
||||||
|
"speculative_config": {
|
||||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||||
"disable_logprobs_during_spec_decoding": False,
|
"disable_logprobs": False,
|
+            },
         },
         {
+            "speculative_config": {
                 "num_speculative_tokens": MAX_SPEC_TOKENS,
-            "disable_logprobs_during_spec_decoding": True,
+                "disable_logprobs": True,
+            },
         },
 ])
 @pytest.mark.parametrize("output_len", [
@@ -119,7 +125,8 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
                              batch_size: int, output_len: int, seed: int,
                              logprobs: int):
 
-    run_equality_correctness_test(vllm_runner,
+    run_equality_correctness_test(
+        vllm_runner,
         common_llm_kwargs,
         per_test_common_llm_kwargs,
         baseline_llm_kwargs,
@@ -129,8 +136,8 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
         seed,
         logprobs=logprobs,
         prompt_logprobs=logprobs,
-        disable_logprobs=test_llm_kwargs[
-            'disable_logprobs_during_spec_decoding'])
+        disable_logprobs=test_llm_kwargs["speculative_config"]
+        ["disable_logprobs"])
 
 
 @pytest.mark.parametrize(
@@ -152,8 +159,10 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
 @pytest.mark.parametrize("test_llm_kwargs", [
     {
+        "speculative_config": {
             "num_speculative_tokens": MAX_SPEC_TOKENS,
+        },
     },
 ])
 @pytest.mark.parametrize("output_len", [
     128,
@@ -198,8 +207,10 @@ def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs,
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
 @pytest.mark.parametrize("test_llm_kwargs", [
     {
+        "speculative_config": {
             "num_speculative_tokens": MAX_SPEC_TOKENS,
+        },
     },
 ])
 @pytest.mark.parametrize(
     "output_len",
@@ -243,7 +254,9 @@ def test_mtp_e2e_greedy_correctness_with_preemption(
     "test_llm_kwargs",
     [
         {
+            "speculative_config": {
                 "num_speculative_tokens": k,
+            },
         }
         # Try a range of num. speculative tokens
         for k in range(1, 1 + MAX_SPEC_TOKENS)
@@ -286,10 +299,11 @@ def test_mtp_different_k(vllm_runner, common_llm_kwargs,
 }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
-@pytest.mark.parametrize("test_llm_kwargs",
-                         [{
+@pytest.mark.parametrize("test_llm_kwargs", [{
+    "speculative_config": {
         "num_speculative_tokens": MAX_SPEC_TOKENS,
-    "speculative_disable_by_batch_size": 4
+        "disable_by_batch_size": 4
+    },
 }])
 @pytest.mark.parametrize("batch_size", [1, 5])
 @pytest.mark.parametrize(
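The parametrizations above only change how the draft-model settings are spelled: they now live in a nested `speculative_config` dict inside `test_llm_kwargs`. As a rough sketch of how such kwargs reach the engine — assuming, as these tests suggest, that the conftest helpers simply merge the kwargs dicts and forward them to the `LLM` constructor; the helper name below is illustrative, not part of the repository:

```python
# Illustrative only: mirrors how the test fixtures are assumed to combine
# their kwargs dicts before constructing the engine under test.
from vllm import LLM


def build_test_llm(common_llm_kwargs: dict,
                   per_test_common_llm_kwargs: dict,
                   test_llm_kwargs: dict) -> LLM:
    # Later dicts win, exactly like a plain dict merge.
    kwargs = {
        **common_llm_kwargs,
        **per_test_common_llm_kwargs,
        **test_llm_kwargs,
    }
    # "speculative_config" is passed through as a nested dict, e.g.
    # {"model": "JackFram/llama-68m", "num_speculative_tokens": 5}.
    return LLM(**kwargs)
```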
||||||
|
@ -61,15 +61,19 @@ from .conftest import (get_output_from_llm_generator,
|
|||||||
"per_test_common_llm_kwargs",
|
"per_test_common_llm_kwargs",
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
"speculative_model": "JackFram/llama-68m",
|
"speculative_config": {
|
||||||
|
"model": "JackFram/llama-68m",
|
||||||
"num_speculative_tokens": 5,
|
"num_speculative_tokens": 5,
|
||||||
|
},
|
||||||
"enable_chunked_prefill": False,
|
"enable_chunked_prefill": False,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
# Chunked prefill enabled with small value
|
# Chunked prefill enabled with small value
|
||||||
# to make sure we get mixed batches.
|
# to make sure we get mixed batches.
|
||||||
"speculative_model": "JackFram/llama-68m",
|
"speculative_config": {
|
||||||
|
"model": "JackFram/llama-68m",
|
||||||
"num_speculative_tokens": 5,
|
"num_speculative_tokens": 5,
|
||||||
|
},
|
||||||
"enable_chunked_prefill": True,
|
"enable_chunked_prefill": True,
|
||||||
"max_num_batched_tokens": 4,
|
"max_num_batched_tokens": 4,
|
||||||
"max_num_seqs": 4
|
"max_num_seqs": 4
|
||||||
@ -148,19 +152,22 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
|
|||||||
},
|
},
|
||||||
])
|
])
|
||||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
@pytest.mark.parametrize("test_llm_kwargs",
|
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||||
[{
|
"speculative_config": {
|
||||||
"speculative_model": "JackFram/llama-68m",
|
"model": "JackFram/llama-68m",
|
||||||
"num_speculative_tokens": 5,
|
"num_speculative_tokens": 5,
|
||||||
|
"disable_logprobs": False,
|
||||||
|
},
|
||||||
"enable_chunked_prefill": False,
|
"enable_chunked_prefill": False,
|
||||||
"disable_logprobs_during_spec_decoding": False
|
|
||||||
}, {
|
}, {
|
||||||
"speculative_model": "JackFram/llama-68m",
|
"speculative_config": {
|
||||||
|
"model": "JackFram/llama-68m",
|
||||||
"num_speculative_tokens": 3,
|
"num_speculative_tokens": 3,
|
||||||
|
"disable_logprobs": False,
|
||||||
|
},
|
||||||
"enable_chunked_prefill": True,
|
"enable_chunked_prefill": True,
|
||||||
"max_num_batched_tokens": 4,
|
"max_num_batched_tokens": 4,
|
||||||
"max_num_seqs": 4,
|
"max_num_seqs": 4,
|
||||||
"disable_logprobs_during_spec_decoding": False
|
|
||||||
}])
|
}])
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"output_len",
|
"output_len",
|
||||||
@ -184,7 +191,7 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
|
|||||||
whether all speculative tokens are accepted.
|
whether all speculative tokens are accepted.
|
||||||
"""
|
"""
|
||||||
ensure_all_accepted = per_test_common_llm_kwargs.get(
|
ensure_all_accepted = per_test_common_llm_kwargs.get(
|
||||||
"model_name") == test_llm_kwargs.get("speculative_model")
|
"model_name") == test_llm_kwargs.get("speculative_config")["model"]
|
||||||
run_equality_correctness_test(vllm_runner,
|
run_equality_correctness_test(vllm_runner,
|
||||||
common_llm_kwargs,
|
common_llm_kwargs,
|
||||||
per_test_common_llm_kwargs,
|
per_test_common_llm_kwargs,
|
||||||
@ -224,13 +231,17 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
|
|||||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||||
{
|
{
|
||||||
"speculative_model": "JackFram/llama-68m",
|
"speculative_config": {
|
||||||
|
"model": "JackFram/llama-68m",
|
||||||
"num_speculative_tokens": 5,
|
"num_speculative_tokens": 5,
|
||||||
|
},
|
||||||
"enable_chunked_prefill": False,
|
"enable_chunked_prefill": False,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"speculative_model": "JackFram/llama-68m",
|
"speculative_config": {
|
||||||
|
"model": "JackFram/llama-68m",
|
||||||
"num_speculative_tokens": 5,
|
"num_speculative_tokens": 5,
|
||||||
|
},
|
||||||
"enable_chunked_prefill": True,
|
"enable_chunked_prefill": True,
|
||||||
"max_num_batched_tokens": 4,
|
"max_num_batched_tokens": 4,
|
||||||
"max_num_seqs": 4
|
"max_num_seqs": 4
|
||||||
@ -283,13 +294,17 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
|
|||||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||||
{
|
{
|
||||||
"speculative_model": "JackFram/llama-68m",
|
"speculative_config": {
|
||||||
|
"model": "JackFram/llama-68m",
|
||||||
"num_speculative_tokens": 5,
|
"num_speculative_tokens": 5,
|
||||||
|
},
|
||||||
"enable_chunked_prefill": False,
|
"enable_chunked_prefill": False,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"speculative_model": "JackFram/llama-68m",
|
"speculative_config": {
|
||||||
|
"model": "JackFram/llama-68m",
|
||||||
"num_speculative_tokens": 5,
|
"num_speculative_tokens": 5,
|
||||||
|
},
|
||||||
"enable_chunked_prefill": True,
|
"enable_chunked_prefill": True,
|
||||||
"max_num_batched_tokens": 4,
|
"max_num_batched_tokens": 4,
|
||||||
"max_num_seqs": 4
|
"max_num_seqs": 4
|
||||||
@ -336,13 +351,17 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
|
|||||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||||
{
|
{
|
||||||
"speculative_model": "JackFram/llama-68m",
|
"speculative_config": {
|
||||||
|
"model": "JackFram/llama-68m",
|
||||||
"num_speculative_tokens": 5,
|
"num_speculative_tokens": 5,
|
||||||
|
},
|
||||||
"enable_chunked_prefill": False,
|
"enable_chunked_prefill": False,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"speculative_model": "JackFram/llama-68m",
|
"speculative_config": {
|
||||||
|
"model": "JackFram/llama-68m",
|
||||||
"num_speculative_tokens": 5,
|
"num_speculative_tokens": 5,
|
||||||
|
},
|
||||||
"enable_chunked_prefill": True,
|
"enable_chunked_prefill": True,
|
||||||
"max_num_batched_tokens": 4,
|
"max_num_batched_tokens": 4,
|
||||||
"max_num_seqs": 4
|
"max_num_seqs": 4
|
||||||
@ -391,13 +410,17 @@ def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
|
|||||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||||
{
|
{
|
||||||
"speculative_model": "JackFram/llama-68m",
|
"speculative_config": {
|
||||||
|
"model": "JackFram/llama-68m",
|
||||||
"num_speculative_tokens": 5,
|
"num_speculative_tokens": 5,
|
||||||
|
},
|
||||||
"enable_chunked_prefill": False,
|
"enable_chunked_prefill": False,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"speculative_model": "JackFram/llama-68m",
|
"speculative_config": {
|
||||||
|
"model": "JackFram/llama-68m",
|
||||||
"num_speculative_tokens": 5,
|
"num_speculative_tokens": 5,
|
||||||
|
},
|
||||||
"enable_chunked_prefill": True,
|
"enable_chunked_prefill": True,
|
||||||
"max_num_batched_tokens": 4,
|
"max_num_batched_tokens": 4,
|
||||||
"max_num_seqs": 4
|
"max_num_seqs": 4
|
||||||
@ -449,13 +472,17 @@ def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
|
|||||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||||
{
|
{
|
||||||
"speculative_model": "JackFram/llama-68m",
|
"speculative_config": {
|
||||||
|
"model": "JackFram/llama-68m",
|
||||||
"num_speculative_tokens": 5,
|
"num_speculative_tokens": 5,
|
||||||
|
},
|
||||||
"enable_chunked_prefill": False,
|
"enable_chunked_prefill": False,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"speculative_model": "JackFram/llama-68m",
|
"speculative_config": {
|
||||||
|
"model": "JackFram/llama-68m",
|
||||||
"num_speculative_tokens": 5,
|
"num_speculative_tokens": 5,
|
||||||
|
},
|
||||||
"enable_chunked_prefill": True,
|
"enable_chunked_prefill": True,
|
||||||
"max_num_batched_tokens": 4,
|
"max_num_batched_tokens": 4,
|
||||||
"max_num_seqs": 4
|
"max_num_seqs": 4
|
||||||
@ -514,13 +541,17 @@ def test_spec_decode_e2e_greedy_correctness_with_preemption(
|
|||||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||||
{
|
{
|
||||||
"speculative_model": "JackFram/llama-68m",
|
"speculative_config": {
|
||||||
|
"model": "JackFram/llama-68m",
|
||||||
"num_speculative_tokens": 5,
|
"num_speculative_tokens": 5,
|
||||||
|
},
|
||||||
"enable_chunked_prefill": False,
|
"enable_chunked_prefill": False,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"speculative_model": "JackFram/llama-68m",
|
"speculative_config": {
|
||||||
|
"model": "JackFram/llama-68m",
|
||||||
"num_speculative_tokens": 5,
|
"num_speculative_tokens": 5,
|
||||||
|
},
|
||||||
"enable_chunked_prefill": True,
|
"enable_chunked_prefill": True,
|
||||||
"max_num_batched_tokens": 4,
|
"max_num_batched_tokens": 4,
|
||||||
"max_num_seqs": 4
|
"max_num_seqs": 4
|
||||||
@ -567,21 +598,25 @@ def test_spec_decode_different_block_size(vllm_runner, common_llm_kwargs,
|
|||||||
"test_llm_kwargs",
|
"test_llm_kwargs",
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
"speculative_model": "JackFram/llama-68m",
|
|
||||||
"num_speculative_tokens": 5,
|
|
||||||
|
|
||||||
# Artificially limit the draft model max model len; this forces vLLM
|
# Artificially limit the draft model max model len; this forces vLLM
|
||||||
# to skip speculation once the sequences grow beyond 32-k tokens.
|
# to skip speculation once the sequences grow beyond 32-k tokens.
|
||||||
"speculative_max_model_len": 32,
|
"speculative_config": {
|
||||||
|
"model": "JackFram/llama-68m",
|
||||||
|
"num_speculative_tokens": 5,
|
||||||
|
"max_model_len": 32,
|
||||||
|
},
|
||||||
"enable_chunked_prefill": False,
|
"enable_chunked_prefill": False,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"speculative_model": "JackFram/llama-68m",
|
"speculative_config": {
|
||||||
|
"model": "JackFram/llama-68m",
|
||||||
"num_speculative_tokens": 5,
|
"num_speculative_tokens": 5,
|
||||||
|
"max_model_len": 32,
|
||||||
|
},
|
||||||
"enable_chunked_prefill": True,
|
"enable_chunked_prefill": True,
|
||||||
"max_num_batched_tokens": 4,
|
"max_num_batched_tokens": 4,
|
||||||
"max_num_seqs": 4,
|
"max_num_seqs": 4,
|
||||||
"speculative_max_model_len": 32,
|
|
||||||
},
|
},
|
||||||
])
|
])
|
||||||
@pytest.mark.parametrize("batch_size", [8])
|
@pytest.mark.parametrize("batch_size", [8])
|
||||||
@ -627,15 +662,19 @@ def test_skip_speculation(vllm_runner, common_llm_kwargs,
|
|||||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||||
{
|
{
|
||||||
"speculative_model": "JackFram/llama-68m",
|
"speculative_config": {
|
||||||
|
"model": "JackFram/llama-68m",
|
||||||
"num_speculative_tokens": 5,
|
"num_speculative_tokens": 5,
|
||||||
"speculative_disable_by_batch_size": 2,
|
"disable_by_batch_size": 2,
|
||||||
|
},
|
||||||
"enable_chunked_prefill": False,
|
"enable_chunked_prefill": False,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"speculative_model": "JackFram/llama-68m",
|
"speculative_config": {
|
||||||
|
"model": "JackFram/llama-68m",
|
||||||
"num_speculative_tokens": 5,
|
"num_speculative_tokens": 5,
|
||||||
"speculative_disable_by_batch_size": 2,
|
"disable_by_batch_size": 2,
|
||||||
|
},
|
||||||
"enable_chunked_prefill": True,
|
"enable_chunked_prefill": True,
|
||||||
"max_num_batched_tokens": 4,
|
"max_num_batched_tokens": 4,
|
||||||
"max_num_seqs": 4,
|
"max_num_seqs": 4,
|
||||||
@ -676,15 +715,19 @@ def test_disable_speculation(vllm_runner, common_llm_kwargs,
|
|||||||
"test_llm_kwargs",
|
"test_llm_kwargs",
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
"speculative_model": "JackFram/llama-68m",
|
"speculative_config": {
|
||||||
|
"model": "JackFram/llama-68m",
|
||||||
"num_speculative_tokens": k,
|
"num_speculative_tokens": k,
|
||||||
|
},
|
||||||
"enable_chunked_prefill": False,
|
"enable_chunked_prefill": False,
|
||||||
}
|
}
|
||||||
# Try a range of common k, as well as large speculation.
|
# Try a range of common k, as well as large speculation.
|
||||||
for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]
|
for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]
|
||||||
] + [{
|
] + [{
|
||||||
"speculative_model": "JackFram/llama-68m",
|
"speculative_config": {
|
||||||
|
"model": "JackFram/llama-68m",
|
||||||
"num_speculative_tokens": k,
|
"num_speculative_tokens": k,
|
||||||
|
},
|
||||||
"enable_chunked_prefill": True,
|
"enable_chunked_prefill": True,
|
||||||
"max_num_batched_tokens": 4,
|
"max_num_batched_tokens": 4,
|
||||||
"max_num_seqs": 4,
|
"max_num_seqs": 4,
|
||||||
@ -729,17 +772,21 @@ def test_many_k(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
|||||||
"test_llm_kwargs",
|
"test_llm_kwargs",
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
"speculative_model": "JackFram/llama-68m",
|
"speculative_config": {
|
||||||
|
"model": "JackFram/llama-68m",
|
||||||
"num_speculative_tokens": k,
|
"num_speculative_tokens": k,
|
||||||
"spec_decoding_acceptance_method": "typical_acceptance_sampler",
|
"acceptance_method": "typical_acceptance_sampler",
|
||||||
|
},
|
||||||
"enable_chunked_prefill": False
|
"enable_chunked_prefill": False
|
||||||
}
|
}
|
||||||
# Try a range of common k.
|
# Try a range of common k.
|
||||||
for k in [1, 2, 3]
|
for k in [1, 2, 3]
|
||||||
] + [{
|
] + [{
|
||||||
"speculative_model": "JackFram/llama-68m",
|
"speculative_config": {
|
||||||
|
"model": "JackFram/llama-68m",
|
||||||
"num_speculative_tokens": k,
|
"num_speculative_tokens": k,
|
||||||
"spec_decoding_acceptance_method": "typical_acceptance_sampler",
|
"acceptance_method": "typical_acceptance_sampler",
|
||||||
|
},
|
||||||
"enable_chunked_prefill": True,
|
"enable_chunked_prefill": True,
|
||||||
"max_num_batched_tokens": 4,
|
"max_num_batched_tokens": 4,
|
||||||
"max_num_seqs": 4
|
"max_num_seqs": 4
|
||||||
|
@ -48,16 +48,20 @@ from .conftest import run_equality_correctness_test
|
|||||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||||
{
|
{
|
||||||
"speculative_model": "[ngram]",
|
"speculative_config": {
|
||||||
|
"method": "ngram",
|
||||||
"num_speculative_tokens": 5,
|
"num_speculative_tokens": 5,
|
||||||
"ngram_prompt_lookup_max": 3,
|
"prompt_lookup_max": 3,
|
||||||
"speculative_disable_mqa_scorer": False,
|
"disable_mqa_scorer": False,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"speculative_model": "[ngram]",
|
"speculative_config": {
|
||||||
|
"method": "ngram",
|
||||||
"num_speculative_tokens": 5,
|
"num_speculative_tokens": 5,
|
||||||
"ngram_prompt_lookup_max": 3,
|
"prompt_lookup_max": 3,
|
||||||
"speculative_disable_mqa_scorer": True,
|
"disable_mqa_scorer": True,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
])
|
])
|
||||||
@pytest.mark.parametrize("output_len", [
|
@pytest.mark.parametrize("output_len", [
|
||||||
@ -101,16 +105,20 @@ def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
|
|||||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||||
{
|
{
|
||||||
"speculative_model": "[ngram]",
|
"speculative_config": {
|
||||||
|
"method": "ngram",
|
||||||
"num_speculative_tokens": 5,
|
"num_speculative_tokens": 5,
|
||||||
"ngram_prompt_lookup_max": 3,
|
"prompt_lookup_max": 3,
|
||||||
"disable_logprobs_during_spec_decoding": False,
|
"disable_logprobs": False,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"speculative_model": "[ngram]",
|
"speculative_config": {
|
||||||
|
"method": "ngram",
|
||||||
"num_speculative_tokens": 5,
|
"num_speculative_tokens": 5,
|
||||||
"ngram_prompt_lookup_max": 3,
|
"prompt_lookup_max": 3,
|
||||||
"disable_logprobs_during_spec_decoding": True,
|
"disable_logprobs": True,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
])
|
])
|
||||||
@pytest.mark.parametrize("output_len", [
|
@pytest.mark.parametrize("output_len", [
|
||||||
@ -125,7 +133,8 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
|
|||||||
batch_size: int, output_len: int, seed: int,
|
batch_size: int, output_len: int, seed: int,
|
||||||
logprobs: int):
|
logprobs: int):
|
||||||
"""Verify greedy equality on a tiny model with different batch size."""
|
"""Verify greedy equality on a tiny model with different batch size."""
|
||||||
run_equality_correctness_test(vllm_runner,
|
run_equality_correctness_test(
|
||||||
|
vllm_runner,
|
||||||
common_llm_kwargs,
|
common_llm_kwargs,
|
||||||
per_test_common_llm_kwargs,
|
per_test_common_llm_kwargs,
|
||||||
baseline_llm_kwargs,
|
baseline_llm_kwargs,
|
||||||
@ -136,8 +145,8 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
|
|||||||
temperature=0.0,
|
temperature=0.0,
|
||||||
logprobs=logprobs,
|
logprobs=logprobs,
|
||||||
prompt_logprobs=logprobs,
|
prompt_logprobs=logprobs,
|
||||||
disable_logprobs=test_llm_kwargs[
|
disable_logprobs=test_llm_kwargs["speculative_config"]
|
||||||
'disable_logprobs_during_spec_decoding'])
|
["disable_logprobs"])
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -159,17 +168,21 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
|
|||||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||||
{
|
{
|
||||||
"speculative_model": "[ngram]",
|
"speculative_config": {
|
||||||
|
"method": "ngram",
|
||||||
"num_speculative_tokens": 5,
|
"num_speculative_tokens": 5,
|
||||||
"ngram_prompt_lookup_max": 3,
|
"prompt_lookup_max": 3,
|
||||||
|
},
|
||||||
"enable_chunked_prefill": False,
|
"enable_chunked_prefill": False,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"speculative_model": "[ngram]",
|
"speculative_config": {
|
||||||
|
"method": "ngram",
|
||||||
"num_speculative_tokens": 5,
|
"num_speculative_tokens": 5,
|
||||||
"ngram_prompt_lookup_max": 3,
|
"prompt_lookup_max": 3,
|
||||||
|
"disable_mqa_scorer": True,
|
||||||
|
},
|
||||||
"enable_chunked_prefill": True,
|
"enable_chunked_prefill": True,
|
||||||
"speculative_disable_mqa_scorer": True,
|
|
||||||
"max_num_batched_tokens": 4,
|
"max_num_batched_tokens": 4,
|
||||||
"max_num_seqs": 4
|
"max_num_seqs": 4
|
||||||
},
|
},
|
||||||
@ -214,17 +227,21 @@ def test_ngram_e2e_greedy_correctness_with_preemption(
|
|||||||
"test_llm_kwargs",
|
"test_llm_kwargs",
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
"speculative_model": "[ngram]",
|
"speculative_config": {
|
||||||
|
"method": "ngram",
|
||||||
"num_speculative_tokens": k,
|
"num_speculative_tokens": k,
|
||||||
"ngram_prompt_lookup_max": 3,
|
"prompt_lookup_max": 3,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
# Try a range of common k, as well as large speculation.
|
# Try a range of common k, as well as large speculation.
|
||||||
for k in [1, 3, 5]
|
for k in [1, 3, 5]
|
||||||
] + [
|
] + [
|
||||||
{
|
{
|
||||||
"speculative_model": "[ngram]",
|
"speculative_config": {
|
||||||
|
"method": "ngram",
|
||||||
"num_speculative_tokens": k,
|
"num_speculative_tokens": k,
|
||||||
"ngram_prompt_lookup_max": 1,
|
"prompt_lookup_max": 1,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
# Try a range of common k, as well as large speculation.
|
# Try a range of common k, as well as large speculation.
|
||||||
for k in [1, 3, 5]
|
for k in [1, 3, 5]
|
||||||
@ -243,7 +260,7 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs,
|
|||||||
seed: int):
|
seed: int):
|
||||||
"""Verify that ngram speculative decoding produces exact equality
|
"""Verify that ngram speculative decoding produces exact equality
|
||||||
to without spec decode with many different values of k and
|
to without spec decode with many different values of k and
|
||||||
different ngram_prompt_lookup_max.
|
different ngram prompt_lookup_max.
|
||||||
"""
|
"""
|
||||||
run_equality_correctness_test(vllm_runner,
|
run_equality_correctness_test(vllm_runner,
|
||||||
common_llm_kwargs,
|
common_llm_kwargs,
|
||||||
@ -266,19 +283,22 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs,
|
|||||||
}])
|
}])
|
||||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
@pytest.mark.parametrize("test_llm_kwargs",
|
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||||
[{
|
"speculative_config": {
|
||||||
"speculative_model": "[ngram]",
|
"method": "ngram",
|
||||||
"num_speculative_tokens": 5,
|
"num_speculative_tokens": 5,
|
||||||
"ngram_prompt_lookup_max": 3,
|
"prompt_lookup_max": 3,
|
||||||
"speculative_disable_by_batch_size": 4
|
"disable_by_batch_size": 4
|
||||||
|
},
|
||||||
}, {
|
}, {
|
||||||
"speculative_model": "[ngram]",
|
"speculative_config": {
|
||||||
|
"method": "ngram",
|
||||||
"num_speculative_tokens": 5,
|
"num_speculative_tokens": 5,
|
||||||
"ngram_prompt_lookup_max": 3,
|
"prompt_lookup_max": 3,
|
||||||
"speculative_disable_by_batch_size": 4,
|
"disable_by_batch_size": 4,
|
||||||
|
"disable_mqa_scorer": True,
|
||||||
|
},
|
||||||
"enable_chunked_prefill": True,
|
"enable_chunked_prefill": True,
|
||||||
"speculative_disable_mqa_scorer": True,
|
|
||||||
"max_num_batched_tokens": 4,
|
"max_num_batched_tokens": 4,
|
||||||
"max_num_seqs": 4
|
"max_num_seqs": 4
|
||||||
}])
|
}])
|
||||||
@ -296,7 +316,7 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs,
|
|||||||
seed: int):
|
seed: int):
|
||||||
"""Verify that ngram speculative decoding produces exact equality
|
"""Verify that ngram speculative decoding produces exact equality
|
||||||
to without spec decode with many different values of k and
|
to without spec decode with many different values of k and
|
||||||
different ngram_prompt_lookup_max.
|
different ngram prompt_lookup_max.
|
||||||
"""
|
"""
|
||||||
run_equality_correctness_test(vllm_runner,
|
run_equality_correctness_test(vllm_runner,
|
||||||
common_llm_kwargs,
|
common_llm_kwargs,
|
||||||
@ -316,17 +336,16 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs,
|
|||||||
|
|
||||||
# Skip cuda graph recording for fast test.
|
# Skip cuda graph recording for fast test.
|
||||||
"enforce_eager": True,
|
"enforce_eager": True,
|
||||||
|
|
||||||
# Required for spec decode.
|
|
||||||
"speculative_model": "[ngram]",
|
|
||||||
"num_speculative_tokens": 5,
|
|
||||||
"ngram_prompt_lookup_max": 3,
|
|
||||||
}])
|
}])
|
||||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
@pytest.mark.parametrize("test_llm_kwargs",
|
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||||
[{
|
"speculative_config": {
|
||||||
"speculative_disable_mqa_scorer": True,
|
"method": "ngram",
|
||||||
|
"num_speculative_tokens": 5,
|
||||||
|
"prompt_lookup_max": 3,
|
||||||
|
"disable_mqa_scorer": True,
|
||||||
|
},
|
||||||
}])
|
}])
|
||||||
@pytest.mark.parametrize("batch_size", [1, 5])
|
@pytest.mark.parametrize("batch_size", [1, 5])
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
@@ -19,11 +19,11 @@ SPEC_MODEL = "JackFram/llama-160m"
         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
 
-        # speculative model
-        "speculative_model": "JackFram/llama-160m",
-        # num speculative tokens
+        # speculative config
+        "speculative_config": {
+            "model": "JackFram/llama-160m",
         "num_speculative_tokens": 3,
+        },
 }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}])
@@ -70,12 +70,16 @@ def test_ngram_correctness(
     ref_outputs = ref_llm.chat(test_prompts, sampling_config)
     del ref_llm
 
-    spec_llm = LLM(model=model_name,
-                   speculative_model='[ngram]',
-                   ngram_prompt_lookup_max=5,
-                   ngram_prompt_lookup_min=3,
-                   num_speculative_tokens=3,
-                   max_model_len=1024)
+    spec_llm = LLM(
+        model=model_name,
+        speculative_config={
+            "method": "ngram",
+            "prompt_lookup_max": 5,
+            "prompt_lookup_min": 3,
+            "num_speculative_tokens": 3,
+        },
+        max_model_len=1024,
+    )
     spec_outputs = spec_llm.chat(test_prompts, sampling_config)
     matches = 0
     misses = 0
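Every test hunk above performs the same mechanical renaming. As a quick reference — inferred from the renames applied throughout this commit, not an exhaustive or authoritative list — the old flat arguments map onto keys of the nested `speculative_config` dict roughly as follows:

```python
# Old flat engine/test kwargs -> new nested speculative_config keys,
# as inferred from the hunks in this commit. The "[ngram]" sentinel model
# name becomes {"method": "ngram"} instead of a key rename.
OLD_TO_NEW_SPEC_KEYS = {
    "speculative_model": "model",
    "speculative_model_quantization": "quantization",
    "num_speculative_tokens": "num_speculative_tokens",
    "speculative_max_model_len": "max_model_len",
    "speculative_draft_tensor_parallel_size": "draft_tensor_parallel_size",
    "speculative_disable_by_batch_size": "disable_by_batch_size",
    "speculative_disable_mqa_scorer": "disable_mqa_scorer",
    "ngram_prompt_lookup_max": "prompt_lookup_max",
    "ngram_prompt_lookup_min": "prompt_lookup_min",
    "spec_decoding_acceptance_method": "acceptance_method",
    "typical_acceptance_sampler_posterior_threshold": "posterior_threshold",
    "typical_acceptance_sampler_posterior_alpha": "posterior_alpha",
    "disable_logprobs_during_spec_decoding": "disable_logprobs",
}
```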
vllm/config.py (566 changed lines)
@@ -1810,12 +1810,139 @@ class DeviceConfig:
         self.device = torch.device(self.device_type)
 
 
+@dataclass
 class SpeculativeConfig:
-    """Configuration for speculative decoding.
-
-    The configuration is currently specialized to draft-model speculative
-    decoding with top-1 proposals.
     """
+    Configuration for speculative decoding.
+    Configurable parameters include:
+        - General Speculative Decoding Control:
+            - num_speculative_tokens (int): The number of speculative
+                tokens, if provided. It will default to the number in the draft
+                model config if present, otherwise, it is required.
+            - model (Optional[str]): The name of the draft model, eagle head,
+                or additional weights, if provided.
+            - method (Optional[str]): The name of the speculative method to use.
+                If users provide and set the `model` param, the speculative method
+                type will be detected automatically if possible, if `model` param
+                is not provided, the method name must be provided.
+                - Possible values:
+                    - ngram
+                        Related additional configuration:
+                            - prompt_lookup_max (Optional[int]):
+                                Maximum size of ngram token window when using Ngram
+                                proposer, required when method is set to ngram.
+                            - prompt_lookup_min (Optional[int]):
+                                Minimum size of ngram token window when using Ngram
+                                proposer, if provided. Defaults to 1.
+                    - eagle
+                    - medusa
+                    - mlp_speculator
+                    - draft_model
+            - acceptance_method (str): The method to use for accepting draft
+                tokens. This can take two possible values: 'rejection_sampler' and
+                'typical_acceptance_sampler' for RejectionSampler and
+                TypicalAcceptanceSampler respectively. If not specified, it
+                defaults to 'rejection_sampler'.
+                - Possible values:
+                    - rejection_sampler
+                    - typical_acceptance_sampler
+                        Related additional configuration:
+                            - posterior_threshold (Optional[float]):
+                                A threshold value that sets a lower bound on the
+                                posterior probability of a token in the target model
+                                for it to be accepted. This threshold is used only
+                                when we use the TypicalAcceptanceSampler for token
+                                acceptance.
+                            - posterior_alpha (Optional[float]):
+                                Scaling factor for entropy-based threshold, applied
+                                when using TypicalAcceptanceSampler.
+            - draft_tensor_parallel_size (Optional[int]): The degree of the tensor
+                parallelism for the draft model. Can only be 1 or the same as the
+                target model's tensor parallel size.
+            - disable_logprobs (bool): If set to True, token log probabilities are
+                not returned during speculative decoding. If set to False, token
+                log probabilities are returned according to the log probability
+                settings in SamplingParams. If not specified, it defaults to True.
+
+        - Draft Model Configuration:
+            - quantization (Optional[str]): Quantization method that was used to
+                quantize the draft model weights. If None, we assume the
+                model weights are not quantized. Note that it only takes effect
+                when using the draft model-based speculative method.
+            - max_model_len (Optional[int]): The maximum model length of the
+                draft model. Used when testing the ability to skip
+                speculation for some sequences.
+            - revision: The specific model version to use for the draft model. It
+                can be a branch name, a tag name, or a commit id. If unspecified,
+                will use the default version.
+            - code_revision: The specific revision to use for the draft model code
+                on Hugging Face Hub. It can be a branch name, a tag name, or a
+                commit id. If unspecified, will use the default version.
+
+        - Advanced Control:
+            - disable_mqa_scorer (bool): Disable the MQA scorer and fall back to
+                batch expansion for scoring proposals. If not specified, it
+                defaults to False.
+            - disable_by_batch_size (Optional[int]): Disable speculative decoding
+                for new incoming requests when the number of enqueued requests is
+                larger than this value, if provided.
+
+    Although the parameters above are structured hierarchically, there is no
+    need to nest them during configuration.
+
+    Non-configurable internal parameters include:
+        - Model Configuration:
+            - target_model_config (ModelConfig): The configuration of the target
+                model.
+            - draft_model_config (ModelConfig): The configuration of the draft
+                model initialized internal.
+        - Parallelism Configuration:
+            - target_parallel_config (ParallelConfig): The parallel configuration
+                for the target model.
+            - draft_parallel_config (ParallelConfig): The parallel configuration
+                for the draft model initialized internal.
+        - Execution Control:
+            - enable_chunked_prefill (bool): Whether vLLM is configured to use
+                chunked prefill or not. Used for raising an error since it's not
+                yet compatible with speculative decode.
+            - disable_log_stats (bool): Whether to disable the periodic printing of
+                stage times in speculative decoding.
+    """
+    # speculative configs from cli args
+    num_speculative_tokens: int = field(default=None,
+                                        init=True)  # type: ignore
+    method: Optional[str] = None
+    acceptance_method: str = "rejection_sampler"
+    draft_tensor_parallel_size: Optional[int] = None
+    disable_logprobs: bool = True
+
+    model: Optional[str] = None
+    quantization: Optional[str] = None
+    max_model_len: Optional[int] = None
+    revision: Optional[str] = None
+    code_revision: Optional[str] = None
+
+    disable_mqa_scorer: bool = False
+    disable_by_batch_size: Optional[int] = None
+    prompt_lookup_max: Optional[int] = None
+    prompt_lookup_min: Optional[int] = None
+    posterior_threshold: Optional[float] = None
+    posterior_alpha: Optional[float] = None
+
+    # required configuration params passed from engine
+    target_model_config: ModelConfig = field(default=None,
+                                             init=True)  # type: ignore
+    target_parallel_config: ParallelConfig = field(default=None,
+                                                   init=True)  # type: ignore
+    enable_chunked_prefill: bool = field(default=None,
+                                         init=True)  # type: ignore
+    disable_log_stats: bool = field(default=None, init=True)  # type: ignore
+
+    # params generated in the post-init stage
+    draft_model_config: ModelConfig = field(default=None,
+                                            init=True)  # type: ignore
+    draft_parallel_config: ParallelConfig = field(default=None,
+                                                  init=True)  # type: ignore
 
     def compute_hash(self) -> str:
         """
@@ -1835,6 +1962,11 @@ class SpeculativeConfig:
         hash_str = hashlib.md5(str(factors).encode()).hexdigest()
         return hash_str
 
+    @classmethod
+    def from_dict(cls, dict_value: dict) -> "SpeculativeConfig":
+        """Parse the CLI value for the speculative config."""
+        return cls(**dict_value)
+
     @staticmethod
     def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
         if hf_config.model_type == "deepseek_v3":
@ -1847,230 +1979,160 @@ class SpeculativeConfig:
|
|||||||
})
|
})
|
||||||
return hf_config
|
return hf_config
|
||||||
|
|
||||||
@staticmethod
|
def __post_init__(self):
|
||||||
def maybe_create_spec_config(
|
|
||||||
target_model_config: ModelConfig,
|
|
||||||
target_parallel_config: ParallelConfig,
|
|
||||||
target_dtype: str,
|
|
||||||
speculative_model: Optional[str],
|
|
||||||
speculative_model_quantization: Optional[str],
|
|
||||||
speculative_draft_tensor_parallel_size: Optional[int],
|
|
||||||
num_speculative_tokens: Optional[int],
|
|
||||||
speculative_disable_mqa_scorer: Optional[bool],
|
|
||||||
speculative_max_model_len: Optional[int],
|
|
||||||
enable_chunked_prefill: bool,
|
|
||||||
disable_log_stats: bool,
|
|
||||||
speculative_disable_by_batch_size: Optional[int],
|
|
||||||
ngram_prompt_lookup_max: Optional[int],
|
|
||||||
ngram_prompt_lookup_min: Optional[int],
|
|
||||||
draft_token_acceptance_method: str,
|
|
||||||
typical_acceptance_sampler_posterior_threshold: Optional[float],
|
|
||||||
typical_acceptance_sampler_posterior_alpha: Optional[float],
|
|
||||||
disable_logprobs: Optional[bool],
|
|
||||||
) -> Optional["SpeculativeConfig"]:
|
|
||||||
"""Create a SpeculativeConfig if possible, else return None.
|
|
||||||
|
|
||||||
This function attempts to create a SpeculativeConfig object based on the
|
# Note: After next release, the method parameter will be used to
|
||||||
provided parameters. If the necessary conditions are met, it returns an
|
# specify the speculative method, which helps to extend the
|
||||||
instance of SpeculativeConfig. Otherwise, it returns None.
|
# configuration of non-model-based proposers, and the model parameter
|
||||||
|
# will be used when the draft model or head is needed.
|
||||||
|
# If users do not specify the method, the speculative method will
|
||||||
|
# be detected automatically if possible. If the speculative method can
|
||||||
|
# not be detected, it will be considered as the draft-model-based
|
||||||
|
# method by default.
|
||||||
|
|
||||||
Args:
|
if self.model is None and self.num_speculative_tokens is not None:
|
||||||
target_model_config (ModelConfig): The configuration of the target
|
# TODO(Shangming): Refactor mtp configuration logic when supporting
|
||||||
model.
|
# mtp acceleration for more models besides deepseek_v3
|
||||||
target_parallel_config (ParallelConfig): The parallel configuration
|
if self.target_model_config.hf_text_config.model_type \
|
||||||
for the target model.
|
|
||||||
target_dtype (str): The data type used for the target model.
|
|
||||||
speculative_model (Optional[str]): The name of the speculative
|
|
||||||
model, if provided.
|
|
||||||
speculative_model_quantization (Optional[str]): Quantization method
|
|
||||||
that was used to quantize the speculative model weights. If
|
|
||||||
None, we assume the model weights are not quantized.
|
|
||||||
speculative_draft_tensor_parallel_size (Optional[int]): The degree
|
|
||||||
of the tensor parallelism for the draft model.
|
|
||||||
num_speculative_tokens (Optional[int]): The number of speculative
|
|
||||||
tokens, if provided. Will default to the number in the draft
|
|
||||||
model config if present, otherwise is required.
|
|
||||||
speculative_disable_mqa_scorer (Optional[bool]): Disable the MQA
|
|
||||||
scorer for the speculative model and fall back to batch
|
|
||||||
expansion for scoring.
|
|
||||||
speculative_max_model_len (Optional[int]): The maximum model len of
|
|
||||||
the speculative model. Used when testing the ability to skip
|
|
||||||
speculation for some sequences.
|
|
||||||
enable_chunked_prefill (bool): Whether vLLM is configured to use
|
|
||||||
chunked prefill or not. Used for raising an error since its not
|
|
||||||
yet compatible with spec decode.
|
|
||||||
speculative_disable_by_batch_size (Optional[int]): Disable
|
|
||||||
speculative decoding for new incoming requests when the number
|
|
||||||
of enqueue requests is larger than this value, if provided.
|
|
||||||
ngram_prompt_lookup_max (Optional[int]): Max size of ngram token
|
|
||||||
window, if provided.
|
|
||||||
ngram_prompt_lookup_min (Optional[int]): Min size of ngram token
|
|
||||||
window, if provided.
|
|
||||||
draft_token_acceptance_method (str): The method to use for
|
|
||||||
accepting draft tokens. This can take two possible
|
|
||||||
values 'rejection_sampler' and 'typical_acceptance_sampler'
|
|
||||||
for RejectionSampler and TypicalAcceptanceSampler
|
|
||||||
respectively.
|
|
||||||
typical_acceptance_sampler_posterior_threshold (Optional[float]):
|
|
||||||
A threshold value that sets a lower bound on the posterior
|
|
||||||
probability of a token in the target model for it to be
|
|
||||||
accepted. This threshold is used only when we use the
|
|
||||||
TypicalAcceptanceSampler for token acceptance.
|
|
||||||
typical_acceptance_sampler_posterior_alpha (Optional[float]):
|
|
||||||
A scaling factor for the entropy-based threshold in the
|
|
||||||
TypicalAcceptanceSampler.
|
|
||||||
disable_logprobs (Optional[bool]): If set to True, token log
|
|
||||||
probabilities are not returned during speculative decoding.
|
|
||||||
If set to False, token log probabilities are returned
|
|
||||||
according to the log probability settings in SamplingParams.
|
|
||||||
If not specified, it defaults to True.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
|
|
||||||
the necessary conditions are met, else None.
|
|
||||||
"""
|
|
||||||
if speculative_model is None:
|
|
||||||
if num_speculative_tokens is not None:
|
|
||||||
if target_model_config.hf_text_config.model_type \
|
|
||||||
== "deepseek_v3":
|
== "deepseek_v3":
|
||||||
# use the draft model from the same model:
|
# use the draft model from the same model:
|
||||||
speculative_model = target_model_config.model
|
self.model = self.target_model_config.model
|
||||||
|
elif self.method in ("ngram", "[ngram]"):
|
||||||
|
self.model = "ngram"
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError("num_speculative_tokens was provided without "
|
||||||
"num_speculative_tokens was provided without "
|
"speculative model.")
|
||||||
"speculative_model.")
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if (speculative_disable_by_batch_size is not None
|
# Automatically configure the ngram method during configuration
|
||||||
and speculative_disable_by_batch_size < 2):
|
# refactoring to ensure a smooth transition.
|
||||||
raise ValueError("Expect the batch size threshold of disabling "
|
if self.method is None and (self.model is not None
|
||||||
"speculative decoding is > 1, but got "
|
and self.model in ("ngram", "[ngram]")):
|
||||||
f"{speculative_disable_by_batch_size=}")
|
self.method = "ngram"
|
||||||
if (enable_chunked_prefill and speculative_model == "eagle"):
|
|
||||||
raise ValueError("Chunked prefill and EAGLE are not compatible.")
|
|
||||||
# TODO: The user should be able to specify revision/max model len
|
|
||||||
# for the draft model. It is not currently supported.
|
|
||||||
draft_revision = None
|
|
||||||
draft_code_revision = None
|
|
||||||
draft_quantization = speculative_model_quantization
|
|
||||||
|
|
||||||
if speculative_model == "[ngram]":
|
if self.method in ("ngram", "[ngram]"):
|
||||||
if ngram_prompt_lookup_min is None:
|
# Unified to "ngram" internally
|
||||||
ngram_prompt_lookup_min = 1
|
self.method = "ngram"
|
||||||
if ngram_prompt_lookup_max is None or ngram_prompt_lookup_max < 1:
|
if self.prompt_lookup_min is None:
|
||||||
raise ValueError(f"{ngram_prompt_lookup_max=} must be > 0")
|
self.prompt_lookup_min = 1
|
||||||
if ngram_prompt_lookup_min < 1:
|
if self.prompt_lookup_max is None or self.prompt_lookup_max < 1:
|
||||||
raise ValueError(f"{ngram_prompt_lookup_min=} must be > 0")
|
raise ValueError("prompt_lookup_max="
|
||||||
if ngram_prompt_lookup_min > ngram_prompt_lookup_max:
|
f"{self.prompt_lookup_max} must be > 0")
|
||||||
raise ValueError(f"{ngram_prompt_lookup_min=} cannot be "
|
if self.prompt_lookup_min < 1:
|
||||||
f"larger than {ngram_prompt_lookup_max=}")
|
raise ValueError("prompt_lookup_min="
|
||||||
|
f"{self.prompt_lookup_min} must be > 0")
|
||||||
|
if self.prompt_lookup_min > self.prompt_lookup_max:
|
||||||
|
raise ValueError(f"prompt_lookup_min={self.prompt_lookup_min} "
|
||||||
|
"cannot be larger than prompt_lookup_max="
|
||||||
|
f"{self.prompt_lookup_max}")
|
||||||
|
|
||||||
# TODO: current we still need extract vocab_size from target model
|
# TODO: current we still need extract vocab_size from target model
|
||||||
# config, in future, we may try refactor it out, and set
|
# config, in future, we may try refactor it out, and set
|
||||||
# draft related config as None here.
|
# draft related config as None here.
|
||||||
draft_model_config = target_model_config
|
self.draft_model_config = self.target_model_config
|
||||||
draft_parallel_config = target_parallel_config
|
self.draft_parallel_config = self.target_parallel_config
|
||||||
else:
|
else:
|
||||||
ngram_prompt_lookup_max = 0
|
self.prompt_lookup_max = 0
|
||||||
ngram_prompt_lookup_min = 0
|
self.prompt_lookup_min = 0
|
||||||
draft_model_config = ModelConfig(
|
|
||||||
model=speculative_model,
|
if self.model is not None:
|
||||||
|
self.draft_model_config = ModelConfig(
|
||||||
|
model=self.model,
|
||||||
task="draft",
|
task="draft",
|
||||||
tokenizer=target_model_config.tokenizer,
|
tokenizer=self.target_model_config.tokenizer,
|
||||||
tokenizer_mode=target_model_config.tokenizer_mode,
|
tokenizer_mode=self.target_model_config.tokenizer_mode,
|
||||||
trust_remote_code=target_model_config.trust_remote_code,
|
trust_remote_code=self.target_model_config.
|
||||||
allowed_local_media_path=target_model_config.
|
trust_remote_code,
|
||||||
|
allowed_local_media_path=self.target_model_config.
|
||||||
allowed_local_media_path,
|
allowed_local_media_path,
|
||||||
dtype=target_model_config.dtype,
|
dtype=self.target_model_config.dtype,
|
||||||
seed=target_model_config.seed,
|
seed=self.target_model_config.seed,
|
||||||
revision=draft_revision,
|
revision=self.revision,
|
||||||
code_revision=draft_code_revision,
|
code_revision=self.code_revision,
|
||||||
tokenizer_revision=target_model_config.tokenizer_revision,
|
tokenizer_revision=self.target_model_config.
|
||||||
|
tokenizer_revision,
|
||||||
max_model_len=None,
|
max_model_len=None,
|
||||||
spec_target_max_model_len=target_model_config.max_model_len,
|
spec_target_max_model_len=self.target_model_config.
|
||||||
quantization=draft_quantization,
|
max_model_len,
|
||||||
enforce_eager=target_model_config.enforce_eager,
|
quantization=self.quantization,
|
||||||
max_seq_len_to_capture=target_model_config.
|
enforce_eager=self.target_model_config.enforce_eager,
|
||||||
|
max_seq_len_to_capture=self.target_model_config.
|
||||||
max_seq_len_to_capture,
|
max_seq_len_to_capture,
|
||||||
max_logprobs=target_model_config.max_logprobs,
|
max_logprobs=self.target_model_config.max_logprobs,
|
||||||
hf_overrides=SpeculativeConfig.hf_config_override,
|
hf_overrides=SpeculativeConfig.hf_config_override,
|
||||||
)
|
)
|
||||||
|
|
||||||
draft_hf_config = draft_model_config.hf_config
|
# Automatically detect the method
|
||||||
|
if "eagle-" in self.draft_model_config.model.lower():
|
||||||
|
self.method = "eagle"
|
||||||
|
elif self.draft_model_config.hf_config.model_type == "medusa":
|
||||||
|
self.method = "medusa"
|
||||||
|
elif (self.draft_model_config.hf_config.model_type ==
|
||||||
|
"mlp_speculator"):
|
||||||
|
self.method = "mlp_speculator"
|
||||||
|
else:
|
||||||
|
self.method = "draft_model"
|
||||||
|
|
||||||
# Detect EAGLE prefix to replace hf_config for EAGLE draft_model
|
# Replace hf_config for EAGLE draft_model
|
||||||
if "eagle-" in draft_model_config.model.lower():
|
if self.method == "eagle":
|
||||||
from vllm.transformers_utils.configs.eagle import EAGLEConfig
|
if self.enable_chunked_prefill:
|
||||||
if isinstance(draft_model_config.hf_config, EAGLEConfig):
|
raise ValueError(
|
||||||
|
"Chunked prefill and EAGLE are not compatible.")
|
||||||
|
|
||||||
|
from vllm.transformers_utils.configs.eagle import (
|
||||||
|
EAGLEConfig)
|
||||||
|
if isinstance(self.draft_model_config.hf_config,
|
||||||
|
EAGLEConfig):
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
eagle_config = EAGLEConfig(draft_model_config.hf_config)
|
eagle_config = EAGLEConfig(
|
||||||
draft_model_config.hf_config = eagle_config
|
self.draft_model_config.hf_config)
|
||||||
|
self.draft_model_config.hf_config = eagle_config
|
||||||
|
|
||||||
if (num_speculative_tokens is not None
|
if (self.num_speculative_tokens is not None
|
||||||
and hasattr(draft_hf_config, "num_lookahead_tokens")):
|
and hasattr(self.draft_model_config.hf_config,
|
||||||
draft_hf_config.num_lookahead_tokens = num_speculative_tokens
|
"num_lookahead_tokens")):
|
||||||
n_predict = getattr(draft_hf_config, "n_predict", None)
|
self.draft_model_config.hf_config.num_lookahead_tokens = \
|
||||||
|
self.num_speculative_tokens
|
||||||
|
|
||||||
|
n_predict = getattr(self.draft_model_config.hf_config,
|
||||||
|
"n_predict", None)
|
||||||
if n_predict is not None:
|
if n_predict is not None:
|
||||||
if num_speculative_tokens is None:
|
if self.num_speculative_tokens is None:
|
||||||
# Default to max value defined in draft model config.
|
# Default to max value defined in draft model config.
|
||||||
num_speculative_tokens = n_predict
|
self.num_speculative_tokens = n_predict
|
||||||
elif num_speculative_tokens > n_predict and \
|
elif self.num_speculative_tokens > n_predict and \
|
||||||
num_speculative_tokens % n_predict != 0:
|
self.num_speculative_tokens % n_predict != 0:
|
||||||
# Ensure divisibility for MTP module reuse.
|
# Ensure divisibility for MTP module reuse.
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"{num_speculative_tokens=} must be divisible by "
|
f"num_speculative_tokens:{self.num_speculative_tokens}"
|
||||||
f"{n_predict=}")
|
f" must be divisible by {n_predict=}")
|
||||||
|
|
||||||
speculative_draft_tensor_parallel_size = \
|
self.draft_tensor_parallel_size = \
|
||||||
SpeculativeConfig._verify_and_get_draft_model_tensor_parallel_size(
|
SpeculativeConfig._verify_and_get_draft_tp(
|
||||||
target_parallel_config,
|
self.target_parallel_config,
|
||||||
speculative_draft_tensor_parallel_size,
|
self.draft_tensor_parallel_size,
|
||||||
draft_hf_config
|
self.draft_model_config.hf_config
|
||||||
)
|
)
|
||||||
|
|
||||||
draft_model_config.max_model_len = (
|
self.draft_model_config.max_model_len = (
|
||||||
SpeculativeConfig._maybe_override_draft_max_model_len(
|
SpeculativeConfig._maybe_override_draft_max_model_len(
|
||||||
speculative_max_model_len,
|
self.max_model_len,
|
||||||
draft_model_config.max_model_len,
|
self.draft_model_config.max_model_len,
|
||||||
target_model_config.max_model_len,
|
self.target_model_config.max_model_len,
|
||||||
))
|
))
|
||||||
|
|
||||||
draft_parallel_config = (
|
self.draft_parallel_config = (
|
||||||
SpeculativeConfig.create_draft_parallel_config(
|
SpeculativeConfig.create_draft_parallel_config(
|
||||||
target_parallel_config,
|
self.target_parallel_config,
|
||||||
speculative_draft_tensor_parallel_size, draft_hf_config))
|
self.draft_tensor_parallel_size))
|
||||||
|
|
||||||
if num_speculative_tokens is None:
|
if self.acceptance_method == "typical_acceptance_sampler":
|
||||||
raise ValueError(
|
if self.posterior_threshold is None:
|
||||||
"num_speculative_tokens must be provided with "
|
self.posterior_threshold = 0.09
|
||||||
"speculative_model unless the draft model config contains an "
|
if self.posterior_alpha is None:
|
||||||
"n_predict parameter.")
|
self.posterior_alpha = 0.3
|
||||||
|
|
||||||
if typical_acceptance_sampler_posterior_threshold is None:
|
self._verify_args()
|
||||||
typical_acceptance_sampler_posterior_threshold = 0.09
|
|
||||||
if typical_acceptance_sampler_posterior_alpha is None:
|
|
||||||
typical_acceptance_sampler_posterior_alpha = 0.3
|
|
||||||
if disable_logprobs is None:
|
|
||||||
disable_logprobs = True
|
|
||||||
|
|
||||||
return SpeculativeConfig(
|
|
||||||
draft_model_config,
|
|
||||||
draft_parallel_config,
|
|
||||||
num_speculative_tokens,
|
|
||||||
speculative_disable_mqa_scorer,
|
|
||||||
speculative_disable_by_batch_size,
|
|
||||||
ngram_prompt_lookup_max,
|
|
||||||
ngram_prompt_lookup_min,
|
|
||||||
draft_token_acceptance_method=draft_token_acceptance_method,
|
|
||||||
typical_acceptance_sampler_posterior_threshold=\
|
|
||||||
typical_acceptance_sampler_posterior_threshold,
|
|
||||||
typical_acceptance_sampler_posterior_alpha=\
|
|
||||||
typical_acceptance_sampler_posterior_alpha,
|
|
||||||
disable_logprobs=disable_logprobs,
|
|
||||||
disable_log_stats=disable_log_stats,
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _maybe_override_draft_max_model_len(
|
def _maybe_override_draft_max_model_len(
|
||||||
@ -2108,7 +2170,7 @@ class SpeculativeConfig:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _verify_and_get_draft_model_tensor_parallel_size(
|
def _verify_and_get_draft_tp(
|
||||||
target_parallel_config: ParallelConfig,
|
target_parallel_config: ParallelConfig,
|
||||||
speculative_draft_tensor_parallel_size: Optional[int],
|
speculative_draft_tensor_parallel_size: Optional[int],
|
||||||
draft_hf_config: PretrainedConfig) -> int:
|
draft_hf_config: PretrainedConfig) -> int:
|
||||||
@ -2140,7 +2202,6 @@ class SpeculativeConfig:
|
|||||||
def create_draft_parallel_config(
|
def create_draft_parallel_config(
|
||||||
target_parallel_config: ParallelConfig,
|
target_parallel_config: ParallelConfig,
|
||||||
speculative_draft_tensor_parallel_size: int,
|
speculative_draft_tensor_parallel_size: int,
|
||||||
draft_hf_config: PretrainedConfig,
|
|
||||||
) -> ParallelConfig:
|
) -> ParallelConfig:
|
||||||
"""Create a parallel config for use by the draft worker.
|
"""Create a parallel config for use by the draft worker.
|
||||||
|
|
||||||
@ -2164,74 +2225,13 @@ class SpeculativeConfig:
|
|||||||
|
|
||||||
return draft_parallel_config
|
return draft_parallel_config
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
draft_model_config: ModelConfig,
|
|
||||||
draft_parallel_config: ParallelConfig,
|
|
||||||
num_speculative_tokens: int,
|
|
||||||
speculative_disable_mqa_scorer: Optional[bool],
|
|
||||||
speculative_disable_by_batch_size: Optional[int],
|
|
||||||
ngram_prompt_lookup_max: Optional[int],
|
|
||||||
ngram_prompt_lookup_min: Optional[int],
|
|
||||||
draft_token_acceptance_method: str,
|
|
||||||
typical_acceptance_sampler_posterior_threshold: float,
|
|
||||||
typical_acceptance_sampler_posterior_alpha: float,
|
|
||||||
disable_logprobs: bool,
|
|
||||||
disable_log_stats: bool,
|
|
||||||
):
|
|
||||||
"""Create a SpeculativeConfig object.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
draft_model_config: ModelConfig for the draft model.
|
|
||||||
draft_parallel_config: ParallelConfig for the draft model.
|
|
||||||
num_speculative_tokens: The number of tokens to sample from the
|
|
||||||
draft model before scoring with the target model.
|
|
||||||
speculative_disable_by_batch_size: Disable speculative
|
|
||||||
decoding for new incoming requests when the number of
|
|
||||||
enqueue requests is larger than this value.
|
|
||||||
ngram_prompt_lookup_max: Max size of ngram token window.
|
|
||||||
ngram_prompt_lookup_min: Min size of ngram token window.
|
|
||||||
draft_token_acceptance_method (str): The method to use for
|
|
||||||
accepting draft tokens. This can take two possible
|
|
||||||
values 'rejection_sampler' and 'typical_acceptance_sampler'
|
|
||||||
for RejectionSampler and TypicalAcceptanceSampler
|
|
||||||
respectively.
|
|
||||||
typical_acceptance_sampler_posterior_threshold (Optional[float]):
|
|
||||||
A threshold value that sets a lower bound on the posterior
|
|
||||||
probability of a token in the target model for it to be
|
|
||||||
accepted. This threshold is used only when we use the
|
|
||||||
TypicalAcceptanceSampler for token acceptance.
|
|
||||||
typical_acceptance_sampler_posterior_alpha (Optional[float]):
|
|
||||||
A scaling factor for the entropy-based threshold in the
|
|
||||||
TypicalAcceptanceSampler.
|
|
||||||
disable_logprobs: If set to True, token log probabilities will not
|
|
||||||
be returned even if requested by sampling parameters. This
|
|
||||||
reduces latency by skipping logprob calculation in proposal
|
|
||||||
sampling, target sampling, and after accepted tokens are
|
|
||||||
determined. If set to False, log probabilities will be
|
|
||||||
returned.
|
|
||||||
disable_log_stats: Whether to disable periodic printing of stage
|
|
||||||
times in speculative decoding.
|
|
||||||
"""
|
|
||||||
self.draft_model_config = draft_model_config
|
|
||||||
self.draft_parallel_config = draft_parallel_config
|
|
||||||
self.num_speculative_tokens = num_speculative_tokens
|
|
||||||
self.speculative_disable_mqa_scorer = speculative_disable_mqa_scorer
|
|
||||||
self.speculative_disable_by_batch_size = \
|
|
||||||
speculative_disable_by_batch_size
|
|
||||||
self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0
|
|
||||||
self.ngram_prompt_lookup_min = ngram_prompt_lookup_min or 0
|
|
||||||
self.draft_token_acceptance_method = draft_token_acceptance_method
|
|
||||||
self.typical_acceptance_sampler_posterior_threshold = \
|
|
||||||
typical_acceptance_sampler_posterior_threshold
|
|
||||||
self.typical_acceptance_sampler_posterior_alpha = \
|
|
||||||
typical_acceptance_sampler_posterior_alpha
|
|
||||||
self.disable_logprobs = disable_logprobs
|
|
||||||
self.disable_log_stats = disable_log_stats
|
|
||||||
|
|
||||||
self._verify_args()
|
|
||||||
|
|
||||||
     def _verify_args(self) -> None:
+        if self.num_speculative_tokens is None:
+            raise ValueError(
+                "num_speculative_tokens must be provided with "
+                "speculative model unless the draft model config contains an "
+                "n_predict parameter.")
+
         if self.num_speculative_tokens <= 0:
             raise ValueError("Expected num_speculative_tokens to be greater "
                              f"than zero ({self.num_speculative_tokens}).")
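The knobs described in the removed docstring survive the refactor under shorter, flat key names (see the renamed fields in the hunks below). A minimal sketch of the equivalent `speculative_config` dictionary; the draft model name and numeric values are placeholders:

```python
# Sketch only: values are placeholders; keys follow the renamed fields
# introduced by this refactor (acceptance_method, posterior_threshold, ...).
speculative_config = {
    "model": "facebook/opt-125m",                       # draft model
    "num_speculative_tokens": 5,
    "acceptance_method": "typical_acceptance_sampler",
    "posterior_threshold": 0.09,   # lower bound on the target-model posterior prob
    "posterior_alpha": 0.3,        # scale for the entropy-based threshold
    "disable_logprobs": True,      # skip logprob calculation during spec decode
    "disable_log_stats": False,    # keep periodic stage-time logging
}
```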
@@ -2241,29 +2241,34 @@ class SpeculativeConfig:
                 self.draft_parallel_config)

         # Validate and set draft token acceptance related settings.
-        if (self.draft_token_acceptance_method is None):
-            raise ValueError("draft_token_acceptance_method is not set. "
+        if self.acceptance_method is None:
+            raise ValueError("acceptance_method is not set. "
                              "Expected values are rejection_sampler or "
                              "typical_acceptance_sampler.")

-        if (self.draft_token_acceptance_method != 'rejection_sampler'
-                and self.draft_token_acceptance_method
-                != 'typical_acceptance_sampler'):
+        if (self.acceptance_method != 'rejection_sampler'
+                and self.acceptance_method != 'typical_acceptance_sampler'):
             raise ValueError(
-                "Expected draft_token_acceptance_method to be either "
+                "Expected acceptance_method to be either "
                 "rejection_sampler or typical_acceptance_sampler. Instead it "
-                f"is {self.draft_token_acceptance_method}")
+                f"is {self.acceptance_method}")

-        if (self.typical_acceptance_sampler_posterior_threshold < 0
-                or self.typical_acceptance_sampler_posterior_alpha < 0):
+        if self.acceptance_method == "typical_acceptance_sampler" and (
+            (self.posterior_threshold is not None
+             and self.posterior_threshold < 0) or
+            (self.posterior_alpha is not None and self.posterior_alpha < 0)):
             raise ValueError(
-                "Expected typical_acceptance_sampler_posterior_threshold "
-                "and typical_acceptance_sampler_posterior_alpha to be > 0. "
-                "Instead found "
-                f"typical_acceptance_sampler_posterior_threshold = "
-                f"{self.typical_acceptance_sampler_posterior_threshold} and "
-                f"typical_acceptance_sampler_posterior_alpha = "
-                f"{self.typical_acceptance_sampler_posterior_alpha}")
+                "Expected the posterior_threshold and posterior_alpha of "
+                "typical_acceptance_sampler to be > 0. "
+                "Instead found posterior_threshold = "
+                f"{self.posterior_threshold} and posterior_alpha = "
+                f"{self.posterior_alpha}")
+
+        if (self.disable_by_batch_size is not None
+                and self.disable_by_batch_size < 2):
+            raise ValueError("Expect the batch size threshold of disabling "
+                             "speculative decoding is > 1, but got "
+                             f"{self.disable_by_batch_size=}")

     @property
     def num_lookahead_slots(self) -> int:
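To make the tightened checks concrete, a hedged sketch of configurations that the new `_verify_args` rejects (hypothetical values, shown only to illustrate the error paths added above):

```python
# Hypothetical configs; each entry trips a different ValueError in _verify_args.
rejected = [
    {"acceptance_method": "top_k_sampler"},      # must be rejection_sampler or
                                                 # typical_acceptance_sampler
    {"acceptance_method": "typical_acceptance_sampler",
     "posterior_threshold": -0.1},               # posterior values must not be negative
    {"disable_by_batch_size": 1},                # batch-size threshold must be > 1
]
```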
@@ -2276,8 +2281,8 @@ class SpeculativeConfig:
         return self.num_speculative_tokens

     def __repr__(self) -> str:
-        if self.ngram_prompt_lookup_max > 0:
-            draft_model = "[ngram]"
+        if self.prompt_lookup_max is not None and self.prompt_lookup_max > 0:
+            draft_model = "ngram"
         else:
             draft_model = self.draft_model_config.model
         num_spec_tokens = self.num_speculative_tokens
@@ -3285,7 +3290,8 @@ class VllmConfig:
                                             init=True)  # type: ignore
     load_config: LoadConfig = field(default=None, init=True)  # type: ignore
     lora_config: Optional[LoRAConfig] = None
-    speculative_config: Optional[SpeculativeConfig] = None
+    speculative_config: SpeculativeConfig = field(default=None,
+                                                  init=True)  # type: ignore
     decoding_config: Optional[DecodingConfig] = None
     observability_config: Optional[ObservabilityConfig] = None
     prompt_adapter_config: Optional[PromptAdapterConfig] = None
@@ -177,7 +177,10 @@ class EngineArgs:

     guided_decoding_backend: str = 'xgrammar'
     logits_processor_pattern: Optional[str] = None
+    # Speculative decoding configuration.
+    speculative_config: Optional[Union[str, Dict[str, Any]]] = None

+    # TODO(Shangming): Deprecate these out-of-date params after next release
     speculative_model: Optional[str] = None
     speculative_model_quantization: Optional[str] = None
     speculative_draft_tensor_parallel_size: Optional[int] = None
@@ -190,9 +193,9 @@ class EngineArgs:
     spec_decoding_acceptance_method: str = 'rejection_sampler'
     typical_acceptance_sampler_posterior_threshold: Optional[float] = None
     typical_acceptance_sampler_posterior_alpha: Optional[float] = None
-    qlora_adapter_name_or_path: Optional[str] = None
     disable_logprobs_during_spec_decoding: Optional[bool] = None

+    qlora_adapter_name_or_path: Optional[str] = None
     show_hidden_metrics_for_version: Optional[str] = None
     otlp_traces_endpoint: Optional[str] = None
     collect_detailed_traces: Optional[str] = None
@@ -780,7 +783,11 @@ class EngineArgs:
             const="True",
             help='If set, the prefill requests can be chunked based on the '
             'max_num_batched_tokens.')
+        parser.add_argument('--speculative-config',
+                            type=nullable_str,
+                            default=None,
+                            help='The configurations for speculative decoding.'
+                            ' Should be a JSON string.')
         parser.add_argument(
             '--speculative-model',
             type=nullable_str,
@@ -1192,6 +1199,82 @@ class EngineArgs:
             use_tqdm_on_load=self.use_tqdm_on_load,
         )

+    def create_speculative_config(
+        self,
+        target_model_config: ModelConfig,
+        target_parallel_config: ParallelConfig,
+        enable_chunked_prefill: bool,
+        disable_log_stats: bool,
+    ) -> Optional["SpeculativeConfig"]:
+        """Initializes and returns a SpeculativeConfig object based on
+        `speculative_config`.
+
+        This function utilizes `speculative_config` to create a
+        SpeculativeConfig object. The `speculative_config` can either be
+        provided as a JSON string input via CLI arguments or directly as a
+        dictionary from the engine. If `speculative_config` is not set, this
+        function will attempt to construct a configuration dictionary using
+        certain parameters, which are scheduled for deprecation in the next
+        release. Note that in next releases, `speculative_config` must be
+        provided, and the deprecated standalone speculative-related parameters
+        will be removed.
+        """
+        if self.speculative_config is None:
+            if (self.speculative_model is None
+                    and self.num_speculative_tokens is None):
+                return None
+
+            # TODO(Shangming): Deprecate this way of setting SpeculativeConfig,
+            # only allow '--speculative-config' after next release
+            logger.warning_once(
+                "Please use '--speculative-config' to set all configurations "
+                "related to speculative decoding. The current method of "
+                "specifying the model through '--speculative-model' and "
+                "adding related parameters (e.g., '--num-speculative-tokens') "
+                "separately will be deprecated in the next release.")
+
+            spec_config_dict = {
+                "model": self.speculative_model,
+                "quantization": self.speculative_model_quantization,
+                "max_model_len": self.speculative_max_model_len,
+                "draft_tensor_parallel_size":
+                self.speculative_draft_tensor_parallel_size,
+                "num_speculative_tokens": self.num_speculative_tokens,
+                "disable_mqa_scorer": self.speculative_disable_mqa_scorer,
+                "disable_by_batch_size":
+                self.speculative_disable_by_batch_size,
+                "prompt_lookup_max": self.ngram_prompt_lookup_max,
+                "prompt_lookup_min": self.ngram_prompt_lookup_min,
+                "acceptance_method": self.spec_decoding_acceptance_method,
+                "posterior_threshold":
+                self.typical_acceptance_sampler_posterior_threshold,
+                "posterior_alpha":
+                self.typical_acceptance_sampler_posterior_alpha,
+                "disable_logprobs": self.disable_logprobs_during_spec_decoding,
+            }
+
+            self.speculative_config = spec_config_dict
+        else:
+            if isinstance(self.speculative_config, str):
+                import ast
+                self.speculative_config = ast.literal_eval(
+                    self.speculative_config)
+        # Note(Shangming): These parameters are not obtained from the cli arg
+        # '--speculative-config' and must be passed in when creating the engine
+        # config.
+
+        assert isinstance(self.speculative_config, dict)
+        self.speculative_config.update({
+            "target_model_config": target_model_config,
+            "target_parallel_config": target_parallel_config,
+            "enable_chunked_prefill": enable_chunked_prefill,
+            "disable_log_stats": disable_log_stats,
+        })
+        speculative_config = SpeculativeConfig.from_dict(
+            self.speculative_config)
+
+        return speculative_config
+
     def create_engine_config(
         self,
         usage_context: Optional[UsageContext] = None,
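As an illustration of the two input forms this helper accepts, a minimal sketch; the model names and values are placeholders, and `EngineArgs` is assumed to be importable from `vllm.engine.arg_utils`:

```python
from vllm.engine.arg_utils import EngineArgs

# Programmatic use: pass the configuration directly as a dict.
args = EngineArgs(
    model="facebook/opt-6.7b",
    speculative_config={"model": "facebook/opt-125m",
                        "num_speculative_tokens": 5},
)

# CLI-style use: the same configuration arrives as a string and is parsed
# with ast.literal_eval inside create_speculative_config().
args = EngineArgs(
    model="facebook/opt-6.7b",
    speculative_config=('{"model": "facebook/opt-125m", '
                        '"num_speculative_tokens": 5}'),
)
```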
@@ -1238,6 +1321,8 @@ class EngineArgs:
         else:
             self._set_default_args_v0(model_config)

+        assert self.enable_chunked_prefill is not None
+
         cache_config = CacheConfig(
             block_size=self.block_size,
             gpu_memory_utilization=self.gpu_memory_utilization,
@@ -1280,31 +1365,11 @@ class EngineArgs:
             worker_extension_cls=self.worker_extension_cls,
         )

-        speculative_config = SpeculativeConfig.maybe_create_spec_config(
+        speculative_config = self.create_speculative_config(
             target_model_config=model_config,
             target_parallel_config=parallel_config,
-            target_dtype=self.dtype,
-            speculative_model=self.speculative_model,
-            speculative_model_quantization = \
-                self.speculative_model_quantization,
-            speculative_draft_tensor_parallel_size = \
-                self.speculative_draft_tensor_parallel_size,
-            num_speculative_tokens=self.num_speculative_tokens,
-            speculative_disable_mqa_scorer=self.speculative_disable_mqa_scorer,
-            speculative_disable_by_batch_size=self.
-            speculative_disable_by_batch_size,
-            speculative_max_model_len=self.speculative_max_model_len,
             enable_chunked_prefill=self.enable_chunked_prefill,
             disable_log_stats=self.disable_log_stats,
-            ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
-            ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
-            draft_token_acceptance_method=\
-                self.spec_decoding_acceptance_method,
-            typical_acceptance_sampler_posterior_threshold=self.
-            typical_acceptance_sampler_posterior_threshold,
-            typical_acceptance_sampler_posterior_alpha=self.
-            typical_acceptance_sampler_posterior_alpha,
-            disable_logprobs=self.disable_logprobs_during_spec_decoding,
         )

         # Reminder: Please update docs/source/features/compatibility_matrix.md
@@ -1569,7 +1634,7 @@ class EngineArgs:
         if (self.speculative_model is not None
                 or self.num_speculative_tokens is not None):
             # This is supported but experimental (handled below).
-            if self.speculative_model == "[ngram]":
+            if self.speculative_model in ("ngram", "[ngram]"):
                 pass
             else:
                 _raise_or_fallback(feature_name="Speculative Decoding",
@@ -1617,7 +1682,8 @@ class EngineArgs:
             return False

         # ngram is supported on V1, but off by default for now.
-        if self.speculative_model == "[ngram]" and _warn_or_fallback("ngram"):
+        if self.speculative_model in (
+                "ngram", "[ngram]") and _warn_or_fallback("ngram"):
             return False

         # Non-CUDA is supported on V1, but off by default for now.
@@ -92,22 +92,20 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
     # Override draft-model specific worker args.
     draft_worker_kwargs.update(
         vllm_config=draft_worker_config,
-        ngram_prompt_lookup_max=speculative_config.ngram_prompt_lookup_max,
-        ngram_prompt_lookup_min=speculative_config.ngram_prompt_lookup_min,
+        ngram_prompt_lookup_max=speculative_config.prompt_lookup_max,
+        ngram_prompt_lookup_min=speculative_config.prompt_lookup_min,
     )

     spec_decode_worker = SpecDecodeWorker.create_worker(
         scorer_worker=target_worker,
         draft_worker_kwargs=draft_worker_kwargs,
-        disable_mqa_scorer=speculative_config.speculative_disable_mqa_scorer,
-        disable_by_batch_size=speculative_config.
-        speculative_disable_by_batch_size,
-        draft_token_acceptance_method=speculative_config.
-        draft_token_acceptance_method,
+        disable_mqa_scorer=speculative_config.disable_mqa_scorer,
+        disable_by_batch_size=speculative_config.disable_by_batch_size,
+        draft_token_acceptance_method=speculative_config.acceptance_method,
         typical_acceptance_sampler_posterior_threshold=speculative_config.
-        typical_acceptance_sampler_posterior_threshold,
+        posterior_threshold,
         typical_acceptance_sampler_posterior_alpha=speculative_config.
-        typical_acceptance_sampler_posterior_alpha,
+        posterior_alpha,
         disable_logprobs=speculative_config.disable_logprobs,
         disable_log_stats=speculative_config.disable_log_stats,
         num_speculative_tokens=speculative_config.num_speculative_tokens,
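For reference, the attribute renames consumed here (and defined on `SpeculativeConfig` earlier in this diff) are summarized below; this mapping is only a reading aid drawn from the hunks in this commit:

```python
# Old SpeculativeConfig attribute -> new attribute (per this commit).
RENAMED_FIELDS = {
    "ngram_prompt_lookup_max": "prompt_lookup_max",
    "ngram_prompt_lookup_min": "prompt_lookup_min",
    "speculative_disable_mqa_scorer": "disable_mqa_scorer",
    "speculative_disable_by_batch_size": "disable_by_batch_size",
    "draft_token_acceptance_method": "acceptance_method",
    "typical_acceptance_sampler_posterior_threshold": "posterior_threshold",
    "typical_acceptance_sampler_posterior_alpha": "posterior_alpha",
}
```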
@@ -151,8 +151,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.use_spec_decode = False
         if self.speculative_config:
             self.use_spec_decode = True
-            # TODO: find a better way to check if we are using ngram.
-            assert self.speculative_config.ngram_prompt_lookup_min, \
+            assert self.speculative_config.method == "ngram", \
                 "Currently, only ngram spec decode is supported in V1."
             if get_pp_group().is_last_rank:
                 self.drafter = NgramProposer()
@@ -160,7 +159,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 # This usually takes less than 1 second.
                 self.drafter.propose(
                     np.zeros(1024, dtype=np.int32),
-                    self.speculative_config.ngram_prompt_lookup_min,
+                    self.speculative_config.prompt_lookup_min,
                     self.speculative_config.num_speculative_tokens,
                 )
             self.rejection_sampler = RejectionSampler()
@@ -1155,7 +1154,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids
             drafter_output = self.drafter.propose(
                 self.input_batch.token_ids_cpu[i, :end_idx],
-                self.speculative_config.ngram_prompt_lookup_min,
+                self.speculative_config.prompt_lookup_min,
                 self.speculative_config.num_speculative_tokens,
             )
             if drafter_output is None or len(drafter_output) == 0:
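The V1 changes above only rewire the existing `NgramProposer` to the renamed config fields. For readers unfamiliar with prompt-lookup drafting, here is a simplified, self-contained sketch of the idea; it is not vLLM's implementation, and the `propose` function below is a toy over a flat token-id array:

```python
from typing import Optional

import numpy as np


def propose(token_ids: np.ndarray, n: int, k: int) -> Optional[np.ndarray]:
    """Toy prompt-lookup proposer: find the most recent earlier occurrence of
    the last n tokens and speculate the k tokens that followed it."""
    if len(token_ids) <= n:
        return None
    tail = token_ids[-n:]
    # Scan candidate start positions from most recent to oldest,
    # excluding the tail itself.
    for start in range(len(token_ids) - n - 1, -1, -1):
        if np.array_equal(token_ids[start:start + n], tail):
            draft = token_ids[start + n:start + n + k]
            return draft if len(draft) > 0 else None
    return None


# The 2-gram (7, 8) previously appeared followed by 9, 5, 3, so those are drafted.
print(propose(np.array([1, 7, 8, 9, 5, 3, 7, 8]), n=2, k=3))  # -> [9 5 3]
```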