[V1][Spec Decode] Implement Eagle Proposer [1/N] (#15729)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon 2025-04-01 12:33:16 -07:00 committed by GitHub
parent a79cc68b3a
commit e75a6301bd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 378 additions and 21 deletions

View File

@ -2152,9 +2152,10 @@ class SpeculativeConfig:
# Replace hf_config for EAGLE draft_model
if self.method == "eagle":
if self.enable_chunked_prefill:
if self.enable_chunked_prefill and not envs.VLLM_USE_V1:
raise ValueError(
"Chunked prefill and EAGLE are not compatible.")
"Chunked prefill and EAGLE are not compatible "
"when using V0.")
from vllm.transformers_utils.configs.eagle import (
EAGLEConfig)

View File

@ -1468,15 +1468,21 @@ class EngineArgs:
# Only Ngram speculative decoding so far.
is_ngram_enabled = False
is_eagle_enabled = False
if self.speculative_config is not None:
# This is supported but experimental (handled below).
if (("method" in self.speculative_config
and self.speculative_config["method"] in ("ngram", "[ngram]"))
or
("model" in self.speculative_config and
self.speculative_config["model"] in ("ngram", "[ngram]"))):
speculative_method = self.speculative_config.get("method")
if speculative_method:
if speculative_method in ("ngram", "[ngram]"):
is_ngram_enabled = True
elif speculative_method == "eagle":
is_eagle_enabled = True
else:
speculative_model = self.speculative_config.get("model")
if speculative_model in ("ngram", "[ngram]"):
is_ngram_enabled = True
if not (is_ngram_enabled or is_eagle_enabled):
# Other speculative decoding methods are not supported yet.
_raise_or_fallback(feature_name="Speculative Decoding",
recommend_to_remove=False)
return False
@ -1523,6 +1529,10 @@ class EngineArgs:
if is_ngram_enabled and _warn_or_fallback("ngram"):
return False
# Eagle is under development, so we don't support it yet.
if is_eagle_enabled and _warn_or_fallback("Eagle"):
return False
# Non-CUDA is supported on V1, but off by default for now.
not_cuda = not current_platform.is_cuda()
if not_cuda and _warn_or_fallback( # noqa: SIM103

View File

@ -0,0 +1,262 @@
# SPDX-License-Identifier: Apache-2.0
import torch
import torch.nn as nn
import triton
import triton.language as tl
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.sample.metadata import SamplingMetadata
class EagleProposer:
def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
):
self.vllm_config = vllm_config
self.num_speculative_tokens = (
vllm_config.speculative_config.num_speculative_tokens)
self.block_size = vllm_config.cache_config.block_size
self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs,
device=device)
def propose(
self,
# [num_tokens]
target_token_ids: torch.Tensor,
# [num_tokens]
target_positions: torch.Tensor,
# [num_tokens, hidden_size]
target_hidden_states: torch.Tensor,
# [num_tokens]
target_slot_mapping: torch.Tensor,
# [batch_size]
next_token_ids: torch.Tensor,
# [batch_size + 1] starting with 0
cu_num_tokens: torch.Tensor,
# [batch_size, max_num_blocks_per_req]
block_table: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0]
last_token_indices = cu_num_tokens[1:] - 1
input_ids = torch.empty_like(target_token_ids)
# Shift the input ids by one token.
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
input_ids[:-1] = target_token_ids[1:]
# Replace the last token with the next token.
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
input_ids[last_token_indices] = next_token_ids
seq_lens = target_positions[last_token_indices] + 1
# FIXME(woosuk): The below two ops cause synchronization. Optimize.
max_seq_len = seq_lens.max().item()
max_num_tokens = (cu_num_tokens[1:] - cu_num_tokens[:-1]).max().item()
attn_metadata = FlashAttentionMetadata(
num_actual_tokens=num_tokens,
max_query_len=max_num_tokens,
query_start_loc=cu_num_tokens,
max_seq_len=max_seq_len,
seq_lens=seq_lens,
block_table=block_table,
slot_mapping=target_slot_mapping,
# TODO(woosuk): Support cascade attention.
use_cascade=False,
common_prefix_len=0,
cu_prefix_query_lens=None,
prefix_kv_lens=None,
suffix_kv_lens=None,
)
with set_forward_context(attn_metadata, self.vllm_config):
hidden_states = self.model(
input_ids=input_ids,
hidden_states=target_hidden_states,
positions=target_positions,
)
sample_hidden_states = hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
draft_token_ids, draft_probs = compute_probs_and_sample_next_token(
logits, sampling_metadata)
# Early exit if there is only one draft token to be generated.
if self.num_speculative_tokens == 1:
# [batch_size, 1] and [batch_size, 1, vocab_size]
return draft_token_ids.view(-1, 1), draft_probs.unsqueeze(dim=1)
# Generate the remaining draft tokens.
draft_token_ids_list = [draft_token_ids]
draft_probs_list = [draft_probs]
positions = target_positions[last_token_indices]
hidden_states = sample_hidden_states
attn_metadata.num_actual_tokens = batch_size
attn_metadata.max_query_len = 1
attn_metadata.query_start_loc = self.arange[:batch_size]
for _ in range(self.num_speculative_tokens - 1):
# Update the inputs.
input_ids = draft_token_ids_list[-1]
positions += 1
attn_metadata.max_seq_len += 1
attn_metadata.seq_lens += 1
# Compute the slot mapping.
block_numbers = positions // self.block_size
block_ids = block_table.gather(dim=1,
index=block_numbers.view(-1, 1))
block_ids = block_ids.view(-1)
attn_metadata.slot_mapping = (block_ids * self.block_size +
positions % self.block_size)
# Run the model.
with set_forward_context(attn_metadata, self.vllm_config):
hidden_states = self.model(
input_ids=input_ids,
hidden_states=hidden_states,
positions=positions,
)
logits = self.model.compute_logits(hidden_states, None)
draft_token_ids, probs = compute_probs_and_sample_next_token(
logits, sampling_metadata)
draft_token_ids_list.append(draft_token_ids)
draft_probs_list.append(probs)
# [batch_size, num_speculative_tokens]
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
# [batch_size, num_speculative_tokens, vocab_size]
draft_probs = torch.stack(draft_probs_list, dim=1)
return draft_token_ids, draft_probs
@staticmethod
def prepare_inputs(
# [batch_size + 1]
cu_target_query_lens: torch.Tensor,
# [batch_size]
num_rejected_tokens: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
# cu_target_query_lens: [0, a, a + b, a + b + c]
# num_rejected_tokens: [n1, n2, n3]
# num_tokens_per_req: [a - n1, b - n2, c - n3]
# cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
# token_indices: [0, 1, ..., a - n1 - 1,
# a, a + 1, ..., a + b - n2 - 1,
# a + b, a + b + 1, ..., a + b + c - n3 - 1]
# [0, a, a + b, a + b + c] -> [a, b, c]
query_len_per_req = (cu_target_query_lens[1:] -
cu_target_query_lens[:-1])
# [a, b, c] -> [a - n1, b - n2, c - n3]
num_tokens_per_req = query_len_per_req - num_rejected_tokens
cu_num_tokens = torch.empty_like(cu_target_query_lens)
torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
cu_num_tokens[0] = 0
# FIXME(woosuk): Avoid synchronization.
num_tokens = cu_num_tokens[-1].item()
token_indices = torch.empty(
num_tokens,
dtype=torch.int32,
device=cu_num_tokens.device,
)
batch_size = num_rejected_tokens.shape[0]
BLOCK_SIZE = 1024
prepare_input_kernel[(batch_size, )](
token_indices,
cu_target_query_lens,
cu_num_tokens,
BLOCK_SIZE=BLOCK_SIZE,
)
return cu_num_tokens, token_indices
def load_model(self, target_model: nn.Module) -> None:
self.model = DummyEagleModel()
self.model.get_input_embeddings = target_model.get_input_embeddings
self.model.compute_logits = target_model.compute_logits
# FIXME(woosuk): This is a dummy model for testing.
# Remove this once we have a real model.
class DummyEagleModel(nn.Module):
def __init__(self):
super().__init__()
def forward(
self,
input_ids: torch.Tensor,
hidden_states: torch.Tensor,
positions: torch.Tensor,
) -> torch.Tensor:
input_embeddings = self.get_input_embeddings(input_ids)
return hidden_states + input_embeddings # Dummy return.
# FIXME(woosuk): The logic here is duplicated with the main sampling code.
# We should refactor this to reuse the same sampling implementation.
def compute_probs_and_sample_next_token(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
if sampling_metadata.all_greedy:
# For greedy requests, draft_probs is not used in rejection sampling.
# Therefore, we can just return the logits.
probs = logits
next_token_ids = logits.argmax(dim=-1)
return next_token_ids, probs
is_greedy = sampling_metadata.temperature == -1
temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature)
logits.div_(temperature.view(-1, 1))
probs = logits.softmax(dim=-1, dtype=torch.float32)
# NOTE(woosuk): Currently, we ignore most of the sampling parameters in
# generating the draft tokens. We only use the temperature. While this
# could degrade the acceptance rate, it does not affect the distribution
# of the generated tokens after rejection sampling.
# TODO(woosuk): Consider seeds.
q = torch.empty_like(probs)
q.exponential_()
next_token_ids = probs.div_(q).argmax(dim=-1).view(-1)
if not sampling_metadata.all_random:
greedy_token_ids = probs.argmax(dim=-1)
next_token_ids = torch.where(
is_greedy,
greedy_token_ids,
next_token_ids,
)
return next_token_ids, probs
@triton.jit
def prepare_input_kernel(
out_ptr,
cu_query_lens_ptr,
cu_num_tokens_ptr,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
# [start_pos, end_pos)
start_pos = tl.load(cu_num_tokens_ptr + pid)
end_pos = tl.load(cu_num_tokens_ptr + pid + 1)
num_tokens = end_pos - start_pos
index_start = tl.load(cu_query_lens_ptr + pid)
indices = index_start + tl.arange(0, BLOCK_SIZE)
num_blocks = tl.cdiv(num_tokens, BLOCK_SIZE)
for i in tl.range(num_blocks):
offset = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
tl.store(
out_ptr + start_pos + offset,
indices,
mask=offset < num_tokens,
)

View File

@ -4,9 +4,14 @@ from typing import Optional
import numpy as np
from numba import jit
from vllm.config import VllmConfig
class NgramProposer:
def __init__(self, vllm_config: VllmConfig):
self.vllm_config = vllm_config
def propose(
self,
context_token_ids: np.ndarray,
@ -50,6 +55,10 @@ class NgramProposer:
return result
return None
def load_model(self, *args, **kwargs):
# No model to load.
pass
@jit(nopython=True)
def _kmp_lps_array(pattern: np.ndarray) -> np.ndarray:

View File

@ -39,9 +39,18 @@ class CachedRequestState:
lora_request: Optional[LoRARequest] = None
def __post_init__(self):
self.num_prompt_tokens = len(self.prompt_token_ids)
@property
def num_tokens(self) -> int:
return len(self.prompt_token_ids) + len(self.output_token_ids)
return self.num_prompt_tokens + len(self.output_token_ids)
def get_token_id(self, idx: int) -> int:
if idx < self.num_prompt_tokens:
return self.prompt_token_ids[idx]
else:
return self.output_token_ids[idx - self.num_prompt_tokens]
class InputBatch:

View File

@ -35,6 +35,7 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
ModelRunnerOutput)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.spec_decode.utils import is_spec_decode_supported
@ -157,18 +158,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.use_spec_decode = False
if self.speculative_config:
self.use_spec_decode = True
assert self.speculative_config.method == "ngram", \
"Currently, only ngram spec decode is supported in V1."
if get_pp_group().is_last_rank:
self.drafter = NgramProposer()
# Trigger Numba JIT compilation for N-gram proposer.
# This usually takes less than 1 second.
self.drafter.propose(
np.zeros(1024, dtype=np.int32),
self.speculative_config.prompt_lookup_min,
self.speculative_config.prompt_lookup_max,
self.speculative_config.num_speculative_tokens,
)
if self.speculative_config.method == "ngram":
self.drafter = NgramProposer(self.vllm_config)
elif self.speculative_config.method == "eagle":
self.drafter = EagleProposer(self.vllm_config,
self.device) # type: ignore
else:
raise ValueError("Unknown speculative decoding method: "
f"{self.speculative_config.method}")
self.rejection_sampler = RejectionSampler()
# Request states.
@ -1144,10 +1142,75 @@ class GPUModelRunner(LoRAModelRunnerMixin):
valid_sampled_token_ids[i].clear()
if not self.use_spec_decode:
# Speculative decoding is not enabled.
spec_token_ids = None
else:
elif self.speculative_config.method == "ngram":
assert isinstance(self.drafter, NgramProposer)
spec_token_ids = self.generate_draft_token_ids(
valid_sampled_token_ids, sampling_metadata)
elif self.speculative_config.method == "eagle":
assert isinstance(self.drafter, EagleProposer)
# TODO(woosuk): Refactor the loop.
next_token_ids: list[int] = []
for i, token_ids in enumerate(valid_sampled_token_ids):
if token_ids:
# Common case.
next_token_id = token_ids[-1]
else:
# Partial prefill (rare case).
# Get the next token id from the request state.
req_id = self.input_batch.req_ids[i]
req_state = self.requests[req_id]
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
next_token_id = req_state.get_token_id(seq_len)
next_token_ids.append(next_token_id)
next_token_ids = torch.tensor(next_token_ids,
dtype=torch.int32,
device=self.device)
if spec_decode_metadata is None:
# input_ids can be None for multimodal models.
target_token_ids = self.input_ids[:num_scheduled_tokens]
target_positions = positions
target_hidden_states = hidden_states
target_slot_mapping = attn_metadata.slot_mapping
cu_num_tokens = attn_metadata.query_start_loc
else:
# TODO(woosuk): Refactor this.
num_draft_tokens = spec_decode_metadata.num_draft_tokens
num_rejected_tokens = [
n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
for i, n in enumerate(num_draft_tokens)
]
num_rejected_tokens = torch.tensor(
num_rejected_tokens,
dtype=torch.int32,
device=self.device,
)
cu_num_tokens, token_indices = self.drafter.prepare_inputs(
attn_metadata.query_start_loc,
num_rejected_tokens,
)
target_token_ids = self.input_ids[token_indices]
target_positions = positions[token_indices]
target_hidden_states = hidden_states[token_indices]
target_slot_mapping = attn_metadata.slot_mapping[token_indices]
draft_token_ids, draft_probs = self.drafter.propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
target_slot_mapping=target_slot_mapping,
next_token_ids=next_token_ids,
cu_num_tokens=cu_num_tokens,
block_table=attn_metadata.block_table,
sampling_metadata=sampling_metadata,
)
spec_token_ids = draft_token_ids.tolist()
# TODO(woosuk): Cache draft_probs and use it for rejection sampling
# in the next step.
del draft_probs
return ModelRunnerOutput(
req_ids=self.input_batch.req_ids,
@ -1205,6 +1268,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.scheduler_config,
self.lora_config,
self.device)
if hasattr(self, "drafter"):
logger.info("Loading drafter model...")
self.drafter.load_model(self.model)
time_after_load = time.perf_counter()
self.model_memory_usage = m.consumed_memory
logger.info("Model loading took %.4f GiB and %.6f seconds",