[SpecDecode][Kernel] Flashinfer Rejection Sampling (#7244)
commit e6a26ed037
parent f8d60145b4
@@ -162,7 +162,7 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
 RUN --mount=type=cache,target=/root/.cache/pip \
     . /etc/environment && \
-    python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.4/flashinfer-0.1.4+cu121torch2.4-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl
+    python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.6/flashinfer-0.1.6+cu121torch2.4-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl
 #################### vLLM installation IMAGE ####################
@@ -44,12 +44,16 @@ def mock_causal_accepted_tensor(
     ["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"])
 @pytest.mark.parametrize("disable_bonus_tokens", [True, False])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("use_flashinfer", [True, False])
 @torch.inference_mode()
-def test_correct_output_format(which_tokens_accepted: str,
-                               disable_bonus_tokens: bool, seed: int,
-                               device: str):
+def test_correct_output_format(which_tokens_accepted: str, seed: int,
+                               disable_bonus_tokens: bool, device: str,
+                               use_flashinfer: bool):
     """Verify the output has correct format given predetermined accepted matrix.
     """
+    if use_flashinfer and disable_bonus_tokens:
+        pytest.skip("Flashinfer rejection sampler must enable bonus token.")
+
     set_random_seed(seed)
     torch.set_default_device(device)
 
@@ -85,7 +89,8 @@ def test_correct_output_format(which_tokens_accepted: str,
                                     dtype=torch.int64)
 
     rejection_sampler = RejectionSampler(
-        disable_bonus_tokens=disable_bonus_tokens)
+        disable_bonus_tokens=disable_bonus_tokens,
+        use_flashinfer=use_flashinfer)
     rejection_sampler.init_gpu_tensors(device=device)
     output_token_ids = rejection_sampler._create_output(  # pylint: disable=protected-access
         accepted,
@@ -133,15 +138,20 @@ def test_correct_output_format(which_tokens_accepted: str,
 @pytest.mark.parametrize("vocab_size", [30_000, 50_000])
 @pytest.mark.parametrize("batch_size", list(range(1, 32)))
 @pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("use_flashinfer", [True, False])
 @torch.inference_mode()
 def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
-                                    device: str):
+                                    device: str, use_flashinfer: bool):
     torch.set_default_device(device)
-    rejection_sampler = RejectionSampler()
+    rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
+                                         use_flashinfer=use_flashinfer)
     rejection_sampler.init_gpu_tensors(device=device)
 
     draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
-    target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
+    target_probs = torch.rand(batch_size,
+                              k + 1,
+                              vocab_size,
+                              dtype=torch.float32)
     bonus_token_ids = torch.randint(low=0,
                                     high=vocab_size,
                                     size=(batch_size, 1),
@@ -161,16 +171,21 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
 @pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
 @pytest.mark.parametrize("n_rep", [100])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("use_flashinfer", [True, False])
 @torch.inference_mode()
 def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
-                                   frac_seeded: float, n_rep: int,
-                                   device: str):
+                                   frac_seeded: float, n_rep: int, device: str,
+                                   use_flashinfer: bool):
     torch.set_default_device(device)
-    rejection_sampler = RejectionSampler()
+    rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
+                                         use_flashinfer=use_flashinfer)
     rejection_sampler.init_gpu_tensors(device=device)
 
     draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
-    target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
+    target_probs = torch.rand(batch_size,
+                              k + 1,
+                              vocab_size,
+                              dtype=torch.float32)
     bonus_token_ids = torch.randint(low=0,
                                     high=vocab_size,
                                     size=(batch_size, 1),
@@ -198,23 +213,85 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
             assert torch.equal(results[j][i], results[0][i])
 
 
+@pytest.mark.parametrize("k", [1, 3, 6])
+@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
+@pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@torch.inference_mode()
+def test_compare_nonflashinfer_backend(k: int, vocab_size: int,
+                                       batch_size: int, device: str):
+    """
+    Test the flashinfer and nonflashinfer backend generate
+    the same output metrics.
+    """
+    torch.set_default_device(device)
+    torch.manual_seed(0)
+    draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
+    target_probs = torch.rand(batch_size,
+                              k + 1,
+                              vocab_size,
+                              dtype=torch.float32)
+    bonus_token_ids = torch.randint(low=0,
+                                    high=vocab_size,
+                                    size=(batch_size, 1),
+                                    dtype=torch.int64)
+    draft_token_ids = torch.randint(low=0,
+                                    high=vocab_size,
+                                    size=(batch_size, k),
+                                    dtype=torch.int64)
+
+    num_accepted_tokens = []
+    num_emitted_tokens = []
+    num_draft_tokens = []
+
+    def get_seeded_seqs():
+        return {
+            i: torch.Generator(device=device).manual_seed(i)
+            for i in range(batch_size)
+        }
+
+    for use_flashinfer in [True, False]:
+        rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
+                                             use_flashinfer=use_flashinfer)
+        rejection_sampler.init_gpu_tensors(device=device)
+        # We use seeded sequences to ensure the same tokens are accepted
+        # for both flashinfer and nonflashinfer backends.
+        seeded_seqs = get_seeded_seqs()
+        rejection_sampler(target_probs, bonus_token_ids, draft_probs,
+                          draft_token_ids, seeded_seqs)
+        num_accepted_tokens.append(rejection_sampler.num_accepted_tokens)
+        num_emitted_tokens.append(rejection_sampler.num_emitted_tokens)
+        num_draft_tokens.append(rejection_sampler.num_draft_tokens)
+
+    assert num_accepted_tokens[0] == num_accepted_tokens[1]
+    assert num_emitted_tokens[0] == num_emitted_tokens[1]
+    assert num_draft_tokens[0] == num_draft_tokens[1]
+
+
 @pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"])
 @pytest.mark.parametrize("which_token_ids",
                          ["bonus_token_ids", "draft_token_ids"])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("use_flashinfer", [True, False])
 @torch.inference_mode()
 def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
-                               which_token_ids: str, device: str):
+                               which_token_ids: str, device: str,
+                               use_flashinfer: bool):
     k = 3
     batch_size = 5
     vocab_size = 30_000
     torch.set_default_device(device)
 
-    rejection_sampler = RejectionSampler(strict_mode=True)
+    rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
+                                         use_flashinfer=use_flashinfer,
+                                         strict_mode=True)
     rejection_sampler.init_gpu_tensors(device=device)
 
     draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
-    target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
+    target_probs = torch.rand(batch_size,
+                              k + 1,
+                              vocab_size,
+                              dtype=torch.float32)
     bonus_token_ids = torch.randint(low=0,
                                     high=vocab_size,
                                     size=(batch_size, 1),
@@ -248,9 +325,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
 
 @pytest.mark.parametrize("draft_and_target_probs_equal", [True, False])
 @pytest.mark.parametrize("seed", list(range(5)))
+@pytest.mark.parametrize("use_flashinfer", [True, False])
 @torch.inference_mode()
 def test_rejection_sampling_approximates_target_distribution(
-        seed: int, draft_and_target_probs_equal: bool):
+        seed: int, draft_and_target_probs_equal: bool, use_flashinfer: bool):
     """Verify rejection sampling approximates target distribution,
     despite sampling from a potentially distinct draft distribution.
 
@@ -279,10 +357,10 @@ def test_rejection_sampling_approximates_target_distribution(
     """
     torch.set_default_device("cpu")
     set_random_seed(seed)
 
     helper = _CorrectnessTestHelper(
         vocab_size=10,
-        rejection_sampler=RejectionSampler(),
+        rejection_sampler=RejectionSampler(disable_bonus_tokens=False,
+                                           use_flashinfer=use_flashinfer),
     )
 
     draft_probs, target_probs, reference_probs = helper.generate_probs_for_test(
@@ -398,10 +476,10 @@ class _CorrectnessTestHelper:
         draft_probs = draft_probs.reshape(1, self.k, self.vocab_size).repeat(
             num_samples, 1, 1)
 
-        # Repeat target probs num_samples * k times.
+        # Repeat target probs num_samples * (k + 1) times.
         # Rejection sampler requires bonus token probs, but they aren't used.
         target_probs = target_probs.reshape(1, 1, self.vocab_size).repeat(
-            num_samples, self.k, 1)
+            num_samples, self.k + 1, 1)
 
         # Randomly sample draft token ids from draft probs.
         draft_token_ids = torch.multinomial(draft_probs[:, 0, :],
@@ -79,7 +79,10 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
     torch.set_default_device(device)
     typical_acceptance_sampler = get_acceptance_sampler()
     typical_acceptance_sampler.init_gpu_tensors(device=device)
-    target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
+    target_with_bonus_probs = torch.rand(batch_size,
+                                         k + 1,
+                                         vocab_size,
+                                         dtype=torch.float32)
     bonus_token_ids = torch.randint(low=0,
                                     high=vocab_size,
                                     size=(batch_size, 1),
@@ -89,7 +92,7 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
                                     size=(batch_size, k),
                                     dtype=torch.int64)
     # Verify that sampling succeeds for all cases.
-    typical_acceptance_sampler(target_probs,
+    typical_acceptance_sampler(target_with_bonus_probs,
                                bonus_token_ids,
                                draft_probs=None,
                                draft_token_ids=draft_token_ids)
@@ -112,7 +115,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
     torch.set_default_device(device)
     typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
     typical_acceptance_sampler.init_gpu_tensors(device=device)
-    target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
+    target_with_bonus_probs = torch.rand(batch_size,
+                                         k + 1,
+                                         vocab_size,
+                                         dtype=torch.float32)
     bonus_token_ids = torch.randint(low=0,
                                     high=vocab_size,
                                     size=(batch_size, 1),
@@ -141,7 +147,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
     oob_token_ids[0][0] = rogue_token_id
 
     with pytest.raises(AssertionError):
-        typical_acceptance_sampler(target_probs,
+        typical_acceptance_sampler(target_with_bonus_probs,
                                    bonus_token_ids,
                                    draft_probs=None,
                                    draft_token_ids=draft_token_ids)
@@ -172,7 +178,10 @@ def test_uniform_target_distribution_accepts_all_tokens(
     typical_acceptance_sampler = get_acceptance_sampler(
         strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
     typical_acceptance_sampler.init_gpu_tensors(device=device)
-    target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
+    target_with_bonus_probs = torch.rand(batch_size,
+                                         k + 1,
+                                         vocab_size,
+                                         dtype=torch.float32)
     draft_token_ids = torch.randint(low=0,
                                     high=vocab_size,
                                     size=(batch_size, k),
@@ -182,7 +191,7 @@ def test_uniform_target_distribution_accepts_all_tokens(
                                     size=(batch_size, 1),
                                     dtype=torch.int64)
     output_token_ids = typical_acceptance_sampler(
-        target_probs,
+        target_with_bonus_probs,
         bonus_token_ids,
         draft_probs=None,
         draft_token_ids=draft_token_ids)
@@ -229,8 +238,9 @@ def test_temperature_zero_target_distribution(seed: int,
     # Simulate temperature 0 probability distribution for target probabilities
     # and create target probabilities such that only 1 token id has
     # probability 1.0
-    target_probs, zero_temperature_token_ids = get_zero_temperature_prob_dist(
-        batch_size, k, vocab_size)
+    target_with_bonus_probs, zero_temperature_token_ids = \
+        get_zero_temperature_prob_dist(batch_size, k + 1, vocab_size)
+    zero_temperature_token_ids = zero_temperature_token_ids[:, :-1]
     # Populate draft_token_ids such that they exclude the token_ids
     # with probability = 1.0
     draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size,
@@ -245,7 +255,7 @@ def test_temperature_zero_target_distribution(seed: int,
     # fallback to the greedy sampling for selecting 1 token for each sequence.
     # Verify the same.
     output_token_ids = typical_acceptance_sampler(
-        target_probs,
+        target_with_bonus_probs,
         bonus_token_ids,
         draft_probs=None,
         draft_token_ids=draft_token_ids)
@@ -289,8 +299,10 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
     # For sequences 0 and 2 set the distribution to a temperature
     # zero distribution. For sequences 1 and 3 set it to a uniform
     # distribution.
-    target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist(
-        batch_size, k, vocab_size))
+    target_with_bonus_probs, zero_temperature_token_ids = \
+        get_zero_temperature_prob_dist(batch_size, k + 1, vocab_size)
+    zero_temperature_token_ids = zero_temperature_token_ids[:, :-1]
+    target_probs = target_with_bonus_probs[:, :-1]
     draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size,
                                           zero_temperature_token_ids)
     uniform_probs = torch.rand(2, k, vocab_size, dtype=torch.float32)
@@ -300,7 +312,7 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
                                     size=(batch_size, 1),
                                     dtype=torch.int64)
     output_token_ids = typical_acceptance_sampler(
-        target_probs,
+        target_with_bonus_probs,
         bonus_token_ids,
         draft_probs=None,
         draft_token_ids=draft_token_ids)
@@ -356,15 +368,16 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
     # Create a temperature zero target probability distribution and ensure
     # all draft token ids correspond to the tokens with 1.0 probability.
     # Verify that all of them are accepted.
-    target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist(
-        batch_size, k, vocab_size))
+    target_with_bonus_probs, zero_temperature_token_ids = \
+        get_zero_temperature_prob_dist(batch_size, k + 1, vocab_size)
+    zero_temperature_token_ids = zero_temperature_token_ids[:, :-1]
     draft_token_ids = zero_temperature_token_ids
     bonus_token_ids = torch.randint(low=0,
                                     high=vocab_size,
                                     size=(batch_size, 1),
                                     dtype=torch.int64)
     output_token_ids = typical_acceptance_sampler(
-        target_probs,
+        target_with_bonus_probs,
         bonus_token_ids,
         draft_probs=None,
         draft_token_ids=draft_token_ids)
@@ -384,7 +397,7 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
     draft_token_ids = torch.cat(
         (draft_token_ids[:, :2], draft_token_ids_to_replace[:, -3:]), dim=1)
     output_token_ids = typical_acceptance_sampler(
-        target_probs,
+        target_with_bonus_probs,
         bonus_token_ids,
         draft_probs=None,
         draft_token_ids=draft_token_ids)
@@ -421,8 +434,9 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
     # 0.00001. Populate draft_token_ids such that they exclude the token_ids
     # with probability = 1.0. Without any changes to the posterior thresholds
     # none of the draft tokens are accepted.
-    target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist(
-        batch_size, k, vocab_size))
+    target_probs, zero_temperature_token_ids = get_zero_temperature_prob_dist(
+        batch_size, k + 1, vocab_size)
+    zero_temperature_token_ids = zero_temperature_token_ids[:, :-1]
     target_probs[target_probs == 0] = 0.00001
     draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size,
                                           zero_temperature_token_ids)
@@ -230,9 +230,8 @@ def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int,
 
     assert torch.equal(actual.bonus_token_ids,
                        target_token_ids.reshape(batch_size, k + 1)[:, -1:])
-    assert torch.equal(
-        actual.target_probs,
-        target_token_probs.reshape(batch_size, k + 1, -1)[:, :-1])
+    assert torch.equal(actual.target_with_bonus_probs,
+                       target_token_probs.reshape(batch_size, k + 1, -1))
     assert torch.equal(actual.draft_token_ids, proposal_token_ids)
     assert torch.equal(actual.draft_probs, proposal_probs)
 
@@ -31,6 +31,7 @@ if TYPE_CHECKING:
     VLLM_TRACE_FUNCTION: int = 0
     VLLM_ATTENTION_BACKEND: Optional[str] = None
     VLLM_USE_FLASHINFER_SAMPLER: bool = False
+    VLLM_USE_FLASHINFER_REJECTION_SAMPLER: bool = False
    VLLM_PP_LAYER_PARTITION: Optional[str] = None
     VLLM_CPU_KVCACHE_SPACE: int = 0
     VLLM_CPU_OMP_THREADS_BIND: str = ""
@@ -1,12 +1,28 @@
 from functools import cached_property
+from importlib.util import find_spec
 from typing import Dict, List, Optional, Tuple
 
 import torch
 import torch.jit
 
+import vllm.envs as envs
+from vllm.logger import init_logger
 from vllm.model_executor.layers.spec_decode_base_sampler import (
     SpecDecodeStochasticBaseSampler)
 
+logger = init_logger(__name__)
+
+if find_spec("flashinfer"):
+    """
+    Consider utilizing the FlashInfer rejection sampling kernel initially,
+    as it employs a dedicated kernel rather than relying on
+    Torch tensor operations. This design choice helps to fuse operations,
+    reduce memory I/O, and consequently enhances performance.
+    """
+    from flashinfer.sampling import chain_speculative_sampling
+else:
+    chain_speculative_sampling = None
+
 
 class RejectionSampler(SpecDecodeStochasticBaseSampler):
     """Apply modified rejection sampling as described in "Accelerating Large
@@ -16,7 +32,8 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
 
     def __init__(self,
                  disable_bonus_tokens: bool = True,
-                 strict_mode: bool = False):
+                 strict_mode: bool = False,
+                 use_flashinfer: Optional[bool] = None):
         """Create a rejection sampler.
 
         Args:
@@ -26,13 +43,29 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
             strict_mode: Whether or not to perform shape/device/dtype checks
                 during sampling. This catches correctness issues but adds
                 nontrivial latency.
+            use_flashinfer: We will use this parameter to determine whether
+                to use the FlashInfer rejection sampling kernel or not. If it's
+                None, we will use the default value from the environment
+                variable. This parameter is only used for testing purposes.
         """
         super().__init__(disable_bonus_tokens=disable_bonus_tokens,
                          strict_mode=strict_mode)
+        if use_flashinfer is None:
+            self.use_flashinfer = envs.VLLM_USE_FLASHINFER_SAMPLER and (
+                chain_speculative_sampling is not None)
+        else:
+            self.use_flashinfer = use_flashinfer
+
+        if self.use_flashinfer:
+            assert not disable_bonus_tokens, \
+                "flashinfer will enable bonus token by default"
+            logger.info("Use flashinfer for rejection sampling.")
+        else:
+            logger.info("Use pytorch for rejection sampling.")
 
     def forward(
         self,
-        target_probs: torch.Tensor,
+        target_with_bonus_probs: torch.Tensor,
         bonus_token_ids: torch.Tensor,
         draft_probs: torch.Tensor,
         draft_token_ids: torch.Tensor,
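For orientation, a minimal usage sketch of the backend selection added above (not part of the diff; it assumes the usual vllm.model_executor.layers.rejection_sampler module path):

from vllm.model_executor.layers.rejection_sampler import RejectionSampler

# Explicit opt-in, as the tests do; flashinfer always emits a bonus token,
# so disable_bonus_tokens must stay False.
sampler = RejectionSampler(disable_bonus_tokens=False, use_flashinfer=True)

# Explicit opt-out: always take the PyTorch rejection sampling path.
sampler = RejectionSampler(disable_bonus_tokens=False, use_flashinfer=False)

# Default (use_flashinfer=None): follow VLLM_USE_FLASHINFER_SAMPLER and fall
# back to PyTorch when flashinfer's chain_speculative_sampling is unavailable.
sampler = RejectionSampler(disable_bonus_tokens=False)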
@@ -50,9 +83,9 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
             sequence.
 
         Args:
-            target_probs: The probability distribution over token ids given
-                context according to the target model.
-                shape = [batch_size, num_speculative_tokens, vocab_size]
+            target_with_bonus_probs: The probability distribution
+                over token ids given context according to the target model.
+                shape = [batch_size, num_speculative_tokens + 1, vocab_size]
 
             bonus_token_ids: The "bonus" token ids that are accepted iff all
                 speculative tokens in a sequence are accepted.
@@ -78,23 +111,52 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
         # Only perform shape/dtype/device checking in strict mode, as it adds
         # overhead.
         if self._strict_mode:
-            self._raise_if_incorrect_input(target_probs, draft_token_ids,
-                                           bonus_token_ids, draft_probs)
+            self._raise_if_incorrect_input(target_with_bonus_probs,
+                                           draft_token_ids, bonus_token_ids,
+                                           draft_probs)
 
-        accepted, recovered_token_ids = (
-            self._batch_modified_rejection_sampling(
-                target_probs,
-                draft_probs,
-                draft_token_ids,
-                seeded_seqs,
-            ))
-
-        output_token_ids = self._create_output(
-            accepted,
-            recovered_token_ids,
-            draft_token_ids,
-            bonus_token_ids,
-        )
+        batch_size, k, _ = draft_probs.shape
+
+        # batch_size = 0 when all requests in the batch are
+        # non_spec requests. In this case, output_token_ids is
+        # just an empty tensor.
+        if batch_size == 0:
+            return torch.empty(0, k + 1, device=draft_probs.device, dtype=int)
+
+        # If use Flashinfer chain_speculative_sampling kernel
+        # for rejection sampling
+        if self.use_flashinfer:
+            batch_size, k, _ = draft_probs.shape
+            uniform_samples = self._create_uniform_samples(
+                seeded_seqs, batch_size, k, draft_probs.device)
+            output_token_ids, accepted_token_num, emitted_token_num \
+                = chain_speculative_sampling(
+                    draft_probs, draft_token_ids, uniform_samples,
+                    target_with_bonus_probs)
+
+            # num_emitted_tokens returned by flashinfer
+            # does not include the bonus token
+            # Flashinfer stops at the first token that violates
+            # the condition p >= q and does not include recovery/bonus token.
+            # Therefore, we need to add batch_size here.
+            self.num_accepted_tokens += accepted_token_num.sum()
+            self.num_emitted_tokens += emitted_token_num.sum() + batch_size
+            self.num_draft_tokens += batch_size * k
+        else:
+            accepted, recovered_token_ids = (
+                self._batch_modified_rejection_sampling(
+                    target_with_bonus_probs[:, :-1],
+                    draft_probs,
+                    draft_token_ids,
+                    seeded_seqs,
+                ))
+
+            output_token_ids = self._create_output(
+                accepted,
+                recovered_token_ids,
+                draft_token_ids,
+                bonus_token_ids,
+            )
 
         return output_token_ids
 
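To make the metric bookkeeping above concrete, a small worked example with made-up numbers (not part of the diff):

# Assume batch_size = 2 and k = 3 draft tokens per sequence, and the kernel
# reports per-sequence counts accepted_token_num = emitted_token_num = [3, 1]:
# sequence 0 kept all 3 draft tokens, sequence 1 stopped after 1. Flashinfer's
# emitted count excludes the recovered/bonus token, so one extra token per
# sequence is added back.
batch_size, k = 2, 3
accepted_token_num = [3, 1]
emitted_token_num = [3, 1]

num_accepted_tokens = sum(accepted_token_num)             # 4
num_emitted_tokens = sum(emitted_token_num) + batch_size  # 6 = 4 draft + 2 recovered/bonus
num_draft_tokens = batch_size * k                         # 6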
@@ -135,6 +197,63 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
 
         return accepted, recovered_token_ids
 
+    def _create_uniform_samples(self,
+                                seeded_seqs: Optional[Dict[int,
+                                                           torch.Generator]],
+                                batch_size: int, k: int,
+                                device: torch.device) -> torch.Tensor:
+        """
+        Generates a batch of uniform random samples, with optional seeding
+        for specific sequences.
+
+        This method creates a tensor of shape `(batch_size, k + 1)` filled
+        with uniform random values in the range [0, 1). If `seeded_seqs`
+        is provided, the sequences corresponding to specific indices
+        will be generated using the provided `torch.Generator` for
+        reproducibility. The other sequences will be generated without
+        a seed.
+
+        Args:
+            seeded_seqs : Optional[Dict[int, torch.Generator]]
+                A dictionary mapping indices in the batch to
+                `torch.Generator` objects. If `None`, all samples are
+                generated without a seed.
+            batch_size : int
+                The number of sequences to generate.
+            k : int
+                The number of random samples per sequence.
+            device : torch.device
+                The device on which to allocate the tensor.
+
+        Returns:
+            uniform_rand : torch.Tensor
+                A tensor of shape `(batch_size, k + 1)` containing uniform
+                random values in the range [0, 1).
+        """
+        if not seeded_seqs:
+            return torch.rand(batch_size, k + 1, device=device)
+
+        uniform_rand = torch.empty(batch_size, k + 1, device=device)
+
+        non_seeded_indices = []
+        for idx in range(batch_size):
+            generator = seeded_seqs.get(idx)
+            if generator is None:
+                non_seeded_indices.append(idx)
+            else:
+                uniform_rand[idx, :] = torch.rand(1,
+                                                  k + 1,
+                                                  dtype=self.probs_dtype,
+                                                  device=device,
+                                                  generator=generator)
+        if non_seeded_indices:
+            uniform_rand[non_seeded_indices, :] = torch.rand(
+                len(non_seeded_indices),
+                k + 1,
+                dtype=self.probs_dtype,
+                device=device)
+        return uniform_rand
+
     def _get_accepted(
         self,
         target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
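A small usage sketch of the helper above (not part of the diff; it assumes a CUDA device and the module path used by the tests):

import torch
from vllm.model_executor.layers.rejection_sampler import RejectionSampler

sampler = RejectionSampler(disable_bonus_tokens=False)
device = torch.device("cuda:0")

# Row 0 is driven by a seeded generator, row 1 is unseeded.
gens = {0: torch.Generator(device=device).manual_seed(123)}
first = sampler._create_uniform_samples(gens, batch_size=2, k=3, device=device)

gens = {0: torch.Generator(device=device).manual_seed(123)}
second = sampler._create_uniform_samples(gens, batch_size=2, k=3, device=device)

assert first.shape == (2, 4)             # k + 1 samples per sequence
assert torch.equal(first[0], second[0])  # the seeded row reproduces exactly
# first[1] and second[1] come from the global RNG and will generally differ.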
@@ -175,29 +294,8 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
         selected_target_probs = target_probs[batch_indices, probs_indicies,
                                              draft_token_ids]
 
-        if not seeded_seqs:
-            uniform_rand = torch.rand_like(selected_target_probs)
-        else:
-            uniform_rand = torch.empty_like(selected_target_probs)
-
-            non_seeded_indices = []
-            for idx in range(batch_size):
-                generator = seeded_seqs.get(idx)
-                if generator is None:
-                    non_seeded_indices.append(idx)
-                else:
-                    uniform_rand[idx, :] = torch.rand(
-                        1,
-                        k,
-                        dtype=self.probs_dtype,
-                        device=target_probs.device,
-                        generator=generator)
-            if non_seeded_indices:
-                uniform_rand[non_seeded_indices, :] = torch.rand(
-                    len(non_seeded_indices),
-                    k,
-                    dtype=self.probs_dtype,
-                    device=target_probs.device)
+        uniform_rand = self._create_uniform_samples(seeded_seqs, batch_size,
+                                                    k - 1, target_probs.device)
 
         capped_ratio = torch.minimum(
             selected_target_probs / selected_draft_probs,
@@ -130,29 +130,35 @@ class SpecDecodeBaseSampler(nn.Module):
 
     def _raise_if_incorrect_input(
         self,
-        target_probs: torch.Tensor,
+        target_with_bonus_probs: torch.Tensor,
         draft_token_ids: torch.Tensor,
         bonus_token_ids: torch.Tensor,
         draft_probs: Optional[torch.Tensor] = None,
     ) -> None:
-        self._raise_if_incorrect_shape(target_probs, draft_token_ids,
-                                       bonus_token_ids, draft_probs)
-        self._raise_if_incorrect_dtype(target_probs, draft_token_ids,
-                                       bonus_token_ids, draft_probs)
-        self._raise_if_inconsistent_device(target_probs, draft_token_ids,
-                                           bonus_token_ids, draft_probs)
-        self._raise_if_out_of_bounds_vocab(target_probs.shape[-1],
+        self._raise_if_incorrect_shape(target_with_bonus_probs,
+                                       draft_token_ids, bonus_token_ids,
+                                       draft_probs)
+        self._raise_if_incorrect_dtype(target_with_bonus_probs,
+                                       draft_token_ids, bonus_token_ids,
+                                       draft_probs)
+        self._raise_if_inconsistent_device(target_with_bonus_probs,
+                                           draft_token_ids, bonus_token_ids,
+                                           draft_probs)
+        self._raise_if_out_of_bounds_vocab(target_with_bonus_probs.shape[-1],
                                            draft_token_ids, bonus_token_ids)
 
     def _raise_if_incorrect_shape(
         self,
-        target_probs: torch.Tensor,
+        target_with_bonus_probs: torch.Tensor,
         draft_token_ids: torch.Tensor,
         bonus_token_ids: torch.Tensor,
         draft_probs: Optional[torch.Tensor] = None,
     ) -> None:
         (target_batch_size, num_target_probs,
-         target_vocab_size) = target_probs.shape
+         target_vocab_size) = target_with_bonus_probs.shape
+
+        # Does not count the extra token
+        num_target_probs -= 1
 
         # validate the shape of draft token ids.
         draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape
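The shape convention behind the `num_target_probs -= 1` adjustment, as a standalone sketch (not part of the diff):

import torch

batch_size, k, vocab_size = 2, 3, 10
# Callers now pass probabilities for the k speculative positions plus one bonus position.
target_with_bonus_probs = torch.rand(batch_size, k + 1, vocab_size)

target_probs = target_with_bonus_probs[:, :-1]  # [2, 3, 10]: the k draft positions
bonus_probs = target_with_bonus_probs[:, -1]    # [2, 10]: the extra bonus position
assert target_probs.shape == (batch_size, k, vocab_size)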
@@ -175,12 +181,12 @@ class SpecDecodeBaseSampler(nn.Module):
 
     def _raise_if_incorrect_dtype(
         self,
-        target_probs: torch.Tensor,
+        target_with_bonus_probs: torch.Tensor,
         draft_token_ids: torch.Tensor,
         bonus_token_ids: torch.Tensor,
         draft_probs: Optional[torch.Tensor] = None,
     ) -> None:
-        assert target_probs.dtype == self.probs_dtype
+        assert target_with_bonus_probs.dtype == self.probs_dtype
         assert draft_token_ids.dtype == self.token_id_dtype
         assert bonus_token_ids.dtype == self.token_id_dtype
         if draft_probs is not None:
@@ -188,15 +194,16 @@ class SpecDecodeBaseSampler(nn.Module):
 
     def _raise_if_inconsistent_device(
         self,
-        target_probs: torch.Tensor,
+        target_with_bonus_probs: torch.Tensor,
         draft_token_ids: torch.Tensor,
         bonus_token_ids: torch.Tensor,
         draft_probs: Optional[torch.Tensor] = None,
     ) -> None:
         devices = [
-            t.device for t in
-            [target_probs, bonus_token_ids, draft_probs, draft_token_ids]
-            if t is not None
+            t.device for t in [
+                target_with_bonus_probs, bonus_token_ids, draft_probs,
+                draft_token_ids
+            ] if t is not None
         ]
         assert all([devices[0] == device for device in devices])
 
@@ -220,7 +227,7 @@ class SpecDecodeDeterministicBaseSampler(SpecDecodeBaseSampler):
     @abstractmethod
     def forward(
         self,
-        target_probs: torch.Tensor,
+        target_with_bonus_probs: torch.Tensor,
         bonus_token_ids: torch.Tensor,
         draft_probs: torch.Tensor,
         draft_token_ids: torch.Tensor,
@@ -236,7 +243,7 @@ class SpecDecodeStochasticBaseSampler(SpecDecodeBaseSampler):
     @abstractmethod
     def forward(
         self,
-        target_probs: torch.Tensor,
+        target_with_bonus_probs: torch.Tensor,
         bonus_token_ids: torch.Tensor,
         draft_probs: torch.Tensor,
         draft_token_ids: torch.Tensor,
@@ -41,7 +41,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
 
     def forward(
         self,
-        target_probs: torch.Tensor,
+        target_with_bonus_probs: torch.Tensor,
         bonus_token_ids: torch.Tensor,
         draft_probs: torch.Tensor,
         draft_token_ids: torch.Tensor,
@@ -80,8 +80,9 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
         # Only perform shape/dtype/device checking in strict mode, as it adds
         # overhead.
         if self._strict_mode:
-            self._raise_if_incorrect_input(target_probs, draft_token_ids,
-                                           bonus_token_ids)
+            self._raise_if_incorrect_input(target_with_bonus_probs,
+                                           draft_token_ids, bonus_token_ids)
+        target_probs = target_with_bonus_probs[:, :-1]
         accepted = self._evaluate_accepted_tokens(target_probs,
                                                   draft_token_ids)
         recovered_token_ids = self._replacement_token_ids(target_probs)
@@ -625,8 +625,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             seq_group_metadata_list, proposal_lens_list)
         original_indices = spec_indices + non_spec_indices
 
-        # Get probabilities of target model, excluding bonus token.
-        proposal_verifier_probs = proposal_scores.probs[spec_indices, :-1]
+        # Get probabilities of target model, including bonus tokens.
+        proposal_verifier_probs = proposal_scores.probs[spec_indices]
 
         # Get non-speculative sampled tokens from target model.
         non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]
@@ -651,13 +651,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         }
 
         accepted_token_ids = self.spec_decode_sampler(
-            target_probs=proposal_verifier_probs,
+            target_with_bonus_probs=proposal_verifier_probs,
             bonus_token_ids=bonus_token_ids,
             draft_probs=proposal_probs,
             draft_token_ids=proposal_token_ids,
             **sampler_extra_kwargs,
         )
 
         # Append output tokens from non-speculative sequences to
         # the accepted token ids tensor.
         non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len +