[SpecDec][Misc] Cleanup, remove bonus token logic. (#8701)
commit c6bd70d772
parent 5b59532760
@@ -42,18 +42,13 @@ def mock_causal_accepted_tensor(
 @pytest.mark.parametrize(
     "which_tokens_accepted",
     ["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"])
-@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @pytest.mark.parametrize("use_flashinfer", [True, False])
 @torch.inference_mode()
 def test_correct_output_format(which_tokens_accepted: str, seed: int,
-                               disable_bonus_tokens: bool, device: str,
-                               use_flashinfer: bool):
+                               device: str, use_flashinfer: bool):
     """Verify the output has correct format given predetermined accepted matrix.
     """
-    if use_flashinfer and disable_bonus_tokens:
-        pytest.skip("Flashinfer rejection sampler must enable bonus token.")
-
     set_random_seed(seed)
     torch.set_default_device(device)
 
@@ -88,9 +83,7 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
                                     size=(batch_size, 1),
                                     dtype=torch.int64)
 
-    rejection_sampler = RejectionSampler(
-        disable_bonus_tokens=disable_bonus_tokens,
-        use_flashinfer=use_flashinfer)
+    rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
     rejection_sampler.init_gpu_tensors(device=device)
     output_token_ids = rejection_sampler._create_output(  # pylint: disable=protected-access
         accepted,
@@ -100,10 +93,6 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
     )
 
     expected_bonus_token_ids = bonus_token_ids.clone()
-    # If bonus tokens disabled. Verify they are set to -1.
-    # See https://github.com/vllm-project/vllm/issues/4212
-    if disable_bonus_tokens:
-        expected_bonus_token_ids = expected_bonus_token_ids * 0 - 1
 
     if which_tokens_accepted == "all_tokens_accepted":
         # Expect all tokens to be equal to draft tokens.
@@ -143,8 +132,7 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
 def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
                                     device: str, use_flashinfer: bool):
     torch.set_default_device(device)
-    rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
-                                         use_flashinfer=use_flashinfer)
+    rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
     rejection_sampler.init_gpu_tensors(device=device)
 
     draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
@@ -177,8 +165,7 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
                                    frac_seeded: float, n_rep: int, device: str,
                                    use_flashinfer: bool):
     torch.set_default_device(device)
-    rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
-                                         use_flashinfer=use_flashinfer)
+    rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
     rejection_sampler.init_gpu_tensors(device=device)
 
     draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
@@ -251,8 +238,7 @@ def test_compare_nonflashinfer_backend(k: int, vocab_size: int,
     }
 
     for use_flashinfer in [True, False]:
-        rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
-                                             use_flashinfer=use_flashinfer)
+        rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
         rejection_sampler.init_gpu_tensors(device=device)
         # We use seeded sequences to ensure the same tokens are accepted
         # for both flashinfer and nonflashinfer backends.
@@ -282,8 +268,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
     vocab_size = 30_000
     torch.set_default_device(device)
 
-    rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
-                                         use_flashinfer=use_flashinfer,
+    rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer,
                                          strict_mode=True)
     rejection_sampler.init_gpu_tensors(device=device)
 
@@ -359,8 +344,7 @@ def test_rejection_sampling_approximates_target_distribution(
     set_random_seed(seed)
     helper = _CorrectnessTestHelper(
         vocab_size=10,
-        rejection_sampler=RejectionSampler(disable_bonus_tokens=False,
-                                           use_flashinfer=use_flashinfer),
+        rejection_sampler=RejectionSampler(use_flashinfer=use_flashinfer),
     )
 
     draft_probs, target_probs, reference_probs = helper.generate_probs_for_test(
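Note on the rejection-sampler test hunks above: with the disable_bonus_tokens parametrization removed, the sampler is built from use_flashinfer alone and the expected bonus column passes through unmasked. A minimal sketch of the post-cleanup setup; the import path and device string are assumptions for illustration and are not part of the diff:

import torch
from vllm.model_executor.layers.rejection_sampler import RejectionSampler

# Construct the sampler the way the updated tests do: no bonus-token flag.
sampler = RejectionSampler(use_flashinfer=False)
sampler.init_gpu_tensors(device="cuda:0")
# The expected bonus column is simply the sampled bonus ids, never masked to -1.
bonus_token_ids = torch.randint(0, 30_000, size=(4, 1), dtype=torch.int64, device="cuda:0")
expected_bonus_token_ids = bonus_token_ids.clone()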

@@ -55,14 +55,13 @@ def get_draft_token_ids(batch_size: int, k: int, vocab_size: int,
 def get_acceptance_sampler(
     posterior_threshold: float = 0.03,
     posterior_alpha: float = 0.9,
-    disable_bonus_tokens: bool = False,
     strict_mode: bool = False,
 ) -> TypicalAcceptanceSampler:
     """
     Initializes and returns a TypicalAcceptanceSampler.
     """
     return TypicalAcceptanceSampler(posterior_threshold, posterior_alpha,
-                                    disable_bonus_tokens, strict_mode)
+                                    strict_mode)
 
 
 @pytest.mark.parametrize("k", list(range(1, 6)))
@@ -154,11 +153,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
 
 
 @pytest.mark.parametrize("seed", list(range(10)))
-@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
 def test_uniform_target_distribution_accepts_all_tokens(
-        seed: int, disable_bonus_tokens: bool, device: str):
+        seed: int, device: str):
     """
     Test the TypicalAcceptanceSampler with a uniform target probability
     distribution.
@@ -166,17 +164,14 @@ def test_uniform_target_distribution_accepts_all_tokens(
     This test verifies that when provided with a uniform target probability
     distribution, the TypicalAcceptanceSampler accepts all draft tokens. The
     entropy of the uniform target distribution being high should lead to all
-    draft tokens being accepted. The test also ensures that the behavior
-    regarding bonus tokens is consistent with the `disable_bonus_tokens`
-    flag.
+    draft tokens being accepted.
     """
     set_random_seed(seed)
     k = 3
     batch_size = 5
     vocab_size = 30_000
     torch.set_default_device(device)
-    typical_acceptance_sampler = get_acceptance_sampler(
-        strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
+    typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
     typical_acceptance_sampler.init_gpu_tensors(device=device)
     target_with_bonus_probs = torch.rand(batch_size,
                                          k + 1,
@@ -200,21 +195,15 @@ def test_uniform_target_distribution_accepts_all_tokens(
     # should lead to all draft tokens being accepted. Verify that.
     assert output_token_ids.shape[0] == batch_size
     assert output_token_ids.shape[1] == (k + 1)
-    if disable_bonus_tokens:
-        assert torch.all(output_token_ids[:, -1] == -1)
-    else:
-        assert torch.all(output_token_ids[:, -1] == bonus_token_ids.squeeze())
+    assert torch.all(output_token_ids[:, -1] == bonus_token_ids.squeeze())
 
     assert torch.all(output_token_ids[:, :k] == draft_token_ids)
 
 
 @pytest.mark.parametrize("seed", list(range(10)))
-@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
-def test_temperature_zero_target_distribution(seed: int,
-                                              disable_bonus_tokens: bool,
-                                              device: str):
+def test_temperature_zero_target_distribution(seed: int, device: str):
     """
     Test the TypicalAcceptanceSampler with a zero-temperature target
     probability distribution.
@@ -232,8 +221,7 @@ def test_temperature_zero_target_distribution(seed: int,
     vocab_size = 30_000
     torch.set_default_device(device)
 
-    typical_acceptance_sampler = get_acceptance_sampler(
-        strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
+    typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
     typical_acceptance_sampler.init_gpu_tensors(device=device)
     # Simulate temperature 0 probability distribution for target probabilities
     # and create target probabilities such that only 1 token id has
@@ -267,11 +255,9 @@ def test_temperature_zero_target_distribution(seed: int,
 
 
 @pytest.mark.parametrize("seed", list(range(10)))
-@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
-def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
-                                   device: str):
+def test_mixed_target_distribution(seed: int, device: str):
     """
     Test the TypicalAcceptanceSampler with a mixed target probability
     distribution.
@@ -285,16 +271,13 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
     with a probability of 1.0 is accepted, and all other tokens are rejected.
     - For sequences with a uniform distribution, all draft tokens are
       accepted.
-    - When `disable_bonus_tokens` is False, the bonus tokens are also accepted
-      for sequences with a uniform distribution.
     """
     set_random_seed(seed)
     k = 3
     batch_size = 4
     vocab_size = 30_000
     torch.set_default_device(device)
-    typical_acceptance_sampler = get_acceptance_sampler(
-        strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
+    typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
     typical_acceptance_sampler.init_gpu_tensors(device=device)
     # For sequences 0 and 2 set the distribution to a temperature
     # zero distribution. For sequences 1 and 3 set it to a uniform
@@ -328,21 +311,16 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
                                        0]))
     # For sequences 1 and 3 verify that all tokens are accepted since the
     # target probability distribution is uniform. In addition verify that
-    # if disable_bonus_tokens is false then we also accept the bonus tokens.
+    # we also accept the bonus tokens.
     assert torch.all(
         output_token_ids[[1, 3], :-1] == draft_token_ids[[1, 3], :])
-    if disable_bonus_tokens:
-        assert torch.all(output_token_ids[[1, 3], -1] == -1)
-    else:
-        assert torch.all(output_token_ids[[1, 3], -1] != -1)
+    assert torch.all(output_token_ids[[1, 3], -1] != -1)
 
 
 @pytest.mark.parametrize("seed", list(range(10)))
-@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
-def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
-                                 device: str):
+def test_accept_tokens_partially(seed: int, device: str):
     """
     Test the TypicalAcceptanceSampler's behavior when only a subset of draft
     tokens should be accepted.
@@ -362,8 +340,7 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
     batch_size = 1
     vocab_size = 30_000
     torch.set_default_device(device)
-    typical_acceptance_sampler = get_acceptance_sampler(
-        strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
+    typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
     typical_acceptance_sampler.init_gpu_tensors(device=device)
     # Create a temperature zero target probability distribution and ensure
     # all draft token ids correspond to the tokens with 1.0 probability.
@@ -384,9 +361,6 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
     assert output_token_ids.shape[0] == batch_size
     assert output_token_ids.shape[1] == (k + 1)
     assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids)
-    if disable_bonus_tokens:
-        assert torch.all(output_token_ids[:, -1] == -1)
-    else:
-        assert torch.all(output_token_ids[:, -1] == bonus_token_ids)
+    assert torch.all(output_token_ids[:, -1] == bonus_token_ids)
     # Next only keep the first 2 draft tokens same as the zero temperature
     # tokens. For the remaining 3 choose some other tokens. In the
@@ -408,12 +382,9 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
 
 
 @pytest.mark.parametrize("seed", list(range(1)))
-@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
-def test_accept_tokens_set_non_default_posteriors(seed: int,
-                                                  disable_bonus_tokens: bool,
-                                                  device: str):
+def test_accept_tokens_set_non_default_posteriors(seed: int, device: str):
     """
     Test the TypicalAcceptanceSampler with custom posterior thresholds and
     alpha values. This test verifies that by modifying the posterior
@@ -425,8 +396,7 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
     batch_size = 1
     vocab_size = 30_000
     torch.set_default_device(device)
-    typical_acceptance_sampler = get_acceptance_sampler(
-        strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
+    typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
     typical_acceptance_sampler.init_gpu_tensors(device=device)
     # Simulate temperature 0 probability distribution for target
     # probabilities and create target probabilities such that only 1 token
@@ -457,10 +427,7 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
     # now accept even draft tokens with very low probability in the
     # target distribution. Simulate and verify the same.
     typical_acceptance_sampler = TypicalAcceptanceSampler(
-        strict_mode=True,
-        disable_bonus_tokens=disable_bonus_tokens,
-        posterior_threshold=0.0,
-        posterior_alpha=0.0)
+        strict_mode=True, posterior_threshold=0.0, posterior_alpha=0.0)
     typical_acceptance_sampler.init_gpu_tensors(device=device)
     output_token_ids = typical_acceptance_sampler(
         target_probs,
@@ -470,18 +437,13 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
     assert output_token_ids.shape[0] == batch_size
     assert output_token_ids.shape[1] == (k + 1)
     assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids)
-    if disable_bonus_tokens:
-        assert torch.all(output_token_ids[:, -1] == -1)
-    else:
-        assert torch.all(output_token_ids[:, -1] == bonus_token_ids)
+    assert torch.all(output_token_ids[:, -1] == bonus_token_ids)
 
 
 @pytest.mark.parametrize("seed", list(range(10)))
-@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
-def test_replacement_token_ids(seed: int, disable_bonus_tokens: bool,
-                               device: str):
+def test_replacement_token_ids(seed: int, device: str):
     """
     Test the TypicalAcceptanceSampler's method for generating
     replacement token IDs.
@@ -497,8 +459,7 @@ def test_replacement_token_ids(seed: int, disable_bonus_tokens: bool,
     batch_size = 5
     vocab_size = 30_000
     torch.set_default_device(device)
-    typical_acceptance_sampler = get_acceptance_sampler(
-        strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
+    typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
     typical_acceptance_sampler.init_gpu_tensors(device=device)
     target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
     expected_replacement_tokens = -torch.ones(

@@ -31,7 +31,7 @@ MAIN_MODEL = "JackFram/llama-68m"
 # speculative model
 SPEC_MODEL = "abhigoyal/vllm-medusa-llama-68m-random"
 
-# max. number of speculative tokens: this corresponds to
+# max number of speculative tokens: this corresponds to
 # num_heads in the config.json of the speculator model.
 MAX_SPEC_TOKENS = 5
 

@@ -31,15 +31,11 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
     """
 
     def __init__(self,
-                 disable_bonus_tokens: bool = True,
                  strict_mode: bool = False,
                  use_flashinfer: Optional[bool] = None):
         """Create a rejection sampler.
 
         Args:
-            disable_bonus_tokens: Whether or not to disable the bonus token.
-                Require when bonus tokens will cause corrupt KV cache for
-                proposal methods that require KV cache.
             strict_mode: Whether or not to perform shape/device/dtype checks
                 during sampling. This catches correctness issues but adds
                 nontrivial latency.
@@ -48,8 +44,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
                 None, we will use the default value from the environment variable.
                 This parameter is only used for testing purposes.
         """
-        super().__init__(disable_bonus_tokens=disable_bonus_tokens,
-                         strict_mode=strict_mode)
+        super().__init__(strict_mode=strict_mode)
         if use_flashinfer is None:
             self.use_flashinfer = envs.VLLM_USE_FLASHINFER_SAMPLER and (
                 chain_speculative_sampling is not None)
@@ -57,8 +52,6 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
             self.use_flashinfer = use_flashinfer
 
         if self.use_flashinfer:
-            assert not disable_bonus_tokens, \
-                "flashinfer will enable bonus token by default"
             logger.info("Use flashinfer for rejection sampling.")
         else:
             logger.info("Use pytorch for rejection sampling.")
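Annotation on the RejectionSampler constructor hunks above: with disable_bonus_tokens gone, the only decision left in __init__ is which rejection-sampling backend to use. A small illustrative restatement of that fallback, written as a standalone sketch rather than code taken from the diff:

def resolve_use_flashinfer(use_flashinfer, env_default, chain_speculative_sampling):
    # An explicit True/False from the caller wins. None falls back to the
    # VLLM_USE_FLASHINFER_SAMPLER environment default, and flashinfer is only
    # usable when its sampling kernel was actually importable.
    if use_flashinfer is None:
        return bool(env_default) and (chain_speculative_sampling is not None)
    return use_flashinfer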

@@ -11,20 +11,14 @@ class SpecDecodeBaseSampler(nn.Module):
     step.
     """
 
-    def __init__(self,
-                 disable_bonus_tokens: bool = True,
-                 strict_mode: bool = False):
+    def __init__(self, strict_mode: bool = False):
         """Base class constructor.
         Args:
-            disable_bonus_tokens: Whether or not to disable the bonus token.
-                Require when bonus tokens will cause corrupt KV cache for
-                proposal methods that require KV cache.
             strict_mode: Whether or not to perform shape/device/dtype checks
                 during sampling. This catches correctness issues but adds
                 nontrivial latency.
         """
         super().__init__()
-        self._disable_bonus_tokens = disable_bonus_tokens
         self._strict_mode = strict_mode
 
         # NOTE: A "bonus token" is accepted iff all proposal tokens are
@@ -111,13 +105,6 @@ class SpecDecodeBaseSampler(nn.Module):
         output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,
                                                       bonus_token_ids, -1)
 
-        # We disable bonus tokens because it causes corrupt KV cache for
-        # proposal methods that require KV cache. We can fix it by "prefilling"
-        # the bonus token in the proposer. The following issue tracks the fix.
-        # https://github.com/vllm-project/vllm/issues/4212
-        if self._disable_bonus_tokens:
-            output_with_bonus_tokens[:, -1] = -1
-
         # Fill the recovered token ids.
         output.mul_(~after_false_mask).add_(
             substitute_token_ids.mul(after_false_mask))
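Annotation on the _create_output hunk above: after deleting the disable branch, the only remaining bonus-token rule is the torch.where kept as context, i.e. a bonus token survives only for rows whose last proposal token was accepted. A tiny self-contained illustration with made-up example values:

import torch

output = torch.tensor([[5, 7], [9, -1]])  # last column: accepted vs. already rejected
bonus_token_ids = torch.tensor([11, 12])
last_column = torch.where(output[:, -1] != -1, bonus_token_ids, torch.tensor(-1))
# last_column == tensor([11, -1]): the bonus id is kept only where the final
# proposal token was accepted; nothing is forced to -1 unconditionally anymore.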

@@ -16,15 +16,11 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
         self,
         posterior_threshold: float,
         posterior_alpha: float,
-        disable_bonus_tokens: bool = False,
         strict_mode: bool = False,
     ):
         """Create a Typical Acceptance Sampler.
 
         Args:
-            disable_bonus_tokens: Whether or not to disable the bonus token.
-                Require when bonus tokens will cause corrupt KV cache for
-                proposal methods that require KV cache.
             strict_mode: Whether or not to perform shape/device/dtype checks
                 during sampling. This catches correctness issues but adds
                 nontrivial latency.
@@ -36,8 +32,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
         """
         self._posterior_threshold = posterior_threshold
         self._posterior_alpha = posterior_alpha
-        super().__init__(disable_bonus_tokens=disable_bonus_tokens,
-                         strict_mode=strict_mode)
+        super().__init__(strict_mode=strict_mode)
 
     def forward(
         self,
@@ -54,7 +49,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
         one token will be emitted.
 
         In the case where all draft tokens are accepted, the bonus token will be
-        accepted conditioned on self._disable_bonus_tokens being false.
+        accepted.
 
         Args:
             target_probs: The probability distribution over token ids given

@@ -164,11 +164,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
 
         spec_decode_sampler: SpecDecodeBaseSampler = None
         if draft_token_acceptance_method == "rejection_sampler":
-            spec_decode_sampler = RejectionSampler(
-                disable_bonus_tokens=False, )
+            spec_decode_sampler = RejectionSampler()
        elif draft_token_acceptance_method == "typical_acceptance_sampler":
             spec_decode_sampler = TypicalAcceptanceSampler(
-                disable_bonus_tokens=False,
                 posterior_threshold=\
                     typical_acceptance_sampler_posterior_threshold,
                 posterior_alpha=typical_acceptance_sampler_posterior_alpha,
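Annotation on the SpecDecodeWorker hunk above: for callers of these samplers the migration is simply dropping the removed keyword. A hedged before/after sketch; the import paths are assumptions, and the threshold/alpha values are the test-helper defaults from earlier in this diff, not SpecDecodeWorker's own configuration:

from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.layers.typical_acceptance_sampler import TypicalAcceptanceSampler

# Before: RejectionSampler(disable_bonus_tokens=False)
rejection = RejectionSampler()
# Before: TypicalAcceptanceSampler(0.03, 0.9, disable_bonus_tokens=False)
typical = TypicalAcceptanceSampler(posterior_threshold=0.03, posterior_alpha=0.9)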