[V1][Usage] Refactor speculative decoding configuration and tests (#14434)
Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
Parent: 0661cfef7a
Commit: 50c9636d87
@@ -30,8 +30,10 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
 llm = LLM(
     model="facebook/opt-6.7b",
     tensor_parallel_size=1,
-    speculative_model="facebook/opt-125m",
-    num_speculative_tokens=5,
+    speculative_config={
+        "model": "facebook/opt-125m",
+        "num_speculative_tokens": 5,
+    },
 )
 outputs = llm.generate(prompts, sampling_params)

@@ -45,10 +47,14 @@ To perform the same with an online mode launch the server:

 ```bash
 python -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --port 8000 --model facebook/opt-6.7b \
-    --seed 42 -tp 1 --speculative_model facebook/opt-125m \
-    --num_speculative_tokens 5 --gpu_memory_utilization 0.8
+    --seed 42 -tp 1 --gpu_memory_utilization 0.8 \
+    --speculative_config '{"model": "facebook/opt-125m", "num_speculative_tokens": 5}'
 ```

+:::{warning}
+Note: Please use `--speculative_config` to set all configurations related to speculative decoding. The previous method of specifying the model through `--speculative_model` and adding related parameters (e.g., `--num_speculative_tokens`) separately will be deprecated in the next release.
+:::
+
 Then use a client:

 ```python
@@ -101,9 +107,11 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
 llm = LLM(
     model="facebook/opt-6.7b",
     tensor_parallel_size=1,
-    speculative_model="[ngram]",
-    num_speculative_tokens=5,
-    ngram_prompt_lookup_max=4,
+    speculative_config={
+        "method": "ngram",
+        "num_speculative_tokens": 5,
+        "prompt_lookup_max": 4,
+    },
 )
 outputs = llm.generate(prompts, sampling_params)

@@ -131,8 +139,10 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
 llm = LLM(
     model="meta-llama/Meta-Llama-3.1-70B-Instruct",
     tensor_parallel_size=4,
-    speculative_model="ibm-ai-platform/llama3-70b-accelerator",
-    speculative_draft_tensor_parallel_size=1,
+    speculative_config={
+        "model": "ibm-ai-platform/llama3-70b-accelerator",
+        "draft_tensor_parallel_size": 1,
+    },
 )
 outputs = llm.generate(prompts, sampling_params)

@@ -175,8 +185,10 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
 llm = LLM(
     model="meta-llama/Meta-Llama-3-8B-Instruct",
     tensor_parallel_size=4,
-    speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B",
-    speculative_draft_tensor_parallel_size=1,
+    speculative_config={
+        "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
+        "draft_tensor_parallel_size": 1,
+    },
 )

 outputs = llm.generate(prompts, sampling_params)
@@ -194,11 +206,10 @@ A few important things to consider when using the EAGLE based draft models:
    be able to be loaded and used directly by vLLM after [PR 12304](https://github.com/vllm-project/vllm/pull/12304).
    If you are using vllm version before [PR 12304](https://github.com/vllm-project/vllm/pull/12304), please use the
    [script](https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d) to convert the speculative model,
-   and specify `speculative_model="path/to/modified/eagle/model"`. If weight-loading problems still occur when using
-   the latest version of vLLM, please leave a comment or raise an issue.
+   and specify `"model": "path/to/modified/eagle/model"` in `speculative_config`. If weight-loading problems still occur when using the latest version of vLLM, please leave a comment or raise an issue.

 2. The EAGLE based draft models need to be run without tensor parallelism
-   (i.e. speculative_draft_tensor_parallel_size is set to 1), although
+   (i.e. draft_tensor_parallel_size is set to 1 in `speculative_config`), although
    it is possible to run the main model using tensor parallelism (see example above).

 3. When using EAGLE-based speculators with vLLM, the observed speedup is lower than what is
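For point 1 above, a minimal offline sketch of what loading a locally converted EAGLE checkpoint could look like with the new `speculative_config` interface. This is not part of the diff: the checkpoint path is a placeholder for the output of the linked conversion script, and the target model mirrors the EAGLE example earlier in this document.

```python
from vllm import LLM, SamplingParams

prompts = ["The future of AI is"]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    tensor_parallel_size=4,
    speculative_config={
        # Placeholder path: the checkpoint produced by the conversion script
        # linked in point 1 above.
        "model": "path/to/modified/eagle/model",
        "draft_tensor_parallel_size": 1,
    },
)

outputs = llm.generate(prompts, sampling_params)
```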
@@ -50,7 +50,9 @@ if __name__ == "__main__":
     # Create an LLM with spec decoding
     llm = LLM(
         model="meta-llama/Llama-2-13b-chat-hf",
-        speculative_model="ibm-ai-platform/llama-13b-accelerator",
+        speculative_config={
+            "model": "ibm-ai-platform/llama-13b-accelerator",
+        },
     )

     print("With speculation")
@@ -56,7 +56,7 @@ def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
 def maybe_assert_ngram_worker(llm):
     # Verify the proposer worker is ngram if ngram is specified.
     if (llm.llm_engine.speculative_config is not None
-            and llm.llm_engine.speculative_config.ngram_prompt_lookup_max > 0):
+            and llm.llm_engine.speculative_config.method == "ngram"):
         from vllm.spec_decode.ngram_worker import NGramWorker
         assert isinstance(
             llm.llm_engine.model_executor.driver_worker.proposer_worker,
@@ -7,28 +7,39 @@ from vllm import SamplingParams
from .conftest import get_output_from_llm_generator


@pytest.mark.parametrize("common_llm_kwargs", [{
    "model": "meta-llama/Llama-3.2-1B-Instruct",
    "speculative_model": "JackFram/llama-68m",
    "num_speculative_tokens": 5,
}])
@pytest.mark.parametrize("common_llm_kwargs",
                         [{
                             "model": "meta-llama/Llama-3.2-1B-Instruct",
                         }])
@pytest.mark.parametrize(
    "per_test_common_llm_kwargs",
    [
        {
            # Speculative max model len > overridden max model len should raise.
            "speculative_config": {
                "model": "JackFram/llama-68m",
                "num_speculative_tokens": 5,
                "max_model_len": 129,
            },
            "max_model_len": 128,
            "speculative_max_model_len": 129,
        },
        {
            # Speculative max model len > draft max model len should raise.
            # https://huggingface.co/JackFram/llama-68m/blob/3b606af5198a0b26762d589a3ee3d26ee6fa6c85/config.json#L12
            "speculative_max_model_len": 2048 + 1,
            "speculative_config": {
                "model": "JackFram/llama-68m",
                "num_speculative_tokens": 5,
                "max_model_len": 2048 + 1,
            },
        },
        {
            # Speculative max model len > target max model len should raise.
            # https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct/blob/9213176726f574b556790deb65791e0c5aa438b6/config.json#L18
            "speculative_max_model_len": 131072 + 1,
            # https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct/blob/9213176726f574b556790deb65791e0c5aa438b6/config.json#L18
            "speculative_config": {
                "model": "JackFram/llama-68m",
                "num_speculative_tokens": 5,
                "max_model_len": 131072 + 1,
            },
        },
    ])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
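The cases above assert that an out-of-range draft max model length is rejected. The same `max_model_len` key inside `speculative_config` can also be used legitimately to cap the draft model so speculation is skipped for longer sequences. A minimal sketch, not part of the diff, reusing the model names from the tests above:

```python
from vllm import LLM

# Cap the draft model's context at 32 tokens while the target model keeps
# its own, larger limit; speculation is then expected to be skipped once
# sequences grow past the draft cap.
llm = LLM(
    model="meta-llama/Llama-3.2-1B-Instruct",
    speculative_config={
        "model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,
        "max_model_len": 32,
    },
)
```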
@@ -57,8 +57,10 @@ PRECISION = "float32"
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": SPEC_MODEL,
        "num_speculative_tokens": MAX_SPEC_TOKENS,
        "speculative_config": {
            "model": SPEC_MODEL,
            "num_speculative_tokens": MAX_SPEC_TOKENS,
        },
    },
])
@pytest.mark.parametrize("output_len", [
@@ -95,18 +97,19 @@ def test_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": SPEC_MODEL,
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_config": {
        "model": SPEC_MODEL,
        "num_speculative_tokens": MAX_SPEC_TOKENS,
        "disable_logprobs_during_spec_decoding": False,
        "disable_logprobs": False,
    },
    {
        "speculative_model": SPEC_MODEL,
}, {
    "speculative_config": {
        "model": SPEC_MODEL,
        "num_speculative_tokens": MAX_SPEC_TOKENS,
        "disable_logprobs_during_spec_decoding": True,
        "disable_logprobs": True,
    },
])
}])
@pytest.mark.parametrize("output_len", [
    128,
])
@@ -119,18 +122,19 @@ def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
                                   batch_size: int, output_len: int, seed: int,
                                   logprobs: int):

    run_equality_correctness_test(vllm_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs,
                                  test_llm_kwargs,
                                  batch_size,
                                  output_len,
                                  seed,
                                  logprobs=logprobs,
                                  prompt_logprobs=logprobs,
                                  disable_logprobs=test_llm_kwargs[
                                      'disable_logprobs_during_spec_decoding'])
    run_equality_correctness_test(
        vllm_runner,
        common_llm_kwargs,
        per_test_common_llm_kwargs,
        baseline_llm_kwargs,
        test_llm_kwargs,
        batch_size,
        output_len,
        seed,
        logprobs=logprobs,
        prompt_logprobs=logprobs,
        disable_logprobs=test_llm_kwargs["speculative_config"]
        ["disable_logprobs"])


@pytest.mark.parametrize(
@@ -151,8 +155,10 @@ def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": SPEC_MODEL,
        "num_speculative_tokens": MAX_SPEC_TOKENS,
        "speculative_config": {
            "model": SPEC_MODEL,
            "num_speculative_tokens": MAX_SPEC_TOKENS,
        },
    },
])
@pytest.mark.parametrize("output_len", [
@@ -193,8 +199,10 @@ def test_eagle_e2e_greedy_correctness_cuda_graph(
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": SPEC_MODEL,
        "num_speculative_tokens": MAX_SPEC_TOKENS,
        "speculative_config": {
            "model": SPEC_MODEL,
            "num_speculative_tokens": MAX_SPEC_TOKENS,
        },
    },
])
@pytest.mark.parametrize(
@@ -236,8 +244,10 @@ def test_eagle_e2e_greedy_correctness_with_preemption(
    "test_llm_kwargs",
    [
        {
            "speculative_model": SPEC_MODEL,
            "num_speculative_tokens": k,
            "speculative_config": {
                "model": SPEC_MODEL,
                "num_speculative_tokens": k,
            },
        }
        # Try a range of num. speculative tokens
        for k in range(1, 1 + MAX_SPEC_TOKENS)
@@ -277,12 +287,13 @@ def test_eagle_different_k(vllm_runner, common_llm_kwargs,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
                         [{
                             "speculative_model": SPEC_MODEL,
                             "num_speculative_tokens": MAX_SPEC_TOKENS,
                             "speculative_disable_by_batch_size": 4
                         }])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_config": {
        "model": SPEC_MODEL,
        "num_speculative_tokens": MAX_SPEC_TOKENS,
        "disable_by_batch_size": 4,
    },
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
    "output_len",
@@ -324,8 +335,10 @@ def test_eagle_disable_queue(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": "yuhuili/EAGLE-llama2-chat-7B",
        "num_speculative_tokens": MAX_SPEC_TOKENS,
        "speculative_config": {
            "model": "yuhuili/EAGLE-llama2-chat-7B",
            "num_speculative_tokens": MAX_SPEC_TOKENS,
        },
    },
])
@pytest.mark.parametrize(
@@ -372,8 +385,10 @@ def test_llama2_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
        "num_speculative_tokens": MAX_SPEC_TOKENS,
        "speculative_config": {
            "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
            "num_speculative_tokens": MAX_SPEC_TOKENS,
        },
    },
])
@pytest.mark.parametrize(
@@ -420,8 +435,10 @@ def test_llama3_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": "yuhuili/EAGLE-Qwen2-7B-Instruct",
        "num_speculative_tokens": MAX_SPEC_TOKENS,
        "speculative_config": {
            "model": "yuhuili/EAGLE-Qwen2-7B-Instruct",
            "num_speculative_tokens": MAX_SPEC_TOKENS,
        },
    },
])
@pytest.mark.parametrize(
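Several of the cases above exercise the renamed `disable_by_batch_size` knob. As a rough offline sketch, not part of the diff and reusing the models from the documentation above, disabling speculation under load could be configured like this (threshold semantics as documented for the old `speculative_disable_by_batch_size` argument):

```python
from vllm import LLM

# Illustration only: speculation is expected to be skipped for new requests
# once more than 4 requests are enqueued, falling back to plain decoding.
llm = LLM(
    model="facebook/opt-6.7b",
    speculative_config={
        "model": "facebook/opt-125m",
        "num_speculative_tokens": 5,
        "disable_by_batch_size": 4,
    },
)
```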
@ -23,8 +23,10 @@ MAIN_MODEL = "JackFram/llama-68m"
|
||||
[
|
||||
{
|
||||
# Identical models.
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@ -57,26 +59,33 @@ def test_spec_decode_cuda_graph(vllm_runner, common_llm_kwargs,
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [])
|
||||
@pytest.mark.parametrize(
|
||||
"test_llm_kwargs",
|
||||
[
|
||||
# Explicitly specify draft model quantization
|
||||
{
|
||||
"speculative_model_quantization": "gptq",
|
||||
"speculative_config": {
|
||||
"model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
|
||||
"num_speculative_tokens": 5,
|
||||
"quantization": "gptq",
|
||||
},
|
||||
},
|
||||
# Explicitly specify GPTQ-based draft model to use marlin quantization
|
||||
{
|
||||
"speculative_model_quantization": "marlin",
|
||||
"speculative_config": {
|
||||
"model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
|
||||
"num_speculative_tokens": 5,
|
||||
"quantization": "marlin",
|
||||
},
|
||||
},
|
||||
# Not explicitly specify draft model quantization
|
||||
{
|
||||
"speculative_model_quantization": None,
|
||||
"speculative_config": {
|
||||
"model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
|
||||
"num_speculative_tokens": 5,
|
||||
"quantization": None,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@ -107,15 +116,16 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs,
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 3,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs",
|
||||
[{
|
||||
"speculative_disable_mqa_scorer": True,
|
||||
}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 3,
|
||||
"disable_mqa_scorer": True,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [1, 5])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
@ -127,7 +137,7 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs,
|
||||
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
|
||||
output_len: int, seed: int):
|
||||
"""Verify that ngram speculative decoding generates the same output
|
||||
"""Verify that speculative decoding generates the same output
|
||||
with batch expansion scorer and mqa scorer.
|
||||
"""
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
|
@ -27,18 +27,19 @@ from .conftest import run_equality_correctness_test_tp
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
[
|
||||
"--speculative-model",
|
||||
"JackFram/llama-68m",
|
||||
"--num-speculative-tokens",
|
||||
"3",
|
||||
"--speculative_config",
|
||||
str({
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 3,
|
||||
}),
|
||||
],
|
||||
[
|
||||
"--speculative-model",
|
||||
"[ngram]",
|
||||
"--num-speculative-tokens",
|
||||
"5",
|
||||
"--ngram-prompt-lookup-max",
|
||||
"3",
|
||||
"--speculative_config",
|
||||
str({
|
||||
"model": "ngram",
|
||||
"num_speculative_tokens": 5,
|
||||
"prompt_lookup_max": 3,
|
||||
}),
|
||||
],
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [2])
|
||||
@ -83,23 +84,24 @@ def test_target_model_tp_gt_1(common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
]])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [[]])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
|
||||
@pytest.mark.parametrize("model, test_llm_kwargs",
|
||||
[("JackFram/llama-68m", [
|
||||
"--speculative-model",
|
||||
"JackFram/llama-68m",
|
||||
"--num_speculative-tokens",
|
||||
"5",
|
||||
"--speculative-draft-tensor-parallel-size",
|
||||
"1",
|
||||
]),
|
||||
("ibm-granite/granite-3b-code-instruct", [
|
||||
"--speculative-model",
|
||||
"ibm-granite/granite-3b-code-instruct",
|
||||
"--num_speculative-tokens",
|
||||
"5",
|
||||
"--speculative-draft-tensor-parallel-size",
|
||||
"1",
|
||||
])])
|
||||
@pytest.mark.parametrize(
|
||||
"model, test_llm_kwargs",
|
||||
[("JackFram/llama-68m", [
|
||||
"--speculative_config",
|
||||
str({
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"draft_tensor_parallel_size": 1,
|
||||
}),
|
||||
]),
|
||||
("ibm-granite/granite-3b-code-instruct", [
|
||||
"--speculative_config",
|
||||
str({
|
||||
"model": "ibm-granite/granite-3b-code-instruct",
|
||||
"num_speculative_tokens": 5,
|
||||
"draft_tensor_parallel_size": 1,
|
||||
}),
|
||||
])])
|
||||
@pytest.mark.parametrize("batch_size", [2])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs,
|
||||
@ -144,18 +146,19 @@ def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs,
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
|
||||
@pytest.mark.parametrize("model, test_llm_kwargs",
|
||||
[("JackFram/llama-68m", [
|
||||
"--speculative-model",
|
||||
"JackFram/llama-68m",
|
||||
"--num_speculative-tokens",
|
||||
"3",
|
||||
"--speculative_config",
|
||||
str({
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 3,
|
||||
}),
|
||||
]),
|
||||
("JackFram/llama-68m", [
|
||||
"--speculative-model",
|
||||
"JackFram/llama-68m",
|
||||
"--num_speculative-tokens",
|
||||
"3",
|
||||
"--speculative-draft-tensor-parallel-size",
|
||||
"1",
|
||||
"--speculative_config",
|
||||
str({
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 3,
|
||||
"draft_tensor_parallel_size": 1,
|
||||
}),
|
||||
])])
|
||||
@pytest.mark.parametrize("logprobs", [None, 2])
|
||||
@pytest.mark.parametrize("batch_size", [2])
|
||||
|
@ -24,12 +24,7 @@ SPEC_MODEL = "JackFram/llama-68m"
|
||||
"4",
|
||||
]])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
|
||||
[
|
||||
"--speculative-model",
|
||||
f"{SPEC_MODEL}",
|
||||
"--num-speculative-tokens",
|
||||
"5",
|
||||
],
|
||||
[],
|
||||
])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
|
||||
@pytest.mark.parametrize(
|
||||
@ -37,8 +32,12 @@ SPEC_MODEL = "JackFram/llama-68m"
|
||||
[
|
||||
#TODO(wooyeon): add spec_draft_dp=2 case
|
||||
[
|
||||
"--speculative-draft-tensor-parallel-size",
|
||||
"1",
|
||||
"--speculative_config",
|
||||
str({
|
||||
"model": f"{SPEC_MODEL}",
|
||||
"num_speculative_tokens": 5,
|
||||
"draft_tensor_parallel_size": 1,
|
||||
}),
|
||||
],
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [2])
|
||||
@ -78,15 +77,14 @@ def test_draft_model_tp_lt_target_model_tp4(common_llm_kwargs,
|
||||
"test_llm_kwargs",
|
||||
[
|
||||
[
|
||||
"--speculative-model",
|
||||
f"{SPEC_MODEL}",
|
||||
"--num-speculative-tokens",
|
||||
"5",
|
||||
|
||||
# Artificially limit the draft model max model len; this forces vLLM
|
||||
# to skip speculation once the sequences grow beyond 32-k tokens.
|
||||
"--speculative-max-model-len",
|
||||
"32",
|
||||
"--speculative_config",
|
||||
str({
|
||||
"model": f"{SPEC_MODEL}",
|
||||
"num_speculative_tokens": 5,
|
||||
"max_model_len": 32,
|
||||
}),
|
||||
],
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
|
@ -20,16 +20,19 @@ from .conftest import run_equality_correctness_test
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs",
|
||||
[{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 3,
|
||||
"disable_logprobs_during_spec_decoding": False,
|
||||
}, {
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 3,
|
||||
"disable_logprobs_during_spec_decoding": True,
|
||||
}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 3,
|
||||
"disable_logprobs": False,
|
||||
},
|
||||
}, {
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 3,
|
||||
"disable_logprobs": True,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
@ -48,19 +51,20 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs,
|
||||
as well as with and without chunked prefill.
|
||||
"""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, common_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
output_len,
|
||||
seed,
|
||||
temperature=0.0,
|
||||
logprobs=logprobs,
|
||||
prompt_logprobs=logprobs,
|
||||
disable_logprobs=test_llm_kwargs[
|
||||
'disable_logprobs_during_spec_decoding'])
|
||||
run_equality_correctness_test(
|
||||
vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
output_len,
|
||||
seed,
|
||||
temperature=0.0,
|
||||
logprobs=logprobs,
|
||||
prompt_logprobs=logprobs,
|
||||
disable_logprobs=test_llm_kwargs["speculative_config"]
|
||||
["disable_logprobs"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -73,16 +77,19 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs",
|
||||
[{
|
||||
"speculative_model": "JackFram/llama-160m",
|
||||
"num_speculative_tokens": 3,
|
||||
"disable_logprobs_during_spec_decoding": False,
|
||||
}, {
|
||||
"speculative_model": "JackFram/llama-160m",
|
||||
"num_speculative_tokens": 6,
|
||||
"disable_logprobs_during_spec_decoding": False,
|
||||
}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-160m",
|
||||
"num_speculative_tokens": 3,
|
||||
"disable_logprobs": False,
|
||||
},
|
||||
}, {
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-160m",
|
||||
"num_speculative_tokens": 6,
|
||||
"disable_logprobs": False,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
@ -98,18 +105,19 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs,
|
||||
output_len: int, seed: int, logprobs: int):
|
||||
"""Veriy logprob greedy equality with different speculation lens.
|
||||
"""
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
output_len,
|
||||
seed,
|
||||
temperature=0.0,
|
||||
logprobs=logprobs,
|
||||
disable_logprobs=test_llm_kwargs[
|
||||
'disable_logprobs_during_spec_decoding'])
|
||||
run_equality_correctness_test(
|
||||
vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
output_len,
|
||||
seed,
|
||||
temperature=0.0,
|
||||
logprobs=logprobs,
|
||||
disable_logprobs=test_llm_kwargs["speculative_config"]
|
||||
["disable_logprobs"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -125,13 +133,15 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs,
|
||||
@pytest.mark.parametrize(
|
||||
"test_llm_kwargs",
|
||||
[{
|
||||
"speculative_model": "JackFram/llama-160m",
|
||||
"num_speculative_tokens": 3,
|
||||
"disable_logprobs_during_spec_decoding": False,
|
||||
|
||||
# Artificially limit the draft model max model len; this forces vLLM
|
||||
# to skip speculation once the sequences grow beyond 32-k tokens.
|
||||
"speculative_max_model_len": 32,
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-160m",
|
||||
"num_speculative_tokens": 3,
|
||||
"disable_logprobs": False,
|
||||
# Artificially limit the draft model max model len; this forces
|
||||
# vLLM to skip speculation once the sequences grow beyond 32-k
|
||||
# tokens.
|
||||
"max_model_len": 32,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize(
|
||||
@ -149,18 +159,19 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs,
|
||||
seed: int, logprobs: int):
|
||||
"""Verify logprobs greedy equality when some sequences skip speculation.
|
||||
"""
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
output_len,
|
||||
seed,
|
||||
temperature=0.0,
|
||||
logprobs=logprobs,
|
||||
disable_logprobs=test_llm_kwargs[
|
||||
'disable_logprobs_during_spec_decoding'])
|
||||
run_equality_correctness_test(
|
||||
vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
output_len,
|
||||
seed,
|
||||
temperature=0.0,
|
||||
logprobs=logprobs,
|
||||
disable_logprobs=test_llm_kwargs["speculative_config"]
|
||||
["disable_logprobs"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -173,12 +184,13 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs",
|
||||
[{
|
||||
"speculative_model": "JackFram/llama-160m",
|
||||
"num_speculative_tokens": 3,
|
||||
"disable_logprobs_during_spec_decoding": False,
|
||||
}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-160m",
|
||||
"num_speculative_tokens": 3,
|
||||
"disable_logprobs": False,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [1])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
@ -248,12 +260,13 @@ def test_logprobs_temp_1(vllm_runner, common_llm_kwargs,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs",
|
||||
[{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 3,
|
||||
"disable_logprobs_during_spec_decoding": True,
|
||||
}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 3,
|
||||
"disable_logprobs": True,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("batch_size", [4])
|
||||
@pytest.mark.parametrize(
|
||||
@ -270,15 +283,16 @@ def test_logprobs_disabled(vllm_runner, common_llm_kwargs,
|
||||
"""Check the behavior when logprobs are disabled.
|
||||
Token choices should match with the base model.
|
||||
"""
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
output_len,
|
||||
seed,
|
||||
temperature=0.0,
|
||||
logprobs=logprobs,
|
||||
disable_logprobs=test_llm_kwargs[
|
||||
'disable_logprobs_during_spec_decoding'])
|
||||
run_equality_correctness_test(
|
||||
vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
output_len,
|
||||
seed,
|
||||
temperature=0.0,
|
||||
logprobs=logprobs,
|
||||
disable_logprobs=test_llm_kwargs["speculative_config"]
|
||||
["disable_logprobs"])
|
||||
|
@ -60,8 +60,10 @@ PRECISION = "float32"
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
@ -107,14 +109,18 @@ def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_logprobs_during_spec_decoding": False,
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_logprobs": False,
|
||||
},
|
||||
},
|
||||
{
|
||||
"speculative_model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_logprobs_during_spec_decoding": True,
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_logprobs": True,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
@ -132,19 +138,20 @@ def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
|
||||
prefill_chunk_size: int):
|
||||
"""Verify greedy equality with different batch size."""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0,
|
||||
logprobs=logprobs,
|
||||
prompt_logprobs=logprobs,
|
||||
disable_logprobs=test_llm_kwargs[
|
||||
'disable_logprobs_during_spec_decoding'])
|
||||
run_equality_correctness_test(
|
||||
vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0,
|
||||
logprobs=logprobs,
|
||||
prompt_logprobs=logprobs,
|
||||
disable_logprobs=test_llm_kwargs["speculative_config"]
|
||||
["disable_logprobs"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -165,8 +172,10 @@ def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
@ -214,8 +223,10 @@ def test_medusa_e2e_greedy_correctness_cuda_graph(
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
@ -264,8 +275,10 @@ def test_medusa_e2e_greedy_correctness_with_preemption(
|
||||
"test_llm_kwargs",
|
||||
[
|
||||
{
|
||||
"speculative_model": SPEC_MODEL,
|
||||
"num_speculative_tokens": k,
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": k,
|
||||
},
|
||||
}
|
||||
# Try a range of num. speculative tokens
|
||||
for k in range(1, 1 + MAX_SPEC_TOKENS)
|
||||
@ -312,12 +325,13 @@ def test_medusa_different_k(vllm_runner, common_llm_kwargs,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs",
|
||||
[{
|
||||
"speculative_model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"speculative_disable_by_batch_size": 4
|
||||
}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_by_batch_size": 4,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [1, 5])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
@ -359,16 +373,17 @@ def test_medusa_disable_queue(vllm_runner, common_llm_kwargs,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
"speculative_model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"speculative_disable_by_batch_size": 4
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs",
|
||||
[{
|
||||
"speculative_disable_mqa_scorer": True,
|
||||
}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_by_batch_size": 4,
|
||||
"disable_mqa_scorer": True,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [1, 5])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
|
@ -62,7 +62,9 @@ PRECISION = "float32"
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": SPEC_MODEL,
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
@ -108,12 +110,16 @@ def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": SPEC_MODEL,
|
||||
"disable_logprobs_during_spec_decoding": False,
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"disable_logprobs": False,
|
||||
},
|
||||
},
|
||||
{
|
||||
"speculative_model": SPEC_MODEL,
|
||||
"disable_logprobs_during_spec_decoding": True,
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"disable_logprobs": True,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [8])
|
||||
@ -133,19 +139,20 @@ def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
|
||||
# up sampling different tokens at the tail (ie top tokens don't change).
|
||||
# TL;DR: sd+cp == org+cp but sd+cp != org..is this expected?
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, baseline_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0,
|
||||
logprobs=logprobs,
|
||||
prompt_logprobs=logprobs,
|
||||
disable_logprobs=test_llm_kwargs[
|
||||
'disable_logprobs_during_spec_decoding'])
|
||||
run_equality_correctness_test(
|
||||
vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0,
|
||||
logprobs=logprobs,
|
||||
prompt_logprobs=logprobs,
|
||||
disable_logprobs=test_llm_kwargs["speculative_config"]
|
||||
["disable_logprobs"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -167,7 +174,9 @@ def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": SPEC_MODEL,
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [2048])
|
||||
@ -209,8 +218,10 @@ def test_mlp_e2e_acceptance_rate(vllm_runner, common_llm_kwargs,
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
|
||||
# Speculative model
|
||||
"speculative_model": SPEC_MODEL,
|
||||
# Speculative config
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}])
|
||||
@ -274,7 +285,9 @@ def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs,
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": SPEC_MODEL,
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
@ -326,7 +339,9 @@ def test_mlp_e2e_greedy_correctness_with_preemption(
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": SPEC_MODEL,
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
@ -382,8 +397,10 @@ def test_mlp_e2e_greedy_correctness_with_padding(
|
||||
"test_llm_kwargs",
|
||||
[
|
||||
{
|
||||
"speculative_model": SPEC_MODEL,
|
||||
"num_speculative_tokens": k,
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": k,
|
||||
},
|
||||
}
|
||||
# Try a range of num. speculative tokens
|
||||
for k in range(1, 1 + MAX_SPEC_TOKENS)
|
||||
@ -430,11 +447,12 @@ def test_mlp_different_k(vllm_runner, common_llm_kwargs,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs",
|
||||
[{
|
||||
"speculative_model": SPEC_MODEL,
|
||||
"speculative_disable_by_batch_size": 4
|
||||
}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"disable_by_batch_size": 4,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [1, 5])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
@ -475,14 +493,15 @@ def test_mlp_disable_queue(vllm_runner, common_llm_kwargs,
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
"speculative_model": SPEC_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs",
|
||||
[{
|
||||
"speculative_disable_mqa_scorer": True,
|
||||
}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"disable_mqa_scorer": True,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [1, 5])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
|
@ -57,7 +57,9 @@ PRECISION = "bfloat16"
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"speculative_config": {
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
@ -99,12 +101,16 @@ def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_logprobs_during_spec_decoding": False,
|
||||
"speculative_config": {
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_logprobs": False,
|
||||
},
|
||||
},
|
||||
{
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_logprobs_during_spec_decoding": True,
|
||||
"speculative_config": {
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_logprobs": True,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
@ -119,18 +125,19 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
|
||||
batch_size: int, output_len: int, seed: int,
|
||||
logprobs: int):
|
||||
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
output_len,
|
||||
seed,
|
||||
logprobs=logprobs,
|
||||
prompt_logprobs=logprobs,
|
||||
disable_logprobs=test_llm_kwargs[
|
||||
'disable_logprobs_during_spec_decoding'])
|
||||
run_equality_correctness_test(
|
||||
vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
output_len,
|
||||
seed,
|
||||
logprobs=logprobs,
|
||||
prompt_logprobs=logprobs,
|
||||
disable_logprobs=test_llm_kwargs["speculative_config"]
|
||||
["disable_logprobs"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -152,7 +159,9 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"speculative_config": {
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
@ -198,7 +207,9 @@ def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs,
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"speculative_config": {
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
@ -243,7 +254,9 @@ def test_mtp_e2e_greedy_correctness_with_preemption(
|
||||
"test_llm_kwargs",
|
||||
[
|
||||
{
|
||||
"num_speculative_tokens": k,
|
||||
"speculative_config": {
|
||||
"num_speculative_tokens": k,
|
||||
},
|
||||
}
|
||||
# Try a range of num. speculative tokens
|
||||
for k in range(1, 1 + MAX_SPEC_TOKENS)
|
||||
@ -286,11 +299,12 @@ def test_mtp_different_k(vllm_runner, common_llm_kwargs,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs",
|
||||
[{
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"speculative_disable_by_batch_size": 4
|
||||
}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_by_batch_size": 4
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [1, 5])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
|
@ -61,15 +61,19 @@ from .conftest import (get_output_from_llm_generator,
|
||||
"per_test_common_llm_kwargs",
|
||||
[
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
"enable_chunked_prefill": False,
|
||||
},
|
||||
{
|
||||
# Chunked prefill enabled with small value
|
||||
# to make sure we get mixed batches.
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4
|
||||
@ -148,20 +152,23 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs",
|
||||
[{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"enable_chunked_prefill": False,
|
||||
"disable_logprobs_during_spec_decoding": False
|
||||
}, {
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 3,
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4,
|
||||
"disable_logprobs_during_spec_decoding": False
|
||||
}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"disable_logprobs": False,
|
||||
},
|
||||
"enable_chunked_prefill": False,
|
||||
}, {
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 3,
|
||||
"disable_logprobs": False,
|
||||
},
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4,
|
||||
}])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
@ -184,7 +191,7 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
|
||||
whether all speculative tokens are accepted.
|
||||
"""
|
||||
ensure_all_accepted = per_test_common_llm_kwargs.get(
|
||||
"model_name") == test_llm_kwargs.get("speculative_model")
|
||||
"model_name") == test_llm_kwargs.get("speculative_config")["model"]
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
@ -224,13 +231,17 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
"enable_chunked_prefill": False,
|
||||
},
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4
|
||||
@ -283,13 +294,17 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
"enable_chunked_prefill": False,
|
||||
},
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4
|
||||
@ -336,13 +351,17 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
"enable_chunked_prefill": False,
|
||||
},
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4
|
||||
@ -391,13 +410,17 @@ def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
"enable_chunked_prefill": False,
|
||||
},
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4
|
||||
@ -449,13 +472,17 @@ def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
"enable_chunked_prefill": False,
|
||||
},
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4
|
||||
@ -514,13 +541,17 @@ def test_spec_decode_e2e_greedy_correctness_with_preemption(
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
"enable_chunked_prefill": False,
|
||||
},
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4
|
||||
@ -567,21 +598,25 @@ def test_spec_decode_different_block_size(vllm_runner, common_llm_kwargs,
|
||||
"test_llm_kwargs",
|
||||
[
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
|
||||
# Artificially limit the draft model max model len; this forces vLLM
|
||||
# to skip speculation once the sequences grow beyond 32-k tokens.
|
||||
"speculative_max_model_len": 32,
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"max_model_len": 32,
|
||||
},
|
||||
"enable_chunked_prefill": False,
|
||||
},
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"max_model_len": 32,
|
||||
},
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4,
|
||||
"speculative_max_model_len": 32,
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@ -627,15 +662,19 @@ def test_skip_speculation(vllm_runner, common_llm_kwargs,
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"speculative_disable_by_batch_size": 2,
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"disable_by_batch_size": 2,
|
||||
},
|
||||
"enable_chunked_prefill": False,
|
||||
},
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"speculative_disable_by_batch_size": 2,
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"disable_by_batch_size": 2,
|
||||
},
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4,
|
||||
@ -676,15 +715,19 @@ def test_disable_speculation(vllm_runner, common_llm_kwargs,
|
||||
"test_llm_kwargs",
|
||||
[
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": k,
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": k,
|
||||
},
|
||||
"enable_chunked_prefill": False,
|
||||
}
|
||||
# Try a range of common k, as well as large speculation.
|
||||
for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]
|
||||
] + [{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": k,
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": k,
|
||||
},
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4,
|
||||
@ -729,17 +772,21 @@ def test_many_k(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
"test_llm_kwargs",
|
||||
[
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": k,
|
||||
"spec_decoding_acceptance_method": "typical_acceptance_sampler",
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": k,
|
||||
"acceptance_method": "typical_acceptance_sampler",
|
||||
},
|
||||
"enable_chunked_prefill": False
|
||||
}
|
||||
# Try a range of common k.
|
||||
for k in [1, 2, 3]
|
||||
] + [{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": k,
|
||||
"spec_decoding_acceptance_method": "typical_acceptance_sampler",
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": k,
|
||||
"acceptance_method": "typical_acceptance_sampler",
|
||||
},
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4
|
||||
|
@ -48,16 +48,20 @@ from .conftest import run_equality_correctness_test
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": "[ngram]",
        "num_speculative_tokens": 5,
        "ngram_prompt_lookup_max": 3,
        "speculative_disable_mqa_scorer": False,
        "speculative_config": {
            "method": "ngram",
            "num_speculative_tokens": 5,
            "prompt_lookup_max": 3,
            "disable_mqa_scorer": False,
        },
    },
    {
        "speculative_model": "[ngram]",
        "num_speculative_tokens": 5,
        "ngram_prompt_lookup_max": 3,
        "speculative_disable_mqa_scorer": True,
        "speculative_config": {
            "method": "ngram",
            "num_speculative_tokens": 5,
            "prompt_lookup_max": 3,
            "disable_mqa_scorer": True,
        },
    },
])
@pytest.mark.parametrize("output_len", [
@ -101,16 +105,20 @@ def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": "[ngram]",
        "num_speculative_tokens": 5,
        "ngram_prompt_lookup_max": 3,
        "disable_logprobs_during_spec_decoding": False,
        "speculative_config": {
            "method": "ngram",
            "num_speculative_tokens": 5,
            "prompt_lookup_max": 3,
            "disable_logprobs": False,
        },
    },
    {
        "speculative_model": "[ngram]",
        "num_speculative_tokens": 5,
        "ngram_prompt_lookup_max": 3,
        "disable_logprobs_during_spec_decoding": True,
        "speculative_config": {
            "method": "ngram",
            "num_speculative_tokens": 5,
            "prompt_lookup_max": 3,
            "disable_logprobs": True,
        },
    },
])
@pytest.mark.parametrize("output_len", [
@ -125,19 +133,20 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
                                   batch_size: int, output_len: int, seed: int,
                                   logprobs: int):
    """Verify greedy equality on a tiny model with different batch size."""
    run_equality_correctness_test(vllm_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs,
                                  test_llm_kwargs,
                                  batch_size,
                                  max_output_len=output_len,
                                  seed=seed,
                                  temperature=0.0,
                                  logprobs=logprobs,
                                  prompt_logprobs=logprobs,
                                  disable_logprobs=test_llm_kwargs[
                                      'disable_logprobs_during_spec_decoding'])
    run_equality_correctness_test(
        vllm_runner,
        common_llm_kwargs,
        per_test_common_llm_kwargs,
        baseline_llm_kwargs,
        test_llm_kwargs,
        batch_size,
        max_output_len=output_len,
        seed=seed,
        temperature=0.0,
        logprobs=logprobs,
        prompt_logprobs=logprobs,
        disable_logprobs=test_llm_kwargs["speculative_config"]
        ["disable_logprobs"])


@pytest.mark.parametrize(
@ -159,17 +168,21 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": "[ngram]",
        "num_speculative_tokens": 5,
        "ngram_prompt_lookup_max": 3,
        "speculative_config": {
            "method": "ngram",
            "num_speculative_tokens": 5,
            "prompt_lookup_max": 3,
        },
        "enable_chunked_prefill": False,
    },
    {
        "speculative_model": "[ngram]",
        "num_speculative_tokens": 5,
        "ngram_prompt_lookup_max": 3,
        "speculative_config": {
            "method": "ngram",
            "num_speculative_tokens": 5,
            "prompt_lookup_max": 3,
            "disable_mqa_scorer": True,
        },
        "enable_chunked_prefill": True,
        "speculative_disable_mqa_scorer": True,
        "max_num_batched_tokens": 4,
        "max_num_seqs": 4
    },
@ -214,17 +227,21 @@ def test_ngram_e2e_greedy_correctness_with_preemption(
    "test_llm_kwargs",
    [
        {
            "speculative_model": "[ngram]",
            "num_speculative_tokens": k,
            "ngram_prompt_lookup_max": 3,
            "speculative_config": {
                "method": "ngram",
                "num_speculative_tokens": k,
                "prompt_lookup_max": 3,
            },
        }
        # Try a range of common k, as well as large speculation.
        for k in [1, 3, 5]
    ] + [
        {
            "speculative_model": "[ngram]",
            "num_speculative_tokens": k,
            "ngram_prompt_lookup_max": 1,
            "speculative_config": {
                "method": "ngram",
                "num_speculative_tokens": k,
                "prompt_lookup_max": 1,
            },
        }
        # Try a range of common k, as well as large speculation.
        for k in [1, 3, 5]
@ -243,7 +260,7 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs,
                           seed: int):
    """Verify that ngram speculative decoding produces exact equality
    to without spec decode with many different values of k and
    different ngram_prompt_lookup_max.
    different ngram prompt_lookup_max.
    """
    run_equality_correctness_test(vllm_runner,
                                  common_llm_kwargs,
@ -266,22 +283,25 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
                         [{
                             "speculative_model": "[ngram]",
                             "num_speculative_tokens": 5,
                             "ngram_prompt_lookup_max": 3,
                             "speculative_disable_by_batch_size": 4
                         }, {
                             "speculative_model": "[ngram]",
                             "num_speculative_tokens": 5,
                             "ngram_prompt_lookup_max": 3,
                             "speculative_disable_by_batch_size": 4,
                             "enable_chunked_prefill": True,
                             "speculative_disable_mqa_scorer": True,
                             "max_num_batched_tokens": 4,
                             "max_num_seqs": 4
                         }])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_config": {
        "method": "ngram",
        "num_speculative_tokens": 5,
        "prompt_lookup_max": 3,
        "disable_by_batch_size": 4
    },
}, {
    "speculative_config": {
        "method": "ngram",
        "num_speculative_tokens": 5,
        "prompt_lookup_max": 3,
        "disable_by_batch_size": 4,
        "disable_mqa_scorer": True,
    },
    "enable_chunked_prefill": True,
    "max_num_batched_tokens": 4,
    "max_num_seqs": 4
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
    "output_len",
@ -296,7 +316,7 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs,
                             seed: int):
    """Verify that ngram speculative decoding produces exact equality
    to without spec decode with many different values of k and
    different ngram_prompt_lookup_max.
    different ngram prompt_lookup_max.
    """
    run_equality_correctness_test(vllm_runner,
                                  common_llm_kwargs,
@ -316,18 +336,17 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs,

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "speculative_model": "[ngram]",
        "num_speculative_tokens": 5,
        "ngram_prompt_lookup_max": 3,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
                         [{
                             "speculative_disable_mqa_scorer": True,
                         }])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_config": {
        "method": "ngram",
        "num_speculative_tokens": 5,
        "prompt_lookup_max": 3,
        "disable_mqa_scorer": True,
    },
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
    "output_len",
@ -19,11 +19,11 @@ SPEC_MODEL = "JackFram/llama-160m"
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # speculative model
        "speculative_model": "JackFram/llama-160m",

        # num speculative tokens
        "num_speculative_tokens": 3,
        # speculative config
        "speculative_config": {
            "model": "JackFram/llama-160m",
            "num_speculative_tokens": 3,
        },
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}])
@ -70,12 +70,16 @@ def test_ngram_correctness(
    ref_outputs = ref_llm.chat(test_prompts, sampling_config)
    del ref_llm

    spec_llm = LLM(model=model_name,
                   speculative_model='[ngram]',
                   ngram_prompt_lookup_max=5,
                   ngram_prompt_lookup_min=3,
                   num_speculative_tokens=3,
                   max_model_len=1024)
    spec_llm = LLM(
        model=model_name,
        speculative_config={
            "method": "ngram",
            "prompt_lookup_max": 5,
            "prompt_lookup_min": 3,
            "num_speculative_tokens": 3,
        },
        max_model_len=1024,
    )
    spec_outputs = spec_llm.chat(test_prompts, sampling_config)
    matches = 0
    misses = 0
vllm/config.py
@ -1810,12 +1810,139 @@ class DeviceConfig:
        self.device = torch.device(self.device_type)


@dataclass
class SpeculativeConfig:
    """Configuration for speculative decoding.

    The configuration is currently specialized to draft-model speculative
    decoding with top-1 proposals.
    """
    Configuration for speculative decoding.
    Configurable parameters include:
        - General Speculative Decoding Control:
            - num_speculative_tokens (int): The number of speculative
                tokens, if provided. It will default to the number in the draft
                model config if present, otherwise, it is required.
            - model (Optional[str]): The name of the draft model, eagle head,
                or additional weights, if provided.
            - method (Optional[str]): The name of the speculative method to use.
                If users provide and set the `model` param, the speculative method
                type will be detected automatically if possible, if `model` param
                is not provided, the method name must be provided.
                - Possible values:
                    - ngram
                        Related additional configuration:
                            - prompt_lookup_max (Optional[int]):
                                Maximum size of ngram token window when using Ngram
                                proposer, required when method is set to ngram.
                            - prompt_lookup_min (Optional[int]):
                                Minimum size of ngram token window when using Ngram
                                proposer, if provided. Defaults to 1.
                    - eagle
                    - medusa
                    - mlp_speculator
                    - draft_model
            - acceptance_method (str): The method to use for accepting draft
                tokens. This can take two possible values: 'rejection_sampler' and
                'typical_acceptance_sampler' for RejectionSampler and
                TypicalAcceptanceSampler respectively. If not specified, it
                defaults to 'rejection_sampler'.
                - Possible values:
                    - rejection_sampler
                    - typical_acceptance_sampler
                        Related additional configuration:
                            - posterior_threshold (Optional[float]):
                                A threshold value that sets a lower bound on the
                                posterior probability of a token in the target model
                                for it to be accepted. This threshold is used only
                                when we use the TypicalAcceptanceSampler for token
                                acceptance.
                            - posterior_alpha (Optional[float]):
                                Scaling factor for entropy-based threshold, applied
                                when using TypicalAcceptanceSampler.
            - draft_tensor_parallel_size (Optional[int]): The degree of the tensor
                parallelism for the draft model. Can only be 1 or the same as the
                target model's tensor parallel size.
            - disable_logprobs (bool): If set to True, token log probabilities are
                not returned during speculative decoding. If set to False, token
                log probabilities are returned according to the log probability
                settings in SamplingParams. If not specified, it defaults to True.

        - Draft Model Configuration:
            - quantization (Optional[str]): Quantization method that was used to
                quantize the draft model weights. If None, we assume the
                model weights are not quantized. Note that it only takes effect
                when using the draft model-based speculative method.
            - max_model_len (Optional[int]): The maximum model length of the
                draft model. Used when testing the ability to skip
                speculation for some sequences.
            - revision: The specific model version to use for the draft model. It
                can be a branch name, a tag name, or a commit id. If unspecified,
                will use the default version.
            - code_revision: The specific revision to use for the draft model code
                on Hugging Face Hub. It can be a branch name, a tag name, or a
                commit id. If unspecified, will use the default version.

        - Advanced Control:
            - disable_mqa_scorer (bool): Disable the MQA scorer and fall back to
                batch expansion for scoring proposals. If not specified, it
                defaults to False.
            - disable_by_batch_size (Optional[int]): Disable speculative decoding
                for new incoming requests when the number of enqueued requests is
                larger than this value, if provided.

    Although the parameters above are structured hierarchically, there is no
    need to nest them during configuration.

    Non-configurable internal parameters include:
        - Model Configuration:
            - target_model_config (ModelConfig): The configuration of the target
                model.
            - draft_model_config (ModelConfig): The configuration of the draft
                model initialized internal.
        - Parallelism Configuration:
            - target_parallel_config (ParallelConfig): The parallel configuration
                for the target model.
            - draft_parallel_config (ParallelConfig): The parallel configuration
                for the draft model initialized internal.
        - Execution Control:
            - enable_chunked_prefill (bool): Whether vLLM is configured to use
                chunked prefill or not. Used for raising an error since it's not
                yet compatible with speculative decode.
            - disable_log_stats (bool): Whether to disable the periodic printing of
                stage times in speculative decoding.
    """
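The docstring above documents the parameters hierarchically but states that they are passed as one flat mapping. A hedged sketch of such a flat dict (values are illustrative, not defaults introduced by this commit):

```python
# Passed as-is via LLM(speculative_config=...) or as a JSON string via
# --speculative-config; keys mirror the dataclass fields documented above.
speculative_config = {
    "model": "JackFram/llama-68m",          # draft model (illustrative)
    "num_speculative_tokens": 3,
    "draft_tensor_parallel_size": 1,
    "acceptance_method": "rejection_sampler",
    "disable_logprobs": True,
    "disable_by_batch_size": 8,
}
```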
    # speculative configs from cli args
    num_speculative_tokens: int = field(default=None,
                                        init=True)  # type: ignore
    method: Optional[str] = None
    acceptance_method: str = "rejection_sampler"
    draft_tensor_parallel_size: Optional[int] = None
    disable_logprobs: bool = True

    model: Optional[str] = None
    quantization: Optional[str] = None
    max_model_len: Optional[int] = None
    revision: Optional[str] = None
    code_revision: Optional[str] = None

    disable_mqa_scorer: bool = False
    disable_by_batch_size: Optional[int] = None
    prompt_lookup_max: Optional[int] = None
    prompt_lookup_min: Optional[int] = None
    posterior_threshold: Optional[float] = None
    posterior_alpha: Optional[float] = None

    # required configuration params passed from engine
    target_model_config: ModelConfig = field(default=None,
                                             init=True)  # type: ignore
    target_parallel_config: ParallelConfig = field(default=None,
                                                   init=True)  # type: ignore
    enable_chunked_prefill: bool = field(default=None,
                                         init=True)  # type: ignore
    disable_log_stats: bool = field(default=None, init=True)  # type: ignore

    # params generated in the post-init stage
    draft_model_config: ModelConfig = field(default=None,
                                            init=True)  # type: ignore
    draft_parallel_config: ParallelConfig = field(default=None,
                                                  init=True)  # type: ignore

    def compute_hash(self) -> str:
        """
@ -1835,6 +1962,11 @@ class SpeculativeConfig:
        hash_str = hashlib.md5(str(factors).encode()).hexdigest()
        return hash_str

    @classmethod
    def from_dict(cls, dict_value: dict) -> "SpeculativeConfig":
        """Parse the CLI value for the speculative config."""
        return cls(**dict_value)

    @staticmethod
    def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
        if hf_config.model_type == "deepseek_v3":
@ -1847,230 +1979,160 @@ class SpeculativeConfig:
            })
        return hf_config

    @staticmethod
    def maybe_create_spec_config(
            target_model_config: ModelConfig,
            target_parallel_config: ParallelConfig,
            target_dtype: str,
            speculative_model: Optional[str],
            speculative_model_quantization: Optional[str],
            speculative_draft_tensor_parallel_size: Optional[int],
            num_speculative_tokens: Optional[int],
            speculative_disable_mqa_scorer: Optional[bool],
            speculative_max_model_len: Optional[int],
            enable_chunked_prefill: bool,
            disable_log_stats: bool,
            speculative_disable_by_batch_size: Optional[int],
            ngram_prompt_lookup_max: Optional[int],
            ngram_prompt_lookup_min: Optional[int],
            draft_token_acceptance_method: str,
            typical_acceptance_sampler_posterior_threshold: Optional[float],
            typical_acceptance_sampler_posterior_alpha: Optional[float],
            disable_logprobs: Optional[bool],
    ) -> Optional["SpeculativeConfig"]:
        """Create a SpeculativeConfig if possible, else return None.
    def __post_init__(self):

        This function attempts to create a SpeculativeConfig object based on the
        provided parameters. If the necessary conditions are met, it returns an
        instance of SpeculativeConfig. Otherwise, it returns None.
        # Note: After next release, the method parameter will be used to
        # specify the speculative method, which helps to extend the
        # configuration of non-model-based proposers, and the model parameter
        # will be used when the draft model or head is needed.
        # If users do not specify the method, the speculative method will
        # be detected automatically if possible. If the speculative method can
        # not be detected, it will be considered as the draft-model-based
        # method by default.

        Args:
            target_model_config (ModelConfig): The configuration of the target
                model.
            target_parallel_config (ParallelConfig): The parallel configuration
                for the target model.
            target_dtype (str): The data type used for the target model.
            speculative_model (Optional[str]): The name of the speculative
                model, if provided.
            speculative_model_quantization (Optional[str]): Quantization method
                that was used to quantize the speculative model weights. If
                None, we assume the model weights are not quantized.
            speculative_draft_tensor_parallel_size (Optional[int]): The degree
                of the tensor parallelism for the draft model.
            num_speculative_tokens (Optional[int]): The number of speculative
                tokens, if provided. Will default to the number in the draft
                model config if present, otherwise is required.
            speculative_disable_mqa_scorer (Optional[bool]): Disable the MQA
                scorer for the speculative model and fall back to batch
                expansion for scoring.
            speculative_max_model_len (Optional[int]): The maximum model len of
                the speculative model. Used when testing the ability to skip
                speculation for some sequences.
            enable_chunked_prefill (bool): Whether vLLM is configured to use
                chunked prefill or not. Used for raising an error since its not
                yet compatible with spec decode.
            speculative_disable_by_batch_size (Optional[int]): Disable
                speculative decoding for new incoming requests when the number
                of enqueue requests is larger than this value, if provided.
            ngram_prompt_lookup_max (Optional[int]): Max size of ngram token
                window, if provided.
            ngram_prompt_lookup_min (Optional[int]): Min size of ngram token
                window, if provided.
            draft_token_acceptance_method (str): The method to use for
                accepting draft tokens. This can take two possible
                values 'rejection_sampler' and 'typical_acceptance_sampler'
                for RejectionSampler and TypicalAcceptanceSampler
                respectively.
            typical_acceptance_sampler_posterior_threshold (Optional[float]):
                A threshold value that sets a lower bound on the posterior
                probability of a token in the target model for it to be
                accepted. This threshold is used only when we use the
                TypicalAcceptanceSampler for token acceptance.
            typical_acceptance_sampler_posterior_alpha (Optional[float]):
                A scaling factor for the entropy-based threshold in the
                TypicalAcceptanceSampler.
            disable_logprobs (Optional[bool]): If set to True, token log
                probabilities are not returned during speculative decoding.
                If set to False, token log probabilities are returned
                according to the log probability settings in SamplingParams.
                If not specified, it defaults to True.

        Returns:
            Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
                the necessary conditions are met, else None.
        """
        if speculative_model is None:
            if num_speculative_tokens is not None:
                if target_model_config.hf_text_config.model_type \
        if self.model is None and self.num_speculative_tokens is not None:
            # TODO(Shangming): Refactor mtp configuration logic when supporting
            # mtp acceleration for more models besides deepseek_v3
            if self.target_model_config.hf_text_config.model_type \
                        == "deepseek_v3":
                    # use the draft model from the same model:
                    speculative_model = target_model_config.model
                else:
                    raise ValueError(
                        "num_speculative_tokens was provided without "
                        "speculative_model.")
                # use the draft model from the same model:
                self.model = self.target_model_config.model
            elif self.method in ("ngram", "[ngram]"):
                self.model = "ngram"
            else:
                return None
                raise ValueError("num_speculative_tokens was provided without "
                                 "speculative model.")

        if (speculative_disable_by_batch_size is not None
                and speculative_disable_by_batch_size < 2):
            raise ValueError("Expect the batch size threshold of disabling "
                             "speculative decoding is > 1, but got "
                             f"{speculative_disable_by_batch_size=}")
        if (enable_chunked_prefill and speculative_model == "eagle"):
            raise ValueError("Chunked prefill and EAGLE are not compatible.")
        # TODO: The user should be able to specify revision/max model len
        # for the draft model. It is not currently supported.
        draft_revision = None
        draft_code_revision = None
        draft_quantization = speculative_model_quantization
        # Automatically configure the ngram method during configuration
        # refactoring to ensure a smooth transition.
        if self.method is None and (self.model is not None
                                    and self.model in ("ngram", "[ngram]")):
            self.method = "ngram"

        if speculative_model == "[ngram]":
            if ngram_prompt_lookup_min is None:
                ngram_prompt_lookup_min = 1
            if ngram_prompt_lookup_max is None or ngram_prompt_lookup_max < 1:
                raise ValueError(f"{ngram_prompt_lookup_max=} must be > 0")
            if ngram_prompt_lookup_min < 1:
                raise ValueError(f"{ngram_prompt_lookup_min=} must be > 0")
            if ngram_prompt_lookup_min > ngram_prompt_lookup_max:
                raise ValueError(f"{ngram_prompt_lookup_min=} cannot be "
                                 f"larger than {ngram_prompt_lookup_max=}")
        if self.method in ("ngram", "[ngram]"):
            # Unified to "ngram" internally
            self.method = "ngram"
            if self.prompt_lookup_min is None:
                self.prompt_lookup_min = 1
            if self.prompt_lookup_max is None or self.prompt_lookup_max < 1:
                raise ValueError("prompt_lookup_max="
                                 f"{self.prompt_lookup_max} must be > 0")
            if self.prompt_lookup_min < 1:
                raise ValueError("prompt_lookup_min="
                                 f"{self.prompt_lookup_min} must be > 0")
            if self.prompt_lookup_min > self.prompt_lookup_max:
                raise ValueError(f"prompt_lookup_min={self.prompt_lookup_min} "
                                 "cannot be larger than prompt_lookup_max="
                                 f"{self.prompt_lookup_max}")

            # TODO: current we still need extract vocab_size from target model
            # config, in future, we may try refactor it out, and set
            # draft related config as None here.
            draft_model_config = target_model_config
            draft_parallel_config = target_parallel_config
            self.draft_model_config = self.target_model_config
            self.draft_parallel_config = self.target_parallel_config
        else:
            ngram_prompt_lookup_max = 0
            ngram_prompt_lookup_min = 0
            draft_model_config = ModelConfig(
                model=speculative_model,
                task="draft",
                tokenizer=target_model_config.tokenizer,
                tokenizer_mode=target_model_config.tokenizer_mode,
                trust_remote_code=target_model_config.trust_remote_code,
                allowed_local_media_path=target_model_config.
                allowed_local_media_path,
                dtype=target_model_config.dtype,
                seed=target_model_config.seed,
                revision=draft_revision,
                code_revision=draft_code_revision,
                tokenizer_revision=target_model_config.tokenizer_revision,
                max_model_len=None,
                spec_target_max_model_len=target_model_config.max_model_len,
                quantization=draft_quantization,
                enforce_eager=target_model_config.enforce_eager,
                max_seq_len_to_capture=target_model_config.
                max_seq_len_to_capture,
                max_logprobs=target_model_config.max_logprobs,
                hf_overrides=SpeculativeConfig.hf_config_override,
            )
            self.prompt_lookup_max = 0
            self.prompt_lookup_min = 0

        draft_hf_config = draft_model_config.hf_config
        if self.model is not None:
            self.draft_model_config = ModelConfig(
                model=self.model,
                task="draft",
                tokenizer=self.target_model_config.tokenizer,
                tokenizer_mode=self.target_model_config.tokenizer_mode,
                trust_remote_code=self.target_model_config.
                trust_remote_code,
                allowed_local_media_path=self.target_model_config.
                allowed_local_media_path,
                dtype=self.target_model_config.dtype,
                seed=self.target_model_config.seed,
                revision=self.revision,
                code_revision=self.code_revision,
                tokenizer_revision=self.target_model_config.
                tokenizer_revision,
                max_model_len=None,
                spec_target_max_model_len=self.target_model_config.
                max_model_len,
                quantization=self.quantization,
                enforce_eager=self.target_model_config.enforce_eager,
                max_seq_len_to_capture=self.target_model_config.
                max_seq_len_to_capture,
                max_logprobs=self.target_model_config.max_logprobs,
                hf_overrides=SpeculativeConfig.hf_config_override,
            )

        # Detect EAGLE prefix to replace hf_config for EAGLE draft_model
        if "eagle-" in draft_model_config.model.lower():
            from vllm.transformers_utils.configs.eagle import EAGLEConfig
            if isinstance(draft_model_config.hf_config, EAGLEConfig):
                pass
            # Automatically detect the method
            if "eagle-" in self.draft_model_config.model.lower():
                self.method = "eagle"
            elif self.draft_model_config.hf_config.model_type == "medusa":
                self.method = "medusa"
            elif (self.draft_model_config.hf_config.model_type ==
                  "mlp_speculator"):
                self.method = "mlp_speculator"
            else:
                eagle_config = EAGLEConfig(draft_model_config.hf_config)
                draft_model_config.hf_config = eagle_config
                self.method = "draft_model"

        if (num_speculative_tokens is not None
                and hasattr(draft_hf_config, "num_lookahead_tokens")):
            draft_hf_config.num_lookahead_tokens = num_speculative_tokens
        n_predict = getattr(draft_hf_config, "n_predict", None)
        if n_predict is not None:
            if num_speculative_tokens is None:
                # Default to max value defined in draft model config.
                num_speculative_tokens = n_predict
            elif num_speculative_tokens > n_predict and \
                    num_speculative_tokens % n_predict != 0:
                # Ensure divisibility for MTP module reuse.
                raise ValueError(
                    f"{num_speculative_tokens=} must be divisible by "
                    f"{n_predict=}")
            # Replace hf_config for EAGLE draft_model
            if self.method == "eagle":
                if self.enable_chunked_prefill:
                    raise ValueError(
                        "Chunked prefill and EAGLE are not compatible.")

        speculative_draft_tensor_parallel_size = \
            SpeculativeConfig._verify_and_get_draft_model_tensor_parallel_size(
                target_parallel_config,
                speculative_draft_tensor_parallel_size,
                draft_hf_config
            )
                from vllm.transformers_utils.configs.eagle import (
                    EAGLEConfig)
                if isinstance(self.draft_model_config.hf_config,
                              EAGLEConfig):
                    pass
                else:
                    eagle_config = EAGLEConfig(
                        self.draft_model_config.hf_config)
                    self.draft_model_config.hf_config = eagle_config

        draft_model_config.max_model_len = (
            SpeculativeConfig._maybe_override_draft_max_model_len(
                speculative_max_model_len,
                draft_model_config.max_model_len,
                target_model_config.max_model_len,
            ))
            if (self.num_speculative_tokens is not None
                    and hasattr(self.draft_model_config.hf_config,
                                "num_lookahead_tokens")):
                self.draft_model_config.hf_config.num_lookahead_tokens = \
                    self.num_speculative_tokens

        draft_parallel_config = (
            SpeculativeConfig.create_draft_parallel_config(
                target_parallel_config,
                speculative_draft_tensor_parallel_size, draft_hf_config))
            n_predict = getattr(self.draft_model_config.hf_config,
                                "n_predict", None)
            if n_predict is not None:
                if self.num_speculative_tokens is None:
                    # Default to max value defined in draft model config.
                    self.num_speculative_tokens = n_predict
                elif self.num_speculative_tokens > n_predict and \
                        self.num_speculative_tokens % n_predict != 0:
                    # Ensure divisibility for MTP module reuse.
                    raise ValueError(
                        f"num_speculative_tokens:{self.num_speculative_tokens}"
                        f" must be divisible by {n_predict=}")

        if num_speculative_tokens is None:
            raise ValueError(
                "num_speculative_tokens must be provided with "
                "speculative_model unless the draft model config contains an "
                "n_predict parameter.")
            self.draft_tensor_parallel_size = \
                SpeculativeConfig._verify_and_get_draft_tp(
                    self.target_parallel_config,
                    self.draft_tensor_parallel_size,
                    self.draft_model_config.hf_config
                )

        if typical_acceptance_sampler_posterior_threshold is None:
            typical_acceptance_sampler_posterior_threshold = 0.09
        if typical_acceptance_sampler_posterior_alpha is None:
            typical_acceptance_sampler_posterior_alpha = 0.3
        if disable_logprobs is None:
            disable_logprobs = True
            self.draft_model_config.max_model_len = (
                SpeculativeConfig._maybe_override_draft_max_model_len(
                    self.max_model_len,
                    self.draft_model_config.max_model_len,
                    self.target_model_config.max_model_len,
                ))

        return SpeculativeConfig(
            draft_model_config,
            draft_parallel_config,
            num_speculative_tokens,
            speculative_disable_mqa_scorer,
            speculative_disable_by_batch_size,
            ngram_prompt_lookup_max,
            ngram_prompt_lookup_min,
            draft_token_acceptance_method=draft_token_acceptance_method,
            typical_acceptance_sampler_posterior_threshold=\
                typical_acceptance_sampler_posterior_threshold,
            typical_acceptance_sampler_posterior_alpha=\
                typical_acceptance_sampler_posterior_alpha,
            disable_logprobs=disable_logprobs,
            disable_log_stats=disable_log_stats,
        )
            self.draft_parallel_config = (
                SpeculativeConfig.create_draft_parallel_config(
                    self.target_parallel_config,
                    self.draft_tensor_parallel_size))

        if self.acceptance_method == "typical_acceptance_sampler":
            if self.posterior_threshold is None:
                self.posterior_threshold = 0.09
            if self.posterior_alpha is None:
                self.posterior_alpha = 0.3

        self._verify_args()

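One consequence of the `n_predict` handling above: when the draft config defines `n_predict`, the requested `num_speculative_tokens` must either not exceed it or be an exact multiple of it. A quick illustrative check of that rule (a standalone helper, not part of vLLM's API):

```python
def check_num_speculative_tokens(num_speculative_tokens: int, n_predict: int) -> None:
    # Mirrors the divisibility rule enforced in __post_init__ above.
    if (num_speculative_tokens > n_predict
            and num_speculative_tokens % n_predict != 0):
        raise ValueError(
            f"num_speculative_tokens={num_speculative_tokens} "
            f"must be divisible by n_predict={n_predict}")

check_num_speculative_tokens(4, n_predict=2)    # OK: 4 is a multiple of 2
check_num_speculative_tokens(1, n_predict=2)    # OK: not larger than n_predict
# check_num_speculative_tokens(3, n_predict=2)  # would raise ValueError
```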
    @staticmethod
    def _maybe_override_draft_max_model_len(
@ -2108,7 +2170,7 @@ class SpeculativeConfig:
        )

    @staticmethod
    def _verify_and_get_draft_model_tensor_parallel_size(
    def _verify_and_get_draft_tp(
            target_parallel_config: ParallelConfig,
            speculative_draft_tensor_parallel_size: Optional[int],
            draft_hf_config: PretrainedConfig) -> int:
@ -2140,7 +2202,6 @@ class SpeculativeConfig:
    def create_draft_parallel_config(
        target_parallel_config: ParallelConfig,
        speculative_draft_tensor_parallel_size: int,
        draft_hf_config: PretrainedConfig,
    ) -> ParallelConfig:
        """Create a parallel config for use by the draft worker.

@ -2164,74 +2225,13 @@ class SpeculativeConfig:

        return draft_parallel_config

    def __init__(
        self,
        draft_model_config: ModelConfig,
        draft_parallel_config: ParallelConfig,
        num_speculative_tokens: int,
        speculative_disable_mqa_scorer: Optional[bool],
        speculative_disable_by_batch_size: Optional[int],
        ngram_prompt_lookup_max: Optional[int],
        ngram_prompt_lookup_min: Optional[int],
        draft_token_acceptance_method: str,
        typical_acceptance_sampler_posterior_threshold: float,
        typical_acceptance_sampler_posterior_alpha: float,
        disable_logprobs: bool,
        disable_log_stats: bool,
    ):
        """Create a SpeculativeConfig object.

        Args:
            draft_model_config: ModelConfig for the draft model.
            draft_parallel_config: ParallelConfig for the draft model.
            num_speculative_tokens: The number of tokens to sample from the
                draft model before scoring with the target model.
            speculative_disable_by_batch_size: Disable speculative
                decoding for new incoming requests when the number of
                enqueue requests is larger than this value.
            ngram_prompt_lookup_max: Max size of ngram token window.
            ngram_prompt_lookup_min: Min size of ngram token window.
            draft_token_acceptance_method (str): The method to use for
                accepting draft tokens. This can take two possible
                values 'rejection_sampler' and 'typical_acceptance_sampler'
                for RejectionSampler and TypicalAcceptanceSampler
                respectively.
            typical_acceptance_sampler_posterior_threshold (Optional[float]):
                A threshold value that sets a lower bound on the posterior
                probability of a token in the target model for it to be
                accepted. This threshold is used only when we use the
                TypicalAcceptanceSampler for token acceptance.
            typical_acceptance_sampler_posterior_alpha (Optional[float]):
                A scaling factor for the entropy-based threshold in the
                TypicalAcceptanceSampler.
            disable_logprobs: If set to True, token log probabilities will not
                be returned even if requested by sampling parameters. This
                reduces latency by skipping logprob calculation in proposal
                sampling, target sampling, and after accepted tokens are
                determined. If set to False, log probabilities will be
                returned.
            disable_log_stats: Whether to disable periodic printing of stage
                times in speculative decoding.
        """
        self.draft_model_config = draft_model_config
        self.draft_parallel_config = draft_parallel_config
        self.num_speculative_tokens = num_speculative_tokens
        self.speculative_disable_mqa_scorer = speculative_disable_mqa_scorer
        self.speculative_disable_by_batch_size = \
            speculative_disable_by_batch_size
        self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0
        self.ngram_prompt_lookup_min = ngram_prompt_lookup_min or 0
        self.draft_token_acceptance_method = draft_token_acceptance_method
        self.typical_acceptance_sampler_posterior_threshold = \
            typical_acceptance_sampler_posterior_threshold
        self.typical_acceptance_sampler_posterior_alpha = \
            typical_acceptance_sampler_posterior_alpha
        self.disable_logprobs = disable_logprobs
        self.disable_log_stats = disable_log_stats

        self._verify_args()

    def _verify_args(self) -> None:
        if self.num_speculative_tokens is None:
            raise ValueError(
                "num_speculative_tokens must be provided with "
                "speculative model unless the draft model config contains an "
                "n_predict parameter.")

        if self.num_speculative_tokens <= 0:
            raise ValueError("Expected num_speculative_tokens to be greater "
                             f"than zero ({self.num_speculative_tokens}).")
@ -2241,29 +2241,34 @@ class SpeculativeConfig:
                                self.draft_parallel_config)
        # Validate and set draft token acceptance related settings.

        if (self.draft_token_acceptance_method is None):
            raise ValueError("draft_token_acceptance_method is not set. "
        if self.acceptance_method is None:
            raise ValueError("acceptance_method is not set. "
                             "Expected values are rejection_sampler or "
                             "typical_acceptance_sampler.")

        if (self.draft_token_acceptance_method != 'rejection_sampler'
                and self.draft_token_acceptance_method
                != 'typical_acceptance_sampler'):
        if (self.acceptance_method != 'rejection_sampler'
                and self.acceptance_method != 'typical_acceptance_sampler'):
            raise ValueError(
                "Expected draft_token_acceptance_method to be either "
                "Expected acceptance_method to be either "
                "rejection_sampler or typical_acceptance_sampler. Instead it "
                f"is {self.draft_token_acceptance_method}")
                f"is {self.acceptance_method}")

        if (self.typical_acceptance_sampler_posterior_threshold < 0
                or self.typical_acceptance_sampler_posterior_alpha < 0):
        if self.acceptance_method == "typical_acceptance_sampler" and (
            (self.posterior_threshold is not None
             and self.posterior_threshold < 0) or
            (self.posterior_alpha is not None and self.posterior_alpha < 0)):
            raise ValueError(
                "Expected typical_acceptance_sampler_posterior_threshold "
                "and typical_acceptance_sampler_posterior_alpha to be > 0. "
                "Instead found "
                f"typical_acceptance_sampler_posterior_threshold = "
                f"{self.typical_acceptance_sampler_posterior_threshold} and "
                f"typical_acceptance_sampler_posterior_alpha = "
                f"{self.typical_acceptance_sampler_posterior_alpha}")
                "Expected the posterior_threshold and posterior_alpha of "
                "typical_acceptance_sampler to be > 0. "
                "Instead found posterior_threshold = "
                f"{self.posterior_threshold} and posterior_alpha = "
                f"{self.posterior_alpha}")

        if (self.disable_by_batch_size is not None
                and self.disable_by_batch_size < 2):
            raise ValueError("Expect the batch size threshold of disabling "
                             "speculative decoding is > 1, but got "
                             f"{self.disable_by_batch_size=}")
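A small sketch of how the defaulting in `__post_init__` and the relaxed sign check above combine for the typical acceptance sampler (the helper is illustrative; the real logic lives in the class):

```python
def resolve_posterior_params(acceptance_method, posterior_threshold=None,
                             posterior_alpha=None):
    # Illustrative only: defaults from __post_init__, sign check from _verify_args.
    if acceptance_method == "typical_acceptance_sampler":
        posterior_threshold = 0.09 if posterior_threshold is None else posterior_threshold
        posterior_alpha = 0.3 if posterior_alpha is None else posterior_alpha
        if posterior_threshold < 0 or posterior_alpha < 0:
            raise ValueError("posterior_threshold and posterior_alpha must be > 0")
    return posterior_threshold, posterior_alpha

print(resolve_posterior_params("typical_acceptance_sampler"))  # (0.09, 0.3)
print(resolve_posterior_params("rejection_sampler"))           # (None, None)
```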
    @property
    def num_lookahead_slots(self) -> int:
@ -2276,8 +2281,8 @@ class SpeculativeConfig:
        return self.num_speculative_tokens

    def __repr__(self) -> str:
        if self.ngram_prompt_lookup_max > 0:
            draft_model = "[ngram]"
        if self.prompt_lookup_max is not None and self.prompt_lookup_max > 0:
            draft_model = "ngram"
        else:
            draft_model = self.draft_model_config.model
        num_spec_tokens = self.num_speculative_tokens
@ -3285,7 +3290,8 @@ class VllmConfig:
                                    init=True)  # type: ignore
    load_config: LoadConfig = field(default=None, init=True)  # type: ignore
    lora_config: Optional[LoRAConfig] = None
    speculative_config: Optional[SpeculativeConfig] = None
    speculative_config: SpeculativeConfig = field(default=None,
                                                  init=True)  # type: ignore
    decoding_config: Optional[DecodingConfig] = None
    observability_config: Optional[ObservabilityConfig] = None
    prompt_adapter_config: Optional[PromptAdapterConfig] = None
@ -177,7 +177,10 @@ class EngineArgs:

    guided_decoding_backend: str = 'xgrammar'
    logits_processor_pattern: Optional[str] = None
    # Speculative decoding configuration.
    speculative_config: Optional[Union[str, Dict[str, Any]]] = None

    # TODO(Shangming): Deprecate these out-of-date params after next release
    speculative_model: Optional[str] = None
    speculative_model_quantization: Optional[str] = None
    speculative_draft_tensor_parallel_size: Optional[int] = None
@ -190,9 +193,9 @@ class EngineArgs:
    spec_decoding_acceptance_method: str = 'rejection_sampler'
    typical_acceptance_sampler_posterior_threshold: Optional[float] = None
    typical_acceptance_sampler_posterior_alpha: Optional[float] = None
    qlora_adapter_name_or_path: Optional[str] = None
    disable_logprobs_during_spec_decoding: Optional[bool] = None

    qlora_adapter_name_or_path: Optional[str] = None
    show_hidden_metrics_for_version: Optional[str] = None
    otlp_traces_endpoint: Optional[str] = None
    collect_detailed_traces: Optional[str] = None
@ -780,7 +783,11 @@ class EngineArgs:
            const="True",
            help='If set, the prefill requests can be chunked based on the '
            'max_num_batched_tokens.')

        parser.add_argument('--speculative-config',
                            type=nullable_str,
                            default=None,
                            help='The configurations for speculative decoding.'
                            ' Should be a JSON string.')
        parser.add_argument(
            '--speculative-model',
            type=nullable_str,
@ -1192,6 +1199,82 @@ class EngineArgs:
            use_tqdm_on_load=self.use_tqdm_on_load,
        )

    def create_speculative_config(
        self,
        target_model_config: ModelConfig,
        target_parallel_config: ParallelConfig,
        enable_chunked_prefill: bool,
        disable_log_stats: bool,
    ) -> Optional["SpeculativeConfig"]:
        """Initializes and returns a SpeculativeConfig object based on
        `speculative_config`.

        This function utilizes `speculative_config` to create a
        SpeculativeConfig object. The `speculative_config` can either be
        provided as a JSON string input via CLI arguments or directly as a
        dictionary from the engine. If `speculative_config` is not set, this
        function will attempt to construct a configuration dictionary using
        certain parameters, which are scheduled for deprecation in the next
        release. Note that in next releases, `speculative_config` must be
        provided, and the deprecated standalone speculative-related parameters
        will be removed.
        """
        if self.speculative_config is None:
            if (self.speculative_model is None
                    and self.num_speculative_tokens is None):
                return None

            # TODO(Shangming): Deprecate this way of setting SpeculativeConfig,
            # only allow '--speculative-config' after next release
            logger.warning_once(
                "Please use '--speculative-config' to set all configurations "
                "related to speculative decoding. The current method of "
                "specifying the model through '--speculative-model' and "
                "adding related parameters (e.g., '--num-speculative-tokens') "
                "separately will be deprecated in the next release.")

            spec_config_dict = {
                "model": self.speculative_model,
                "quantization": self.speculative_model_quantization,
                "max_model_len": self.speculative_max_model_len,
                "draft_tensor_parallel_size":
                self.speculative_draft_tensor_parallel_size,
                "num_speculative_tokens": self.num_speculative_tokens,
                "disable_mqa_scorer": self.speculative_disable_mqa_scorer,
                "disable_by_batch_size":
                self.speculative_disable_by_batch_size,
                "prompt_lookup_max": self.ngram_prompt_lookup_max,
                "prompt_lookup_min": self.ngram_prompt_lookup_min,
                "acceptance_method": self.spec_decoding_acceptance_method,
                "posterior_threshold":
                self.typical_acceptance_sampler_posterior_threshold,
                "posterior_alpha":
                self.typical_acceptance_sampler_posterior_alpha,
                "disable_logprobs": self.disable_logprobs_during_spec_decoding,
            }

            self.speculative_config = spec_config_dict
        else:
            if isinstance(self.speculative_config, str):
                import ast
                self.speculative_config = ast.literal_eval(
                    self.speculative_config)
        # Note(Shangming): These parameters are not obtained from the cli arg
        # '--speculative-config' and must be passed in when creating the engine
        # config.

        assert isinstance(self.speculative_config, dict)
        self.speculative_config.update({
            "target_model_config": target_model_config,
            "target_parallel_config": target_parallel_config,
            "enable_chunked_prefill": enable_chunked_prefill,
            "disable_log_stats": disable_log_stats,
        })
        speculative_config = SpeculativeConfig.from_dict(
            self.speculative_config)

        return speculative_config
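Since the CLI value arrives as a string, the method above parses it with `ast.literal_eval` before the engine-side fields are injected. A small sketch of that parsing path; the JSON string and the placeholder values are illustrative only.

```python
import ast

# Illustrative value of --speculative-config as received from the CLI.
cli_value = '{"model": "JackFram/llama-68m", "num_speculative_tokens": 5}'
spec_dict = ast.literal_eval(cli_value)  # -> {'model': ..., 'num_speculative_tokens': 5}

# The engine then adds the non-CLI fields before SpeculativeConfig.from_dict();
# None is used here only as a placeholder for objects the engine supplies.
spec_dict.update({
    "target_model_config": None,
    "target_parallel_config": None,
    "enable_chunked_prefill": False,
    "disable_log_stats": False,
})
```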

    def create_engine_config(
        self,
        usage_context: Optional[UsageContext] = None,
@ -1238,6 +1321,8 @@ class EngineArgs:
        else:
            self._set_default_args_v0(model_config)

        assert self.enable_chunked_prefill is not None

        cache_config = CacheConfig(
            block_size=self.block_size,
            gpu_memory_utilization=self.gpu_memory_utilization,
@ -1280,31 +1365,11 @@ class EngineArgs:
            worker_extension_cls=self.worker_extension_cls,
        )

        speculative_config = SpeculativeConfig.maybe_create_spec_config(
        speculative_config = self.create_speculative_config(
            target_model_config=model_config,
            target_parallel_config=parallel_config,
            target_dtype=self.dtype,
            speculative_model=self.speculative_model,
            speculative_model_quantization = \
                self.speculative_model_quantization,
            speculative_draft_tensor_parallel_size = \
                self.speculative_draft_tensor_parallel_size,
            num_speculative_tokens=self.num_speculative_tokens,
            speculative_disable_mqa_scorer=self.speculative_disable_mqa_scorer,
            speculative_disable_by_batch_size=self.
            speculative_disable_by_batch_size,
            speculative_max_model_len=self.speculative_max_model_len,
            enable_chunked_prefill=self.enable_chunked_prefill,
            disable_log_stats=self.disable_log_stats,
            ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
            ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
            draft_token_acceptance_method=\
                self.spec_decoding_acceptance_method,
            typical_acceptance_sampler_posterior_threshold=self.
            typical_acceptance_sampler_posterior_threshold,
            typical_acceptance_sampler_posterior_alpha=self.
            typical_acceptance_sampler_posterior_alpha,
            disable_logprobs=self.disable_logprobs_during_spec_decoding,
        )

        # Reminder: Please update docs/source/features/compatibility_matrix.md
@ -1569,7 +1634,7 @@ class EngineArgs:
        if (self.speculative_model is not None
                or self.num_speculative_tokens is not None):
            # This is supported but experimental (handled below).
            if self.speculative_model == "[ngram]":
            if self.speculative_model in ("ngram", "[ngram]"):
                pass
            else:
                _raise_or_fallback(feature_name="Speculative Decoding",
@ -1617,7 +1682,8 @@ class EngineArgs:
            return False

        # ngram is supported on V1, but off by default for now.
        if self.speculative_model == "[ngram]" and _warn_or_fallback("ngram"):
        if self.speculative_model in (
                "ngram", "[ngram]") and _warn_or_fallback("ngram"):
            return False

        # Non-CUDA is supported on V1, but off by default for now.
@ -92,22 +92,20 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
    # Override draft-model specific worker args.
    draft_worker_kwargs.update(
        vllm_config=draft_worker_config,
        ngram_prompt_lookup_max=speculative_config.ngram_prompt_lookup_max,
        ngram_prompt_lookup_min=speculative_config.ngram_prompt_lookup_min,
        ngram_prompt_lookup_max=speculative_config.prompt_lookup_max,
        ngram_prompt_lookup_min=speculative_config.prompt_lookup_min,
    )

    spec_decode_worker = SpecDecodeWorker.create_worker(
        scorer_worker=target_worker,
        draft_worker_kwargs=draft_worker_kwargs,
        disable_mqa_scorer=speculative_config.speculative_disable_mqa_scorer,
        disable_by_batch_size=speculative_config.
        speculative_disable_by_batch_size,
        draft_token_acceptance_method=speculative_config.
        draft_token_acceptance_method,
        disable_mqa_scorer=speculative_config.disable_mqa_scorer,
        disable_by_batch_size=speculative_config.disable_by_batch_size,
        draft_token_acceptance_method=speculative_config.acceptance_method,
        typical_acceptance_sampler_posterior_threshold=speculative_config.
        typical_acceptance_sampler_posterior_threshold,
        posterior_threshold,
        typical_acceptance_sampler_posterior_alpha=speculative_config.
        typical_acceptance_sampler_posterior_alpha,
        posterior_alpha,
        disable_logprobs=speculative_config.disable_logprobs,
        disable_log_stats=speculative_config.disable_log_stats,
        num_speculative_tokens=speculative_config.num_speculative_tokens,
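For reference, the `SpeculativeConfig` attribute renames that this hunk relies on, collected from the diff and summarized here in code form:

```python
# Old SpeculativeConfig attribute name -> new name used by the worker above.
SPEC_CONFIG_RENAMES = {
    "ngram_prompt_lookup_max": "prompt_lookup_max",
    "ngram_prompt_lookup_min": "prompt_lookup_min",
    "speculative_disable_mqa_scorer": "disable_mqa_scorer",
    "speculative_disable_by_batch_size": "disable_by_batch_size",
    "draft_token_acceptance_method": "acceptance_method",
    "typical_acceptance_sampler_posterior_threshold": "posterior_threshold",
    "typical_acceptance_sampler_posterior_alpha": "posterior_alpha",
}
```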
@ -151,8 +151,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        self.use_spec_decode = False
        if self.speculative_config:
            self.use_spec_decode = True
            # TODO: find a better way to check if we are using ngram.
            assert self.speculative_config.ngram_prompt_lookup_min, \
            assert self.speculative_config.method == "ngram", \
                "Currently, only ngram spec decode is supported in V1."
            if get_pp_group().is_last_rank:
                self.drafter = NgramProposer()
@ -160,7 +159,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                # This usually takes less than 1 second.
                self.drafter.propose(
                    np.zeros(1024, dtype=np.int32),
                    self.speculative_config.ngram_prompt_lookup_min,
                    self.speculative_config.prompt_lookup_min,
                    self.speculative_config.num_speculative_tokens,
                )
                self.rejection_sampler = RejectionSampler()
@ -1155,7 +1154,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids
            drafter_output = self.drafter.propose(
                self.input_batch.token_ids_cpu[i, :end_idx],
                self.speculative_config.ngram_prompt_lookup_min,
                self.speculative_config.prompt_lookup_min,
                self.speculative_config.num_speculative_tokens,
            )
            if drafter_output is None or len(drafter_output) == 0:
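The V1 runner above only feeds the proposer a token window size (`prompt_lookup_min`) and a draft length (`num_speculative_tokens`). As a standalone illustration of the underlying prompt-lookup idea, and not vLLM's `NgramProposer` implementation, a toy proposer might look like this:

```python
import numpy as np

def toy_ngram_propose(token_ids: np.ndarray, n: int, k: int):
    """Toy prompt lookup: find the most recent earlier occurrence of the last
    n tokens and propose the k tokens that followed it. Returns None if no
    match is found. Illustrative only."""
    if len(token_ids) < n + 1:
        return None
    tail = token_ids[-n:]
    # Search backwards through the context, excluding the trailing n-gram itself.
    for start in range(len(token_ids) - n - 1, -1, -1):
        if np.array_equal(token_ids[start:start + n], tail):
            follow = token_ids[start + n:start + n + k]
            return follow if len(follow) > 0 else None
    return None

print(toy_ngram_propose(np.array([1, 2, 3, 9, 1, 2, 3]), n=3, k=2))  # -> [9 1]
```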