(spec-decode)=

# Speculative decoding

```{warning}
Speculative decoding in vLLM is not yet optimized and may not yield inter-token latency reductions for all prompt datasets or sampling parameters. The work
to optimize it is ongoing and can be followed in [this issue](https://github.com/vllm-project/vllm/issues/4630).
```

```{warning}
Currently, speculative decoding in vLLM is not compatible with pipeline parallelism.
```

This document shows how to use [Speculative Decoding](https://x.com/karpathy/status/1697318534555336961) with vLLM.
Speculative decoding is a technique that improves inter-token latency in memory-bound LLM inference.

## Speculating with a draft model

The following code configures vLLM in offline mode to use speculative decoding with a draft model, speculating 5 tokens at a time.

```python
from vllm import LLM, SamplingParams

prompts = [
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(
    model="facebook/opt-6.7b",
    tensor_parallel_size=1,
    speculative_model="facebook/opt-125m",
    num_speculative_tokens=5,
)
outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```

To do the same in online mode, launch the server:

```bash
python -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --port 8000 --model facebook/opt-6.7b \
    --seed 42 -tp 1 --speculative-model facebook/opt-125m --use-v2-block-manager \
    --num-speculative-tokens 5 --gpu-memory-utilization 0.8
```

Then use a client:

```python
from openai import OpenAI

# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"

client = OpenAI(
    # defaults to os.environ.get("OPENAI_API_KEY")
    api_key=openai_api_key,
    base_url=openai_api_base,
)

models = client.models.list()
model = models.data[0].id

# Completion API
stream = False
completion = client.completions.create(
    model=model,
    prompt="The future of AI is",
    echo=False,
    n=1,
    stream=stream,
)

print("Completion results:")
if stream:
    for c in completion:
        print(c)
else:
    print(completion)
```
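
To make the draft-model mechanism concrete, here is a toy sketch of the propose-then-verify loop for greedy decoding. It is not vLLM's implementation: the "models" are plain Python functions, and `speculative_greedy_decode`, `plain_greedy_decode`, `toy_target`, and `toy_draft` are hypothetical names introduced only for illustration. A real engine scores all proposed positions in one batched forward pass of the target model, which is where the latency saving comes from.

```python
from typing import Callable, List

# A "model" here is any function mapping a token sequence to its greedy next token.
NextToken = Callable[[List[int]], int]


def speculative_greedy_decode(
    target: NextToken,
    draft: NextToken,
    prompt: List[int],
    max_new_tokens: int,
    num_speculative_tokens: int = 5,
) -> List[int]:
    tokens = list(prompt)
    end = len(prompt) + max_new_tokens
    while len(tokens) < end:
        # 1. The cheap draft model proposes k tokens autoregressively.
        proposal: List[int] = []
        for _ in range(num_speculative_tokens):
            proposal.append(draft(tokens + proposal))

        # 2. The target model checks each proposed position in order.
        #    (Looped here for clarity; batched in a real engine.)
        accepted: List[int] = []
        for proposed in proposal:
            expected = target(tokens + accepted)
            if proposed != expected:
                # First mismatch: keep the target's own token and stop.
                accepted.append(expected)
                break
            accepted.append(proposed)
        else:
            # Every proposal matched, so the target adds one bonus token.
            accepted.append(target(tokens + accepted))

        tokens.extend(accepted)
    return tokens[:end]


def plain_greedy_decode(target: NextToken, prompt: List[int], max_new_tokens: int) -> List[int]:
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        tokens.append(target(tokens))
    return tokens


# Hypothetical toy models: the draft disagrees with the target at every 4th position.
def toy_target(seq: List[int]) -> int:
    return (seq[-1] * 3 + len(seq)) % 11


def toy_draft(seq: List[int]) -> int:
    guess = toy_target(seq)
    return guess if len(seq) % 4 else (guess + 1) % 11


prompt = [1, 2, 3]
assert speculative_greedy_decode(toy_target, toy_draft, prompt, 20) == plain_greedy_decode(toy_target, prompt, 20)
```

Because every emitted token is checked against the target model, the speculative output matches plain greedy decoding even when the draft model is often wrong; a poor draft only reduces the speedup.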

## Speculating by matching n-grams in the prompt

The following code configures vLLM to use speculative decoding where proposals are generated by
matching n-grams in the prompt. For more information read [this thread](https://x.com/joao_gante/status/1747322413006643259).

```python
from vllm import LLM, SamplingParams

prompts = [
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(
    model="facebook/opt-6.7b",
    tensor_parallel_size=1,
    speculative_model="[ngram]",
    num_speculative_tokens=5,
    ngram_prompt_lookup_max=4,
)
outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
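
The idea behind n-gram speculation can be sketched in a few lines of plain Python. This is an illustrative sketch, not vLLM's implementation: the function `ngram_propose` and its `ngram_prompt_lookup_min` argument are hypothetical here, and the matching policy (longest n-gram first, earliest match wins) is an assumption. It looks for an earlier occurrence of the last few tokens and proposes the tokens that followed that occurrence.

```python
from typing import List, Optional


def ngram_propose(
    token_ids: List[int],
    ngram_prompt_lookup_max: int = 4,
    ngram_prompt_lookup_min: int = 1,
    num_speculative_tokens: int = 5,
) -> Optional[List[int]]:
    """Propose tokens by finding an earlier copy of the current suffix."""
    for n in range(ngram_prompt_lookup_max, ngram_prompt_lookup_min - 1, -1):
        if n <= 0 or len(token_ids) <= n:
            continue
        suffix = token_ids[-n:]
        # Scan earlier start positions; the suffix itself is excluded.
        for start in range(len(token_ids) - n):
            if token_ids[start:start + n] == suffix:
                # Propose whatever followed the earlier occurrence.
                return token_ids[start + n:start + n + num_speculative_tokens]
    return None  # No match: nothing to speculate on this step.


# The suffix [5, 6, 7] also appears at the start, so the tokens after it are proposed.
print(ngram_propose([5, 6, 7, 8, 9, 1, 5, 6, 7]))  # [8, 9, 1, 5, 6]
```

This kind of proposer needs no draft model, so it tends to help most when the output repeats spans of the prompt, for example in summarization or code editing.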

## Speculating using MLP speculators

The following code configures vLLM to use speculative decoding where proposals are generated by
draft models that condition draft predictions on both context vectors and sampled tokens.
For more information see [this blog](https://pytorch.org/blog/hitchhikers-guide-speculative-decoding/) or
[this technical report](https://arxiv.org/abs/2404.19124).

```python
from vllm import LLM, SamplingParams

prompts = [
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(
    model="meta-llama/Meta-Llama-3.1-70B-Instruct",
    tensor_parallel_size=4,
    speculative_model="ibm-fms/llama3-70b-accelerator",
    speculative_draft_tensor_parallel_size=1,
)
outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```

Note that these speculative models currently need to be run without tensor parallelism, although
the main model can still be run with tensor parallelism (see the example above). Since the
speculative models are relatively small, we still see significant speedups. This
limitation will be fixed in a future release.

A variety of speculative models of this type are available on the Hugging Face Hub:

- [llama-13b-accelerator](https://huggingface.co/ibm-fms/llama-13b-accelerator)
- [llama3-8b-accelerator](https://huggingface.co/ibm-fms/llama3-8b-accelerator)
- [codellama-34b-accelerator](https://huggingface.co/ibm-fms/codellama-34b-accelerator)
- [llama2-70b-accelerator](https://huggingface.co/ibm-fms/llama2-70b-accelerator)
- [llama3-70b-accelerator](https://huggingface.co/ibm-fms/llama3-70b-accelerator)
- [granite-3b-code-instruct-accelerator](https://huggingface.co/ibm-granite/granite-3b-code-instruct-accelerator)
- [granite-8b-code-instruct-accelerator](https://huggingface.co/ibm-granite/granite-8b-code-instruct-accelerator)
- [granite-7b-instruct-accelerator](https://huggingface.co/ibm-granite/granite-7b-instruct-accelerator)
- [granite-20b-code-instruct-accelerator](https://huggingface.co/ibm-granite/granite-20b-code-instruct-accelerator)

## Lossless guarantees of Speculative Decoding

In vLLM, speculative decoding aims to enhance inference efficiency while maintaining accuracy. This section addresses the lossless guarantees of
speculative decoding, breaking down the guarantees into three key areas:

1. **Theoretical Losslessness**: Speculative decoding sampling is theoretically lossless up to the precision limits of hardware numerics. Floating-point errors might
   cause slight variations in output distributions, as discussed
   in [Accelerating Large Language Model Decoding with Speculative Sampling](https://arxiv.org/pdf/2302.01318).

2. **Algorithmic Losslessness**: vLLM’s implementation of speculative decoding is algorithmically validated to be lossless. Key validation tests include:

   - **Rejection Sampler Convergence**: Ensures that samples from vLLM’s rejection sampler align with the target
     distribution. [View Test Code](https://github.com/vllm-project/vllm/blob/47b65a550866c7ffbd076ecb74106714838ce7da/tests/samplers/test_rejection_sampler.py#L252)
   - **Greedy Sampling Equality**: Confirms that greedy sampling with speculative decoding matches greedy sampling
     without it. This verifies that vLLM's speculative decoding framework, when integrated with the vLLM forward pass and the vLLM rejection sampler,
     provides a lossless guarantee. Almost all of the tests in [this directory](https://github.com/vllm-project/vllm/tree/b67ae00cdbbe1a58ffc8ff170f0c8d79044a684a/tests/spec_decode/e2e)
     verify this property using [this assertion implementation](https://github.com/vllm-project/vllm/blob/b67ae00cdbbe1a58ffc8ff170f0c8d79044a684a/tests/spec_decode/e2e/conftest.py#L291).

3. **vLLM Logprob Stability**: vLLM does not currently guarantee stable token log probabilities (logprobs). This can result in different outputs for the
   same request across runs. For more details, see the FAQ section
   titled *Can the output of a prompt vary across runs in vLLM?* in the {ref}`FAQs <faq>`.

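
The acceptance rule from the speculative sampling paper cited above can be sketched for a single token as follows. This is a minimal sketch in plain Python, not vLLM's rejection sampler; `speculative_sample` and `empirical_distribution` are hypothetical names, and the two distributions are made-up examples. A draft token x drawn from the draft distribution q is accepted with probability min(1, p(x)/q(x)); on rejection, a token is drawn from the residual distribution proportional to max(0, p - q). The resulting samples follow the target distribution p.

```python
import random
from typing import List, Sequence


def sample(probs: Sequence[float], rng: random.Random) -> int:
    # Draw one index with probability proportional to its weight.
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


def speculative_sample(
    target_probs: Sequence[float],
    draft_probs: Sequence[float],
    rng: random.Random,
) -> int:
    # The draft model proposes x ~ q.
    x = sample(draft_probs, rng)
    # Accept the proposal with probability min(1, p(x) / q(x)).
    if rng.random() < min(1.0, target_probs[x] / draft_probs[x]):
        return x
    # On rejection, resample from the residual proportional to max(0, p - q).
    residual = [max(0.0, p - q) for p, q in zip(target_probs, draft_probs)]
    return sample(residual, rng)


def empirical_distribution(
    target_probs: Sequence[float],
    draft_probs: Sequence[float],
    num_samples: int,
    seed: int = 0,
) -> List[float]:
    # Estimate the output distribution by repeated sampling.
    rng = random.Random(seed)
    counts = [0] * len(target_probs)
    for _ in range(num_samples):
        counts[speculative_sample(target_probs, draft_probs, rng)] += 1
    return [c / num_samples for c in counts]


p = [0.6, 0.3, 0.1]  # example target distribution (assumed values)
q = [0.2, 0.5, 0.3]  # example draft distribution (assumed values)
print(empirical_distribution(p, q, 100_000))  # close to p, not q
```

A convergence test in the spirit of the one listed above would check that this empirical distribution approaches `p` as the number of samples grows, regardless of how poor the draft distribution is.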

**Conclusion**

While vLLM strives to ensure losslessness in speculative decoding, variations in generated outputs with and without speculative decoding
can occur due to the following factors:

- **Floating-Point Precision**: Differences in hardware numerical precision may lead to slight discrepancies in the output distribution.
- **Batch Size and Numerical Stability**: Changes in batch size may cause variations in logprobs and output probabilities, potentially
  due to non-deterministic behavior in batched operations or numerical instability.

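
One root cause of such discrepancies is that floating-point addition is not associative, so summing the same values in a different order (as can happen when batch shapes change) may give slightly different results. A minimal illustration in plain Python, unrelated to vLLM's actual kernels:

```python
a, b, c = 0.1, 0.2, 0.3

# Same three numbers, two different summation orders.
left_to_right = (a + b) + c
right_to_left = a + (b + c)

print(left_to_right == right_to_left)      # False
print(abs(left_to_right - right_to_left))  # tiny, but not zero
```

A difference this small can still flip the choice between two tokens whose probabilities are nearly tied, after which the generated sequences diverge.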

**Mitigation Strategies**

For mitigation strategies, please refer to the FAQ entry *Can the output of a prompt vary across runs in vLLM?* in the {ref}`FAQs <faq>`.

## Resources for vLLM contributors

- [A Hacker's Guide to Speculative Decoding in vLLM](https://www.youtube.com/watch?v=9wNAgpX6z_4)
- [What is Lookahead Scheduling in vLLM?](https://docs.google.com/document/d/1Z9TvqzzBPnh5WHcRwjvK2UEeFeq5zMZb5mFE8jR0HCs/edit#heading=h.1fjfb0donq5a)
- [Information on batch expansion](https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit#heading=h.kk7dq05lc6q8)
- [Dynamic speculative decoding](https://github.com/vllm-project/vllm/issues/4565)