[SpecDec][Misc] Cleanup, remove bonus token logic. (#8701)

Lily Liu 2024-09-22 12:34:14 -07:00 committed by GitHub
parent 5b59532760
commit c6bd70d772
7 changed files with 33 additions and 115 deletions
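
For orientation before the per-file diffs: this cleanup removes the disable_bonus_tokens flag from the speculative-decoding samplers, so the bonus token is always emitted when all proposal tokens are accepted. A minimal sketch of the resulting call sites follows; the import paths are assumptions about the vLLM source tree, and the posterior values reuse the test helper's defaults shown further down.

from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.layers.typical_acceptance_sampler import (
    TypicalAcceptanceSampler)

# Before this commit, every caller threaded a disable_bonus_tokens flag:
#   RejectionSampler(disable_bonus_tokens=False, use_flashinfer=False)
#   TypicalAcceptanceSampler(0.03, 0.9, disable_bonus_tokens=False)

# After this commit the flag is gone; the bonus token is always kept when
# every draft token in a sequence is accepted.
rejection_sampler = RejectionSampler(use_flashinfer=False)
typical_sampler = TypicalAcceptanceSampler(posterior_threshold=0.03,
                                           posterior_alpha=0.9)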

View File

@@ -42,18 +42,13 @@ def mock_causal_accepted_tensor(
@pytest.mark.parametrize(
"which_tokens_accepted",
["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"])
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_flashinfer", [True, False])
@torch.inference_mode()
def test_correct_output_format(which_tokens_accepted: str, seed: int,
disable_bonus_tokens: bool, device: str,
use_flashinfer: bool):
device: str, use_flashinfer: bool):
"""Verify the output has correct format given predetermined accepted matrix.
"""
if use_flashinfer and disable_bonus_tokens:
pytest.skip("Flashinfer rejection sampler must enable bonus token.")
set_random_seed(seed)
torch.set_default_device(device)
@@ -88,9 +83,7 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
size=(batch_size, 1),
dtype=torch.int64)
rejection_sampler = RejectionSampler(
disable_bonus_tokens=disable_bonus_tokens,
use_flashinfer=use_flashinfer)
rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
rejection_sampler.init_gpu_tensors(device=device)
output_token_ids = rejection_sampler._create_output( # pylint: disable=protected-access
accepted,
@@ -100,10 +93,6 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
)
expected_bonus_token_ids = bonus_token_ids.clone()
# If bonus tokens disabled. Verify they are set to -1.
# See https://github.com/vllm-project/vllm/issues/4212
if disable_bonus_tokens:
expected_bonus_token_ids = expected_bonus_token_ids * 0 - 1
if which_tokens_accepted == "all_tokens_accepted":
# Expect all tokens to be equal to draft tokens.
@@ -143,8 +132,7 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
device: str, use_flashinfer: bool):
torch.set_default_device(device)
rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
use_flashinfer=use_flashinfer)
rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
rejection_sampler.init_gpu_tensors(device=device)
draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
@@ -177,8 +165,7 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
frac_seeded: float, n_rep: int, device: str,
use_flashinfer: bool):
torch.set_default_device(device)
rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
use_flashinfer=use_flashinfer)
rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
rejection_sampler.init_gpu_tensors(device=device)
draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
@@ -251,8 +238,7 @@ def test_compare_nonflashinfer_backend(k: int, vocab_size: int,
}
for use_flashinfer in [True, False]:
rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
use_flashinfer=use_flashinfer)
rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
rejection_sampler.init_gpu_tensors(device=device)
# We use seeded sequences to ensure the same tokens are accepted
# for both flashinfer and nonflashinfer backends.
@@ -282,8 +268,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
vocab_size = 30_000
torch.set_default_device(device)
rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
use_flashinfer=use_flashinfer,
rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer,
strict_mode=True)
rejection_sampler.init_gpu_tensors(device=device)
@@ -359,8 +344,7 @@ def test_rejection_sampling_approximates_target_distribution(
set_random_seed(seed)
helper = _CorrectnessTestHelper(
vocab_size=10,
rejection_sampler=RejectionSampler(disable_bonus_tokens=False,
use_flashinfer=use_flashinfer),
rejection_sampler=RejectionSampler(use_flashinfer=use_flashinfer),
)
draft_probs, target_probs, reference_probs = helper.generate_probs_for_test(

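The assertions in this test file pin down the output layout, which no longer depends on a disable_bonus_tokens flag: the result has shape [batch_size, k + 1], the last column holds the bonus token only when every draft token in that row was accepted, and rejected positions are padded with -1. A standalone sketch of that rule follows (toy tensors; it deliberately omits the recovered token that real rejection sampling places at the first rejected position).

import torch

batch_size, k = 2, 3
draft_token_ids = torch.tensor([[11, 12, 13],
                                [21, 22, 23]])
bonus_token_ids = torch.tensor([[99], [88]])
# Row 0: every draft token accepted; row 1: none accepted.
accepted = torch.tensor([[True, True, True],
                         [False, False, False]])

# Keep accepted drafts, pad rejected positions with -1.
output = torch.where(accepted, draft_token_ids,
                     torch.full_like(draft_token_ids, -1))
# The bonus token survives only when the whole row of drafts was accepted.
bonus_col = torch.where(accepted.all(dim=1), bonus_token_ids.squeeze(1),
                        torch.full_like(bonus_token_ids.squeeze(1), -1))
output_with_bonus = torch.cat([output, bonus_col.unsqueeze(1)], dim=1)

assert output_with_bonus.shape == (batch_size, k + 1)
assert torch.equal(output_with_bonus[0], torch.tensor([11, 12, 13, 99]))
assert torch.equal(output_with_bonus[1], torch.tensor([-1, -1, -1, -1]))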
View File

@@ -55,14 +55,13 @@ def get_draft_token_ids(batch_size: int, k: int, vocab_size: int,
def get_acceptance_sampler(
posterior_threshold: float = 0.03,
posterior_alpha: float = 0.9,
disable_bonus_tokens: bool = False,
strict_mode: bool = False,
) -> TypicalAcceptanceSampler:
"""
Initializes and returns a TypicalAcceptanceSampler.
"""
return TypicalAcceptanceSampler(posterior_threshold, posterior_alpha,
disable_bonus_tokens, strict_mode)
strict_mode)
@pytest.mark.parametrize("k", list(range(1, 6)))
@@ -154,11 +153,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_uniform_target_distribution_accepts_all_tokens(
seed: int, disable_bonus_tokens: bool, device: str):
seed: int, device: str):
"""
Test the TypicalAcceptanceSampler with a uniform target probability
distribution.
@@ -166,17 +164,14 @@ def test_uniform_target_distribution_accepts_all_tokens(
This test verifies that when provided with a uniform target probability
distribution, the TypicalAcceptanceSampler accepts all draft tokens. The
entropy of the uniform target distribution being high should lead to all
draft tokens being accepted. The test also ensures that the behavior
regarding bonus tokens is consistent with the `disable_bonus_tokens`
flag.
draft tokens being accepted.
"""
set_random_seed(seed)
k = 3
batch_size = 5
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
target_with_bonus_probs = torch.rand(batch_size,
k + 1,
@@ -200,21 +195,15 @@ def test_uniform_target_distribution_accepts_all_tokens(
# should lead to all draft tokens being accepted. Verify that.
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
if disable_bonus_tokens:
assert torch.all(output_token_ids[:, -1] == -1)
else:
assert torch.all(output_token_ids[:, -1] == bonus_token_ids.squeeze())
assert torch.all(output_token_ids[:, :k] == draft_token_ids)
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_temperature_zero_target_distribution(seed: int,
disable_bonus_tokens: bool,
device: str):
def test_temperature_zero_target_distribution(seed: int, device: str):
"""
Test the TypicalAcceptanceSampler with a zero-temperature target
probability distribution.
@@ -232,8 +221,7 @@ def test_temperature_zero_target_distribution(seed: int,
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
# Simulate temperature 0 probability distribution for target probabilities
# and create target probabilities such that only 1 token id has
@@ -267,11 +255,9 @@ def test_temperature_zero_target_distribution(seed: int,
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
device: str):
def test_mixed_target_distribution(seed: int, device: str):
"""
Test the TypicalAcceptanceSampler with a mixed target probability
distribution.
@@ -285,16 +271,13 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
with a probability of 1.0 is accepted, and all other tokens are rejected.
- For sequences with a uniform distribution, all draft tokens are
accepted.
- When `disable_bonus_tokens` is False, the bonus tokens are also accepted
for sequences with a uniform distribution.
"""
set_random_seed(seed)
k = 3
batch_size = 4
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
# For sequences 0 and 2 set the distribution to a temperature
# zero distribution. For sequences 1 and 3 set it to a uniform
@@ -328,21 +311,16 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
0]))
# For sequences 1 and 3 verify that all tokens are accepted since the
# target probability distribution is uniform. In addition verify that
# if disable_bonus_tokens is false then we also accept the bonus tokens.
# we also accept the bonus tokens.
assert torch.all(
output_token_ids[[1, 3], :-1] == draft_token_ids[[1, 3], :])
if disable_bonus_tokens:
assert torch.all(output_token_ids[[1, 3], -1] == -1)
else:
assert torch.all(output_token_ids[[1, 3], -1] != -1)
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
device: str):
def test_accept_tokens_partially(seed: int, device: str):
"""
Test the TypicalAcceptanceSampler's behavior when only a subset of draft
tokens should be accepted.
@@ -362,8 +340,7 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
batch_size = 1
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
# Create a temperature zero target probability distribution and ensure
# all draft token ids correspond to the tokens with 1.0 probability.
@@ -384,9 +361,6 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids)
if disable_bonus_tokens:
assert torch.all(output_token_ids[:, -1] == -1)
else:
assert torch.all(output_token_ids[:, -1] == bonus_token_ids)
# Next only keep the first 2 draft tokens same as the zero temperature
# tokens. For the remaining 3 choose some other tokens. In the
@@ -408,12 +382,9 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
@pytest.mark.parametrize("seed", list(range(1)))
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_accept_tokens_set_non_default_posteriors(seed: int,
disable_bonus_tokens: bool,
device: str):
def test_accept_tokens_set_non_default_posteriors(seed: int, device: str):
"""
Test the TypicalAcceptanceSampler with custom posterior thresholds and
alpha values. This test verifies that by modifying the posterior
@@ -425,8 +396,7 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
batch_size = 1
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
# Simulate temperature 0 probability distribution for target
# probabilities and create target probabilities such that only 1 token
@@ -457,10 +427,7 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
# now accept even draft tokens with very low probability in the
# target distribution. Simulate and verify the same.
typical_acceptance_sampler = TypicalAcceptanceSampler(
strict_mode=True,
disable_bonus_tokens=disable_bonus_tokens,
posterior_threshold=0.0,
posterior_alpha=0.0)
strict_mode=True, posterior_threshold=0.0, posterior_alpha=0.0)
typical_acceptance_sampler.init_gpu_tensors(device=device)
output_token_ids = typical_acceptance_sampler(
target_probs,
@@ -470,18 +437,13 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids)
if disable_bonus_tokens:
assert torch.all(output_token_ids[:, -1] == -1)
else:
assert torch.all(output_token_ids[:, -1] == bonus_token_ids)
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_replacement_token_ids(seed: int, disable_bonus_tokens: bool,
device: str):
def test_replacement_token_ids(seed: int, device: str):
"""
Test the TypicalAcceptanceSampler's method for generating
replacement token IDs.
@@ -497,8 +459,7 @@ def test_replacement_token_ids(seed: int, disable_bonus_tokens: bool,
batch_size = 5
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
expected_replacement_tokens = -torch.ones(

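The two extreme cases tested above fall out of the entropy-scaled threshold used by typical acceptance: a uniform target distribution has high entropy, so the bar drops below 1/vocab_size and every draft token clears it, while a temperature-zero distribution has near-zero entropy, so only the single high-probability token clears posterior_threshold. The numerical sketch below uses my paraphrase of the Medusa-style rule, threshold = min(posterior_threshold, posterior_alpha * exp(-entropy)); treat the exact formula as an assumption rather than a quote of vLLM's implementation.

import math

def acceptance_threshold(entropy: float,
                         posterior_threshold: float = 0.03,
                         posterior_alpha: float = 0.9) -> float:
    # Assumed Medusa-style rule: the higher the target entropy, the lower
    # the probability a draft token must have in order to be accepted.
    return min(posterior_threshold, posterior_alpha * math.exp(-entropy))

vocab_size = 30_000

# Uniform target: entropy ~= ln(vocab_size); the threshold sinks below
# 1/vocab_size, so every draft token is accepted.
assert 1.0 / vocab_size > acceptance_threshold(math.log(vocab_size))

# Temperature-zero target: entropy ~= 0; the threshold is posterior_threshold,
# so only the token with probability ~1.0 is accepted.
assert 1.0 > acceptance_threshold(0.0)
assert 1e-6 < acceptance_threshold(0.0)  # a near-zero-probability draft is rejected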
View File

@@ -31,7 +31,7 @@ MAIN_MODEL = "JackFram/llama-68m"
# speculative model
SPEC_MODEL = "abhigoyal/vllm-medusa-llama-68m-random"
# max. number of speculative tokens: this corresponds to
# max number of speculative tokens: this corresponds to
# num_heads in the config.json of the speculator model.
MAX_SPEC_TOKENS = 5

View File

@@ -31,15 +31,11 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
"""
def __init__(self,
disable_bonus_tokens: bool = True,
strict_mode: bool = False,
use_flashinfer: Optional[bool] = None):
"""Create a rejection sampler.
Args:
disable_bonus_tokens: Whether or not to disable the bonus token.
Require when bonus tokens will cause corrupt KV cache for
proposal methods that require KV cache.
strict_mode: Whether or not to perform shape/device/dtype checks
during sampling. This catches correctness issues but adds
nontrivial latency.
@@ -48,8 +44,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
None, we will use the default value from the environment variable.
This parameter is only used for testing purposes.
"""
super().__init__(disable_bonus_tokens=disable_bonus_tokens,
strict_mode=strict_mode)
super().__init__(strict_mode=strict_mode)
if use_flashinfer is None:
self.use_flashinfer = envs.VLLM_USE_FLASHINFER_SAMPLER and (
chain_speculative_sampling is not None)
@@ -57,8 +52,6 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
self.use_flashinfer = use_flashinfer
if self.use_flashinfer:
assert not disable_bonus_tokens, \
"flashinfer will enable bonus token by default"
logger.info("Use flashinfer for rejection sampling.")
else:
logger.info("Use pytorch for rejection sampling.")

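With the bonus-token flag gone, the constructor above is only left deciding which rejection-sampling backend to use: an explicit use_flashinfer argument (as in the tests) wins, otherwise the VLLM_USE_FLASHINFER_SAMPLER environment variable and kernel availability decide. A hedged sketch of that resolution follows; the helper function and the env-var parsing are hypothetical, only the variable name comes from the diff.

import os
from typing import Optional

def resolve_use_flashinfer(use_flashinfer: Optional[bool],
                           flashinfer_kernel_available: bool) -> bool:
    # Explicit constructor argument takes priority (used by the unit tests).
    if use_flashinfer is not None:
        return use_flashinfer
    # Otherwise fall back to the environment variable plus kernel availability.
    env_enabled = os.environ.get("VLLM_USE_FLASHINFER_SAMPLER", "0").lower() in ("1", "true")
    return env_enabled and flashinfer_kernel_available

# Example: a test that pins use_flashinfer=False ignores the environment.
assert resolve_use_flashinfer(False, flashinfer_kernel_available=True) is False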
View File

@@ -11,20 +11,14 @@ class SpecDecodeBaseSampler(nn.Module):
step.
"""
def __init__(self,
disable_bonus_tokens: bool = True,
strict_mode: bool = False):
def __init__(self, strict_mode: bool = False):
"""Base class constructor.
Args:
disable_bonus_tokens: Whether or not to disable the bonus token.
Require when bonus tokens will cause corrupt KV cache for
proposal methods that require KV cache.
strict_mode: Whether or not to perform shape/device/dtype checks
during sampling. This catches correctness issues but adds
nontrivial latency.
"""
super().__init__()
self._disable_bonus_tokens = disable_bonus_tokens
self._strict_mode = strict_mode
# NOTE: A "bonus token" is accepted iff all proposal tokens are
@@ -111,13 +105,6 @@ class SpecDecodeBaseSampler(nn.Module):
output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,
bonus_token_ids, -1)
# We disable bonus tokens because it causes corrupt KV cache for
# proposal methods that require KV cache. We can fix it by "prefilling"
# the bonus token in the proposer. The following issue tracks the fix.
# https://github.com/vllm-project/vllm/issues/4212
if self._disable_bonus_tokens:
output_with_bonus_tokens[:, -1] = -1
# Fill the recovered token ids.
output.mul_(~after_false_mask).add_(
substitute_token_ids.mul(after_false_mask))

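The surviving recovered-token fill just above, output.mul_(~after_false_mask).add_(substitute_token_ids.mul(after_false_mask)), is an in-place mask blend. A tiny sketch with toy tensors shows the idiom is equivalent to a torch.where select; the tensors here are not the sampler's actual state.

import torch

out = torch.tensor([[5, 6, 7]])
sub = torch.tensor([[50, 60, 70]])
mask = torch.tensor([[False, True, True]])

# In-place blend: keep `out` where mask is False, take `sub` where it is True.
blended = out.clone().mul_(~mask).add_(sub.mul(mask))
assert torch.equal(blended, torch.where(mask, sub, out))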
View File

@@ -16,15 +16,11 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
self,
posterior_threshold: float,
posterior_alpha: float,
disable_bonus_tokens: bool = False,
strict_mode: bool = False,
):
"""Create a Typical Acceptance Sampler.
Args:
disable_bonus_tokens: Whether or not to disable the bonus token.
Require when bonus tokens will cause corrupt KV cache for
proposal methods that require KV cache.
strict_mode: Whether or not to perform shape/device/dtype checks
during sampling. This catches correctness issues but adds
nontrivial latency.
@@ -36,8 +32,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
"""
self._posterior_threshold = posterior_threshold
self._posterior_alpha = posterior_alpha
super().__init__(disable_bonus_tokens=disable_bonus_tokens,
strict_mode=strict_mode)
super().__init__(strict_mode=strict_mode)
def forward(
self,
@@ -54,7 +49,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
one token will be emitted.
In the case where all draft tokens are accepted, the bonus token will be
accepted conditioned on self._disable_bonus_tokens being false.
accepted.
Args:
target_probs: The probability distribution over token ids given

View File

@@ -164,11 +164,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
spec_decode_sampler: SpecDecodeBaseSampler = None
if draft_token_acceptance_method == "rejection_sampler":
spec_decode_sampler = RejectionSampler(
disable_bonus_tokens=False, )
spec_decode_sampler = RejectionSampler()
elif draft_token_acceptance_method == "typical_acceptance_sampler":
spec_decode_sampler = TypicalAcceptanceSampler(
disable_bonus_tokens=False,
posterior_threshold=\
typical_acceptance_sampler_posterior_threshold,
posterior_alpha=typical_acceptance_sampler_posterior_alpha,
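
The worker-side dispatch above (shown truncated by the final hunk) now constructs both samplers without any bonus-token plumbing. For completeness, a hedged sketch of that selection as a standalone helper; the wrapper function and import paths are hypothetical, not vLLM API.

from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.layers.typical_acceptance_sampler import (
    TypicalAcceptanceSampler)

def make_spec_decode_sampler(draft_token_acceptance_method: str,
                             posterior_threshold: float,
                             posterior_alpha: float):
    # Mirrors the dispatch in SpecDecodeWorker, minus the removed
    # disable_bonus_tokens arguments.
    if draft_token_acceptance_method == "rejection_sampler":
        return RejectionSampler()
    if draft_token_acceptance_method == "typical_acceptance_sampler":
        return TypicalAcceptanceSampler(
            posterior_threshold=posterior_threshold,
            posterior_alpha=posterior_alpha)
    raise ValueError(
        f"Unknown draft_token_acceptance_method: {draft_token_acceptance_method}")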