(spec-decode)=

# Speculative Decoding

:::{warning}
Please note that speculative decoding in vLLM is not yet optimized and does
not usually yield inter-token latency reductions for all prompt datasets or sampling parameters.
The work to optimize it is ongoing and can be followed here: <gh-issue:4630>
:::

:::{warning}
Currently, speculative decoding in vLLM is not compatible with pipeline parallelism.
:::

This document shows how to use [Speculative Decoding](https://x.com/karpathy/status/1697318534555336961) with vLLM.
Speculative decoding is a technique which improves inter-token latency in memory-bound LLM inference.
## Speculating with a draft model

The following code configures vLLM in offline mode to use speculative decoding with a draft model, speculating 5 tokens at a time.

```python
from vllm import LLM, SamplingParams

prompts = [
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(
    model="facebook/opt-6.7b",
    tensor_parallel_size=1,
    speculative_config={
        "model": "facebook/opt-125m",
        "num_speculative_tokens": 5,
    },
)
outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
To perform the same in online mode, launch the server:

```bash
python -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --port 8000 --model facebook/opt-6.7b \
    --seed 42 -tp 1 --gpu_memory_utilization 0.8 \
    --speculative_config '{"model": "facebook/opt-125m", "num_speculative_tokens": 5}'
```

:::{warning}
Please use `--speculative_config` to set all configurations related to speculative decoding. The previous method of specifying the model through `--speculative_model` and adding related parameters (e.g., `--num_speculative_tokens`) separately is now deprecated.
:::

Then use a client:

```python
from openai import OpenAI

# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"

client = OpenAI(
    # defaults to os.environ.get("OPENAI_API_KEY")
    api_key=openai_api_key,
    base_url=openai_api_base,
)

models = client.models.list()
model = models.data[0].id

# Completion API
stream = False
completion = client.completions.create(
    model=model,
    prompt="The future of AI is",
    echo=False,
    n=1,
    stream=stream,
)

print("Completion results:")
if stream:
    for c in completion:
        print(c)
else:
    print(completion)
```
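You can also exercise the endpoint directly over HTTP. The following `curl` sketch queries the OpenAI-compatible Completions API started above; the prompt and sampling values are illustrative. Speculative decoding is applied server-side, so the request format is unchanged.

```bash
# Query the OpenAI-compatible Completions endpoint served by vLLM.
curl http://localhost:8000/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "facebook/opt-6.7b",
        "prompt": "The future of AI is",
        "max_tokens": 64,
        "temperature": 0.8
    }'
```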
## Speculating by matching n-grams in the prompt

The following code configures vLLM to use speculative decoding where proposals are generated by
matching n-grams in the prompt. For more information read [this thread](https://x.com/joao_gante/status/1747322413006643259).

```python
from vllm import LLM, SamplingParams

prompts = [
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(
    model="facebook/opt-6.7b",
    tensor_parallel_size=1,
    speculative_config={
        "method": "ngram",
        "num_speculative_tokens": 5,
        "prompt_lookup_max": 4,
    },
)
outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
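The same setup can be used when serving online: the JSON passed to `--speculative_config` mirrors the offline dictionary above. A minimal sketch (host, port, seed, and memory settings are illustrative and copied from the earlier example):

```bash
# Launch the OpenAI-compatible server with n-gram speculative decoding.
python -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --port 8000 --model facebook/opt-6.7b \
    --seed 42 -tp 1 --gpu_memory_utilization 0.8 \
    --speculative_config '{"method": "ngram", "num_speculative_tokens": 5, "prompt_lookup_max": 4}'
```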
## Speculating using MLP speculators

The following code configures vLLM to use speculative decoding where proposals are generated by
draft models that condition draft predictions on both context vectors and sampled tokens.
For more information see [this blog](https://pytorch.org/blog/hitchhikers-guide-speculative-decoding/) or
[this technical report](https://arxiv.org/abs/2404.19124).

```python
from vllm import LLM, SamplingParams

prompts = [
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(
    model="meta-llama/Meta-Llama-3.1-70B-Instruct",
    tensor_parallel_size=4,
    speculative_config={
        "model": "ibm-ai-platform/llama3-70b-accelerator",
        "draft_tensor_parallel_size": 1,
    },
)
outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```

Note that these speculative models currently need to be run without tensor parallelism, although
it is possible to run the main model using tensor parallelism (see example above). Since the
speculative models are relatively small, we still see significant speedups. However, this
limitation will be fixed in a future release.
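The same configuration can be passed when serving online. A sketch, assuming the model names from the offline example above and an illustrative host, port, and GPU count:

```bash
# Serve the 70B target model with an MLP speculator as the draft model.
python -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --port 8000 \
    --model meta-llama/Meta-Llama-3.1-70B-Instruct -tp 4 \
    --speculative_config '{"model": "ibm-ai-platform/llama3-70b-accelerator", "draft_tensor_parallel_size": 1}'
```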
A variety of speculative models of this type are available on the Hugging Face hub:

- [llama-13b-accelerator](https://huggingface.co/ibm-ai-platform/llama-13b-accelerator)
- [llama3-8b-accelerator](https://huggingface.co/ibm-ai-platform/llama3-8b-accelerator)
- [codellama-34b-accelerator](https://huggingface.co/ibm-ai-platform/codellama-34b-accelerator)
- [llama2-70b-accelerator](https://huggingface.co/ibm-ai-platform/llama2-70b-accelerator)
- [llama3-70b-accelerator](https://huggingface.co/ibm-ai-platform/llama3-70b-accelerator)
- [granite-3b-code-instruct-accelerator](https://huggingface.co/ibm-granite/granite-3b-code-instruct-accelerator)
- [granite-8b-code-instruct-accelerator](https://huggingface.co/ibm-granite/granite-8b-code-instruct-accelerator)
- [granite-7b-instruct-accelerator](https://huggingface.co/ibm-granite/granite-7b-instruct-accelerator)
- [granite-20b-code-instruct-accelerator](https://huggingface.co/ibm-granite/granite-20b-code-instruct-accelerator)
## Speculating using EAGLE based draft models

The following code configures vLLM to use speculative decoding where proposals are generated by
an [EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency)](https://arxiv.org/pdf/2401.15077) based draft model. A more detailed example for offline mode, including how to extract request-level acceptance rates, can be found [here](<gh-file:examples/offline_inference/eagle.py>).

```python
from vllm import LLM, SamplingParams

prompts = [
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    tensor_parallel_size=4,
    speculative_config={
        "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
        "draft_tensor_parallel_size": 1,
    },
)

outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
A few important things to consider when using the EAGLE based draft models:

1. The EAGLE draft models available in the [HF repository for EAGLE models](https://huggingface.co/yuhuili) should
   be directly loadable and usable by vLLM after [PR 12304](https://github.com/vllm-project/vllm/pull/12304).
   If you are using a vLLM version before [PR 12304](https://github.com/vllm-project/vllm/pull/12304), please use the
   [script](https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d) to convert the speculative model,
   and specify `"model": "path/to/modified/eagle/model"` in `speculative_config`. If weight-loading problems still occur when using the latest version of vLLM, please leave a comment or raise an issue.

2. The EAGLE based draft models need to be run without tensor parallelism
   (i.e. `draft_tensor_parallel_size` is set to 1 in `speculative_config`), although
   it is possible to run the main model using tensor parallelism (see example above).

3. When using EAGLE-based speculators with vLLM, the observed speedup is lower than what is
   reported in the reference implementation [here](https://github.com/SafeAILab/EAGLE). This issue is under
   investigation and tracked here: [vllm-project/vllm#9565](https://github.com/vllm-project/vllm/issues/9565).
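EAGLE draft models can also be used when serving online. A sketch mirroring the offline configuration above (host, port, and GPU count are illustrative):

```bash
# Serve the 8B target model with an EAGLE draft model.
python -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --port 8000 \
    --model meta-llama/Meta-Llama-3-8B-Instruct -tp 4 \
    --speculative_config '{"model": "yuhuili/EAGLE-LLaMA3-Instruct-8B", "draft_tensor_parallel_size": 1}'
```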
A variety of EAGLE draft models are available on the Hugging Face hub:

| Base Model                 | EAGLE on Hugging Face               | # EAGLE Parameters |
|----------------------------|-------------------------------------|--------------------|
| Vicuna-7B-v1.3             | yuhuili/EAGLE-Vicuna-7B-v1.3        | 0.24B              |
| Vicuna-13B-v1.3            | yuhuili/EAGLE-Vicuna-13B-v1.3       | 0.37B              |
| Vicuna-33B-v1.3            | yuhuili/EAGLE-Vicuna-33B-v1.3       | 0.56B              |
| LLaMA2-Chat 7B             | yuhuili/EAGLE-llama2-chat-7B        | 0.24B              |
| LLaMA2-Chat 13B            | yuhuili/EAGLE-llama2-chat-13B       | 0.37B              |
| LLaMA2-Chat 70B            | yuhuili/EAGLE-llama2-chat-70B       | 0.99B              |
| Mixtral-8x7B-Instruct-v0.1 | yuhuili/EAGLE-mixtral-instruct-8x7B | 0.28B              |
| LLaMA3-Instruct 8B         | yuhuili/EAGLE-LLaMA3-Instruct-8B    | 0.25B              |
| LLaMA3-Instruct 70B        | yuhuili/EAGLE-LLaMA3-Instruct-70B   | 0.99B              |
| Qwen2-7B-Instruct          | yuhuili/EAGLE-Qwen2-7B-Instruct     | 0.26B              |
| Qwen2-72B-Instruct         | yuhuili/EAGLE-Qwen2-72B-Instruct    | 1.05B              |
## Lossless guarantees of Speculative Decoding

In vLLM, speculative decoding aims to enhance inference efficiency while maintaining accuracy. This section addresses the lossless guarantees of
speculative decoding, breaking down the guarantees into three key areas:

1. **Theoretical Losslessness**
   \- Speculative decoding sampling is theoretically lossless up to the precision limits of hardware numerics. Floating-point errors might
   cause slight variations in output distributions, as discussed
   in [Accelerating Large Language Model Decoding with Speculative Sampling](https://arxiv.org/pdf/2302.01318).

2. **Algorithmic Losslessness**
   \- vLLM’s implementation of speculative decoding is algorithmically validated to be lossless. Key validation tests include:

   > - **Rejection Sampler Convergence**: Ensures that samples from vLLM’s rejection sampler align with the target
   >   distribution. [View Test Code](https://github.com/vllm-project/vllm/blob/47b65a550866c7ffbd076ecb74106714838ce7da/tests/samplers/test_rejection_sampler.py#L252)
   > - **Greedy Sampling Equality**: Confirms that greedy sampling with speculative decoding matches greedy sampling
   >   without it. This verifies that vLLM's speculative decoding framework, when integrated with the vLLM forward pass and the vLLM rejection sampler,
   >   provides a lossless guarantee. Almost all of the tests in <gh-dir:tests/spec_decode/e2e>
   >   verify this property using [this assertion implementation](https://github.com/vllm-project/vllm/blob/b67ae00cdbbe1a58ffc8ff170f0c8d79044a684a/tests/spec_decode/e2e/conftest.py#L291).

3. **vLLM Logprob Stability**
   \- vLLM does not currently guarantee stable token log probabilities (logprobs). This can result in different outputs for the
   same request across runs. For more details, see the FAQ section
   titled *Can the output of a prompt vary across runs in vLLM?* in the [FAQs](#faq).
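As a rough illustration of the greedy-sampling-equality property above, the following sketch compares greedy (temperature 0) outputs generated with and without a draft model. It reuses the models from the earlier examples and is illustrative only, not a substitute for the test suite:

```python
from vllm import LLM, SamplingParams

# Illustrative check only: with lossless rejection sampling, greedy outputs
# with and without speculative decoding should match. In practice, run the
# two configurations in separate processes so each engine can claim GPU memory.
prompts = ["The future of AI is"]
greedy = SamplingParams(temperature=0, max_tokens=32)

def greedy_text(**llm_kwargs) -> str:
    llm = LLM(model="facebook/opt-6.7b", **llm_kwargs)
    return llm.generate(prompts, greedy)[0].outputs[0].text

baseline = greedy_text()
speculative = greedy_text(
    speculative_config={
        "model": "facebook/opt-125m",
        "num_speculative_tokens": 5,
    },
)
print(baseline == speculative)
```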
While vLLM strives to ensure losslessness in speculative decoding, variations in generated outputs with and without speculative decoding
can occur due to the following factors:

- **Floating-Point Precision**: Differences in hardware numerical precision may lead to slight discrepancies in the output distribution.
- **Batch Size and Numerical Stability**: Changes in batch size may cause variations in logprobs and output probabilities, potentially
  due to non-deterministic behavior in batched operations or numerical instability.

For mitigation strategies, please refer to the FAQ entry *Can the output of a prompt vary across runs in vLLM?* in the [FAQs](#faq).
## Resources for vLLM contributors

- [A Hacker's Guide to Speculative Decoding in vLLM](https://www.youtube.com/watch?v=9wNAgpX6z_4)
- [What is Lookahead Scheduling in vLLM?](https://docs.google.com/document/d/1Z9TvqzzBPnh5WHcRwjvK2UEeFeq5zMZb5mFE8jR0HCs/edit#heading=h.1fjfb0donq5a)
- [Information on batch expansion](https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit#heading=h.kk7dq05lc6q8)
- [Dynamic speculative decoding](gh-issue:4565)