(spec-decode)=

# Speculative Decoding

:::{warning}
Please note that speculative decoding in vLLM is not yet optimized and does
not usually yield inter-token latency reductions for all prompt datasets or sampling parameters.
The work to optimize it is ongoing and can be followed here: <gh-issue:4630>
:::

:::{warning}
Currently, speculative decoding in vLLM is not compatible with pipeline parallelism.
:::

This document shows how to use [Speculative Decoding](https://x.com/karpathy/status/1697318534555336961) with vLLM.
Speculative decoding is a technique which improves inter-token latency in memory-bound LLM inference.
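For a rough back-of-the-envelope intuition (a simplified sketch, not a vLLM benchmark): if each draft token is accepted independently with probability `alpha`, then a single target-model forward pass over `k` draft tokens yields on average `(1 - alpha**(k + 1)) / (1 - alpha)` tokens, which is where the inter-token latency reduction comes from. The helper and numbers below are purely illustrative.

```python
# Illustrative only: expected tokens generated per target-model forward pass,
# assuming each draft token is accepted i.i.d. with probability `alpha`
# (the standard simplification used in the speculative decoding literature).
def expected_tokens_per_target_pass(alpha: float, k: int) -> float:
    return (1 - alpha ** (k + 1)) / (1 - alpha)

for alpha in (0.6, 0.7, 0.8):
    print(f"alpha={alpha}: ~{expected_tokens_per_target_pass(alpha, k=5):.2f} tokens per pass")
```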
## Speculating with a draft model
The following code configures vLLM in an offline mode to use speculative decoding with a draft model, speculating 5 tokens at a time.
```python
from vllm import LLM, SamplingParams

prompts = [
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(
    model="facebook/opt-6.7b",
    tensor_parallel_size=1,
    speculative_model="facebook/opt-125m",
    num_speculative_tokens=5,
)
outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
To perform the same with an online mode, launch the server:
```bash
python -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --port 8000 --model facebook/opt-6.7b \
--seed 42 -tp 1 --speculative_model facebook/opt-125m --use-v2-block-manager \
--num_speculative_tokens 5 --gpu_memory_utilization 0.8
```
Then use a client:
```python
from openai import OpenAI

# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"

client = OpenAI(
    # defaults to os.environ.get("OPENAI_API_KEY")
    api_key=openai_api_key,
    base_url=openai_api_base,
)

models = client.models.list()
model = models.data[0].id

# Completion API
stream = False
completion = client.completions.create(
    model=model,
    prompt="The future of AI is",
    echo=False,
    n=1,
    stream=stream,
)

print("Completion results:")
if stream:
    for c in completion:
        print(c)
else:
    print(completion)
```
## Speculating by matching n-grams in the prompt
The following code configures vLLM to use speculative decoding where proposals are generated by
matching n-grams in the prompt. For more information, read [this thread](https://x.com/joao_gante/status/1747322413006643259).
```python
from vllm import LLM, SamplingParams

prompts = [
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(
    model="facebook/opt-6.7b",
    tensor_parallel_size=1,
    speculative_model="[ngram]",
    num_speculative_tokens=5,
    ngram_prompt_lookup_max=4,
)
outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
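To build intuition for what `speculative_model="[ngram]"` does, here is a simplified, self-contained sketch of prompt-lookup proposals (the `propose_ngram` helper below is purely illustrative and is not vLLM's implementation): the most recently generated tokens are matched against the prompt, and the tokens that followed the match are proposed as the speculative continuation, with `ngram_prompt_lookup_max` bounding the n-gram length.

```python
# Simplified illustration of prompt-lookup (n-gram) proposals; not vLLM's implementation.
def propose_ngram(prompt_tokens, recent_tokens, max_ngram=4, num_speculative_tokens=5):
    """Propose up to `num_speculative_tokens` tokens by matching the most recent
    n-gram of generated tokens against the prompt; return [] if nothing matches."""
    for n in range(min(max_ngram, len(recent_tokens)), 0, -1):
        pattern = recent_tokens[-n:]
        # Scan from the end so the most recent match with a non-empty continuation wins.
        for start in range(len(prompt_tokens) - n, -1, -1):
            if prompt_tokens[start:start + n] == pattern:
                continuation = prompt_tokens[start + n:start + n + num_speculative_tokens]
                if continuation:
                    return continuation
    return []

prompt = "the quick brown fox jumps over the lazy dog . the quick brown".split()
print(propose_ngram(prompt, recent_tokens=["the", "quick", "brown"]))
# ['fox', 'jumps', 'over', 'the', 'lazy']
```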
## Speculating using MLP speculators
The following code configures vLLM to use speculative decoding where proposals are generated by
draft models that condition draft predictions on both context vectors and sampled tokens.
For more information see [this blog](https://pytorch.org/blog/hitchhikers-guide-speculative-decoding/) or
[this technical report](https://arxiv.org/abs/2404.19124).
```python
from vllm import LLM, SamplingParams

prompts = [
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(
    model="meta-llama/Meta-Llama-3.1-70B-Instruct",
    tensor_parallel_size=4,
    speculative_model="ibm-ai-platform/llama3-70b-accelerator",
    speculative_draft_tensor_parallel_size=1,
)
outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
Note that these speculative models currently need to be run without tensor parallelism, although
it is possible to run the main model using tensor parallelism (see example above). Since the
speculative models are relatively small, we still see significant speedups. However, this
limitation will be fixed in a future release.

A variety of speculative models of this type are available on HF hub:

- [llama-13b-accelerator](https://huggingface.co/ibm-ai-platform/llama-13b-accelerator)
- [llama3-8b-accelerator](https://huggingface.co/ibm-ai-platform/llama3-8b-accelerator)
- [codellama-34b-accelerator](https://huggingface.co/ibm-ai-platform/codellama-34b-accelerator)
- [llama2-70b-accelerator](https://huggingface.co/ibm-ai-platform/llama2-70b-accelerator)
- [llama3-70b-accelerator](https://huggingface.co/ibm-ai-platform/llama3-70b-accelerator)
- [granite-3b-code-instruct-accelerator](https://huggingface.co/ibm-granite/granite-3b-code-instruct-accelerator)
- [granite-8b-code-instruct-accelerator](https://huggingface.co/ibm-granite/granite-8b-code-instruct-accelerator)
- [granite-7b-instruct-accelerator](https://huggingface.co/ibm-granite/granite-7b-instruct-accelerator)
- [granite-20b-code-instruct-accelerator](https://huggingface.co/ibm-granite/granite-20b-code-instruct-accelerator)
## Speculating using EAGLE based draft models

The following code configures vLLM to use speculative decoding where proposals are generated by
an [EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency)](https://arxiv.org/pdf/2401.15077) based draft model.
```python
from vllm import LLM, SamplingParams

prompts = [
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    tensor_parallel_size=4,
    speculative_model="path/to/modified/eagle/model",
    speculative_draft_tensor_parallel_size=1,
)
outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
A few important things to consider when using the EAGLE based draft models:

1. The EAGLE draft models available in the [HF repository for EAGLE models](https://huggingface.co/yuhuili) cannot be
   used directly with vLLM due to differences in the expected layer names and model definition.
   To use these models with vLLM, use the [following script](https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d)
   to convert them. Note that this script does not modify the model's weights.

   In the above example, use the script to first convert
   the [yuhuili/EAGLE-LLaMA3-Instruct-8B](https://huggingface.co/yuhuili/EAGLE-LLaMA3-Instruct-8B) model
   and then use the converted checkpoint as the draft model in vLLM.

2. The EAGLE based draft models need to be run without tensor parallelism
   (i.e. `speculative_draft_tensor_parallel_size` is set to 1), although
   it is possible to run the main model using tensor parallelism (see example above).

3. When using EAGLE-based speculators with vLLM, the observed speedup is lower than what is
   reported in the reference implementation [here](https://github.com/SafeAILab/EAGLE). This issue is under
   investigation and tracked here: [https://github.com/vllm-project/vllm/issues/9565](https://github.com/vllm-project/vllm/issues/9565).

A variety of EAGLE draft models are available on the Hugging Face hub:

| Base Model                                                          | EAGLE on Hugging Face                      | # EAGLE Parameters |
|---------------------------------------------------------------------|-------------------------------------------|--------------------|
| Vicuna-7B-v1.3 | yuhuili/EAGLE-Vicuna-7B-v1.3 | 0.24B |
| Vicuna-13B-v1.3 | yuhuili/EAGLE-Vicuna-13B-v1.3 | 0.37B |
| Vicuna-33B-v1.3 | yuhuili/EAGLE-Vicuna-33B-v1.3 | 0.56B |
| LLaMA2-Chat 7B | yuhuili/EAGLE-llama2-chat-7B | 0.24B |
| LLaMA2-Chat 13B | yuhuili/EAGLE-llama2-chat-13B | 0.37B |
| LLaMA2-Chat 70B | yuhuili/EAGLE-llama2-chat-70B | 0.99B |
| Mixtral-8x7B-Instruct-v0.1 | yuhuili/EAGLE-mixtral-instruct-8x7B | 0.28B |
| LLaMA3-Instruct 8B | yuhuili/EAGLE-LLaMA3-Instruct-8B | 0.25B |
| LLaMA3-Instruct 70B | yuhuili/EAGLE-LLaMA3-Instruct-70B | 0.99B |
| Qwen2-7B-Instruct | yuhuili/EAGLE-Qwen2-7B-Instruct | 0.26B |
| Qwen2-72B-Instruct | yuhuili/EAGLE-Qwen2-72B-Instruct | 1.05B |
## Lossless guarantees of Speculative Decoding
In vLLM, speculative decoding aims to enhance inference efficiency while maintaining accuracy. This section addresses the lossless guarantees of
speculative decoding, breaking down the guarantees into three key areas:
1. **Theoretical Losslessness**
   \- Speculative decoding sampling is theoretically lossless up to the precision limits of hardware numerics. Floating-point errors might
   cause slight variations in output distributions, as discussed
   in [Accelerating Large Language Model Decoding with Speculative Sampling](https://arxiv.org/pdf/2302.01318).

2. **Algorithmic Losslessness**
   \- vLLM's implementation of speculative decoding is algorithmically validated to be lossless. Key validation tests include:

   > - **Rejection Sampler Convergence**: Ensures that samples from vLLM's rejection sampler align with the target
   >   distribution. [View Test Code](https://github.com/vllm-project/vllm/blob/47b65a550866c7ffbd076ecb74106714838ce7da/tests/samplers/test_rejection_sampler.py#L252)
   > - **Greedy Sampling Equality**: Confirms that greedy sampling with speculative decoding matches greedy sampling
   >   without it. This verifies that vLLM's speculative decoding framework, when integrated with the vLLM forward pass and the vLLM rejection sampler,
   >   provides a lossless guarantee. Almost all of the tests in <gh-dir:tests/spec_decode/e2e>
   >   verify this property using [this assertion implementation](https://github.com/vllm-project/vllm/blob/b67ae00cdbbe1a58ffc8ff170f0c8d79044a684a/tests/spec_decode/e2e/conftest.py#L291).

3. **vLLM Logprob Stability**
   \- vLLM does not currently guarantee stable token log probabilities (logprobs). This can result in different outputs for the
   same request across runs. For more details, see the FAQ section
   titled *Can the output of a prompt vary across runs in vLLM?* in the [FAQs](#faq).

While vLLM strives to ensure losslessness in speculative decoding, variations in generated outputs with and without speculative decoding
can occur due to the following factors:

- **Floating-Point Precision**: Differences in hardware numerical precision may lead to slight discrepancies in the output distribution.
- **Batch Size and Numerical Stability**: Changes in batch size may cause variations in logprobs and output probabilities, potentially
  due to non-deterministic behavior in batched operations or numerical instability.

For mitigation strategies, please refer to the FAQ entry *Can the output of a prompt vary across runs in vLLM?* in the [FAQs](#faq).
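As a toy illustration of why the accept/reject rule referenced above is distribution-preserving in theory, the following self-contained sketch (using NumPy; it is not vLLM's rejection sampler) accepts a draft token `x` with probability `min(1, p_target(x) / p_draft(x))` and otherwise resamples from the normalized residual `max(0, p_target - p_draft)`; empirically the samples follow the target distribution regardless of the draft distribution.

```python
import numpy as np

rng = np.random.default_rng(0)

def speculative_sample(p_target, p_draft):
    """Sample one token id given target and draft distributions over the vocabulary."""
    x = rng.choice(len(p_draft), p=p_draft)            # token proposed by the draft
    if rng.random() < min(1.0, p_target[x] / p_draft[x]):
        return x                                        # accept the draft token
    residual = np.maximum(p_target - p_draft, 0.0)      # otherwise resample from the residual
    return rng.choice(len(residual), p=residual / residual.sum())

p_target = np.array([0.6, 0.3, 0.1])
p_draft = np.array([0.2, 0.5, 0.3])
samples = [speculative_sample(p_target, p_draft) for _ in range(100_000)]
print(np.bincount(samples, minlength=3) / len(samples))  # approximately [0.6, 0.3, 0.1]
```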
## Resources for vLLM contributors
- [A Hacker's Guide to Speculative Decoding in vLLM](https://www.youtube.com/watch?v=9wNAgpX6z_4)
- [What is Lookahead Scheduling in vLLM?](https://docs.google.com/document/d/1Z9TvqzzBPnh5WHcRwjvK2UEeFeq5zMZb5mFE8jR0HCs/edit#heading=h.1fjfb0donq5a)
- [Information on batch expansion](https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit#heading=h.kk7dq05lc6q8)
- [Dynamic speculative decoding](gh-issue:4565)