.. _spec_decode:

Speculative decoding in vLLM
============================

.. warning::
    Please note that speculative decoding in vLLM is not yet optimized and does not usually yield
    inter-token latency reductions for all prompt datasets or sampling parameters. The work to
    optimize it is ongoing and can be followed in `this issue. `_

This document shows how to use `Speculative Decoding `_ with vLLM.
Speculative decoding is a technique that improves inter-token latency in memory-bound LLM inference.

Speculating with a draft model
------------------------------

The following code configures vLLM in an offline mode to use speculative decoding with a draft model, speculating 5 tokens at a time.

.. code-block:: python

    from vllm import LLM, SamplingParams

    prompts = [
        "The future of AI is",
    ]
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

    llm = LLM(
        model="facebook/opt-6.7b",
        tensor_parallel_size=1,
        speculative_model="facebook/opt-125m",
        num_speculative_tokens=5,
        use_v2_block_manager=True,
    )
    outputs = llm.generate(prompts, sampling_params)

    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

To perform the same with an online mode, launch the server:

.. code-block:: bash

    python -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --port 8000 --model facebook/opt-6.7b \
        --seed 42 -tp 1 --speculative_model facebook/opt-125m --use-v2-block-manager \
        --num_speculative_tokens 5 --gpu_memory_utilization 0.8

Then use a client:

.. code-block:: python

    from openai import OpenAI

    # Modify OpenAI's API key and API base to use vLLM's API server.
    openai_api_key = "EMPTY"
    openai_api_base = "http://localhost:8000/v1"

    client = OpenAI(
        # defaults to os.environ.get("OPENAI_API_KEY")
        api_key=openai_api_key,
        base_url=openai_api_base,
    )

    models = client.models.list()
    model = models.data[0].id

    # Completion API
    stream = False
    completion = client.completions.create(
        model=model,
        prompt="The future of AI is",
        echo=False,
        n=1,
        stream=stream,
    )

    print("Completion results:")
    if stream:
        for c in completion:
            print(c)
    else:
        print(completion)

Speculating by matching n-grams in the prompt
---------------------------------------------

The following code configures vLLM to use speculative decoding where proposals are generated by
matching n-grams in the prompt. For more information read `this thread. `_

.. code-block:: python

    from vllm import LLM, SamplingParams

    prompts = [
        "The future of AI is",
    ]
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

    llm = LLM(
        model="facebook/opt-6.7b",
        tensor_parallel_size=1,
        speculative_model="[ngram]",
        num_speculative_tokens=5,
        ngram_prompt_lookup_max=4,
        use_v2_block_manager=True,
    )
    outputs = llm.generate(prompts, sampling_params)

    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
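For intuition, the prompt-lookup idea can be sketched without any vLLM internals: take the trailing
n-gram of the current context, search for the same n-gram earlier in the prompt, and propose the
tokens that immediately follow the match. The sketch below is only a toy illustration of that idea;
the function name, scan order, and fallback behaviour are assumptions, not vLLM's implementation.

.. code-block:: python

    def propose_from_prompt(prompt_ids, generated_ids, ngram_max=4, num_speculative_tokens=5):
        """Toy prompt-lookup proposer: reuse tokens that followed a matching n-gram in the prompt."""
        context = prompt_ids + generated_ids
        for n in range(min(ngram_max, len(context)), 0, -1):  # prefer longer matches
            tail = context[-n:]
            # Scan the prompt right to left for an earlier occurrence of the same n-gram.
            for start in range(len(prompt_ids) - n - 1, -1, -1):
                if prompt_ids[start:start + n] == tail:
                    proposal = prompt_ids[start + n:start + n + num_speculative_tokens]
                    if proposal:
                        return proposal
        return []  # no match found; this step proceeds without speculation

    # Example: the context tail [7, 9] also appears in the prompt, so the tokens
    # after that earlier occurrence are proposed as the speculation.
    print(propose_from_prompt(prompt_ids=[5, 7, 9, 2, 4, 6, 1], generated_ids=[7, 9]))  # [2, 4, 6, 1]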
Speculating using MLP speculators
---------------------------------

The following code configures vLLM to use speculative decoding where proposals are generated by
draft models that condition draft predictions on both context vectors and sampled tokens. For more
information see `this blog `_ or `this technical report `_.

.. code-block:: python

    from vllm import LLM, SamplingParams

    prompts = [
        "The future of AI is",
    ]
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

    llm = LLM(
        model="meta-llama/Meta-Llama-3.1-70B-Instruct",
        tensor_parallel_size=4,
        speculative_model="ibm-fms/llama3-70b-accelerator",
        speculative_draft_tensor_parallel_size=1,
        use_v2_block_manager=True,
    )
    outputs = llm.generate(prompts, sampling_params)

    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

Note that these speculative models currently need to be run without tensor parallelism, although
it is possible to run the main model using tensor parallelism (see example above). Since the
speculative models are relatively small, we still see significant speedups. However, this
limitation will be fixed in a future release.

A variety of speculative models of this type are available on HF hub:

* `llama-13b-accelerator `_
* `llama3-8b-accelerator `_
* `codellama-34b-accelerator `_
* `llama2-70b-accelerator `_
* `llama3-70b-accelerator `_
* `granite-3b-code-instruct-accelerator `_
* `granite-8b-code-instruct-accelerator `_
* `granite-7b-instruct-accelerator `_
* `granite-20b-code-instruct-accelerator `_

Resources for vLLM contributors
-------------------------------

* `A Hacker's Guide to Speculative Decoding in vLLM `_
* `What is Lookahead Scheduling in vLLM? `_
* `Information on batch expansion `_
* `Dynamic speculative decoding `_
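As background for the resources above, the token acceptance rule at the heart of draft-model
speculation can be written in a few lines. The sketch below is a simplified illustration of
standard speculative sampling (accept a draft token with probability ``min(1, p/q)``, otherwise
resample from the normalized residual distribution); it is not vLLM's rejection sampler.

.. code-block:: python

    import numpy as np

    def accept_or_resample(draft_token, target_probs, draft_probs, rng=None):
        """Accept a draft token with probability min(1, p/q); otherwise resample
        from the normalized residual distribution max(0, p - q)."""
        if rng is None:
            rng = np.random.default_rng()
        p, q = target_probs[draft_token], draft_probs[draft_token]
        if rng.random() < min(1.0, p / q):
            return draft_token, True  # draft token accepted as-is
        # Rejected: resample so the overall output distribution matches the target model.
        residual = np.maximum(target_probs - draft_probs, 0.0)
        residual /= residual.sum()
        return int(rng.choice(len(residual), p=residual)), False

    # Example with a 4-token vocabulary; the draft model proposed token index 2.
    target = np.array([0.1, 0.2, 0.6, 0.1])
    draft = np.array([0.1, 0.1, 0.7, 0.1])
    print(accept_or_resample(2, target, draft))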