[V1][Spec Decode] Optimize Rejection Sampler with Triton Kernels (#14930)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Woosuk Kwon 2025-03-18 14:31:54 -07:00 committed by GitHub
parent 3a1e648158
commit 99abb8b650
8 changed files with 875 additions and 408 deletions

View File

@@ -6,20 +6,23 @@ import torch
import torch.nn.functional as F

from vllm.v1.sample.metadata import SamplingMetadata
-from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID, RejectionSampler
from vllm.v1.sample.rejection_sampler import (PLACEHOLDER_TOKEN_ID,
                                              RejectionSampler)
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata

-DEVICE = "cpu"
DEVICE = "cuda"


@pytest.fixture
-def sampler():
def rejection_sampler():
    return RejectionSampler()


-def create_logits_tensor(token_ids: list[list[int]],
def create_logits_tensor(output_token_ids: list[list[int]],
                         vocab_size: int = 100) -> torch.Tensor:
    """Helper function to create logits tensor that
    will produce desired token ids on argmax"""
    token_ids = [tokens[:-1] for tokens in output_token_ids]
    num_total_tokens = sum(len(tokens) for tokens in token_ids)
    logits = torch.full((num_total_tokens, vocab_size), -100.0, device=DEVICE)
    start_loc = 0
@@ -32,14 +35,21 @@ def create_logits_tensor(token_ids: list[list[int]],
def create_sampling_metadata(
    all_greedy: bool,
-    generators: Optional[dict[int, Any]] = None) -> SamplingMetadata:
    temperature: Optional[torch.Tensor] = None,
    generators: Optional[dict[int, Any]] = None,
) -> SamplingMetadata:
    """Create a v1 sampling metadata object with all_greedy set
    to the given value. Either all greedy or all random sampling
    is used.
    """
    generators = generators or {}
    if all_greedy:
        temperature = None
    else:
        assert temperature is not None
    return SamplingMetadata(
-        temperature=torch.tensor([]),
        temperature=temperature,
        all_greedy=all_greedy,
        all_random=not all_greedy,
        top_p=None,
@@ -61,7 +71,7 @@ def create_sampling_metadata(
########################### Tests for Greedy Sampling ###################
-def test_perfect_match(sampler):
def test_perfect_match(rejection_sampler):
    """Test when output tokens perfectly match speculated tokens"""
    spec_tokens = [[1, 2, 3]]
    output_tokens = [[1, 2, 3, 4]]  # 4 is the bonus token
@@ -70,15 +80,23 @@ def test_perfect_match(sampler):
    logits = create_logits_tensor(output_tokens)
    bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
                                      device=logits.device)
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
                                                         device=logits.device)
-    output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata)
    output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=None,
        target_logits=logits,
        bonus_token_ids=bonus_token_tensor,
        sampling_metadata=metadata,
    )
    expected = torch.tensor([[1, 2, 3, 4]],
                            dtype=torch.int,
                            device=logits.device)
    assert torch.equal(output, expected)


-def test_early_mismatch(sampler):
def test_early_mismatch(rejection_sampler):
    """Test when there's an early mismatch in tokens"""
    spec_tokens = [[1, 2, 3]]
    output_tokens = [[1, 5, 3, 4]]  # Mismatch at position 1
@@ -87,15 +105,25 @@ def test_early_mismatch(sampler):
    logits = create_logits_tensor(output_tokens)
    bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
                                      device=logits.device)
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
                                                         device=logits.device)
-    output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata)
-    expected = torch.tensor([[1, 5, INVALID_TOKEN_ID, INVALID_TOKEN_ID]],
-                            dtype=torch.int,
-                            device=logits.device)
    output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=None,
        target_logits=logits,
        bonus_token_ids=bonus_token_tensor,
        sampling_metadata=metadata,
    )
    expected = torch.tensor(
        [[1, 5, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID]],
        dtype=torch.int,
        device=logits.device,
    )
    assert torch.equal(output, expected)


-def test_multiple_sequences(sampler):
def test_multiple_sequences(rejection_sampler):
    """Test handling multiple sequences of speculated tokens"""
    spec_tokens = [[1, 2], [3]]
    output_tokens = [[1, 2, 5], [3, 4]]
@@ -105,15 +133,23 @@ def test_multiple_sequences(sampler):
    logits = create_logits_tensor(output_tokens)
    bonus_token_tensor = torch.tensor(
        [output_tokens[0][-1], output_tokens[1][-1]], device=logits.device)
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
                                                         device=logits.device)
-    output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata)
-    expected = torch.tensor([[1, 2, 5], [3, 4, INVALID_TOKEN_ID]],
    output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=None,
        target_logits=logits,
        bonus_token_ids=bonus_token_tensor,
        sampling_metadata=metadata,
    )
    expected = torch.tensor([[1, 2, 5], [3, 4, PLACEHOLDER_TOKEN_ID]],
                            dtype=torch.int,
                            device=logits.device)
    assert torch.equal(output, expected)


-def test_single_token_sequence(sampler):
def test_single_token_sequence(rejection_sampler):
    """Test handling sequences with single token"""
    spec_tokens = [[1]]
    output_tokens = [[1, 2]]  # Single token with bonus token 2
@@ -122,13 +158,21 @@ def test_single_token_sequence(sampler):
    logits = create_logits_tensor(output_tokens)
    bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
                                      device=logits.device)
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
                                                         device=logits.device)
-    output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata)
    output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=None,
        target_logits=logits,
        bonus_token_ids=bonus_token_tensor,
        sampling_metadata=metadata,
    )
    expected = torch.tensor([[1, 2]], dtype=torch.int, device=logits.device)
    assert torch.equal(output, expected)


-def test_empty_sequence(sampler):
def test_empty_sequence(rejection_sampler):
    """Test handling empty sequence of speculated tokens"""
    spec_tokens: list[list[int]] = [[]]
    output_tokens = [[5]]  # Just the bonus token
@@ -137,13 +181,21 @@ def test_empty_sequence(sampler):
    logits = create_logits_tensor(output_tokens)
    bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
                                      device=logits.device)
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
                                                         device=logits.device)
-    output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata)
    output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=None,
        target_logits=logits,
        bonus_token_ids=bonus_token_tensor,
        sampling_metadata=metadata,
    )
    expected = torch.tensor([[5]], dtype=torch.int, device=logits.device)
    assert torch.equal(output, expected)


-def test_multiple_mismatches(sampler):
def test_multiple_mismatches(rejection_sampler):
    """Test handling multiple sequences with mismatches"""
    spec_tokens = [[1, 2, 3], [4, 5, 6]]
    output_tokens = [[1, 2, 7, 6], [4, 8, 6,
@@ -153,12 +205,22 @@ def test_multiple_mismatches(sampler):
    logits = create_logits_tensor(output_tokens)
    bonus_token_tensor = torch.tensor(
        [output_tokens[0][-1], output_tokens[1][-1]], device=logits.device)
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
                                                         device=logits.device)
-    output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata)
-    expected = torch.tensor([[1, 2, 7, INVALID_TOKEN_ID],
-                             [4, 8, INVALID_TOKEN_ID, INVALID_TOKEN_ID]],
-                            dtype=torch.int,
-                            device=logits.device)
    output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=None,
        target_logits=logits,
        bonus_token_ids=bonus_token_tensor,
        sampling_metadata=metadata,
    )
    expected = torch.tensor(
        [[1, 2, 7, PLACEHOLDER_TOKEN_ID],
         [4, 8, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID]],
        dtype=torch.int,
        device=logits.device,
    )
    assert torch.equal(output, expected)
@@ -166,18 +228,27 @@ def test_multiple_mismatches(sampler):
    "spec_tokens,output_tokens,expected",
    [
        ([[1, 2]], [[1, 2, 3]], [[1, 2, 3]]),  # Perfect match with bonus
-        ([[1]], [[2, 3]], [[2, INVALID_TOKEN_ID]]),  # First mismatch
        ([[1]], [[2, 3]], [[2, PLACEHOLDER_TOKEN_ID]]),  # First mismatch
        ([[1, 2], [3, 4]], [[1, 5, 6], [3, 4, 7]],
-         [[1, 5, INVALID_TOKEN_ID], [3, 4, 7]]),  # Mixed matches
         [[1, 5, PLACEHOLDER_TOKEN_ID], [3, 4, 7]]),  # Mixed matches
    ])
-def test_parametrized_cases(sampler, spec_tokens, output_tokens, expected):
def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens,
                            expected):
    """Parametrized test for various matching scenarios"""
    metadata = create_sampling_metadata(all_greedy=True)
    logits = create_logits_tensor(output_tokens)
    bonus_token_tensor = torch.tensor([tokens[-1] for tokens in output_tokens],
                                      device=logits.device)
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
                                                         device=logits.device)
-    output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata)
    output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=None,
        target_logits=logits,
        bonus_token_ids=bonus_token_tensor,
        sampling_metadata=metadata,
    )
    expected_tensor = torch.tensor(expected,
                                   dtype=torch.int,
                                   device=logits.device)
@@ -190,21 +261,31 @@ def test_parametrized_cases(sampler, spec_tokens, output_tokens, expected):
@pytest.mark.parametrize("batch_size", [1, 4, 8])
@pytest.mark.parametrize("frac_seeded", [0.0, 0.5])
@pytest.mark.parametrize("n_rep", [20])
-def test_deterministic_when_seeded(sampler, k: int, vocab_size: int,
-                                   batch_size: int, frac_seeded: float,
-                                   n_rep: int):
-    draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
-    target_probs = torch.rand(batch_size * (k + 1),
-                              vocab_size,
-                              dtype=torch.float32)
def test_deterministic_when_seeded(
    rejection_sampler,
    k: int,
    vocab_size: int,
    batch_size: int,
    frac_seeded: float,
    n_rep: int,
):
    num_tokens = batch_size * k
    draft_probs = torch.rand(num_tokens,
                             vocab_size,
                             dtype=torch.float32,
                             device=DEVICE)
    draft_probs = F.softmax(draft_probs, dim=-1)
    target_logits = torch.rand_like(draft_probs)
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64,
                                    device=DEVICE)
    draft_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, k),
                                    dtype=torch.int64,
                                    device=DEVICE)

    seeded_mask = torch.rand(batch_size, dtype=torch.float32) <= frac_seeded
@@ -215,10 +296,21 @@ def test_deterministic_when_seeded(sampler, k: int, vocab_size: int,
            for i in range(batch_size) if seeded_mask[i]
        }

        temperature = torch.ones(batch_size,
                                 dtype=torch.float32,
                                 device=DEVICE)
        sampling_metadata = create_sampling_metadata(all_greedy=False,
                                                     temperature=temperature,
                                                     generators=seeded_seqs)
-        rep_result = sampler(draft_token_ids.tolist(), draft_probs,
-                             bonus_token_ids, target_probs, sampling_metadata)
        spec_decode_metadata = SpecDecodeMetadata.make_dummy(
            draft_token_ids.tolist(), device=DEVICE)
        rep_result = rejection_sampler(
            spec_decode_metadata,
            draft_probs=draft_probs,
            target_logits=target_logits,
            bonus_token_ids=bonus_token_ids,
            sampling_metadata=sampling_metadata,
        )

        results.append(rep_result)
@@ -257,10 +349,10 @@ def test_rejection_sampling_approximates_target_distribution():
    num_reference_probs = 100

    # Prepare draft, target, and reference probability distributions
-    draft_probs, target_probs = (F.softmax(
-        torch.rand(vocab_size, dtype=torch.float32),
-        dim=-1,
-    ) for _ in range(2))
    draft_probs = F.softmax(torch.rand(vocab_size, dtype=torch.float32),
                            dim=-1)
    target_logits = torch.rand(vocab_size, dtype=torch.float32)
    target_probs = F.softmax(target_logits, dim=-1)
    reference_probs = F.softmax(
        torch.rand(num_reference_probs, vocab_size, dtype=torch.float32),
        dim=-1,
@@ -273,7 +365,7 @@ def test_rejection_sampling_approximates_target_distribution():
    for num_samples in sample_sizes:
        # Sample using rejection sampling.
        rej_sample_probs = estimate_rejection_sampling_pdf(
-            draft_probs, target_probs, k, vocab_size, num_samples)
            draft_probs, target_logits, k, vocab_size, num_samples)
        rej_sample_probs = rej_sample_probs.to(DEVICE)

        # Average distance from reference probs.
@@ -313,7 +405,7 @@ def get_ratio_first_to_last(elements: list[float]) -> float:
def estimate_rejection_sampling_pdf(
    draft_probs: torch.Tensor,
-    target_probs: torch.Tensor,
    target_logits: torch.Tensor,
    k: int,
    vocab_size: int,
    num_samples: int,
@@ -323,35 +415,44 @@ def estimate_rejection_sampling_pdf(
    Args:
        draft_probs: Draft probability distribution.
-        target_probs: Target probability distribution.
        target_logits: Target logits.
        num_samples: Number of samples to draw.

    Returns:
        Estimated probability distribution of the output tokens.
    """
-    sampler = RejectionSampler()
    rejection_sampler = RejectionSampler()
    num_tokens = num_samples * k

-    # Repeat draft probs num_samples times.
    # Repeat draft probs num_samples * k times.
    draft_probs = draft_probs.reshape(1, 1,
                                      vocab_size).repeat(num_samples, k, 1)
-    # Repeat target probs num_samples * (k + 1) times.
-    target_probs = target_probs.reshape(1, 1, vocab_size).repeat(
-        num_samples, k + 1, 1).reshape(num_samples * (k + 1), vocab_size)
    # Repeat target probs num_tokens times.
    target_logits = target_logits.reshape(1, vocab_size).repeat(num_tokens, 1)

    # Randomly sample draft token ids from draft probs.
    draft_token_ids = torch.multinomial(draft_probs[:, 0, :],
                                        num_samples=k,
                                        replacement=True).reshape(
                                            num_samples, k)
    draft_probs = draft_probs.view(num_tokens, vocab_size)

    # Bonus tokens not used but required.
    bonus_token_ids = torch.zeros((1, 1), dtype=torch.int64,
                                  device=DEVICE).repeat(num_samples, 1)

-    sampling_metadata = create_sampling_metadata(all_greedy=False)
-    output_token_ids = sampler(draft_token_ids.tolist(), draft_probs,
-                               bonus_token_ids, target_probs,
-                               sampling_metadata)
    temperature = torch.ones(num_samples, dtype=torch.float32, device=DEVICE)
    sampling_metadata = create_sampling_metadata(all_greedy=False,
                                                 temperature=temperature)
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
        draft_token_ids.tolist(), device=bonus_token_ids.device)
    output_token_ids = rejection_sampler(
        spec_decode_metadata,
        draft_probs=draft_probs,
        target_logits=target_logits,
        bonus_token_ids=bonus_token_ids,
        sampling_metadata=sampling_metadata,
    )
    output_token_ids = output_token_ids[:, :-1].flatten()

    hist = torch.histogram(output_token_ids.to(dtype=torch.float,
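The convergence test above checks, end to end, that the sampler's output distribution approaches the target distribution. The same property can be sanity-checked in isolation with a few lines of plain PyTorch; this is an editor's illustrative sketch, not part of the commit:

# Illustrative sketch (not part of the diff): rejection sampling with a draft
# distribution p and target distribution q emits tokens distributed as q.
import torch

torch.manual_seed(0)
vocab = 5
p = torch.softmax(torch.randn(vocab), dim=-1)  # draft
q = torch.softmax(torch.randn(vocab), dim=-1)  # target

counts = torch.zeros(vocab)
for _ in range(50000):
    x = torch.multinomial(p, 1).item()
    if torch.rand(()) <= q[x] / p[x]:
        counts[x] += 1                          # accept the draft token
    else:
        recovered = torch.clamp(q - p, min=0)   # recovered distribution
        counts[torch.multinomial(recovered / recovered.sum(), 1)] += 1
print(torch.max(torch.abs(counts / counts.sum() - q)))  # small, e.g. well below 0.01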

View File

@@ -35,7 +35,6 @@ if TYPE_CHECKING:
    VLLM_TRACE_FUNCTION: int = 0
    VLLM_ATTENTION_BACKEND: Optional[str] = None
    VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None
-    VLLM_USE_FLASHINFER_REJECTION_SAMPLER: bool = False
    VLLM_FLASHINFER_FORCE_TENSOR_CORES: bool = False
    VLLM_PP_LAYER_PARTITION: Optional[str] = None
    VLLM_CPU_KVCACHE_SPACE: int = 0

View File

@@ -46,7 +46,7 @@ class SamplerOutput:
    # [num_reqs, max_num_generated_tokens]
    # Different requests can have different number of generated tokens.
    # All requests are padded to max_num_generated_tokens.
-    # INVALID_TOKEN_ID (-1 by default) is used for padding.
    # PLACEHOLDER_TOKEN_ID (-1 by default) is used for padding.
    sampled_token_ids: torch.Tensor
    logprobs_tensors: Optional[LogprobsTensors]

View File

@@ -0,0 +1,30 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Union
import torch
def compiled_softmax(
logits: torch.Tensor,
temperature: Union[float, torch.Tensor] = 1.0,
) -> torch.Tensor:
"""Faster softmax kernel generated by torch.compile.
Args:
logits: [n, vocab_size]
temperature: [n] or float
"""
# NOTE(woosuk): Avoid recompilation by marking the first dim as dynamic.
torch._dynamo.mark_dynamic(logits, index=0)
if isinstance(temperature, torch.Tensor):
torch._dynamo.mark_dynamic(temperature, index=0)
return _softmax(logits, temperature)
@torch.compile
def _softmax(
logits: torch.Tensor,
temperature: Union[float, torch.Tensor],
) -> torch.Tensor:
logits = logits / temperature
return torch.softmax(logits, dim=-1, dtype=torch.float32)
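A short usage sketch for the helper above (editor's illustration, not part of the commit). The [n, 1] temperature shape mirrors how compute_probs in this commit calls it; the shapes and values are made up:

# Illustrative sketch (not part of the diff).
import torch
from vllm.v1.sample.ops.utils import compiled_softmax

logits = torch.randn(8, 32000, device="cuda")          # [n, vocab_size]
temperature = torch.full((8, 1), 0.7, device="cuda")   # one temperature per row
probs = compiled_softmax(logits, temperature)
assert torch.allclose(probs.sum(dim=-1),
                      torch.ones(8, device="cuda"), atol=1e-5)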

View File

@@ -3,14 +3,21 @@ from typing import Optional
import torch
import torch.nn as nn
-from torch.nn.utils.rnn import pad_sequence
import triton
import triton.language as tl

from vllm.logger import init_logger
from vllm.v1.sample.metadata import SamplingMetadata
-from vllm.v1.spec_decode.utils import random_sample
from vllm.v1.sample.ops.utils import compiled_softmax
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata

logger = init_logger(__name__)

-INVALID_TOKEN_ID = -1
PLACEHOLDER_TOKEN_ID: tl.constexpr = -1
GREEDY_TEMPERATURE: tl.constexpr = -1
# Maximum number of speculative draft tokens allowed per request in a single
# step. This value is chosen to be large enough to handle typical use cases.
MAX_SPEC_LEN = 32


class RejectionSampler(nn.Module):
@@ -36,29 +43,30 @@ class RejectionSampler(nn.Module):
    output tokens = accepted tokens + recovered tokens + bonus tokens
    """

-    def __init__(self):
-        super().__init__()

    def forward(
        self,
-        draft_token_ids: list[list[int]],
-        draft_probs: Optional[torch.Tensor],
-        bonus_token_ids_tensor: torch.Tensor,  # [batch_size, 1]
-        target_probs: torch.Tensor,  # [num_total_tokens, vocab_size]
        metadata: SpecDecodeMetadata,
        # [num_tokens, vocab_size]
        draft_probs: Optional[torch.Tensor],
        # [num_tokens, vocab_size]
        target_logits: torch.Tensor,
        # [batch_size, 1]
        bonus_token_ids: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        '''
        Args:
-            draft_token_ids (List[List[int]]):
-                A 2D list of token IDs for each request in the batch.
-                Each request might have different number of draft tokens.
-                It may also contain empty lists for requests that have
-                no draft tokens.
            metadata:
                Metadata for spec decoding.
            draft_probs (Optional[torch.Tensor]):
                Probability distribution for the draft tokens. Shape is
-                [batch_size, max_spec_len, vocab_size]. Can be None if
-                probabilities are not provided, which is the case for
-                ngram spec decode.
                [num_tokens, vocab_size]. Can be None if probabilities are
                not provided, which is the case for ngram spec decode.
            target_logits (torch.Tensor):
                Target model's logits probability distribution.
                Shape is [num_tokens, vocab_size]. Here, probabilities from
                different requests are flattened into a single tensor because
                this is the shape of the output logits.
            bonus_token_ids_tensor (torch.Tensor):
                A tensor containing bonus tokens. Shape is [batch_size, 1].
                Bonus tokens are added to the end of the sequence if all
@@ -66,13 +74,6 @@ class RejectionSampler(nn.Module):
                outside of the rejection sampler with the default sampling
                strategy. It allows for more flexibility in the sampling
                process such as top_p, top_k sampling.
-            target_probs (torch.Tensor):
-                Target model probability distribution.
-                Shape is [num_total_tokens, vocab_size]. num_total_tokens
-                is the total number of tokens from all requests. Here,
-                probabilities from different requests are flattened into
-                a single tensor because this is the shape of the output
-                logits.
            sampling_metadata (SamplingMetadata):
                Additional metadata needed for sampling, such as temperature,
                top-k/top-p parameters, or other relevant information.
@@ -80,268 +81,481 @@ class RejectionSampler(nn.Module):
            output_token_ids (torch.Tensor):
                A tensor containing the final output token IDs.
        '''
assert metadata.max_spec_len <= MAX_SPEC_LEN
# [num_tokens, vocab_size]
target_probs = compute_probs(
target_logits,
metadata.cu_num_draft_tokens,
sampling_metadata,
)
        output_token_ids = rejection_sample(
            metadata.draft_token_ids,
            metadata.num_draft_tokens,
            metadata.max_spec_len,
            metadata.cu_num_draft_tokens,
            draft_probs,
            target_probs,
            bonus_token_ids,
            sampling_metadata,
        )
-        # NOTE: The following input preparation can be moved
-        # to the model runner with a persistent manner for better
-        # performance.
-        # Convert draft token IDs to a tensor, split by sample_lens, then pad.
-        draft_token_ids = [
-            torch.tensor(x, dtype=int, device='cpu') for x in draft_token_ids
-        ]
-        draft_token_ids_tensor = pad_sequence(draft_token_ids,
-                                              batch_first=True,
-                                              padding_value=INVALID_TOKEN_ID)
-        # NOTE: CPU <-> GPU synchronization happens here.
-        draft_token_ids_tensor = draft_token_ids_tensor.to(target_probs.device)
-        # Create one-hot tensor for draft token ids.
-        # This is used for ngram where we don't have draft_probs.
-        if draft_probs is None and not sampling_metadata.all_greedy:
-            vocab_size = target_probs.size(-1)
-            draft_probs = _create_greedy_token_probs(draft_token_ids_tensor,
-                                                     vocab_size,
-                                                     target_probs.device)
-        sample_lens = [len(x) + 1 for x in draft_token_ids]
-        target_probs = _convert_2d_probs(target_probs, sample_lens)
-        return self.forward_native(draft_token_ids_tensor, draft_probs,
-                                   bonus_token_ids_tensor, target_probs,
-                                   sampling_metadata)
# TODO: The following method can be optimized for better performance.
def forward_native(
self,
draft_token_ids_tensor: torch.Tensor,
# [batch_size, max_spec_len, vocab_size]
draft_probs: Optional[torch.Tensor],
bonus_token_ids_tensor: torch.Tensor,
# [batch_size, max_spec_len + 1, vocab_size]
target_probs: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
# Add 1 to include the 'bonus' token.
if sampling_metadata.all_greedy:
# Produce a mask that remains 1 (True) until the first
# mismatch (cumprod turns 0 after a mismatch).
target_token_ids_tensor = target_probs.argmax(dim=-1)
accept_mask = (target_token_ids_tensor[:, :-1] ==
draft_token_ids_tensor).cumprod(dim=1)
# Identify valid positions (non-padding).
valid_mask = target_token_ids_tensor != INVALID_TOKEN_ID
# Generate mask with bonus token.
generate_mask = torch.cat([
accept_mask,
torch.zeros(accept_mask.size(0), 1, device=accept_mask.device)
],
dim=1).to(torch.bool) & valid_mask
zeros_mask = (generate_mask == 0)
first_zero_idx = zeros_mask.float().argmax(dim=1)
# Figure out which rows actually contain at least one zero.
rows_with_zero = zeros_mask.any(dim=1)
# Use indexing to set the first zero in each of those rows to 1.
generate_mask[rows_with_zero, first_zero_idx[rows_with_zero]] = 1
output_token_ids = target_token_ids_tensor
output_token_ids[~generate_mask] = INVALID_TOKEN_ID
else:
# Reference: https://arxiv.org/pdf/2211.17192
# 1. Extract the probabilities of the draft tokens.
# [batch_size, max_spec_len]
batch_size = draft_token_ids_tensor.size(0)
max_spec_len = draft_token_ids_tensor.size(1)
invalid_idx = draft_token_ids_tensor == INVALID_TOKEN_ID
draft_token_ids_tensor[invalid_idx] = 0
assert draft_probs is not None
draft_token_probs = draft_probs.gather(
dim=-1, index=draft_token_ids_tensor.unsqueeze(-1)).squeeze(-1)
target_token_probs = target_probs.gather(
dim=-1, index=draft_token_ids_tensor.unsqueeze(-1)).squeeze(-1)
# Force the probabilities of invalid tokens to inf
# so that they are not accepted.
draft_token_probs[invalid_idx] = float('inf')
# 2. Generate uniform samples.
# [batch_size, max_spec_len + 1]
uniform_samples = _create_uniform_samples(
sampling_metadata.generators, batch_size, max_spec_len,
target_probs.device)
# 3. Accept or reject the samples.
# [batch_size, max_spec_len]
# If the draft token probabilities are 0, set them to the smallest
# positive normal value representable by float32.
safe_draft_probs = torch.where(draft_token_probs > 0,
draft_token_probs,
torch.finfo(torch.float32).tiny)
accepted = uniform_samples <= target_token_probs / safe_draft_probs
accept_mask = accepted.cumprod(dim=1)
# Set the token ids to the draft token ids if accepted, otherwise
# set them to INVALID_TOKEN_ID.
accepted_token_ids = (draft_token_ids_tensor * accept_mask +
INVALID_TOKEN_ID * (1 - accept_mask))
# 4. Adjust the distribution for the recovered tokens.
# Clamp the bonus probabilities to the smallest positive normal
# value representable by float32.
bonus_prob = torch.clamp(target_probs[:, :-1, :] - draft_probs,
min=torch.finfo(torch.float32).tiny)
normalized_bonus_prob = bonus_prob / bonus_prob.sum(dim=-1,
keepdim=True)
# 5. Sample recovered token ids.
recovered_token_ids = random_sample(
normalized_bonus_prob,
sampling_metadata.generators).reshape(batch_size, max_spec_len)
# 6. Get the final output token ids.
# output_token_ids = accepted_token_ids +
# recovered_token_ids +
# bonus_token_id
recovered_bonus_token_ids = torch.cat(
[recovered_token_ids, bonus_token_ids_tensor], dim=1)
# Generate mask with bonus tokens.
generate_mask = torch.cat([
accept_mask,
torch.zeros(batch_size, 1, device=accept_mask.device)
],
dim=1).to(torch.bool)
zeros_mask = (generate_mask == 0)
first_zero_idx = zeros_mask.float().argmax(dim=1)
output_token_ids = torch.cat([
accepted_token_ids,
torch.full((batch_size, 1),
fill_value=INVALID_TOKEN_ID,
device=accept_mask.device)
],
dim=1)
output_token_ids[torch.arange(batch_size),
first_zero_idx] = recovered_bonus_token_ids[
torch.arange(batch_size), first_zero_idx]
        return output_token_ids

    @staticmethod
def parse_output(
output_token_ids: torch.Tensor,
vocab_size: int,
) -> list[list[int]]:
output_token_ids_np = output_token_ids.cpu().numpy()
# Create mask for valid tokens.
valid_mask = ((output_token_ids_np != PLACEHOLDER_TOKEN_ID) &
(output_token_ids_np < vocab_size))
outputs = [
row[valid_mask[i]].tolist()
for i, row in enumerate(output_token_ids_np)
]
return outputs
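parse_output turns the padded [batch_size, max_spec_len + 1] tensor back into ragged per-request lists by dropping the placeholder padding and any id outside the vocabulary. A self-contained NumPy sketch of that filtering (editor's illustration; the function and variable names here are made up):

# Illustrative sketch (not part of the diff).
import numpy as np

PLACEHOLDER_TOKEN_ID = -1

def parse_output_sketch(output_token_ids: np.ndarray, vocab_size: int) -> list[list[int]]:
    # Keep only real tokens: drop padding and anything outside the vocab.
    valid_mask = ((output_token_ids != PLACEHOLDER_TOKEN_ID) &
                  (output_token_ids < vocab_size))
    return [row[valid_mask[i]].tolist() for i, row in enumerate(output_token_ids)]

padded = np.array([[1, 2, 3, 4],
                   [5, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID]])
assert parse_output_sketch(padded, vocab_size=100) == [[1, 2, 3, 4], [5]]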
def rejection_sample(
# [num_tokens]
draft_token_ids: torch.Tensor,
# [batch_size]
num_draft_tokens: list[int],
max_spec_len: int,
# [batch_size]
cu_num_draft_tokens: torch.Tensor,
# [num_tokens, vocab_size]
draft_probs: Optional[torch.Tensor],
# [num_tokens, vocab_size]
target_probs: torch.Tensor,
# [batch_size, 1]
bonus_token_ids: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    assert draft_token_ids.ndim == 1
    assert draft_probs is None or draft_probs.ndim == 2
    assert cu_num_draft_tokens.ndim == 1
    assert target_probs.ndim == 2
batch_size = len(num_draft_tokens)
num_tokens = draft_token_ids.shape[0]
vocab_size = target_probs.shape[-1]
device = target_probs.device
assert draft_token_ids.is_contiguous()
assert draft_probs is None or draft_probs.is_contiguous()
assert target_probs.is_contiguous()
assert bonus_token_ids.is_contiguous()
assert target_probs.shape == (num_tokens, vocab_size)
# Create output buffer.
output_token_ids = torch.empty(
(batch_size, max_spec_len + 1),
dtype=torch.int32, # Consistent with SamplerOutput.sampled_token_ids.
device=device,
)
output_token_ids.fill_(PLACEHOLDER_TOKEN_ID)
if sampling_metadata.all_greedy:
is_greedy = None
else:
is_greedy = sampling_metadata.temperature == GREEDY_TEMPERATURE
if not sampling_metadata.all_random:
# Rejection sampling for greedy sampling requests.
target_argmax = target_probs.argmax(dim=-1)
rejection_greedy_sample_kernel[(batch_size, )](
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
target_argmax,
bonus_token_ids,
is_greedy,
max_spec_len,
num_warps=1,
)
if sampling_metadata.all_greedy:
return output_token_ids
# Generate uniform probabilities for rejection sampling.
# [num_tokens]
uniform_probs = generate_uniform_probs(
num_tokens,
num_draft_tokens,
sampling_metadata.generators,
device,
)
# Sample recovered tokens for each position.
# [num_tokens]
recovered_token_ids = sample_recovered_tokens(
max_spec_len,
num_draft_tokens,
cu_num_draft_tokens,
draft_token_ids,
draft_probs,
target_probs,
sampling_metadata,
device,
)
# Rejection sampling for random sampling requests.
rejection_random_sample_kernel[(batch_size, )](
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
draft_probs,
target_probs,
bonus_token_ids,
recovered_token_ids,
uniform_probs,
is_greedy,
max_spec_len,
vocab_size,
IS_NGRAM=draft_probs is None,
num_warps=1,
)
return output_token_ids
-    def compute_probs(self, logits: torch.Tensor,
-                      sampling_metadata: SamplingMetadata,
-                      sample_lens: list[int]) -> torch.Tensor:
-        """
-        Compute probability distribution from logits based on sampling metadata.
-
-        This function applies temperature scaling to the logits and converts
-        them to probabilities using softmax. Note that division by
-        temperature is not performed inplace to preserve the original logits
-        tensor, which will be used by the original sampler to get bonus tokens.
-
-        Args:
-            logits: Input logits tensor to be converted to probabilities
-            sampling_metadata: Metadata containing sampling parameters such
-                as temperature and whether greedy sampling is used
-            sample_lens: List of sample lengths used for repeating
-                temperature values
-
-        Returns:
-            torch.Tensor: Probability distribution (softmax of scaled logits)
-                if non-greedy sampling is used, otherwise returns the
-                original logits
-        """
-        if sampling_metadata.all_greedy:
-            return logits
-        # We should optimize the following code as
-        # it will cause CPU -> GPU synchronization.
-        temperature = torch.repeat_interleave(
-            sampling_metadata.temperature,
-            torch.tensor(sample_lens,
-                         device=sampling_metadata.temperature.device))
-        temperature = temperature.unsqueeze(dim=1)
-        logits = logits / temperature
-        return logits.softmax(dim=-1, dtype=torch.float32)


def compute_probs(
    logits: torch.Tensor,  # [num_tokens, vocab_size]
    cu_num_draft_tokens: torch.Tensor,  # [batch_size]
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    """Compute probability distribution from logits based on sampling metadata.

    This function applies temperature scaling to the logits and converts
    them to probabilities using softmax. For greedy decoding, it returns
    the original logits.

    Args:
        logits: Input logits tensor to be converted to probabilities.
        cu_num_draft_tokens: Cumulative number of draft tokens.
        sampling_metadata: Metadata containing sampling parameters such as
            temperature and whether greedy sampling is used.

    Returns:
        torch.Tensor: Probability distribution (softmax of scaled logits)
            if non-greedy sampling is used, otherwise returns the
            original logits.
    """
    assert logits.ndim == 2
    assert cu_num_draft_tokens.ndim == 1
    if sampling_metadata.all_greedy:
        return logits

    assert sampling_metadata.temperature is not None
    num_tokens = logits.shape[0]
    batch_size = cu_num_draft_tokens.shape[0]
    expanded_temperature = torch.empty(
        (num_tokens, 1),
        dtype=torch.float32,
        device=logits.device,
    )
    expand_kernel[(batch_size, )](
        expanded_temperature,
        sampling_metadata.temperature,
        cu_num_draft_tokens,
        GREEDY_TEMPERATURE,  # replace_from
        1,  # replace_to
        MAX_NUM_TOKENS=MAX_SPEC_LEN,
        num_warps=1,
    )
    output_prob = compiled_softmax(logits, expanded_temperature)
    return output_prob
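For intuition, the expand_kernel launch above fills expanded_temperature with one value per draft token, replacing the greedy sentinel (GREEDY_TEMPERATURE = -1) with 1 so the softmax leaves greedy rows effectively untouched. A rough torch equivalent with made-up values (editor's sketch, not part of the commit):

# Illustrative sketch (not part of the diff).
import torch

GREEDY_TEMPERATURE = -1
temperature = torch.tensor([0.7, GREEDY_TEMPERATURE, 1.3])  # per-request temperature
num_draft_tokens = torch.tensor([3, 2, 1])                  # draft tokens per request

expanded = torch.repeat_interleave(temperature, num_draft_tokens)
expanded = torch.where(expanded == GREEDY_TEMPERATURE,
                       torch.ones_like(expanded), expanded).unsqueeze(1)
print(expanded.squeeze(1))  # tensor([0.7000, 0.7000, 0.7000, 1.0000, 1.0000, 1.3000])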
-def _create_greedy_token_probs(
-    token_ids: torch.Tensor,
-    vocab_size: int,
-    out_device: torch.device,
-) -> torch.Tensor:
-    batch_size, num_tokens = token_ids.shape
-
-    token_probs = torch.zeros(batch_size,
-                              num_tokens,
-                              vocab_size,
-                              dtype=torch.float,
-                              device=out_device)
-
-    # Ignore INVALID_TOKEN_ID.
-    valid_mask = (token_ids != INVALID_TOKEN_ID)
-    valid_indices = token_ids.clone()
-    valid_indices[~valid_mask] = 0
-
-    token_probs.scatter_(dim=2,
-                         index=valid_indices.unsqueeze(-1),
-                         src=valid_mask.unsqueeze(-1).float())
-    return token_probs
-
-
-def _convert_2d_probs(
-        probs: torch.Tensor,  # [num_total_tokens, vocab_size]
-        sample_lens: list[int]) -> torch.Tensor:
-    """
-    Converts a 2D tensor of probabilities to a 3D tensor with padding.
-    [num_total_tokens, vocab_size] ->
-    [batch_size, max_spec_len + 1, vocab_size]
-    """
-    cumulative_lens = torch.cumsum(torch.tensor(sample_lens,
-                                                device=probs.device),
-                                   dim=0)
-    split_indices = cumulative_lens[:-1].tolist()  # Exclude last index
-
-    # Split into chunks without loops
-    chunks = torch.tensor_split(probs, split_indices, dim=0)
-    # Pad all sequences to maximum length
-    padded_probs = pad_sequence(chunks, batch_first=True, padding_value=0.0)
-    return padded_probs
-
-
-def _create_uniform_samples(seeded_seqs: dict[int, torch.Generator],
-                            batch_size: int, k: int,
-                            device: torch.device) -> torch.Tensor:
-    uniform_rand = torch.rand(batch_size,
-                              k,
-                              dtype=torch.float32,
-                              device=device)
-    # Apply seeded generators only where needed
-    if seeded_seqs:
-        for idx, generator in seeded_seqs.items():
-            uniform_rand[idx].uniform_(0, 1, generator=generator)
-    return uniform_rand


def generate_uniform_probs(
    num_tokens: int,
    num_draft_tokens: list[int],
    generators: dict[int, torch.Generator],
    device: torch.device,
) -> torch.Tensor:
    """
    Generates a batch of uniform random samples, with optional seeding
    if available.

    This method creates a tensor of shape `(num_tokens, )` filled
    with uniform random values in the range [0, 1). If `generators` is provided,
    the requests with their own seeds will use the provided `torch.Generator`
    for reproducibility. The samples for the other requests will be generated
    without a seed.

    Args:
        num_tokens : int
            Total number of tokens.
        num_draft_tokens : List[List[int]]
            Number of draft tokens per request.
        generators : Optional[Dict[int, torch.Generator]]
            A dictionary mapping indices in the batch to
            `torch.Generator` objects.
        device : torch.device
            The device on which to allocate the tensor.

    Returns:
        uniform_rand : torch.Tensor
            A tensor of shape `(num_tokens, )` containing uniform
            random values in the range [0, 1).
    """
    uniform_probs = torch.rand(
        (num_tokens, ),
        dtype=torch.float32,
        device=device,
    )
    start_idx = 0
    for req_idx, n in enumerate(num_draft_tokens):
        # Do not generate random numbers for requests with no draft tokens.
        # This can be important for reproducibility.
        if n == 0:
            continue
        end_idx = start_idx + n
        generator = generators.get(req_idx)
        if generator is not None:
            uniform_probs[start_idx:end_idx].uniform_(generator=generator)
        start_idx = end_idx
    return uniform_probs
def sample_recovered_tokens(
max_spec_len: int,
num_draft_tokens: list[int],
# [batch_size]
cu_num_draft_tokens: torch.Tensor,
# [num_tokens]
draft_token_ids: torch.Tensor,
# [num_tokens, vocab_size]
draft_probs: Optional[torch.Tensor],
# [num_tokens, vocab_size]
target_probs: torch.Tensor,
sampling_metadata: SamplingMetadata,
device: torch.device,
) -> torch.Tensor:
# NOTE(woosuk): Create only one distribution for each request.
batch_size = len(num_draft_tokens)
vocab_size = target_probs.shape[-1]
q = torch.empty(
(batch_size, vocab_size),
dtype=torch.float32,
device=device,
)
q.exponential_()
for i, generator in sampling_metadata.generators.items():
# Do not generate random numbers for requests with no draft tokens.
# This can be important for reproducibility.
if num_draft_tokens[i] > 0:
q[i].exponential_(generator=generator)
recovered_token_ids = torch.empty_like(draft_token_ids)
sample_recovered_tokens_kernel[(batch_size, max_spec_len)](
recovered_token_ids,
cu_num_draft_tokens,
draft_token_ids,
draft_probs,
target_probs,
q,
vocab_size,
triton.next_power_of_2(vocab_size),
IS_NGRAM=draft_probs is None,
)
return recovered_token_ids
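The kernel launched above picks each recovered token as argmax(prob / q) with q drawn from an Exponential(1) distribution; that is the "exponential race" (Gumbel-max style) trick, which selects index i with probability proportional to prob[i]. A quick Monte Carlo check (editor's sketch, not part of the commit):

# Illustrative sketch (not part of the diff).
import torch

torch.manual_seed(0)
prob = torch.tensor([0.1, 0.6, 0.3])  # need not be normalized

counts = torch.zeros(3)
for _ in range(20000):
    q = torch.empty(3).exponential_()
    counts[torch.argmax(prob / q)] += 1
print(counts / counts.sum())  # roughly tensor([0.10, 0.60, 0.30])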
# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_greedy_sample_kernel(
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
cu_num_draft_tokens_ptr, # [batch_size]
draft_token_ids_ptr, # [num_tokens]
target_argmax_ptr, # [num_tokens]
bonus_token_ids_ptr, # [batch_size]
is_greedy_ptr, # [batch_size] or None
max_spec_len,
):
req_idx = tl.program_id(0)
# FIXME(woosuk): Because is_greedy_ptr is not None at profiling run,
# re-compilation may happen during runtime when is_greedy_ptr is None.
if is_greedy_ptr is None:
is_greedy = True
else:
is_greedy = tl.load(is_greedy_ptr + req_idx)
if not is_greedy:
# Early exit for non-greedy sampling requests.
return
if req_idx == 0:
start_idx = 0
else:
start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
num_draft_tokens = end_idx - start_idx
rejected = False
for pos in range(num_draft_tokens):
if not rejected:
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos)
tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
target_argmax_id)
if draft_token_id != target_argmax_id:
# Reject.
rejected = True
if not rejected:
# If all tokens are accepted, append the bonus token.
bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
tl.store(
output_token_ids_ptr + req_idx * (max_spec_len + 1) +
num_draft_tokens, bonus_token_id)
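What the greedy kernel above does for a single request, written as plain PyTorch with made-up values (editor's sketch; the real kernel also writes placeholder padding and respects the is_greedy mask):

# Illustrative sketch (not part of the diff).
import torch

draft = torch.tensor([5, 9, 2])          # draft token ids for one request
target_argmax = torch.tensor([5, 9, 7])  # target model argmax at each position
bonus = torch.tensor([4])

n_accepted = int((draft == target_argmax).int().cumprod(dim=0).sum())
if n_accepted == len(draft):
    out = torch.cat([target_argmax, bonus])   # all drafts accepted: keep bonus
else:
    out = target_argmax[:n_accepted + 1]      # emit through the first mismatch
print(out.tolist())  # [5, 9, 7]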
# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_random_sample_kernel(
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
cu_num_draft_tokens_ptr, # [batch_size]
draft_token_ids_ptr, # [num_tokens]
draft_probs_ptr, # [num_tokens, vocab_size] or None
target_probs_ptr, # [num_tokens, vocab_size]
bonus_token_ids_ptr, # [batch_size]
recovered_token_ids_ptr, # [num_tokens]
uniform_probs_ptr, # [num_tokens]
is_greedy_ptr, # [batch_size]
max_spec_len,
vocab_size,
IS_NGRAM: tl.constexpr,
):
req_idx = tl.program_id(0)
is_greedy = tl.load(is_greedy_ptr + req_idx)
if is_greedy:
# Early exit for greedy sampling requests.
return
if req_idx == 0:
start_idx = 0
else:
start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
num_draft_tokens = end_idx - start_idx
rejected = False
for pos in range(num_draft_tokens):
if not rejected:
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
if IS_NGRAM:
draft_prob = 1
else:
draft_prob = tl.load(draft_probs_ptr +
(start_idx + pos) * vocab_size +
draft_token_id)
target_prob = tl.load(target_probs_ptr +
(start_idx + pos) * vocab_size +
draft_token_id)
uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
# NOTE(woosuk): While the draft probability should never be 0,
# we check it to avoid NaNs. If it happens to be 0, we reject.
if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
# Accept.
token_id = draft_token_id
else:
# Reject. Use recovered token.
rejected = True
token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
token_id)
if not rejected:
# If all tokens are accepted, append the bonus token.
bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
tl.store(
output_token_ids_ptr + req_idx * (max_spec_len + 1) +
num_draft_tokens, bonus_token_id)
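The accept/reject rule in the kernel above (accept while target_prob / draft_prob >= u, stop at the first rejection) can be written in vectorized PyTorch for a single request; an illustrative equivalent that ignores recovered and bonus tokens (editor's sketch, not part of the commit):

# Illustrative sketch (not part of the diff).
import torch

torch.manual_seed(0)
k, vocab = 4, 16
draft_probs = torch.softmax(torch.randn(k, vocab), dim=-1)
target_probs = torch.softmax(torch.randn(k, vocab), dim=-1)
draft_token_ids = torch.multinomial(draft_probs, num_samples=1).squeeze(1)

p_draft = draft_probs[torch.arange(k), draft_token_ids]
p_target = target_probs[torch.arange(k), draft_token_ids]
u = torch.rand(k)

# Accept while u <= p_target / p_draft; cumprod zeroes everything after the
# first rejection, mirroring the `rejected` flag in the kernel.
accepted = (u <= p_target / p_draft).int().cumprod(dim=0).bool()
num_accepted = int(accepted.sum())  # draft tokens kept before the first rejection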
# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
@triton.jit(do_not_specialize=["replace_from", "replace_to"])
def expand_kernel(
output_ptr, # [num_tokens]
input_ptr, # [batch_size]
cu_num_tokens_ptr, # [batch_size]
replace_from,
replace_to,
MAX_NUM_TOKENS: tl.constexpr,
):
req_idx = tl.program_id(0)
if req_idx == 0: # noqa: SIM108
start_idx = 0
else:
start_idx = tl.load(cu_num_tokens_ptr + req_idx - 1)
end_idx = tl.load(cu_num_tokens_ptr + req_idx)
num_tokens = end_idx - start_idx
src_val = tl.load(input_ptr + req_idx)
src_val = tl.where(src_val == replace_from, replace_to, src_val)
offset = tl.arange(0, MAX_NUM_TOKENS)
tl.store(output_ptr + start_idx + offset,
src_val,
mask=offset < num_tokens)
@triton.jit
def sample_recovered_tokens_kernel(
output_token_ids_ptr, # [num_tokens]
cu_num_draft_tokens_ptr, # [batch_size]
draft_token_ids_ptr, # [num_tokens]
draft_probs_ptr, # [num_tokens, vocab_size] or None
target_probs_ptr, # [num_tokens, vocab_size]
q_ptr, # [batch_size, vocab_size]
vocab_size,
PADDED_VOCAB_SIZE: tl.constexpr,
IS_NGRAM: tl.constexpr,
):
req_idx = tl.program_id(0)
if req_idx == 0:
start_idx = 0
else:
start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
num_draft_tokens = end_idx - start_idx
# Early exit for out-of-range positions.
pos = tl.program_id(1)
if pos >= num_draft_tokens:
return
vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE)
if IS_NGRAM:
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
orig_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size +
draft_token_id)
# Temporarily zero out the probability of the draft token.
# This is essentially the same as target_prob - draft_prob, except that
# n-gram does not have draft_prob. We regard it as 1.
tl.store(
target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
0)
prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size +
vocab_offset,
mask=vocab_offset < vocab_size,
other=0)
else:
draft_prob = tl.load(draft_probs_ptr + (start_idx + pos) * vocab_size +
vocab_offset,
mask=vocab_offset < vocab_size,
other=0)
target_prob = tl.load(target_probs_ptr +
(start_idx + pos) * vocab_size + vocab_offset,
mask=vocab_offset < vocab_size,
other=0)
prob = tl.maximum(target_prob - draft_prob, 0)
# NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because
# `tl.argmax` will select the maximum value.
q = tl.load(q_ptr + req_idx * vocab_size + vocab_offset,
mask=vocab_offset < vocab_size,
other=float("-inf"))
recovered_id = tl.argmax(prob / q, axis=-1)
tl.store(output_token_ids_ptr + start_idx + pos, recovered_id)
if IS_NGRAM:
# Restore the original probability.
tl.store(
target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
orig_prob)

View File

@@ -0,0 +1,61 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
import numpy as np
import torch
@dataclass
class SpecDecodeMetadata:
# [num_tokens]
draft_token_ids: torch.Tensor
# [batch_size]
num_draft_tokens: list[int]
# [batch_size]
cu_num_draft_tokens: torch.Tensor
# [num_tokens]
target_logits_indices: torch.Tensor
# [batch_size]
bonus_logits_indices: torch.Tensor
# [num_tokens + batch_size]
logits_indices: torch.Tensor
def __post_init__(self):
self.max_spec_len = max(self.num_draft_tokens)
@classmethod
def make_dummy(
cls,
draft_token_ids: list[list[int]],
device: torch.device,
) -> "SpecDecodeMetadata":
batch_size = len(draft_token_ids)
num_draft_tokens = [len(ids) for ids in draft_token_ids]
flattened_draft_token_ids = sum(draft_token_ids, [])
num_tokens = len(flattened_draft_token_ids)
draft_token_ids_tensor = torch.tensor(flattened_draft_token_ids,
dtype=torch.int32,
device=device)
cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32)
cu_num_draft_tokens_tensor = torch.from_numpy(cu_num_draft_tokens).to(
device)
target_logits_indices = torch.zeros(num_tokens,
dtype=torch.int32,
device=device)
bonus_logits_indices = torch.zeros(batch_size,
dtype=torch.int32,
device=device)
logits_indices = torch.zeros(num_tokens + batch_size,
dtype=torch.int32,
device=device)
return cls(
draft_token_ids=draft_token_ids_tensor,
num_draft_tokens=num_draft_tokens,
cu_num_draft_tokens=cu_num_draft_tokens_tensor,
target_logits_indices=target_logits_indices,
bonus_logits_indices=bonus_logits_indices,
logits_indices=logits_indices,
)
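A small usage sketch for the dataclass above (editor's illustration; the values are made up). make_dummy flattens the ragged per-request draft ids and zero-fills the index fields, which is enough for the unit tests and the profiling run:

# Illustrative sketch (not part of the diff).
import torch
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata

meta = SpecDecodeMetadata.make_dummy([[10, 11, 12], [], [7]],
                                     device=torch.device("cpu"))
print(meta.num_draft_tokens)              # [3, 0, 1]
print(meta.cu_num_draft_tokens.tolist())  # [3, 3, 4]
print(meta.max_spec_len)                  # 3
print(meta.draft_token_ids.tolist())      # [10, 11, 12, 7]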

View File

@@ -1,5 +1,4 @@
# SPDX-License-Identifier: Apache-2.0

-from vllm.v1.sample.ops.topk_topp_sampler import random_sample  # noqa
from vllm.v1.worker.gpu_input_batch import InputBatch

View File

@@ -34,7 +34,8 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
                             ModelRunnerOutput)
from vllm.v1.sample.metadata import SamplingMetadata
-from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID, RejectionSampler
from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.spec_decode.utils import is_spec_decode_supported
from vllm.v1.utils import bind_kv_cache
@@ -149,7 +150,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        self.use_spec_decode = False
        if self.speculative_config:
            self.use_spec_decode = True
-            self.rejection_sampler = RejectionSampler()
            # TODO: find a better way to check if we are using ngram.
            assert self.speculative_config.ngram_prompt_lookup_min, \
                "Currently, only ngram spec decode is supported in V1."
@@ -162,6 +162,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                self.speculative_config.ngram_prompt_lookup_min,
                self.speculative_config.num_speculative_tokens,
            )
            self.rejection_sampler = RejectionSampler()

        # Request states.
        self.requests: dict[str, CachedRequestState] = {}
@@ -452,7 +453,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
    def _prepare_inputs(
        self,
        scheduler_output: "SchedulerOutput",
-    ) -> tuple[FlashAttentionMetadata, torch.Tensor]:
    ) -> tuple[FlashAttentionMetadata, torch.Tensor,
               Optional[SpecDecodeMetadata]]:
        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        assert total_num_scheduled_tokens > 0
        num_reqs = self.input_batch.num_reqs
@@ -577,22 +579,33 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        use_spec_decode = len(
            scheduler_output.scheduled_spec_decode_tokens) > 0
-        if use_spec_decode:
-            logits_indices = self._calc_spec_decode_metadata(
-                scheduler_output, cu_num_tokens)
-        else:
        if not use_spec_decode:
            # NOTE(woosuk): Due to chunked prefills, the batch may contain
            # partial requests. While we should not sample any token
            # from these partial requests, we do so for simplicity.
            # We will ignore the sampled tokens from the partial requests.
            # TODO: Support prompt logprobs.
            logits_indices = attn_metadata.query_start_loc[1:] - 1
            spec_decode_metadata = None
        else:
            # Get the number of draft tokens for each request.
            # Iterate over the dictionary rather than all requests since not all
            # requests have draft tokens.
            num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
            for req_id, draft_token_ids in (
                    scheduler_output.scheduled_spec_decode_tokens.items()):
                req_idx = self.input_batch.req_id_to_index[req_id]
                num_draft_tokens[req_idx] = len(draft_token_ids)
            spec_decode_metadata = self._calc_spec_decode_metadata(
                num_draft_tokens, cu_num_tokens)
            logits_indices = spec_decode_metadata.logits_indices

        # Hot-Swap lora model
        if self.lora_config:
            self.set_active_loras(self.input_batch, num_scheduled_tokens)

-        return attn_metadata, logits_indices
        return attn_metadata, logits_indices, spec_decode_metadata
    def _compute_cascade_attn_prefix_len(
        self,
@@ -732,49 +745,78 @@ class GPUModelRunner(LoRAModelRunnerMixin):
-    def _calc_spec_decode_metadata(
-        self,
-        scheduler_output: "SchedulerOutput",
-        cu_num_tokens: np.ndarray,
-    ) -> torch.Tensor:
-        # Get the number of spec decode tokens for each request.
-        num_reqs = self.input_batch.num_reqs
-        num_spec_decode_tokens = np.empty(num_reqs, dtype=np.int32)
-        for i, req_id in enumerate(self.input_batch.req_ids):
-            num_spec_decode_tokens[i] = len(
-                scheduler_output.scheduled_spec_decode_tokens.get(req_id, ()))
-
-        # Get spec decode logits indices.
-        # E.g., num_scheduled_tokens: [4, 100, 3, 100, 2]
-        # cu_num_tokens: [4, 104, 107, 207, 209]
-        # num_spec_tokens_list: [3, 0, 2, 0, 1]
-        # num_sampled_tokens: [4, 1, 3, 1, 2]
-        # spec_decode_logits_indices:
-        #   [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
-        num_sampled_tokens = num_spec_decode_tokens + 1
-        # logits_start_loc: [0, 103, 104, 206, 207]
-        logits_start_loc = cu_num_tokens - num_sampled_tokens
-        # [0, 103, 104, 206, 207] ->
-        # [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
-        logits_start_loc = np.repeat(logits_start_loc, num_sampled_tokens)
-        # The following three lines:
-        # [4, 1, 3, 1, 2] -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
-        # Step 1. [4, 1, 3, 1, 2] -> [4, 5, 8, 9, 11]
-        cu_num_sampled_tokens = np.cumsum(num_sampled_tokens)
-        # Step 2. [4, 5, 8, 9, 11] -> [0, 4, 5, 8, 9]
-        #         -> [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
-        cumsums_sampled_offsets = np.repeat(
-            cu_num_sampled_tokens - num_sampled_tokens, num_sampled_tokens)
-        # Step 3. [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
-        #         - [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
-        #         -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
-        total_num_sampled_tokens = num_sampled_tokens.sum()
-        sampled_arange = (self.arange_np[:total_num_sampled_tokens] -
-                          cumsums_sampled_offsets)
-        # [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] ->
-        # [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
-        spec_decode_logits_indices = logits_start_loc + sampled_arange
-        return torch.from_numpy(spec_decode_logits_indices).to(
-            self.device, non_blocking=True)

    def _calc_spec_decode_metadata(
        self,
        num_draft_tokens: np.ndarray,
        cu_num_scheduled_tokens: np.ndarray,
    ) -> SpecDecodeMetadata:
        # Inputs:
        # cu_num_scheduled_tokens:  [  4, 104, 107, 207, 209]
        # num_draft_tokens:         [  3,   0,   2,   0,   1]
        # Outputs:
        # cu_num_draft_tokens:      [  3,   3,   5,   5,   6]
        # logits_indices:           [  0,   1,   2,   3, 103, 104, 105, 106,
        #                            206, 207, 208]
        # target_logits_indices:    [  0,   1,   2,   5,   6,   9]
        # bonus_logits_indices:     [  3,   4,   7,   8,  10]

        # Compute the logits indices.
        # [4, 1, 3, 1, 2]
        num_sampled_tokens = num_draft_tokens + 1
        # Step 1. [4, 5, 8, 9, 11]
        cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32)
        total_num_sampled_tokens = cu_num_sampled_tokens[-1]
        # Step 2. [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
        cumsums_offsets = np.repeat(cu_num_sampled_tokens - num_sampled_tokens,
                                    num_sampled_tokens)
        # Step 3. [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
        arange = self.arange_np[:total_num_sampled_tokens] - cumsums_offsets
        # Step 4. [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
        logits_indices = np.repeat(
            cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens)
        # Step 5. [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
        logits_indices += arange

        # Compute the bonus logits indices.
        bonus_logits_indices = cu_num_sampled_tokens - 1

        # Compute the draft logits indices.
        # [3, 3, 5, 5, 6]
        cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32)
        total_num_draft_tokens = cu_num_draft_tokens[-1]
        # [0, 0, 0, 3, 3, 5]
        cumsums_offsets = np.repeat(cu_num_draft_tokens - num_draft_tokens,
                                    num_draft_tokens)
        # [0, 1, 2, 0, 1, 0]
        arange = self.arange_np[:total_num_draft_tokens] - cumsums_offsets
        # [0, 0, 0, 5, 5, 9]
        target_logits_indices = np.repeat(
            cu_num_sampled_tokens - num_sampled_tokens, num_draft_tokens)
        # [0, 1, 2, 5, 6, 9]
        target_logits_indices += arange

        # TODO: Optimize the CPU -> GPU copy.
        cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to(
            self.device, non_blocking=True)
        logits_indices = torch.from_numpy(logits_indices).to(self.device,
                                                             non_blocking=True)
        target_logits_indices = torch.from_numpy(target_logits_indices).to(
            self.device, non_blocking=True)
        bonus_logits_indices = torch.from_numpy(bonus_logits_indices).to(
            self.device, non_blocking=True)

        # Compute the draft token ids.
        # draft_token_indices:      [  1,  2,  3, 105, 106, 208]
        draft_token_ids = self.input_ids[logits_indices]
        draft_token_ids = draft_token_ids[target_logits_indices + 1]

        metadata = SpecDecodeMetadata(
            draft_token_ids=draft_token_ids,
            num_draft_tokens=num_draft_tokens.tolist(),
            cu_num_draft_tokens=cu_num_draft_tokens,
            target_logits_indices=target_logits_indices,
            bonus_logits_indices=bonus_logits_indices,
            logits_indices=logits_indices,
        )
        return metadata
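The worked example in the comments above can be reproduced with a few lines of standalone NumPy (editor's sketch; self.arange_np is replaced by a plain np.arange here):

# Illustrative sketch (not part of the diff).
import numpy as np

num_draft_tokens = np.array([3, 0, 2, 0, 1], dtype=np.int32)
cu_num_scheduled_tokens = np.array([4, 104, 107, 207, 209], dtype=np.int32)

num_sampled_tokens = num_draft_tokens + 1
cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32)
offsets = np.repeat(cu_num_sampled_tokens - num_sampled_tokens, num_sampled_tokens)
arange = np.arange(cu_num_sampled_tokens[-1], dtype=np.int32) - offsets
logits_indices = np.repeat(cu_num_scheduled_tokens - num_sampled_tokens,
                           num_sampled_tokens) + arange

print(logits_indices.tolist())
# [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
print((cu_num_sampled_tokens - 1).tolist())  # bonus_logits_indices: [3, 4, 7, 8, 10]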
    def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
        scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
@@ -931,7 +973,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        encoder_outputs = []

        # Prepare the decoder inputs.
-        attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
        attn_metadata, logits_indices, spec_decode_metadata = (
            self._prepare_inputs(scheduler_output))
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        if (self.use_cuda_graph
                and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
@@ -1006,31 +1049,29 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        # Sample the next token and get logprobs if needed.
        sampling_metadata = self.input_batch.sampling_metadata
-        if not self.use_spec_decode:
        if spec_decode_metadata is None:
            sampler_output = self.model.sample(
                logits=logits,
                sampling_metadata=sampling_metadata,
            )
        else:
-            draft_token_ids = [
-                scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])
-                for req_id in self.input_batch.req_ids
-            ]
-            sample_lens = [len(tokens) + 1 for tokens in draft_token_ids]
-            recover_logits_idx = np.cumsum(sample_lens) - 1
-            target_probs = self.rejection_sampler.compute_probs(
-                logits, sampling_metadata, sample_lens)
            # TODO(woosuk): Optimize the memory usage.
            bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
            sampler_output = self.model.sample(
-                logits=logits[recover_logits_idx, :],
                logits=bonus_logits,
                sampling_metadata=sampling_metadata,
            )
            bonus_token_ids = sampler_output.sampled_token_ids

            # TODO(woosuk): Optimize the memory usage.
            target_logits = logits[spec_decode_metadata.target_logits_indices]
            output_token_ids = self.rejection_sampler(
-                draft_token_ids,
                spec_decode_metadata,
                None,  # draft_probs
                target_logits,
                bonus_token_ids,
-                target_probs,
-                sampling_metadata)
                sampling_metadata,
            )
            sampler_output.sampled_token_ids = output_token_ids

        # TODO(woosuk): The following loop can be slow since it iterates over
@@ -1066,13 +1107,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            valid_sampled_token_ids = sampled_token_ids.tolist()
        else:
            # Includes spec decode tokens.
-            valid_mask = sampled_token_ids != INVALID_TOKEN_ID
-            gen_lens = valid_mask.sum(dim=1).tolist()
-            # TODO(woosuk): Optimize this.
-            valid_sampled_token_ids = [
-                seq.tolist()
-                for seq in sampled_token_ids[valid_mask].split(gen_lens)
-            ]
            valid_sampled_token_ids = self.rejection_sampler.parse_output(
                sampled_token_ids, self.input_batch.vocab_size)

        if not self.use_spec_decode:
            spec_token_ids = None
@@ -1316,6 +1352,33 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                    "initializing the engine.") from e
            else:
                raise e

if self.use_spec_decode:
draft_token_ids = [[0] for _ in range(num_reqs)]
dummy_spec_decode_metadata = SpecDecodeMetadata.make_dummy(
draft_token_ids, self.device)
num_tokens = sum(len(ids) for ids in draft_token_ids)
# draft_probs = torch.randn(
# num_tokens, logits.shape[-1], device=self.device,
# dtype=logits.dtype)
draft_probs = None
target_logits = torch.randn(num_tokens,
logits.shape[-1],
device=self.device,
dtype=logits.dtype)
# NOTE(woosuk): Here, we should use int32 because the sampler uses
# int32 for bonus_token_ids. If the dtype mismatches, re-compilation
# will occur at runtime.
bonus_token_ids = torch.zeros(num_reqs,
device=self.device,
dtype=torch.int32)
self.rejection_sampler(
dummy_spec_decode_metadata,
draft_probs,
target_logits,
bonus_token_ids,
dummy_metadata,
)
        return sampler_output

    def profile_run(self) -> None: