[SpecDecode][Kernel] Flashinfer Rejection Sampling (#7244)

Lily Liu 2024-09-01 21:23:29 -07:00 committed by GitHub
parent f8d60145b4
commit e6a26ed037
9 changed files with 306 additions and 109 deletions

View File

@ -162,7 +162,7 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
RUN --mount=type=cache,target=/root/.cache/pip \
. /etc/environment && \
python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.4/flashinfer-0.1.4+cu121torch2.4-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl
python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.6/flashinfer-0.1.6+cu121torch2.4-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl
#################### vLLM installation IMAGE ####################

View File

@ -44,12 +44,16 @@ def mock_causal_accepted_tensor(
["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"])
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_flashinfer", [True, False])
@torch.inference_mode()
def test_correct_output_format(which_tokens_accepted: str,
disable_bonus_tokens: bool, seed: int,
device: str):
def test_correct_output_format(which_tokens_accepted: str, seed: int,
disable_bonus_tokens: bool, device: str,
use_flashinfer: bool):
"""Verify the output has correct format given predetermined accepted matrix.
"""
if use_flashinfer and disable_bonus_tokens:
pytest.skip("Flashinfer rejection sampler must enable bonus token.")
set_random_seed(seed)
torch.set_default_device(device)
@ -85,7 +89,8 @@ def test_correct_output_format(which_tokens_accepted: str,
dtype=torch.int64)
rejection_sampler = RejectionSampler(
disable_bonus_tokens=disable_bonus_tokens)
disable_bonus_tokens=disable_bonus_tokens,
use_flashinfer=use_flashinfer)
rejection_sampler.init_gpu_tensors(device=device)
output_token_ids = rejection_sampler._create_output( # pylint: disable=protected-access
accepted,
@ -133,15 +138,20 @@ def test_correct_output_format(which_tokens_accepted: str,
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", list(range(1, 32)))
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_flashinfer", [True, False])
@torch.inference_mode()
def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
device: str):
device: str, use_flashinfer: bool):
torch.set_default_device(device)
rejection_sampler = RejectionSampler()
rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
use_flashinfer=use_flashinfer)
rejection_sampler.init_gpu_tensors(device=device)
draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
target_probs = torch.rand(batch_size,
k + 1,
vocab_size,
dtype=torch.float32)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
@ -161,16 +171,21 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
@pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
@pytest.mark.parametrize("n_rep", [100])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_flashinfer", [True, False])
@torch.inference_mode()
def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
frac_seeded: float, n_rep: int,
device: str):
frac_seeded: float, n_rep: int, device: str,
use_flashinfer: bool):
torch.set_default_device(device)
rejection_sampler = RejectionSampler()
rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
use_flashinfer=use_flashinfer)
rejection_sampler.init_gpu_tensors(device=device)
draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
target_probs = torch.rand(batch_size,
k + 1,
vocab_size,
dtype=torch.float32)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
@ -198,23 +213,85 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
assert torch.equal(results[j][i], results[0][i])
@pytest.mark.parametrize("k", [1, 3, 6])
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_compare_nonflashinfer_backend(k: int, vocab_size: int,
batch_size: int, device: str):
"""
Test that the flashinfer and nonflashinfer backends generate
the same output metrics.
"""
torch.set_default_device(device)
torch.manual_seed(0)
draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
target_probs = torch.rand(batch_size,
k + 1,
vocab_size,
dtype=torch.float32)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
draft_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64)
num_accepted_tokens = []
num_emitted_tokens = []
num_draft_tokens = []
def get_seeded_seqs():
return {
i: torch.Generator(device=device).manual_seed(i)
for i in range(batch_size)
}
for use_flashinfer in [True, False]:
rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
use_flashinfer=use_flashinfer)
rejection_sampler.init_gpu_tensors(device=device)
# We use seeded sequences to ensure the same tokens are accepted
# for both flashinfer and nonflashinfer backends.
seeded_seqs = get_seeded_seqs()
rejection_sampler(target_probs, bonus_token_ids, draft_probs,
draft_token_ids, seeded_seqs)
num_accepted_tokens.append(rejection_sampler.num_accepted_tokens)
num_emitted_tokens.append(rejection_sampler.num_emitted_tokens)
num_draft_tokens.append(rejection_sampler.num_draft_tokens)
assert num_accepted_tokens[0] == num_accepted_tokens[1]
assert num_emitted_tokens[0] == num_emitted_tokens[1]
assert num_draft_tokens[0] == num_draft_tokens[1]
@pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"])
@pytest.mark.parametrize("which_token_ids",
["bonus_token_ids", "draft_token_ids"])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_flashinfer", [True, False])
@torch.inference_mode()
def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
which_token_ids: str, device: str):
which_token_ids: str, device: str,
use_flashinfer: bool):
k = 3
batch_size = 5
vocab_size = 30_000
torch.set_default_device(device)
rejection_sampler = RejectionSampler(strict_mode=True)
rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
use_flashinfer=use_flashinfer,
strict_mode=True)
rejection_sampler.init_gpu_tensors(device=device)
draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
target_probs = torch.rand(batch_size,
k + 1,
vocab_size,
dtype=torch.float32)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
@ -248,9 +325,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
@pytest.mark.parametrize("draft_and_target_probs_equal", [True, False])
@pytest.mark.parametrize("seed", list(range(5)))
@pytest.mark.parametrize("use_flashinfer", [True, False])
@torch.inference_mode()
def test_rejection_sampling_approximates_target_distribution(
seed: int, draft_and_target_probs_equal: bool):
seed: int, draft_and_target_probs_equal: bool, use_flashinfer: bool):
"""Verify rejection sampling approximates target distribution,
despite sampling from a potentially distinct draft distribution.
@ -279,10 +357,10 @@ def test_rejection_sampling_approximates_target_distribution(
"""
torch.set_default_device("cpu")
set_random_seed(seed)
helper = _CorrectnessTestHelper(
vocab_size=10,
rejection_sampler=RejectionSampler(),
rejection_sampler=RejectionSampler(disable_bonus_tokens=False,
use_flashinfer=use_flashinfer),
)
draft_probs, target_probs, reference_probs = helper.generate_probs_for_test(
@ -398,10 +476,10 @@ class _CorrectnessTestHelper:
draft_probs = draft_probs.reshape(1, self.k, self.vocab_size).repeat(
num_samples, 1, 1)
# Repeat target probs num_samples * k times.
# Repeat target probs num_samples * (k + 1) times.
# Rejection sampler requires bonus token probs, but they aren't used.
target_probs = target_probs.reshape(1, 1, self.vocab_size).repeat(
num_samples, self.k, 1)
num_samples, self.k + 1, 1)
# Randomly sample draft token ids from draft probs.
draft_token_ids = torch.multinomial(draft_probs[:, 0, :],

View File

@ -79,7 +79,10 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler()
typical_acceptance_sampler.init_gpu_tensors(device=device)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
target_with_bonus_probs = torch.rand(batch_size,
k + 1,
vocab_size,
dtype=torch.float32)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
@ -89,7 +92,7 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
size=(batch_size, k),
dtype=torch.int64)
# Verify that sampling succeeds for all cases.
typical_acceptance_sampler(target_probs,
typical_acceptance_sampler(target_with_bonus_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
@ -112,7 +115,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
target_with_bonus_probs = torch.rand(batch_size,
k + 1,
vocab_size,
dtype=torch.float32)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
@ -141,7 +147,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
oob_token_ids[0][0] = rogue_token_id
with pytest.raises(AssertionError):
typical_acceptance_sampler(target_probs,
typical_acceptance_sampler(target_with_bonus_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
@ -172,7 +178,10 @@ def test_uniform_target_distribution_accepts_all_tokens(
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler.init_gpu_tensors(device=device)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
target_with_bonus_probs = torch.rand(batch_size,
k + 1,
vocab_size,
dtype=torch.float32)
draft_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
@ -182,7 +191,7 @@ def test_uniform_target_distribution_accepts_all_tokens(
size=(batch_size, 1),
dtype=torch.int64)
output_token_ids = typical_acceptance_sampler(
target_probs,
target_with_bonus_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
@ -229,8 +238,9 @@ def test_temperature_zero_target_distribution(seed: int,
# Simulate temperature 0 probability distribution for target probabilities
# and create target probabilities such that only 1 token id has
# probability 1.0
target_probs, zero_temperature_token_ids = get_zero_temperature_prob_dist(
batch_size, k, vocab_size)
target_with_bonus_probs, zero_temperature_token_ids = \
get_zero_temperature_prob_dist(batch_size, k + 1, vocab_size)
zero_temperature_token_ids = zero_temperature_token_ids[:, :-1]
# Populate draft_token_ids such that they exclude the token_ids
# with probability = 1.0
draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size,
@ -245,7 +255,7 @@ def test_temperature_zero_target_distribution(seed: int,
# fallback to the greedy sampling for selecting 1 token for each sequence.
# Verify the same.
output_token_ids = typical_acceptance_sampler(
target_probs,
target_with_bonus_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
@ -289,8 +299,10 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
# For sequences 0 and 2 set the distribution to a temperature
# zero distribution. For sequences 1 and 3 set it to a uniform
# distribution.
target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist(
batch_size, k, vocab_size))
target_with_bonus_probs, zero_temperature_token_ids = \
get_zero_temperature_prob_dist(batch_size, k + 1, vocab_size)
zero_temperature_token_ids = zero_temperature_token_ids[:, :-1]
target_probs = target_with_bonus_probs[:, :-1]
draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size,
zero_temperature_token_ids)
uniform_probs = torch.rand(2, k, vocab_size, dtype=torch.float32)
@ -300,7 +312,7 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
size=(batch_size, 1),
dtype=torch.int64)
output_token_ids = typical_acceptance_sampler(
target_probs,
target_with_bonus_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
@ -356,15 +368,16 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
# Create a temperature zero target probability distribution and ensure
# all draft token ids correspond to the tokens with 1.0 probability.
# Verify that all of them are accepted.
target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist(
batch_size, k, vocab_size))
target_with_bonus_probs, zero_temperature_token_ids = \
get_zero_temperature_prob_dist(batch_size, k + 1, vocab_size)
zero_temperature_token_ids = zero_temperature_token_ids[:, :-1]
draft_token_ids = zero_temperature_token_ids
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
output_token_ids = typical_acceptance_sampler(
target_probs,
target_with_bonus_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
@ -384,7 +397,7 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
draft_token_ids = torch.cat(
(draft_token_ids[:, :2], draft_token_ids_to_replace[:, -3:]), dim=1)
output_token_ids = typical_acceptance_sampler(
target_probs,
target_with_bonus_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
@ -421,8 +434,9 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
# 0.00001. Populate draft_token_ids such that they exclude the token_ids
# with probability = 1.0. Without any changes to the posterior thresholds
# none of the draft tokens are accepted.
target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist(
batch_size, k, vocab_size))
target_probs, zero_temperature_token_ids = get_zero_temperature_prob_dist(
batch_size, k + 1, vocab_size)
zero_temperature_token_ids = zero_temperature_token_ids[:, :-1]
target_probs[target_probs == 0] = 0.00001
draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size,
zero_temperature_token_ids)

View File

@ -230,9 +230,8 @@ def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int,
assert torch.equal(actual.bonus_token_ids,
target_token_ids.reshape(batch_size, k + 1)[:, -1:])
assert torch.equal(
actual.target_probs,
target_token_probs.reshape(batch_size, k + 1, -1)[:, :-1])
assert torch.equal(actual.target_with_bonus_probs,
target_token_probs.reshape(batch_size, k + 1, -1))
assert torch.equal(actual.draft_token_ids, proposal_token_ids)
assert torch.equal(actual.draft_probs, proposal_probs)

View File

@ -31,6 +31,7 @@ if TYPE_CHECKING:
VLLM_TRACE_FUNCTION: int = 0
VLLM_ATTENTION_BACKEND: Optional[str] = None
VLLM_USE_FLASHINFER_SAMPLER: bool = False
VLLM_USE_FLASHINFER_REJECTION_SAMPLER: bool = False
VLLM_PP_LAYER_PARTITION: Optional[str] = None
VLLM_CPU_KVCACHE_SPACE: int = 0
VLLM_CPU_OMP_THREADS_BIND: str = ""

View File

@ -1,12 +1,28 @@
from functools import cached_property
from importlib.util import find_spec
from typing import Dict, List, Optional, Tuple
import torch
import torch.jit
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.layers.spec_decode_base_sampler import (
SpecDecodeStochasticBaseSampler)
logger = init_logger(__name__)
if find_spec("flashinfer"):
"""
Prefer the FlashInfer rejection sampling kernel when it is available:
it runs as a single dedicated kernel rather than a sequence of Torch
tensor operations, which fuses the work, reduces memory I/O, and
improves performance.
"""
from flashinfer.sampling import chain_speculative_sampling
else:
chain_speculative_sampling = None
class RejectionSampler(SpecDecodeStochasticBaseSampler):
"""Apply modified rejection sampling as described in "Accelerating Large
@ -16,7 +32,8 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
def __init__(self,
disable_bonus_tokens: bool = True,
strict_mode: bool = False):
strict_mode: bool = False,
use_flashinfer: Optional[bool] = None):
"""Create a rejection sampler.
Args:
@ -26,13 +43,29 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
strict_mode: Whether or not to perform shape/device/dtype checks
during sampling. This catches correctness issues but adds
nontrivial latency.
use_flashinfer: Whether to use the FlashInfer rejection sampling
kernel. If None, the value is taken from the
VLLM_USE_FLASHINFER_SAMPLER environment variable and from whether
the flashinfer kernel could be imported. This parameter is only
used for testing purposes.
"""
super().__init__(disable_bonus_tokens=disable_bonus_tokens,
strict_mode=strict_mode)
if use_flashinfer is None:
self.use_flashinfer = envs.VLLM_USE_FLASHINFER_SAMPLER and (
chain_speculative_sampling is not None)
else:
self.use_flashinfer = use_flashinfer
if self.use_flashinfer:
assert not disable_bonus_tokens, \
"flashinfer will enable bonus token by default"
logger.info("Use flashinfer for rejection sampling.")
else:
logger.info("Use pytorch for rejection sampling.")
def forward(
self,
target_probs: torch.Tensor,
target_with_bonus_probs: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
@ -50,9 +83,9 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
sequence.
Args:
target_probs: The probability distribution over token ids given
context according to the target model.
shape = [batch_size, num_speculative_tokens, vocab_size]
target_with_bonus_probs: The probability distribution
over token ids given context according to the target model.
shape = [batch_size, num_speculative_tokens + 1, vocab_size]
bonus_token_ids: The "bonus" token ids that are accepted iff all
speculative tokens in a sequence are accepted.
@ -78,23 +111,52 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
# Only perform shape/dtype/device checking in strict mode, as it adds
# overhead.
if self._strict_mode:
self._raise_if_incorrect_input(target_probs, draft_token_ids,
bonus_token_ids, draft_probs)
self._raise_if_incorrect_input(target_with_bonus_probs,
draft_token_ids, bonus_token_ids,
draft_probs)
accepted, recovered_token_ids = (
self._batch_modified_rejection_sampling(
target_probs,
draft_probs,
batch_size, k, _ = draft_probs.shape
# batch_size = 0 when all requests in the batch are
# non_spec requests. In this case, output_token_ids is
# just an empty tensor.
if batch_size == 0:
return torch.empty(0, k + 1, device=draft_probs.device, dtype=int)
# Use the FlashInfer chain_speculative_sampling kernel
# for rejection sampling when it is enabled.
if self.use_flashinfer:
batch_size, k, _ = draft_probs.shape
uniform_samples = self._create_uniform_samples(
seeded_seqs, batch_size, k, draft_probs.device)
output_token_ids, accepted_token_num, emitted_token_num \
= chain_speculative_sampling(
draft_probs, draft_token_ids, uniform_samples,
target_with_bonus_probs)
# num_emitted_tokens returned by flashinfer does not include
# the bonus token: flashinfer stops counting at the first
# rejected draft token and never counts the recovered/bonus token.
# Therefore, we need to add batch_size here.
self.num_accepted_tokens += accepted_token_num.sum()
self.num_emitted_tokens += emitted_token_num.sum() + batch_size
self.num_draft_tokens += batch_size * k
else:
accepted, recovered_token_ids = (
self._batch_modified_rejection_sampling(
target_with_bonus_probs[:, :-1],
draft_probs,
draft_token_ids,
seeded_seqs,
))
output_token_ids = self._create_output(
accepted,
recovered_token_ids,
draft_token_ids,
bonus_token_ids,
)
return output_token_ids
@ -135,6 +197,63 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
return accepted, recovered_token_ids
def _create_uniform_samples(self,
seeded_seqs: Optional[Dict[int,
torch.Generator]],
batch_size: int, k: int,
device: torch.device) -> torch.Tensor:
"""
Generates a batch of uniform random samples, with optional seeding
for specific sequences.
This method creates a tensor of shape `(batch_size, k + 1)` filled
with uniform random values in the range [0, 1). If `seeded_seqs`
is provided, the sequences corresponding to specific indices
will be generated using the provided `torch.Generator` for
reproducibility. The other sequences will be generated without
a seed.
Args:
seeded_seqs : Optional[Dict[int, torch.Generator]]
A dictionary mapping indices in the batch to
`torch.Generator` objects. If `None`, all samples are
generated without a seed.
batch_size : int
The number of sequences to generate.
k : int
The number of random samples per sequence.
device : torch.device
The device on which to allocate the tensor.
Returns:
uniform_rand : torch.Tensor
A tensor of shape `(batch_size, k + 1)` containing uniform
random values in the range [0, 1).
"""
if not seeded_seqs:
return torch.rand(batch_size, k + 1, device=device)
uniform_rand = torch.empty(batch_size, k + 1, device=device)
non_seeded_indices = []
for idx in range(batch_size):
generator = seeded_seqs.get(idx)
if generator is None:
non_seeded_indices.append(idx)
else:
uniform_rand[idx, :] = torch.rand(1,
k + 1,
dtype=self.probs_dtype,
device=device,
generator=generator)
if non_seeded_indices:
uniform_rand[non_seeded_indices, :] = torch.rand(
len(non_seeded_indices),
k + 1,
dtype=self.probs_dtype,
device=device)
return uniform_rand
def _get_accepted(
self,
target_probs: torch.Tensor, # [batch_size, k, vocab_size]
@ -175,29 +294,8 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
selected_target_probs = target_probs[batch_indices, probs_indicies,
draft_token_ids]
if not seeded_seqs:
uniform_rand = torch.rand_like(selected_target_probs)
else:
uniform_rand = torch.empty_like(selected_target_probs)
non_seeded_indices = []
for idx in range(batch_size):
generator = seeded_seqs.get(idx)
if generator is None:
non_seeded_indices.append(idx)
else:
uniform_rand[idx, :] = torch.rand(
1,
k,
dtype=self.probs_dtype,
device=target_probs.device,
generator=generator)
if non_seeded_indices:
uniform_rand[non_seeded_indices, :] = torch.rand(
len(non_seeded_indices),
k,
dtype=self.probs_dtype,
device=target_probs.device)
uniform_rand = self._create_uniform_samples(seeded_seqs, batch_size,
k - 1, target_probs.device)
capped_ratio = torch.minimum(
selected_target_probs / selected_draft_probs,

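To make the new interface concrete, here is a minimal usage sketch pieced together from the tests and the code above. It is illustrative only: it assumes a CUDA device with a compatible flashinfer wheel installed, the batch/vocab sizes are arbitrary, and the probabilities are softmax-normalized just to keep the sketch self-contained.

import torch
from vllm.model_executor.layers.rejection_sampler import RejectionSampler

batch_size, k, vocab_size = 4, 3, 32_000
device = "cuda:0"
torch.set_default_device(device)

# Target probabilities now carry the extra bonus position:
# shape [batch_size, k + 1, vocab_size].
target_with_bonus_probs = torch.rand(batch_size, k + 1, vocab_size).softmax(dim=-1)
draft_probs = torch.rand(batch_size, k, vocab_size).softmax(dim=-1)
draft_token_ids = torch.randint(0, vocab_size, (batch_size, k), dtype=torch.int64)
bonus_token_ids = torch.randint(0, vocab_size, (batch_size, 1), dtype=torch.int64)

# use_flashinfer=None would instead consult the environment variable and
# whether the flashinfer kernel could be imported.
sampler = RejectionSampler(disable_bonus_tokens=False, use_flashinfer=True)
sampler.init_gpu_tensors(device=device)
output_token_ids = sampler(target_with_bonus_probs, bonus_token_ids,
                           draft_probs, draft_token_ids)
# output_token_ids has shape [batch_size, k + 1]: the accepted draft tokens
# for each sequence, followed by a recovered or bonus token, with any
# remaining positions padded.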
View File

@ -130,29 +130,35 @@ class SpecDecodeBaseSampler(nn.Module):
def _raise_if_incorrect_input(
self,
target_probs: torch.Tensor,
target_with_bonus_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: Optional[torch.Tensor] = None,
) -> None:
self._raise_if_incorrect_shape(target_probs, draft_token_ids,
bonus_token_ids, draft_probs)
self._raise_if_incorrect_dtype(target_probs, draft_token_ids,
bonus_token_ids, draft_probs)
self._raise_if_inconsistent_device(target_probs, draft_token_ids,
bonus_token_ids, draft_probs)
self._raise_if_out_of_bounds_vocab(target_probs.shape[-1],
self._raise_if_incorrect_shape(target_with_bonus_probs,
draft_token_ids, bonus_token_ids,
draft_probs)
self._raise_if_incorrect_dtype(target_with_bonus_probs,
draft_token_ids, bonus_token_ids,
draft_probs)
self._raise_if_inconsistent_device(target_with_bonus_probs,
draft_token_ids, bonus_token_ids,
draft_probs)
self._raise_if_out_of_bounds_vocab(target_with_bonus_probs.shape[-1],
draft_token_ids, bonus_token_ids)
def _raise_if_incorrect_shape(
self,
target_probs: torch.Tensor,
target_with_bonus_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: Optional[torch.Tensor] = None,
) -> None:
(target_batch_size, num_target_probs,
target_vocab_size) = target_probs.shape
target_vocab_size) = target_with_bonus_probs.shape
# Does not count the extra token
num_target_probs -= 1
# validate the shape of draft token ids.
draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape
@ -175,12 +181,12 @@ class SpecDecodeBaseSampler(nn.Module):
def _raise_if_incorrect_dtype(
self,
target_probs: torch.Tensor,
target_with_bonus_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: Optional[torch.Tensor] = None,
) -> None:
assert target_probs.dtype == self.probs_dtype
assert target_with_bonus_probs.dtype == self.probs_dtype
assert draft_token_ids.dtype == self.token_id_dtype
assert bonus_token_ids.dtype == self.token_id_dtype
if draft_probs is not None:
@ -188,15 +194,16 @@ class SpecDecodeBaseSampler(nn.Module):
def _raise_if_inconsistent_device(
self,
target_probs: torch.Tensor,
target_with_bonus_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: Optional[torch.Tensor] = None,
) -> None:
devices = [
t.device for t in
[target_probs, bonus_token_ids, draft_probs, draft_token_ids]
if t is not None
t.device for t in [
target_with_bonus_probs, bonus_token_ids, draft_probs,
draft_token_ids
] if t is not None
]
assert all([devices[0] == device for device in devices])
@ -220,7 +227,7 @@ class SpecDecodeDeterministicBaseSampler(SpecDecodeBaseSampler):
@abstractmethod
def forward(
self,
target_probs: torch.Tensor,
target_with_bonus_probs: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
@ -236,7 +243,7 @@ class SpecDecodeStochasticBaseSampler(SpecDecodeBaseSampler):
@abstractmethod
def forward(
self,
target_probs: torch.Tensor,
target_with_bonus_probs: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,

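For reference, a small sketch (with illustrative sizes) of the shape contract the updated checks enforce: the target tensor now includes the bonus position, and the validator peels it off before comparing against the draft tensors, mirroring the num_target_probs -= 1 above.

import torch

batch_size, k, vocab_size = 2, 3, 100

target_with_bonus_probs = torch.rand(batch_size, k + 1, vocab_size)  # bonus slot included
draft_probs = torch.rand(batch_size, k, vocab_size)
draft_token_ids = torch.randint(0, vocab_size, (batch_size, k), dtype=torch.int64)
bonus_token_ids = torch.randint(0, vocab_size, (batch_size, 1), dtype=torch.int64)

# Drop the bonus position, then the remaining counts must line up
# with the draft tensors.
num_target_probs = target_with_bonus_probs.shape[1] - 1
assert num_target_probs == k == draft_probs.shape[1] == draft_token_ids.shape[1]
assert bonus_token_ids.shape == (batch_size, 1)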
View File

@ -41,7 +41,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
def forward(
self,
target_probs: torch.Tensor,
target_with_bonus_probs: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
@ -80,8 +80,9 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
# Only perform shape/dtype/device checking in strict mode, as it adds
# overhead.
if self._strict_mode:
self._raise_if_incorrect_input(target_probs, draft_token_ids,
bonus_token_ids)
self._raise_if_incorrect_input(target_with_bonus_probs,
draft_token_ids, bonus_token_ids)
target_probs = target_with_bonus_probs[:, :-1]
accepted = self._evaluate_accepted_tokens(target_probs,
draft_token_ids)
recovered_token_ids = self._replacement_token_ids(target_probs)

View File

@ -625,8 +625,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
seq_group_metadata_list, proposal_lens_list)
original_indices = spec_indices + non_spec_indices
# Get probabilities of target model, excluding bonus token.
proposal_verifier_probs = proposal_scores.probs[spec_indices, :-1]
# Get probabilities of target model, including bonus tokens.
proposal_verifier_probs = proposal_scores.probs[spec_indices]
# Get non-speculative sampled tokens from target model.
non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]
@ -651,13 +651,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
}
accepted_token_ids = self.spec_decode_sampler(
target_probs=proposal_verifier_probs,
target_with_bonus_probs=proposal_verifier_probs,
bonus_token_ids=bonus_token_ids,
draft_probs=proposal_probs,
draft_token_ids=proposal_token_ids,
**sampler_extra_kwargs,
)
# Append output tokens from non-speculative sequences to
# the accepted token ids tensor.
non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len +