[V1][Spec Decode] Optimize Rejection Sampler with Triton Kernels (#14930)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
3a1e648158
commit
99abb8b650
@ -6,20 +6,23 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID, RejectionSampler
|
||||
from vllm.v1.sample.rejection_sampler import (PLACEHOLDER_TOKEN_ID,
|
||||
RejectionSampler)
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
|
||||
DEVICE = "cpu"
|
||||
DEVICE = "cuda"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sampler():
|
||||
def rejection_sampler():
|
||||
return RejectionSampler()
|
||||
|
||||
|
||||
def create_logits_tensor(token_ids: list[list[int]],
|
||||
def create_logits_tensor(output_token_ids: list[list[int]],
|
||||
vocab_size: int = 100) -> torch.Tensor:
|
||||
"""Helper function to create logits tensor that
|
||||
will produce desired token ids on argmax"""
|
||||
token_ids = [tokens[:-1] for tokens in output_token_ids]
|
||||
num_total_tokens = sum(len(tokens) for tokens in token_ids)
|
||||
logits = torch.full((num_total_tokens, vocab_size), -100.0, device=DEVICE)
|
||||
start_loc = 0
|
||||
@ -31,15 +34,22 @@ def create_logits_tensor(token_ids: list[list[int]],
|
||||
|
||||
|
||||
def create_sampling_metadata(
|
||||
all_greedy: bool,
|
||||
generators: Optional[dict[int, Any]] = None) -> SamplingMetadata:
|
||||
all_greedy: bool,
|
||||
temperature: Optional[torch.Tensor] = None,
|
||||
generators: Optional[dict[int, Any]] = None,
|
||||
) -> SamplingMetadata:
|
||||
"""Create a v1 sampling metadata object with all_greedy set
|
||||
to the given value. Either all greedy or all random sampling
|
||||
is used.
|
||||
"""
|
||||
generators = generators or {}
|
||||
if all_greedy:
|
||||
temperature = None
|
||||
else:
|
||||
assert temperature is not None
|
||||
|
||||
return SamplingMetadata(
|
||||
temperature=torch.tensor([]),
|
||||
temperature=temperature,
|
||||
all_greedy=all_greedy,
|
||||
all_random=not all_greedy,
|
||||
top_p=None,
|
||||
@ -61,7 +71,7 @@ def create_sampling_metadata(
|
||||
|
||||
|
||||
########################### Tests for Greedy Sampling ###################
|
||||
def test_perfect_match(sampler):
|
||||
def test_perfect_match(rejection_sampler):
|
||||
"""Test when output tokens perfectly match speculated tokens"""
|
||||
spec_tokens = [[1, 2, 3]]
|
||||
output_tokens = [[1, 2, 3, 4]] # 4 is the bonus token
|
||||
@ -70,15 +80,23 @@ def test_perfect_match(sampler):
|
||||
logits = create_logits_tensor(output_tokens)
|
||||
bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
|
||||
device=logits.device)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
|
||||
device=logits.device)
|
||||
|
||||
output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata)
|
||||
output = rejection_sampler(
|
||||
spec_decode_metadata,
|
||||
draft_probs=None,
|
||||
target_logits=logits,
|
||||
bonus_token_ids=bonus_token_tensor,
|
||||
sampling_metadata=metadata,
|
||||
)
|
||||
expected = torch.tensor([[1, 2, 3, 4]],
|
||||
dtype=torch.int,
|
||||
device=logits.device)
|
||||
assert torch.equal(output, expected)
|
||||
|
||||
|
||||
def test_early_mismatch(sampler):
|
||||
def test_early_mismatch(rejection_sampler):
|
||||
"""Test when there's an early mismatch in tokens"""
|
||||
spec_tokens = [[1, 2, 3]]
|
||||
output_tokens = [[1, 5, 3, 4]] # Mismatch at position 1
|
||||
@ -87,15 +105,25 @@ def test_early_mismatch(sampler):
|
||||
logits = create_logits_tensor(output_tokens)
|
||||
bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
|
||||
device=logits.device)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
|
||||
device=logits.device)
|
||||
|
||||
output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata)
|
||||
expected = torch.tensor([[1, 5, INVALID_TOKEN_ID, INVALID_TOKEN_ID]],
|
||||
dtype=torch.int,
|
||||
device=logits.device)
|
||||
output = rejection_sampler(
|
||||
spec_decode_metadata,
|
||||
draft_probs=None,
|
||||
target_logits=logits,
|
||||
bonus_token_ids=bonus_token_tensor,
|
||||
sampling_metadata=metadata,
|
||||
)
|
||||
expected = torch.tensor(
|
||||
[[1, 5, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID]],
|
||||
dtype=torch.int,
|
||||
device=logits.device,
|
||||
)
|
||||
assert torch.equal(output, expected)
|
||||
|
||||
|
||||
def test_multiple_sequences(sampler):
|
||||
def test_multiple_sequences(rejection_sampler):
|
||||
"""Test handling multiple sequences of speculated tokens"""
|
||||
spec_tokens = [[1, 2], [3]]
|
||||
output_tokens = [[1, 2, 5], [3,
|
||||
@ -105,15 +133,23 @@ def test_multiple_sequences(sampler):
|
||||
logits = create_logits_tensor(output_tokens)
|
||||
bonus_token_tensor = torch.tensor(
|
||||
[output_tokens[0][-1], output_tokens[1][-1]], device=logits.device)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
|
||||
device=logits.device)
|
||||
|
||||
output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata)
|
||||
expected = torch.tensor([[1, 2, 5], [3, 4, INVALID_TOKEN_ID]],
|
||||
output = rejection_sampler(
|
||||
spec_decode_metadata,
|
||||
draft_probs=None,
|
||||
target_logits=logits,
|
||||
bonus_token_ids=bonus_token_tensor,
|
||||
sampling_metadata=metadata,
|
||||
)
|
||||
expected = torch.tensor([[1, 2, 5], [3, 4, PLACEHOLDER_TOKEN_ID]],
|
||||
dtype=torch.int,
|
||||
device=logits.device)
|
||||
assert torch.equal(output, expected)
|
||||
|
||||
|
||||
def test_single_token_sequence(sampler):
|
||||
def test_single_token_sequence(rejection_sampler):
|
||||
"""Test handling sequences with single token"""
|
||||
spec_tokens = [[1]]
|
||||
output_tokens = [[1, 2]] # Single token with bonus token 2
|
||||
@ -122,13 +158,21 @@ def test_single_token_sequence(sampler):
|
||||
logits = create_logits_tensor(output_tokens)
|
||||
bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
|
||||
device=logits.device)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
|
||||
device=logits.device)
|
||||
|
||||
output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata)
|
||||
output = rejection_sampler(
|
||||
spec_decode_metadata,
|
||||
draft_probs=None,
|
||||
target_logits=logits,
|
||||
bonus_token_ids=bonus_token_tensor,
|
||||
sampling_metadata=metadata,
|
||||
)
|
||||
expected = torch.tensor([[1, 2]], dtype=torch.int, device=logits.device)
|
||||
assert torch.equal(output, expected)
|
||||
|
||||
|
||||
def test_empty_sequence(sampler):
|
||||
def test_empty_sequence(rejection_sampler):
|
||||
"""Test handling empty sequence of speculated tokens"""
|
||||
spec_tokens: list[list[int]] = [[]]
|
||||
output_tokens = [[5]] # Just the bonus token
|
||||
@ -137,13 +181,21 @@ def test_empty_sequence(sampler):
|
||||
logits = create_logits_tensor(output_tokens)
|
||||
bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
|
||||
device=logits.device)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
|
||||
device=logits.device)
|
||||
|
||||
output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata)
|
||||
output = rejection_sampler(
|
||||
spec_decode_metadata,
|
||||
draft_probs=None,
|
||||
target_logits=logits,
|
||||
bonus_token_ids=bonus_token_tensor,
|
||||
sampling_metadata=metadata,
|
||||
)
|
||||
expected = torch.tensor([[5]], dtype=torch.int, device=logits.device)
|
||||
assert torch.equal(output, expected)
|
||||
|
||||
|
||||
def test_multiple_mismatches(sampler):
|
||||
def test_multiple_mismatches(rejection_sampler):
|
||||
"""Test handling multiple sequences with mismatches"""
|
||||
spec_tokens = [[1, 2, 3], [4, 5, 6]]
|
||||
output_tokens = [[1, 2, 7, 6], [4, 8, 6,
|
||||
@ -153,12 +205,22 @@ def test_multiple_mismatches(sampler):
|
||||
logits = create_logits_tensor(output_tokens)
|
||||
bonus_token_tensor = torch.tensor(
|
||||
[output_tokens[0][-1], output_tokens[1][-1]], device=logits.device)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
|
||||
device=logits.device)
|
||||
|
||||
output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata)
|
||||
expected = torch.tensor([[1, 2, 7, INVALID_TOKEN_ID],
|
||||
[4, 8, INVALID_TOKEN_ID, INVALID_TOKEN_ID]],
|
||||
dtype=torch.int,
|
||||
device=logits.device)
|
||||
output = rejection_sampler(
|
||||
spec_decode_metadata,
|
||||
draft_probs=None,
|
||||
target_logits=logits,
|
||||
bonus_token_ids=bonus_token_tensor,
|
||||
sampling_metadata=metadata,
|
||||
)
|
||||
expected = torch.tensor(
|
||||
[[1, 2, 7, PLACEHOLDER_TOKEN_ID],
|
||||
[4, 8, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID]],
|
||||
dtype=torch.int,
|
||||
device=logits.device,
|
||||
)
|
||||
assert torch.equal(output, expected)
|
||||
|
||||
|
||||
@ -166,18 +228,27 @@ def test_multiple_mismatches(sampler):
|
||||
"spec_tokens,output_tokens,expected",
|
||||
[
|
||||
([[1, 2]], [[1, 2, 3]], [[1, 2, 3]]), # Perfect match with bonus
|
||||
([[1]], [[2, 3]], [[2, INVALID_TOKEN_ID]]), # First mismatch
|
||||
([[1]], [[2, 3]], [[2, PLACEHOLDER_TOKEN_ID]]), # First mismatch
|
||||
([[1, 2], [3, 4]], [[1, 5, 6], [3, 4, 7]],
|
||||
[[1, 5, INVALID_TOKEN_ID], [3, 4, 7]]), # Mixed matches
|
||||
[[1, 5, PLACEHOLDER_TOKEN_ID], [3, 4, 7]]), # Mixed matches
|
||||
])
|
||||
def test_parametrized_cases(sampler, spec_tokens, output_tokens, expected):
|
||||
def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens,
|
||||
expected):
|
||||
"""Parametrized test for various matching scenarios"""
|
||||
metadata = create_sampling_metadata(all_greedy=True)
|
||||
logits = create_logits_tensor(output_tokens)
|
||||
bonus_token_tensor = torch.tensor([tokens[-1] for tokens in output_tokens],
|
||||
device=logits.device)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
|
||||
device=logits.device)
|
||||
|
||||
output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata)
|
||||
output = rejection_sampler(
|
||||
spec_decode_metadata,
|
||||
draft_probs=None,
|
||||
target_logits=logits,
|
||||
bonus_token_ids=bonus_token_tensor,
|
||||
sampling_metadata=metadata,
|
||||
)
|
||||
expected_tensor = torch.tensor(expected,
|
||||
dtype=torch.int,
|
||||
device=logits.device)
|
||||
@ -190,21 +261,31 @@ def test_parametrized_cases(sampler, spec_tokens, output_tokens, expected):
|
||||
@pytest.mark.parametrize("batch_size", [1, 4, 8])
|
||||
@pytest.mark.parametrize("frac_seeded", [0.0, 0.5])
|
||||
@pytest.mark.parametrize("n_rep", [20])
|
||||
def test_deterministic_when_seeded(sampler, k: int, vocab_size: int,
|
||||
batch_size: int, frac_seeded: float,
|
||||
n_rep: int):
|
||||
draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
|
||||
target_probs = torch.rand(batch_size * (k + 1),
|
||||
vocab_size,
|
||||
dtype=torch.float32)
|
||||
def test_deterministic_when_seeded(
|
||||
rejection_sampler,
|
||||
k: int,
|
||||
vocab_size: int,
|
||||
batch_size: int,
|
||||
frac_seeded: float,
|
||||
n_rep: int,
|
||||
):
|
||||
num_tokens = batch_size * k
|
||||
draft_probs = torch.rand(num_tokens,
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device=DEVICE)
|
||||
draft_probs = F.softmax(draft_probs, dim=-1)
|
||||
target_logits = torch.rand_like(draft_probs)
|
||||
bonus_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, 1),
|
||||
dtype=torch.int64)
|
||||
dtype=torch.int64,
|
||||
device=DEVICE)
|
||||
draft_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, k),
|
||||
dtype=torch.int64)
|
||||
dtype=torch.int64,
|
||||
device=DEVICE)
|
||||
|
||||
seeded_mask = torch.rand(batch_size, dtype=torch.float32) <= frac_seeded
|
||||
|
||||
@ -215,10 +296,21 @@ def test_deterministic_when_seeded(sampler, k: int, vocab_size: int,
|
||||
for i in range(batch_size) if seeded_mask[i]
|
||||
}
|
||||
|
||||
temperature = torch.ones(batch_size,
|
||||
dtype=torch.float32,
|
||||
device=DEVICE)
|
||||
sampling_metadata = create_sampling_metadata(all_greedy=False,
|
||||
temperature=temperature,
|
||||
generators=seeded_seqs)
|
||||
rep_result = sampler(draft_token_ids.tolist(), draft_probs,
|
||||
bonus_token_ids, target_probs, sampling_metadata)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
|
||||
draft_token_ids.tolist(), device=DEVICE)
|
||||
rep_result = rejection_sampler(
|
||||
spec_decode_metadata,
|
||||
draft_probs=draft_probs,
|
||||
target_logits=target_logits,
|
||||
bonus_token_ids=bonus_token_ids,
|
||||
sampling_metadata=sampling_metadata,
|
||||
)
|
||||
|
||||
results.append(rep_result)
|
||||
|
||||
@ -257,10 +349,10 @@ def test_rejection_sampling_approximates_target_distribution():
|
||||
num_reference_probs = 100
|
||||
|
||||
# Prepare draft, target, and reference probability distributions
|
||||
draft_probs, target_probs = (F.softmax(
|
||||
torch.rand(vocab_size, dtype=torch.float32),
|
||||
dim=-1,
|
||||
) for _ in range(2))
|
||||
draft_probs = F.softmax(torch.rand(vocab_size, dtype=torch.float32),
|
||||
dim=-1)
|
||||
target_logits = torch.rand(vocab_size, dtype=torch.float32)
|
||||
target_probs = F.softmax(target_logits, dim=-1)
|
||||
reference_probs = F.softmax(
|
||||
torch.rand(num_reference_probs, vocab_size, dtype=torch.float32),
|
||||
dim=-1,
|
||||
@ -273,7 +365,7 @@ def test_rejection_sampling_approximates_target_distribution():
|
||||
for num_samples in sample_sizes:
|
||||
# Sample using rejection sampling.
|
||||
rej_sample_probs = estimate_rejection_sampling_pdf(
|
||||
draft_probs, target_probs, k, vocab_size, num_samples)
|
||||
draft_probs, target_logits, k, vocab_size, num_samples)
|
||||
rej_sample_probs = rej_sample_probs.to(DEVICE)
|
||||
|
||||
# Average distance from reference probs.
|
||||
@ -313,7 +405,7 @@ def get_ratio_first_to_last(elements: list[float]) -> float:
|
||||
|
||||
def estimate_rejection_sampling_pdf(
|
||||
draft_probs: torch.Tensor,
|
||||
target_probs: torch.Tensor,
|
||||
target_logits: torch.Tensor,
|
||||
k: int,
|
||||
vocab_size: int,
|
||||
num_samples: int,
|
||||
@ -323,35 +415,44 @@ def estimate_rejection_sampling_pdf(
|
||||
|
||||
Args:
|
||||
draft_probs: Draft probability distribution.
|
||||
target_probs: Target probability distribution.
|
||||
target_logits: Target logits.
|
||||
num_samples: Number of samples to draw.
|
||||
|
||||
Returns:
|
||||
Estimated probability distribution of the output tokens.
|
||||
"""
|
||||
sampler = RejectionSampler()
|
||||
# Repeat draft probs num_samples times.
|
||||
rejection_sampler = RejectionSampler()
|
||||
num_tokens = num_samples * k
|
||||
# Repeat draft probs num_samples * k times.
|
||||
draft_probs = draft_probs.reshape(1, 1,
|
||||
vocab_size).repeat(num_samples, k, 1)
|
||||
|
||||
# Repeat target probs num_samples * (k + 1) times.
|
||||
target_probs = target_probs.reshape(1, 1, vocab_size).repeat(
|
||||
num_samples, k + 1, 1).reshape(num_samples * (k + 1), vocab_size)
|
||||
# Repeat target probs num_tokens times.
|
||||
target_logits = target_logits.reshape(1, vocab_size).repeat(num_tokens, 1)
|
||||
|
||||
# Randomly sample draft token ids from draft probs.
|
||||
draft_token_ids = torch.multinomial(draft_probs[:, 0, :],
|
||||
num_samples=k,
|
||||
replacement=True).reshape(
|
||||
num_samples, k)
|
||||
draft_probs = draft_probs.view(num_tokens, vocab_size)
|
||||
|
||||
# Bonus tokens not used but required.
|
||||
bonus_token_ids = torch.zeros((1, 1), dtype=torch.int64,
|
||||
device=DEVICE).repeat(num_samples, 1)
|
||||
|
||||
sampling_metadata = create_sampling_metadata(all_greedy=False)
|
||||
output_token_ids = sampler(draft_token_ids.tolist(), draft_probs,
|
||||
bonus_token_ids, target_probs,
|
||||
sampling_metadata)
|
||||
temperature = torch.ones(num_samples, dtype=torch.float32, device=DEVICE)
|
||||
sampling_metadata = create_sampling_metadata(all_greedy=False,
|
||||
temperature=temperature)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
|
||||
draft_token_ids.tolist(), device=bonus_token_ids.device)
|
||||
output_token_ids = rejection_sampler(
|
||||
spec_decode_metadata,
|
||||
draft_probs=draft_probs,
|
||||
target_logits=target_logits,
|
||||
bonus_token_ids=bonus_token_ids,
|
||||
sampling_metadata=sampling_metadata,
|
||||
)
|
||||
output_token_ids = output_token_ids[:, :-1].flatten()
|
||||
|
||||
hist = torch.histogram(output_token_ids.to(dtype=torch.float,
|
||||
|
@ -35,7 +35,6 @@ if TYPE_CHECKING:
|
||||
VLLM_TRACE_FUNCTION: int = 0
|
||||
VLLM_ATTENTION_BACKEND: Optional[str] = None
|
||||
VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None
|
||||
VLLM_USE_FLASHINFER_REJECTION_SAMPLER: bool = False
|
||||
VLLM_FLASHINFER_FORCE_TENSOR_CORES: bool = False
|
||||
VLLM_PP_LAYER_PARTITION: Optional[str] = None
|
||||
VLLM_CPU_KVCACHE_SPACE: int = 0
|
||||
|
@ -46,7 +46,7 @@ class SamplerOutput:
|
||||
# [num_reqs, max_num_generated_tokens]
|
||||
# Different requests can have different number of generated tokens.
|
||||
# All requests are padded to max_num_generated_tokens.
|
||||
# INVALID_TOKEN_ID (-1 by default) is used for padding.
|
||||
# PLACEHOLDER_TOKEN_ID (-1 by default) is used for padding.
|
||||
sampled_token_ids: torch.Tensor
|
||||
logprobs_tensors: Optional[LogprobsTensors]
|
||||
|
||||
|
30
vllm/v1/sample/ops/utils.py
Normal file
30
vllm/v1/sample/ops/utils.py
Normal file
@ -0,0 +1,30 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def compiled_softmax(
|
||||
logits: torch.Tensor,
|
||||
temperature: Union[float, torch.Tensor] = 1.0,
|
||||
) -> torch.Tensor:
|
||||
"""Faster softmax kernel generated by torch.compile.
|
||||
|
||||
Args:
|
||||
logits: [n, vocab_size]
|
||||
temperature: [n] or float
|
||||
"""
|
||||
# NOTE(woosuk): Avoid recompilation by marking the first dim as dynamic.
|
||||
torch._dynamo.mark_dynamic(logits, index=0)
|
||||
if isinstance(temperature, torch.Tensor):
|
||||
torch._dynamo.mark_dynamic(temperature, index=0)
|
||||
return _softmax(logits, temperature)
|
||||
|
||||
|
||||
@torch.compile
|
||||
def _softmax(
|
||||
logits: torch.Tensor,
|
||||
temperature: Union[float, torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
logits = logits / temperature
|
||||
return torch.softmax(logits, dim=-1, dtype=torch.float32)
|
@ -3,25 +3,32 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.spec_decode.utils import random_sample
|
||||
from vllm.v1.sample.ops.utils import compiled_softmax
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
|
||||
logger = init_logger(__name__)
|
||||
INVALID_TOKEN_ID = -1
|
||||
|
||||
PLACEHOLDER_TOKEN_ID: tl.constexpr = -1
|
||||
GREEDY_TEMPERATURE: tl.constexpr = -1
|
||||
# Maximum number of speculative draft tokens allowed per request in a single
|
||||
# step. This value is chosen to be large enough to handle typical use cases.
|
||||
MAX_SPEC_LEN = 32
|
||||
|
||||
|
||||
class RejectionSampler(nn.Module):
|
||||
"""
|
||||
The implementation strictly follows the algorithm described in
|
||||
The implementation strictly follows the algorithm described in
|
||||
https://arxiv.org/abs/2211.17192.
|
||||
However, we want to clarify the terminology used in the implementation:
|
||||
accepted tokens: tokens that are accepted based on the relationship
|
||||
accepted tokens: tokens that are accepted based on the relationship
|
||||
between the "raw" draft and target probabilities.
|
||||
recovered tokens: tokens that are sampled based on the adjusted probability
|
||||
distribution, which is derived from both the draft and target
|
||||
distribution, which is derived from both the draft and target
|
||||
probabilities.
|
||||
bonus tokens:
|
||||
If all proposed tokens are accepted, the bonus token is added to the
|
||||
@ -31,48 +38,42 @@ class RejectionSampler(nn.Module):
|
||||
sampling process. For example, we can use top_p, top_k sampling for
|
||||
bonus tokens, while spec decode does not support these sampling
|
||||
strategies.
|
||||
output tokens:
|
||||
Tokens are finally generated with the rejection sampler.
|
||||
output tokens:
|
||||
Tokens are finally generated with the rejection sampler.
|
||||
output tokens = accepted tokens + recovered tokens + bonus tokens
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
draft_token_ids: list[list[int]],
|
||||
metadata: SpecDecodeMetadata,
|
||||
# [num_tokens, vocab_size]
|
||||
draft_probs: Optional[torch.Tensor],
|
||||
bonus_token_ids_tensor: torch.Tensor, # [batch_size, 1]
|
||||
target_probs: torch.Tensor, # [num_total_tokens, vocab_size]
|
||||
# [num_tokens, vocab_size]
|
||||
target_logits: torch.Tensor,
|
||||
# [batch_size, 1]
|
||||
bonus_token_ids: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> torch.Tensor:
|
||||
'''
|
||||
Args:
|
||||
draft_token_ids (List[List[int]]):
|
||||
A 2D list of token IDs for each request in the batch.
|
||||
Each request might have different number of draft tokens.
|
||||
It may also contain empty lists for requests that have
|
||||
no draft tokens.
|
||||
metadata:
|
||||
Metadata for spec decoding.
|
||||
draft_probs (Optional[torch.Tensor]):
|
||||
Probability distribution for the draft tokens. Shape is
|
||||
[batch_size, max_spec_len, vocab_size]. Can be None if
|
||||
probabilities are not provided, which is the case for
|
||||
ngram spec decode.
|
||||
[num_tokens, vocab_size]. Can be None if probabilities are
|
||||
not provided, which is the case for ngram spec decode.
|
||||
target_logits (torch.Tensor):
|
||||
Target model's logits probability distribution.
|
||||
Shape is [num_tokens, vocab_size]. Here, probabilities from
|
||||
different requests are flattened into a single tensor because
|
||||
this is the shape of the output logits.
|
||||
bonus_token_ids_tensor (torch.Tensor):
|
||||
A tensor containing bonus tokens. Shape is [batch_size, 1].
|
||||
Bonus tokens are added to the end of the sequence if all
|
||||
proposed tokens are accepted. We generate the bonus tokens
|
||||
outside of the rejection sampler with the default sampling
|
||||
strategy. It allows for more flexibility in the sampling
|
||||
A tensor containing bonus tokens. Shape is [batch_size, 1].
|
||||
Bonus tokens are added to the end of the sequence if all
|
||||
proposed tokens are accepted. We generate the bonus tokens
|
||||
outside of the rejection sampler with the default sampling
|
||||
strategy. It allows for more flexibility in the sampling
|
||||
process such as top_p, top_k sampling.
|
||||
target_probs (torch.Tensor):
|
||||
Target model probability distribution.
|
||||
Shape is [num_total_tokens, vocab_size]. num_total_tokens
|
||||
is the total number of tokens from all requests. Here,
|
||||
probabilities from different requests are flattened into
|
||||
a single tensor because this is the shape of the output
|
||||
logits.
|
||||
sampling_metadata (SamplingMetadata):
|
||||
Additional metadata needed for sampling, such as temperature,
|
||||
top-k/top-p parameters, or other relevant information.
|
||||
@ -80,268 +81,481 @@ class RejectionSampler(nn.Module):
|
||||
output_token_ids (torch.Tensor):
|
||||
A tensor containing the final output token IDs.
|
||||
'''
|
||||
assert metadata.max_spec_len <= MAX_SPEC_LEN
|
||||
# [num_tokens, vocab_size]
|
||||
target_probs = compute_probs(
|
||||
target_logits,
|
||||
metadata.cu_num_draft_tokens,
|
||||
sampling_metadata,
|
||||
)
|
||||
|
||||
# NOTE: The following input preparationg can be moved
|
||||
# to the model runner with a persistent manner for better
|
||||
# performance.
|
||||
# Convert draft token IDs to a tensor, split by sample_lens, then pad.
|
||||
draft_token_ids = [
|
||||
torch.tensor(x, dtype=int, device='cpu') for x in draft_token_ids
|
||||
]
|
||||
draft_token_ids_tensor = pad_sequence(draft_token_ids,
|
||||
batch_first=True,
|
||||
padding_value=INVALID_TOKEN_ID)
|
||||
|
||||
# NOTE: CPU <-> GPU synchronization happens here.
|
||||
draft_token_ids_tensor = draft_token_ids_tensor.to(target_probs.device)
|
||||
|
||||
# Create one-hot tensor for draft token ids.
|
||||
# This is used for ngram where we don't have draft_probs.
|
||||
if draft_probs is None and not sampling_metadata.all_greedy:
|
||||
vocab_size = target_probs.size(-1)
|
||||
draft_probs = _create_greedy_token_probs(draft_token_ids_tensor,
|
||||
vocab_size,
|
||||
target_probs.device)
|
||||
sample_lens = [len(x) + 1 for x in draft_token_ids]
|
||||
target_probs = _convert_2d_probs(target_probs, sample_lens)
|
||||
|
||||
return self.forward_native(draft_token_ids_tensor, draft_probs,
|
||||
bonus_token_ids_tensor, target_probs,
|
||||
sampling_metadata)
|
||||
|
||||
# TODO: The following method can be optimized for better performance.
|
||||
def forward_native(
|
||||
self,
|
||||
draft_token_ids_tensor: torch.Tensor,
|
||||
# [batch_size, max_spec_len, vocab_size]
|
||||
draft_probs: Optional[torch.Tensor],
|
||||
bonus_token_ids_tensor: torch.Tensor,
|
||||
# [batch_size, max_spec_len + 1, vocab_size]
|
||||
target_probs: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> torch.Tensor:
|
||||
# Add 1 to include the 'bonus' token.
|
||||
if sampling_metadata.all_greedy:
|
||||
# Produce a mask that remains 1 (True) until the first
|
||||
# mismatch (cumprod turns 0 after a mismatch).
|
||||
target_token_ids_tensor = target_probs.argmax(dim=-1)
|
||||
accept_mask = (target_token_ids_tensor[:, :-1] ==
|
||||
draft_token_ids_tensor).cumprod(dim=1)
|
||||
|
||||
# Identify valid positions (non-padding).
|
||||
valid_mask = target_token_ids_tensor != INVALID_TOKEN_ID
|
||||
# Generate mask with bonus token.
|
||||
generate_mask = torch.cat([
|
||||
accept_mask,
|
||||
torch.zeros(accept_mask.size(0), 1, device=accept_mask.device)
|
||||
],
|
||||
dim=1).to(torch.bool) & valid_mask
|
||||
zeros_mask = (generate_mask == 0)
|
||||
first_zero_idx = zeros_mask.float().argmax(dim=1)
|
||||
# Figure out which rows actually contain at least one zero.
|
||||
rows_with_zero = zeros_mask.any(dim=1)
|
||||
# Use indexing to set the first zero in each of those rows to 1.
|
||||
generate_mask[rows_with_zero, first_zero_idx[rows_with_zero]] = 1
|
||||
|
||||
output_token_ids = target_token_ids_tensor
|
||||
output_token_ids[~generate_mask] = INVALID_TOKEN_ID
|
||||
else:
|
||||
# Reference: https://arxiv.org/pdf/2211.17192
|
||||
# 1. Extract the probabilities of the draft tokens.
|
||||
# [batch_size, max_spec_len]
|
||||
batch_size = draft_token_ids_tensor.size(0)
|
||||
max_spec_len = draft_token_ids_tensor.size(1)
|
||||
invalid_idx = draft_token_ids_tensor == INVALID_TOKEN_ID
|
||||
draft_token_ids_tensor[invalid_idx] = 0
|
||||
assert draft_probs is not None
|
||||
draft_token_probs = draft_probs.gather(
|
||||
dim=-1, index=draft_token_ids_tensor.unsqueeze(-1)).squeeze(-1)
|
||||
target_token_probs = target_probs.gather(
|
||||
dim=-1, index=draft_token_ids_tensor.unsqueeze(-1)).squeeze(-1)
|
||||
# Force the probabilities of invalid tokens to inf
|
||||
# so that they are not accepted.
|
||||
draft_token_probs[invalid_idx] = float('inf')
|
||||
|
||||
# 2. Generate uniform samples.
|
||||
# [batch_size, max_spec_len + 1]
|
||||
uniform_samples = _create_uniform_samples(
|
||||
sampling_metadata.generators, batch_size, max_spec_len,
|
||||
target_probs.device)
|
||||
|
||||
# 3. Accept or reject the samples.
|
||||
# [batch_size, max_spec_len]
|
||||
# If the draft token probabilities are 0, set them to the smallest
|
||||
# positive normal value representable by float32.
|
||||
safe_draft_probs = torch.where(draft_token_probs > 0,
|
||||
draft_token_probs,
|
||||
torch.finfo(torch.float32).tiny)
|
||||
accepted = uniform_samples <= target_token_probs / safe_draft_probs
|
||||
accept_mask = accepted.cumprod(dim=1)
|
||||
# Set the token ids to the draft token ids if accepted, otherwise
|
||||
# set them to INVALID_TOKEN_ID.
|
||||
accepted_token_ids = (draft_token_ids_tensor * accept_mask +
|
||||
INVALID_TOKEN_ID * (1 - accept_mask))
|
||||
|
||||
# 4. Adjust the distribution for the recovered tokens.
|
||||
# Clamp the bonus probabilities to the smallest positive normal
|
||||
# value representable by float32.
|
||||
bonus_prob = torch.clamp(target_probs[:, :-1, :] - draft_probs,
|
||||
min=torch.finfo(torch.float32).tiny)
|
||||
normalized_bonus_prob = bonus_prob / bonus_prob.sum(dim=-1,
|
||||
keepdim=True)
|
||||
|
||||
# 5. Sample recovered token ids.
|
||||
recovered_token_ids = random_sample(
|
||||
normalized_bonus_prob,
|
||||
sampling_metadata.generators).reshape(batch_size, max_spec_len)
|
||||
|
||||
# 6. Get the final output token ids.
|
||||
# output_token_ids = accepted_token_ids +
|
||||
# recovered_token_ids +
|
||||
# bonus_token_id
|
||||
recovered_bonus_token_ids = torch.cat(
|
||||
[recovered_token_ids, bonus_token_ids_tensor], dim=1)
|
||||
# Generate mask with bonus tokens.
|
||||
generate_mask = torch.cat([
|
||||
accept_mask,
|
||||
torch.zeros(batch_size, 1, device=accept_mask.device)
|
||||
],
|
||||
dim=1).to(torch.bool)
|
||||
zeros_mask = (generate_mask == 0)
|
||||
first_zero_idx = zeros_mask.float().argmax(dim=1)
|
||||
output_token_ids = torch.cat([
|
||||
accepted_token_ids,
|
||||
torch.full((batch_size, 1),
|
||||
fill_value=INVALID_TOKEN_ID,
|
||||
device=accept_mask.device)
|
||||
],
|
||||
dim=1)
|
||||
output_token_ids[torch.arange(batch_size),
|
||||
first_zero_idx] = recovered_bonus_token_ids[
|
||||
torch.arange(batch_size), first_zero_idx]
|
||||
|
||||
output_token_ids = rejection_sample(
|
||||
metadata.draft_token_ids,
|
||||
metadata.num_draft_tokens,
|
||||
metadata.max_spec_len,
|
||||
metadata.cu_num_draft_tokens,
|
||||
draft_probs,
|
||||
target_probs,
|
||||
bonus_token_ids,
|
||||
sampling_metadata,
|
||||
)
|
||||
return output_token_ids
|
||||
|
||||
def compute_probs(self, logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
sample_lens: list[int]) -> torch.Tensor:
|
||||
"""
|
||||
Compute probability distribution from logits based on sampling metadata.
|
||||
|
||||
This function applies temperature scaling to the logits and converts
|
||||
them to probabilities using softmax. Note that division by
|
||||
temperature is not performed inplace to preserve the original logits
|
||||
tensor, which will be used by the original sampler to get bonus tokens.
|
||||
|
||||
Args:
|
||||
logits: Input logits tensor to be converted to probabilities
|
||||
sampling_metadata: Metadata containing sampling parameters such
|
||||
as temperature and whether greedy sampling is used
|
||||
sample_lens: List of sample lengths used for repeating
|
||||
temperature values
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Probability distribution (softmax of scaled logits)
|
||||
if non-greedy sampling is used, otherwise returns the
|
||||
original logits
|
||||
"""
|
||||
if sampling_metadata.all_greedy:
|
||||
return logits
|
||||
assert sampling_metadata.temperature is not None
|
||||
# We should optimize the following code as
|
||||
# it will cause CPU -> GPU synchronization.
|
||||
temperature = torch.repeat_interleave(
|
||||
sampling_metadata.temperature,
|
||||
torch.tensor(sample_lens,
|
||||
device=sampling_metadata.temperature.device))
|
||||
temperature = temperature.unsqueeze(dim=1)
|
||||
logits = logits / temperature
|
||||
return logits.softmax(dim=-1, dtype=torch.float32)
|
||||
@staticmethod
|
||||
def parse_output(
|
||||
output_token_ids: torch.Tensor,
|
||||
vocab_size: int,
|
||||
) -> list[list[int]]:
|
||||
output_token_ids_np = output_token_ids.cpu().numpy()
|
||||
# Create mask for valid tokens.
|
||||
valid_mask = ((output_token_ids_np != PLACEHOLDER_TOKEN_ID) &
|
||||
(output_token_ids_np < vocab_size))
|
||||
outputs = [
|
||||
row[valid_mask[i]].tolist()
|
||||
for i, row in enumerate(output_token_ids_np)
|
||||
]
|
||||
return outputs
|
||||
|
||||
|
||||
def _create_greedy_token_probs(
|
||||
token_ids: torch.Tensor,
|
||||
vocab_size: int,
|
||||
out_device: torch.device,
|
||||
def rejection_sample(
|
||||
# [num_tokens]
|
||||
draft_token_ids: torch.Tensor,
|
||||
# [batch_size]
|
||||
num_draft_tokens: list[int],
|
||||
max_spec_len: int,
|
||||
# [batch_size]
|
||||
cu_num_draft_tokens: torch.Tensor,
|
||||
# [num_tokens, vocab_size]
|
||||
draft_probs: Optional[torch.Tensor],
|
||||
# [num_tokens, vocab_size]
|
||||
target_probs: torch.Tensor,
|
||||
# [batch_size, 1]
|
||||
bonus_token_ids: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> torch.Tensor:
|
||||
batch_size, num_tokens = token_ids.shape
|
||||
assert draft_token_ids.ndim == 1
|
||||
assert draft_probs is None or draft_probs.ndim == 2
|
||||
assert cu_num_draft_tokens.ndim == 1
|
||||
assert target_probs.ndim == 2
|
||||
|
||||
token_probs = torch.zeros(batch_size,
|
||||
num_tokens,
|
||||
vocab_size,
|
||||
dtype=torch.float,
|
||||
device=out_device)
|
||||
batch_size = len(num_draft_tokens)
|
||||
num_tokens = draft_token_ids.shape[0]
|
||||
vocab_size = target_probs.shape[-1]
|
||||
device = target_probs.device
|
||||
assert draft_token_ids.is_contiguous()
|
||||
assert draft_probs is None or draft_probs.is_contiguous()
|
||||
assert target_probs.is_contiguous()
|
||||
assert bonus_token_ids.is_contiguous()
|
||||
assert target_probs.shape == (num_tokens, vocab_size)
|
||||
|
||||
# Ignore INVALID_TOKEN_ID.
|
||||
valid_mask = (token_ids != INVALID_TOKEN_ID)
|
||||
valid_indices = token_ids.clone()
|
||||
valid_indices[~valid_mask] = 0
|
||||
# Create output buffer.
|
||||
output_token_ids = torch.empty(
|
||||
(batch_size, max_spec_len + 1),
|
||||
dtype=torch.int32, # Consistent with SamplerOutput.sampled_token_ids.
|
||||
device=device,
|
||||
)
|
||||
output_token_ids.fill_(PLACEHOLDER_TOKEN_ID)
|
||||
|
||||
token_probs.scatter_(dim=2,
|
||||
index=valid_indices.unsqueeze(-1),
|
||||
src=valid_mask.unsqueeze(-1).float())
|
||||
if sampling_metadata.all_greedy:
|
||||
is_greedy = None
|
||||
else:
|
||||
is_greedy = sampling_metadata.temperature == GREEDY_TEMPERATURE
|
||||
if not sampling_metadata.all_random:
|
||||
# Rejection sampling for greedy sampling requests.
|
||||
target_argmax = target_probs.argmax(dim=-1)
|
||||
rejection_greedy_sample_kernel[(batch_size, )](
|
||||
output_token_ids,
|
||||
cu_num_draft_tokens,
|
||||
draft_token_ids,
|
||||
target_argmax,
|
||||
bonus_token_ids,
|
||||
is_greedy,
|
||||
max_spec_len,
|
||||
num_warps=1,
|
||||
)
|
||||
if sampling_metadata.all_greedy:
|
||||
return output_token_ids
|
||||
|
||||
return token_probs
|
||||
# Generate uniform probabilities for rejection sampling.
|
||||
# [num_tokens]
|
||||
uniform_probs = generate_uniform_probs(
|
||||
num_tokens,
|
||||
num_draft_tokens,
|
||||
sampling_metadata.generators,
|
||||
device,
|
||||
)
|
||||
|
||||
# Sample recovered tokens for each position.
|
||||
# [num_tokens]
|
||||
recovered_token_ids = sample_recovered_tokens(
|
||||
max_spec_len,
|
||||
num_draft_tokens,
|
||||
cu_num_draft_tokens,
|
||||
draft_token_ids,
|
||||
draft_probs,
|
||||
target_probs,
|
||||
sampling_metadata,
|
||||
device,
|
||||
)
|
||||
|
||||
# Rejection sampling for random sampling requests.
|
||||
rejection_random_sample_kernel[(batch_size, )](
|
||||
output_token_ids,
|
||||
cu_num_draft_tokens,
|
||||
draft_token_ids,
|
||||
draft_probs,
|
||||
target_probs,
|
||||
bonus_token_ids,
|
||||
recovered_token_ids,
|
||||
uniform_probs,
|
||||
is_greedy,
|
||||
max_spec_len,
|
||||
vocab_size,
|
||||
IS_NGRAM=draft_probs is None,
|
||||
num_warps=1,
|
||||
)
|
||||
return output_token_ids
|
||||
|
||||
|
||||
def _convert_2d_probs(
|
||||
probs: torch.Tensor, # [num_total_tokens, vocab_size]
|
||||
sample_lens: list[int]) -> torch.Tensor:
|
||||
def compute_probs(
|
||||
logits: torch.Tensor, # [num_tokens, vocab_size]
|
||||
cu_num_draft_tokens: torch.Tensor, # [batch_size]
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> torch.Tensor:
|
||||
"""Compute probability distribution from logits based on sampling metadata.
|
||||
|
||||
This function applies temperature scaling to the logits and converts
|
||||
them to probabilities using softmax. For greedy decoding, it returns
|
||||
the original logits.
|
||||
|
||||
Args:
|
||||
logits: Input logits tensor to be converted to probabilities.
|
||||
cu_num_draft_tokens: Cumulative number of draft tokens.
|
||||
sampling_metadata: Metadata containing sampling parameters such as
|
||||
temperature and whether greedy sampling is used.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Probability distribution (softmax of scaled logits)
|
||||
if non-greedy sampling is used, otherwise returns the
|
||||
original logits.
|
||||
"""
|
||||
Converts a 2D tensor of probabilities to a 3D tensor with padding.
|
||||
[num_total_tokens, vocab_size] ->
|
||||
[batch_size, max_spec_len + 1, vocab_size]
|
||||
assert logits.ndim == 2
|
||||
assert cu_num_draft_tokens.ndim == 1
|
||||
if sampling_metadata.all_greedy:
|
||||
return logits
|
||||
|
||||
num_tokens = logits.shape[0]
|
||||
batch_size = cu_num_draft_tokens.shape[0]
|
||||
expanded_temperature = torch.empty(
|
||||
(num_tokens, 1),
|
||||
dtype=torch.float32,
|
||||
device=logits.device,
|
||||
)
|
||||
expand_kernel[(batch_size, )](
|
||||
expanded_temperature,
|
||||
sampling_metadata.temperature,
|
||||
cu_num_draft_tokens,
|
||||
GREEDY_TEMPERATURE, # replace_from
|
||||
1, # replace_to
|
||||
MAX_NUM_TOKENS=MAX_SPEC_LEN,
|
||||
num_warps=1,
|
||||
)
|
||||
output_prob = compiled_softmax(logits, expanded_temperature)
|
||||
return output_prob
|
||||
|
||||
|
||||
def generate_uniform_probs(
|
||||
num_tokens: int,
|
||||
num_draft_tokens: list[int],
|
||||
generators: dict[int, torch.Generator],
|
||||
device: torch.device,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
cumulative_lens = torch.cumsum(torch.tensor(sample_lens,
|
||||
device=probs.device),
|
||||
dim=0)
|
||||
split_indices = cumulative_lens[:-1].tolist() # Exclude last index
|
||||
Generates a batch of uniform random samples, with optional seeding
|
||||
if available.
|
||||
|
||||
# Split into chunks without loops
|
||||
chunks = torch.tensor_split(probs, split_indices, dim=0)
|
||||
This method creates a tensor of shape `(num_tokens, )` filled
|
||||
with uniform random values in the range [0, 1). If `generators` is provided,
|
||||
the requests with their own seeds will use the provided `torch.Generator`
|
||||
for reproducibility. The samples for the other requests will be generated
|
||||
without a seed.
|
||||
|
||||
# Pad all sequences to maximum length
|
||||
padded_probs = pad_sequence(chunks, batch_first=True, padding_value=0.0)
|
||||
return padded_probs
|
||||
|
||||
|
||||
def _create_uniform_samples(seeded_seqs: dict[int, torch.Generator],
|
||||
batch_size: int, k: int,
|
||||
device: torch.device) -> torch.Tensor:
|
||||
Args:
|
||||
num_tokens : int
|
||||
Total number of tokens.
|
||||
num_draft_tokens : List[List[int]]
|
||||
Number of draft tokens per request.
|
||||
generators : Optional[Dict[int, torch.Generator]]
|
||||
A dictionary mapping indices in the batch to
|
||||
`torch.Generator` objects.
|
||||
device : torch.device
|
||||
The device on which to allocate the tensor.
|
||||
Returns:
|
||||
uniform_rand : torch.Tensor
|
||||
A tensor of shape `(num_tokens, )` containing uniform
|
||||
random values in the range [0, 1).
|
||||
"""
|
||||
Generates a batch of uniform random samples, with optional seeding
|
||||
for specific sequences.
|
||||
uniform_probs = torch.rand(
|
||||
(num_tokens, ),
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
start_idx = 0
|
||||
for req_idx, n in enumerate(num_draft_tokens):
|
||||
# Do not generate random numbers for requests with no draft tokens.
|
||||
# This can be important for reproducibility.
|
||||
if n == 0:
|
||||
continue
|
||||
end_idx = start_idx + n
|
||||
generator = generators.get(req_idx)
|
||||
if generator is not None:
|
||||
uniform_probs[start_idx:end_idx].uniform_(generator=generator)
|
||||
start_idx = end_idx
|
||||
return uniform_probs
|
||||
|
||||
This method creates a tensor of shape `(batch_size, k)` filled
|
||||
with uniform random values in the range [0, 1). If `seeded_seqs`
|
||||
is provided, the sequences corresponding to specific indices
|
||||
will be generated using the provided `torch.Generator` for
|
||||
reproducibility. The other sequences will be generated without
|
||||
a seed.
|
||||
|
||||
Args:
|
||||
seeded_seqs : Optional[Dict[int, torch.Generator]]
|
||||
A dictionary mapping indices in the batch to
|
||||
`torch.Generator` objects.
|
||||
batch_size : int
|
||||
The number of sequences to generate.
|
||||
k : int
|
||||
The number of random samples per sequence.
|
||||
device : torch.device
|
||||
The device on which to allocate the tensor.
|
||||
def sample_recovered_tokens(
|
||||
max_spec_len: int,
|
||||
num_draft_tokens: list[int],
|
||||
# [batch_size]
|
||||
cu_num_draft_tokens: torch.Tensor,
|
||||
# [num_tokens]
|
||||
draft_token_ids: torch.Tensor,
|
||||
# [num_tokens, vocab_size]
|
||||
draft_probs: Optional[torch.Tensor],
|
||||
# [num_tokens, vocab_size]
|
||||
target_probs: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
device: torch.device,
|
||||
) -> torch.Tensor:
|
||||
# NOTE(woosuk): Create only one distribution for each request.
|
||||
batch_size = len(num_draft_tokens)
|
||||
vocab_size = target_probs.shape[-1]
|
||||
q = torch.empty(
|
||||
(batch_size, vocab_size),
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
q.exponential_()
|
||||
for i, generator in sampling_metadata.generators.items():
|
||||
# Do not generate random numbers for requests with no draft tokens.
|
||||
# This can be important for reproducibility.
|
||||
if num_draft_tokens[i] > 0:
|
||||
q[i].exponential_(generator=generator)
|
||||
|
||||
Returns:
|
||||
uniform_rand : torch.Tensor
|
||||
A tensor of shape `(batch_size, k)` containing uniform
|
||||
random values in the range [0, 1).
|
||||
"""
|
||||
recovered_token_ids = torch.empty_like(draft_token_ids)
|
||||
sample_recovered_tokens_kernel[(batch_size, max_spec_len)](
|
||||
recovered_token_ids,
|
||||
cu_num_draft_tokens,
|
||||
draft_token_ids,
|
||||
draft_probs,
|
||||
target_probs,
|
||||
q,
|
||||
vocab_size,
|
||||
triton.next_power_of_2(vocab_size),
|
||||
IS_NGRAM=draft_probs is None,
|
||||
)
|
||||
return recovered_token_ids
|
||||
|
||||
uniform_rand = torch.rand(batch_size,
|
||||
k,
|
||||
dtype=torch.float32,
|
||||
device=device)
|
||||
# Apply seeded generators only where needed
|
||||
if seeded_seqs:
|
||||
for idx, generator in seeded_seqs.items():
|
||||
uniform_rand[idx].uniform_(0, 1, generator=generator)
|
||||
return uniform_rand
|
||||
|
||||
# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
|
||||
@triton.jit(do_not_specialize=["max_spec_len"])
|
||||
def rejection_greedy_sample_kernel(
|
||||
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
|
||||
cu_num_draft_tokens_ptr, # [batch_size]
|
||||
draft_token_ids_ptr, # [num_tokens]
|
||||
target_argmax_ptr, # [num_tokens]
|
||||
bonus_token_ids_ptr, # [batch_size]
|
||||
is_greedy_ptr, # [batch_size] or None
|
||||
max_spec_len,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
# FIXME(woosuk): Because is_greedy_ptr is not None at profiling run,
|
||||
# re-compilation may happen during runtime when is_greedy_ptr is None.
|
||||
if is_greedy_ptr is None:
|
||||
is_greedy = True
|
||||
else:
|
||||
is_greedy = tl.load(is_greedy_ptr + req_idx)
|
||||
if not is_greedy:
|
||||
# Early exit for non-greedy sampling requests.
|
||||
return
|
||||
|
||||
if req_idx == 0:
|
||||
start_idx = 0
|
||||
else:
|
||||
start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
|
||||
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
|
||||
num_draft_tokens = end_idx - start_idx
|
||||
|
||||
rejected = False
|
||||
for pos in range(num_draft_tokens):
|
||||
if not rejected:
|
||||
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
|
||||
target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos)
|
||||
tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
|
||||
target_argmax_id)
|
||||
if draft_token_id != target_argmax_id:
|
||||
# Reject.
|
||||
rejected = True
|
||||
|
||||
if not rejected:
|
||||
# If all tokens are accepted, append the bonus token.
|
||||
bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
|
||||
tl.store(
|
||||
output_token_ids_ptr + req_idx * (max_spec_len + 1) +
|
||||
num_draft_tokens, bonus_token_id)
|
||||
|
||||
|
||||
# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
|
||||
@triton.jit(do_not_specialize=["max_spec_len"])
|
||||
def rejection_random_sample_kernel(
|
||||
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
|
||||
cu_num_draft_tokens_ptr, # [batch_size]
|
||||
draft_token_ids_ptr, # [num_tokens]
|
||||
draft_probs_ptr, # [num_tokens, vocab_size] or None
|
||||
target_probs_ptr, # [num_tokens, vocab_size]
|
||||
bonus_token_ids_ptr, # [batch_size]
|
||||
recovered_token_ids_ptr, # [num_tokens]
|
||||
uniform_probs_ptr, # [num_tokens]
|
||||
is_greedy_ptr, # [batch_size]
|
||||
max_spec_len,
|
||||
vocab_size,
|
||||
IS_NGRAM: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
is_greedy = tl.load(is_greedy_ptr + req_idx)
|
||||
if is_greedy:
|
||||
# Early exit for greedy sampling requests.
|
||||
return
|
||||
|
||||
if req_idx == 0:
|
||||
start_idx = 0
|
||||
else:
|
||||
start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
|
||||
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
|
||||
num_draft_tokens = end_idx - start_idx
|
||||
|
||||
rejected = False
|
||||
for pos in range(num_draft_tokens):
|
||||
if not rejected:
|
||||
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
|
||||
if IS_NGRAM:
|
||||
draft_prob = 1
|
||||
else:
|
||||
draft_prob = tl.load(draft_probs_ptr +
|
||||
(start_idx + pos) * vocab_size +
|
||||
draft_token_id)
|
||||
target_prob = tl.load(target_probs_ptr +
|
||||
(start_idx + pos) * vocab_size +
|
||||
draft_token_id)
|
||||
uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
|
||||
# NOTE(woosuk): While the draft probability should never be 0,
|
||||
# we check it to avoid NaNs. If it happens to be 0, we reject.
|
||||
if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
|
||||
# Accept.
|
||||
token_id = draft_token_id
|
||||
else:
|
||||
# Reject. Use recovered token.
|
||||
rejected = True
|
||||
token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
|
||||
tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
|
||||
token_id)
|
||||
|
||||
if not rejected:
|
||||
# If all tokens are accepted, append the bonus token.
|
||||
bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
|
||||
tl.store(
|
||||
output_token_ids_ptr + req_idx * (max_spec_len + 1) +
|
||||
num_draft_tokens, bonus_token_id)
|
||||
|
||||
|
||||
# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
|
||||
@triton.jit(do_not_specialize=["replace_from", "replace_to"])
|
||||
def expand_kernel(
|
||||
output_ptr, # [num_tokens]
|
||||
input_ptr, # [batch_size]
|
||||
cu_num_tokens_ptr, # [batch_size]
|
||||
replace_from,
|
||||
replace_to,
|
||||
MAX_NUM_TOKENS: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
if req_idx == 0: # noqa: SIM108
|
||||
start_idx = 0
|
||||
else:
|
||||
start_idx = tl.load(cu_num_tokens_ptr + req_idx - 1)
|
||||
end_idx = tl.load(cu_num_tokens_ptr + req_idx)
|
||||
num_tokens = end_idx - start_idx
|
||||
|
||||
src_val = tl.load(input_ptr + req_idx)
|
||||
src_val = tl.where(src_val == replace_from, replace_to, src_val)
|
||||
offset = tl.arange(0, MAX_NUM_TOKENS)
|
||||
tl.store(output_ptr + start_idx + offset,
|
||||
src_val,
|
||||
mask=offset < num_tokens)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def sample_recovered_tokens_kernel(
|
||||
output_token_ids_ptr, # [num_tokens]
|
||||
cu_num_draft_tokens_ptr, # [batch_size]
|
||||
draft_token_ids_ptr, # [num_tokens]
|
||||
draft_probs_ptr, # [num_tokens, vocab_size] or None
|
||||
target_probs_ptr, # [num_tokens, vocab_size]
|
||||
q_ptr, # [batch_size, vocab_size]
|
||||
vocab_size,
|
||||
PADDED_VOCAB_SIZE: tl.constexpr,
|
||||
IS_NGRAM: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
if req_idx == 0:
|
||||
start_idx = 0
|
||||
else:
|
||||
start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
|
||||
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
|
||||
num_draft_tokens = end_idx - start_idx
|
||||
|
||||
# Early exit for out-of-range positions.
|
||||
pos = tl.program_id(1)
|
||||
if pos >= num_draft_tokens:
|
||||
return
|
||||
|
||||
vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE)
|
||||
if IS_NGRAM:
|
||||
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
|
||||
orig_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size +
|
||||
draft_token_id)
|
||||
# Temporarily zero out the probability of the draft token.
|
||||
# This is essentially the same as target_prob - draft_prob, except that
|
||||
# n-gram does not have draft_prob. We regard it as 1.
|
||||
tl.store(
|
||||
target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
|
||||
0)
|
||||
prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size +
|
||||
vocab_offset,
|
||||
mask=vocab_offset < vocab_size,
|
||||
other=0)
|
||||
else:
|
||||
draft_prob = tl.load(draft_probs_ptr + (start_idx + pos) * vocab_size +
|
||||
vocab_offset,
|
||||
mask=vocab_offset < vocab_size,
|
||||
other=0)
|
||||
target_prob = tl.load(target_probs_ptr +
|
||||
(start_idx + pos) * vocab_size + vocab_offset,
|
||||
mask=vocab_offset < vocab_size,
|
||||
other=0)
|
||||
prob = tl.maximum(target_prob - draft_prob, 0)
|
||||
# NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because
|
||||
# `tl.argmax` will select the maximum value.
|
||||
|
||||
q = tl.load(q_ptr + req_idx * vocab_size + vocab_offset,
|
||||
mask=vocab_offset < vocab_size,
|
||||
other=float("-inf"))
|
||||
recovered_id = tl.argmax(prob / q, axis=-1)
|
||||
tl.store(output_token_ids_ptr + start_idx + pos, recovered_id)
|
||||
|
||||
if IS_NGRAM:
|
||||
# Restore the original probability.
|
||||
tl.store(
|
||||
target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
|
||||
orig_prob)
|
||||
|
61
vllm/v1/spec_decode/metadata.py
Normal file
61
vllm/v1/spec_decode/metadata.py
Normal file
@ -0,0 +1,61 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
@dataclass
|
||||
class SpecDecodeMetadata:
|
||||
|
||||
# [num_tokens]
|
||||
draft_token_ids: torch.Tensor
|
||||
# [batch_size]
|
||||
num_draft_tokens: list[int]
|
||||
# [batch_size]
|
||||
cu_num_draft_tokens: torch.Tensor
|
||||
# [num_tokens]
|
||||
target_logits_indices: torch.Tensor
|
||||
# [batch_size]
|
||||
bonus_logits_indices: torch.Tensor
|
||||
# [num_tokens + batch_size]
|
||||
logits_indices: torch.Tensor
|
||||
|
||||
def __post_init__(self):
|
||||
self.max_spec_len = max(self.num_draft_tokens)
|
||||
|
||||
@classmethod
|
||||
def make_dummy(
|
||||
cls,
|
||||
draft_token_ids: list[list[int]],
|
||||
device: torch.device,
|
||||
) -> "SpecDecodeMetadata":
|
||||
batch_size = len(draft_token_ids)
|
||||
num_draft_tokens = [len(ids) for ids in draft_token_ids]
|
||||
flattened_draft_token_ids = sum(draft_token_ids, [])
|
||||
num_tokens = len(flattened_draft_token_ids)
|
||||
|
||||
draft_token_ids_tensor = torch.tensor(flattened_draft_token_ids,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32)
|
||||
cu_num_draft_tokens_tensor = torch.from_numpy(cu_num_draft_tokens).to(
|
||||
device)
|
||||
|
||||
target_logits_indices = torch.zeros(num_tokens,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
bonus_logits_indices = torch.zeros(batch_size,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
logits_indices = torch.zeros(num_tokens + batch_size,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
return cls(
|
||||
draft_token_ids=draft_token_ids_tensor,
|
||||
num_draft_tokens=num_draft_tokens,
|
||||
cu_num_draft_tokens=cu_num_draft_tokens_tensor,
|
||||
target_logits_indices=target_logits_indices,
|
||||
bonus_logits_indices=bonus_logits_indices,
|
||||
logits_indices=logits_indices,
|
||||
)
|
@ -1,5 +1,4 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import random_sample # noqa
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
|
||||
|
||||
|
@ -34,7 +34,8 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
|
||||
ModelRunnerOutput)
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID, RejectionSampler
|
||||
from vllm.v1.sample.rejection_sampler import RejectionSampler
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
|
||||
from vllm.v1.spec_decode.utils import is_spec_decode_supported
|
||||
from vllm.v1.utils import bind_kv_cache
|
||||
@ -149,7 +150,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.use_spec_decode = False
|
||||
if self.speculative_config:
|
||||
self.use_spec_decode = True
|
||||
self.rejection_sampler = RejectionSampler()
|
||||
# TODO: find a better way to check if we are using ngram.
|
||||
assert self.speculative_config.ngram_prompt_lookup_min, \
|
||||
"Currently, only ngram spec decode is supported in V1."
|
||||
@ -162,6 +162,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.speculative_config.ngram_prompt_lookup_min,
|
||||
self.speculative_config.num_speculative_tokens,
|
||||
)
|
||||
self.rejection_sampler = RejectionSampler()
|
||||
|
||||
# Request states.
|
||||
self.requests: dict[str, CachedRequestState] = {}
|
||||
@ -452,7 +453,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
def _prepare_inputs(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
) -> tuple[FlashAttentionMetadata, torch.Tensor]:
|
||||
) -> tuple[FlashAttentionMetadata, torch.Tensor,
|
||||
Optional[SpecDecodeMetadata]]:
|
||||
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
assert total_num_scheduled_tokens > 0
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
@ -577,22 +579,33 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
use_spec_decode = len(
|
||||
scheduler_output.scheduled_spec_decode_tokens) > 0
|
||||
if use_spec_decode:
|
||||
logits_indices = self._calc_spec_decode_metadata(
|
||||
scheduler_output, cu_num_tokens)
|
||||
else:
|
||||
if not use_spec_decode:
|
||||
# NOTE(woosuk): Due to chunked prefills, the batch may contain
|
||||
# partial requests. While we should not sample any token
|
||||
# from these partial requests, we do so for simplicity.
|
||||
# We will ignore the sampled tokens from the partial requests.
|
||||
# TODO: Support prompt logprobs.
|
||||
logits_indices = attn_metadata.query_start_loc[1:] - 1
|
||||
spec_decode_metadata = None
|
||||
else:
|
||||
# Get the number of draft tokens for each request.
|
||||
# Iterate over the dictionary rather than all requests since not all
|
||||
# requests have draft tokens.
|
||||
num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
|
||||
for req_id, draft_token_ids in (
|
||||
scheduler_output.scheduled_spec_decode_tokens.items()):
|
||||
req_idx = self.input_batch.req_id_to_index[req_id]
|
||||
num_draft_tokens[req_idx] = len(draft_token_ids)
|
||||
|
||||
spec_decode_metadata = self._calc_spec_decode_metadata(
|
||||
num_draft_tokens, cu_num_tokens)
|
||||
logits_indices = spec_decode_metadata.logits_indices
|
||||
|
||||
# Hot-Swap lora model
|
||||
if self.lora_config:
|
||||
self.set_active_loras(self.input_batch, num_scheduled_tokens)
|
||||
|
||||
return attn_metadata, logits_indices
|
||||
return attn_metadata, logits_indices, spec_decode_metadata
|
||||
|
||||
def _compute_cascade_attn_prefix_len(
|
||||
self,
|
||||
@ -732,49 +745,78 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
def _calc_spec_decode_metadata(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
cu_num_tokens: np.ndarray,
|
||||
) -> torch.Tensor:
|
||||
# Get the number of spec decode tokens for each request.
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
num_spec_decode_tokens = np.empty(num_reqs, dtype=np.int32)
|
||||
for i, req_id in enumerate(self.input_batch.req_ids):
|
||||
num_spec_decode_tokens[i] = len(
|
||||
scheduler_output.scheduled_spec_decode_tokens.get(req_id, ()))
|
||||
num_draft_tokens: np.ndarray,
|
||||
cu_num_scheduled_tokens: np.ndarray,
|
||||
) -> SpecDecodeMetadata:
|
||||
# Inputs:
|
||||
# cu_num_scheduled_tokens: [ 4, 104, 107, 207, 209]
|
||||
# num_draft_tokens: [ 3, 0, 2, 0, 1]
|
||||
# Outputs:
|
||||
# cu_num_draft_tokens: [ 3, 3, 5, 5, 6]
|
||||
# logits_indices: [ 0, 1, 2, 3, 103, 104, 105, 106,
|
||||
# 206, 207, 208]
|
||||
# target_logits_indices: [ 0, 1, 2, 5, 6, 9]
|
||||
# bonus_logits_indices: [ 3, 4, 7, 8, 10]
|
||||
|
||||
# Get spec decode logits indices.
|
||||
# E.g., num_scheduled_tokens: [4, 100, 3, 100, 2]
|
||||
# cu_num_tokens: [4, 104, 107, 207, 209]
|
||||
# num_spec_tokens_list: [3, 0, 2, 0, 1]
|
||||
# num_sampled_tokens: [4, 1, 3, 1, 2]
|
||||
# spec_decode_logits_indices:
|
||||
# [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
|
||||
num_sampled_tokens = num_spec_decode_tokens + 1
|
||||
# logits_start_loc: [0, 103, 104, 206, 207]
|
||||
logits_start_loc = cu_num_tokens - num_sampled_tokens
|
||||
# [0, 103, 104, 206, 207] ->
|
||||
# [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
|
||||
logits_start_loc = np.repeat(logits_start_loc, num_sampled_tokens)
|
||||
# The following three lines:
|
||||
# [4, 1, 3, 1, 2] -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
|
||||
# Step 1. [4, 1, 3, 1, 2] -> [4, 5, 8, 9, 11]
|
||||
cu_num_sampled_tokens = np.cumsum(num_sampled_tokens)
|
||||
# Step 2. [4, 5, 8, 9, 11] -> [0, 4, 5, 8, 9]
|
||||
# -> [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
|
||||
cumsums_sampled_offsets = np.repeat(
|
||||
cu_num_sampled_tokens - num_sampled_tokens, num_sampled_tokens)
|
||||
# Step 3. [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
# - [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
|
||||
# -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
|
||||
total_num_sampled_tokens = num_sampled_tokens.sum()
|
||||
sampled_arange = (self.arange_np[:total_num_sampled_tokens] -
|
||||
cumsums_sampled_offsets)
|
||||
# Compute the logits indices.
|
||||
# [4, 1, 3, 1, 2]
|
||||
num_sampled_tokens = num_draft_tokens + 1
|
||||
# Step 1. [4, 5, 8, 9, 11]
|
||||
cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32)
|
||||
total_num_sampled_tokens = cu_num_sampled_tokens[-1]
|
||||
# Step 2. [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
|
||||
cumsums_offsets = np.repeat(cu_num_sampled_tokens - num_sampled_tokens,
|
||||
num_sampled_tokens)
|
||||
# Step 3. [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
|
||||
arange = self.arange_np[:total_num_sampled_tokens] - cumsums_offsets
|
||||
# Step 4. [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
|
||||
logits_indices = np.repeat(
|
||||
cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens)
|
||||
# Step 5. [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
|
||||
logits_indices += arange
|
||||
|
||||
# [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] ->
|
||||
# [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
|
||||
spec_decode_logits_indices = logits_start_loc + sampled_arange
|
||||
return torch.from_numpy(spec_decode_logits_indices).to(
|
||||
# Compute the bonus logits indices.
|
||||
bonus_logits_indices = cu_num_sampled_tokens - 1
|
||||
|
||||
# Compute the draft logits indices.
|
||||
# [3, 3, 5, 5, 6]
|
||||
cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32)
|
||||
total_num_draft_tokens = cu_num_draft_tokens[-1]
|
||||
# [0, 0, 0, 3, 3, 5]
|
||||
cumsums_offsets = np.repeat(cu_num_draft_tokens - num_draft_tokens,
|
||||
num_draft_tokens)
|
||||
# [0, 1, 2, 0, 1, 0]
|
||||
arange = self.arange_np[:total_num_draft_tokens] - cumsums_offsets
|
||||
# [0, 0, 0, 5, 5, 9]
|
||||
target_logits_indices = np.repeat(
|
||||
cu_num_sampled_tokens - num_sampled_tokens, num_draft_tokens)
|
||||
# [0, 1, 2, 5, 6, 9]
|
||||
target_logits_indices += arange
|
||||
|
||||
# TODO: Optimize the CPU -> GPU copy.
|
||||
cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to(
|
||||
self.device, non_blocking=True)
|
||||
logits_indices = torch.from_numpy(logits_indices).to(self.device,
|
||||
non_blocking=True)
|
||||
target_logits_indices = torch.from_numpy(target_logits_indices).to(
|
||||
self.device, non_blocking=True)
|
||||
bonus_logits_indices = torch.from_numpy(bonus_logits_indices).to(
|
||||
self.device, non_blocking=True)
|
||||
|
||||
# Compute the draft token ids.
|
||||
# draft_token_indices: [ 1, 2, 3, 105, 106, 208]
|
||||
draft_token_ids = self.input_ids[logits_indices]
|
||||
draft_token_ids = draft_token_ids[target_logits_indices + 1]
|
||||
|
||||
metadata = SpecDecodeMetadata(
|
||||
draft_token_ids=draft_token_ids,
|
||||
num_draft_tokens=num_draft_tokens.tolist(),
|
||||
cu_num_draft_tokens=cu_num_draft_tokens,
|
||||
target_logits_indices=target_logits_indices,
|
||||
bonus_logits_indices=bonus_logits_indices,
|
||||
logits_indices=logits_indices,
|
||||
)
|
||||
return metadata
|
||||
|
||||
def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
|
||||
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
|
||||
@ -931,7 +973,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
encoder_outputs = []
|
||||
|
||||
# Prepare the decoder inputs.
|
||||
attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
|
||||
attn_metadata, logits_indices, spec_decode_metadata = (
|
||||
self._prepare_inputs(scheduler_output))
|
||||
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
if (self.use_cuda_graph
|
||||
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
|
||||
@ -1006,31 +1049,29 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
# Sample the next token and get logprobs if needed.
|
||||
sampling_metadata = self.input_batch.sampling_metadata
|
||||
if not self.use_spec_decode:
|
||||
if spec_decode_metadata is None:
|
||||
sampler_output = self.model.sample(
|
||||
logits=logits,
|
||||
sampling_metadata=sampling_metadata,
|
||||
)
|
||||
else:
|
||||
draft_token_ids = [
|
||||
scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])
|
||||
for req_id in self.input_batch.req_ids
|
||||
]
|
||||
sample_lens = [len(tokens) + 1 for tokens in draft_token_ids]
|
||||
recover_logits_idx = np.cumsum(sample_lens) - 1
|
||||
target_probs = self.rejection_sampler.compute_probs(
|
||||
logits, sampling_metadata, sample_lens)
|
||||
# TODO(woosuk): Optimize the memory usage.
|
||||
bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
|
||||
sampler_output = self.model.sample(
|
||||
logits=logits[recover_logits_idx, :],
|
||||
logits=bonus_logits,
|
||||
sampling_metadata=sampling_metadata,
|
||||
)
|
||||
bonus_token_ids = sampler_output.sampled_token_ids
|
||||
|
||||
# TODO(woosuk): Optimize the memory usage.
|
||||
target_logits = logits[spec_decode_metadata.target_logits_indices]
|
||||
output_token_ids = self.rejection_sampler(
|
||||
draft_token_ids,
|
||||
spec_decode_metadata,
|
||||
None, # draft_probs
|
||||
target_logits,
|
||||
bonus_token_ids,
|
||||
target_probs,
|
||||
sampling_metadata)
|
||||
sampling_metadata,
|
||||
)
|
||||
sampler_output.sampled_token_ids = output_token_ids
|
||||
|
||||
# TODO(woosuk): The following loop can be slow since it iterates over
|
||||
@ -1066,13 +1107,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
valid_sampled_token_ids = sampled_token_ids.tolist()
|
||||
else:
|
||||
# Includes spec decode tokens.
|
||||
valid_mask = sampled_token_ids != INVALID_TOKEN_ID
|
||||
gen_lens = valid_mask.sum(dim=1).tolist()
|
||||
# TODO(woosuk): Optimize this.
|
||||
valid_sampled_token_ids = [
|
||||
seq.tolist()
|
||||
for seq in sampled_token_ids[valid_mask].split(gen_lens)
|
||||
]
|
||||
valid_sampled_token_ids = self.rejection_sampler.parse_output(
|
||||
sampled_token_ids, self.input_batch.vocab_size)
|
||||
|
||||
if not self.use_spec_decode:
|
||||
spec_token_ids = None
|
||||
@ -1316,6 +1352,33 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
"initializing the engine.") from e
|
||||
else:
|
||||
raise e
|
||||
if self.use_spec_decode:
|
||||
draft_token_ids = [[0] for _ in range(num_reqs)]
|
||||
dummy_spec_decode_metadata = SpecDecodeMetadata.make_dummy(
|
||||
draft_token_ids, self.device)
|
||||
|
||||
num_tokens = sum(len(ids) for ids in draft_token_ids)
|
||||
# draft_probs = torch.randn(
|
||||
# num_tokens, logits.shape[-1], device=self.device,
|
||||
# dtype=logits.dtype)
|
||||
draft_probs = None
|
||||
target_logits = torch.randn(num_tokens,
|
||||
logits.shape[-1],
|
||||
device=self.device,
|
||||
dtype=logits.dtype)
|
||||
# NOTE(woosuk): Here, we should use int32 because the sampler uses
|
||||
# int32 for bonus_token_ids. If the dtype mismatches, re-compilation
|
||||
# will occur at runtime.
|
||||
bonus_token_ids = torch.zeros(num_reqs,
|
||||
device=self.device,
|
||||
dtype=torch.int32)
|
||||
self.rejection_sampler(
|
||||
dummy_spec_decode_metadata,
|
||||
draft_probs,
|
||||
target_logits,
|
||||
bonus_token_ids,
|
||||
dummy_metadata,
|
||||
)
|
||||
return sampler_output
|
||||
|
||||
def profile_run(self) -> None:
|
||||
|
Loading…
x
Reference in New Issue
Block a user