[Speculative Decoding 2/2] Integrate typical acceptance sampler into Spec Decode Worker (#5348)

sroy745 2024-07-01 00:33:05 -07:00 committed by GitHub
parent 614aa51203
commit 80ca1e6a3a
14 changed files with 480 additions and 208 deletions

View File

@@ -52,6 +52,19 @@ def get_draft_token_ids(batch_size: int, k: int, vocab_size: int,
     return draft_token_ids


+def get_acceptance_sampler(
+    posterior_threshold: float = 0.03,
+    posterior_alpha: float = 0.9,
+    disable_bonus_tokens: bool = False,
+    strict_mode: bool = False,
+) -> TypicalAcceptanceSampler:
+    """
+    Initializes and returns a TypicalAcceptanceSampler.
+    """
+    return TypicalAcceptanceSampler(posterior_threshold, posterior_alpha,
+                                    disable_bonus_tokens, strict_mode)
+
+
 @pytest.mark.parametrize("k", list(range(1, 6)))
 @pytest.mark.parametrize("vocab_size", [30_000, 50_000])
 @pytest.mark.parametrize("batch_size", list(range(1, 32)))
@@ -64,7 +77,7 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
     different combinations of k, vocab_size, batch_size and num devices.
     """
     torch.set_default_device(device)
-    typical_acceptance_sampler = TypicalAcceptanceSampler()
+    typical_acceptance_sampler = get_acceptance_sampler()
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
     target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
     bonus_token_ids = torch.randint(low=0,
@@ -76,7 +89,10 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
                                     size=(batch_size, k),
                                     dtype=torch.int64)
     # Verify that sampling succeeds for all cases.
-    typical_acceptance_sampler(target_probs, bonus_token_ids, draft_token_ids)
+    typical_acceptance_sampler(target_probs,
+                               bonus_token_ids,
+                               draft_probs=None,
+                               draft_token_ids=draft_token_ids)


 @pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"])
@@ -94,7 +110,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
     batch_size = 5
     vocab_size = 30_000
     torch.set_default_device(device)
-    typical_acceptance_sampler = TypicalAcceptanceSampler(strict_mode=True)
+    typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
     target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
     bonus_token_ids = torch.randint(low=0,
@@ -125,8 +141,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
     oob_token_ids[0][0] = rogue_token_id

     with pytest.raises(AssertionError):
-        typical_acceptance_sampler(target_probs, bonus_token_ids,
-                                   draft_token_ids)
+        typical_acceptance_sampler(target_probs,
+                                   bonus_token_ids,
+                                   draft_probs=None,
+                                   draft_token_ids=draft_token_ids)


 @pytest.mark.parametrize("seed", list(range(10)))
@@ -151,7 +169,7 @@ def test_uniform_target_distribution_accepts_all_tokens(
     batch_size = 5
     vocab_size = 30_000
     torch.set_default_device(device)
-    typical_acceptance_sampler = TypicalAcceptanceSampler(
+    typical_acceptance_sampler = get_acceptance_sampler(
         strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
     target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
@@ -163,9 +181,11 @@ def test_uniform_target_distribution_accepts_all_tokens(
                                     high=vocab_size,
                                     size=(batch_size, 1),
                                     dtype=torch.int64)
-    output_token_ids = typical_acceptance_sampler(target_probs,
-                                                  bonus_token_ids,
-                                                  draft_token_ids)
+    output_token_ids = typical_acceptance_sampler(
+        target_probs,
+        bonus_token_ids,
+        draft_probs=None,
+        draft_token_ids=draft_token_ids)
     # We are using a uniform target probability distribution.
     # For a uniform distribution the entropy is very high and it
     # should lead to all draft tokens being accepted. Verify that.
@@ -203,7 +223,7 @@ def test_temperature_zero_target_distribution(seed: int,
     vocab_size = 30_000
     torch.set_default_device(device)
-    typical_acceptance_sampler = TypicalAcceptanceSampler(
+    typical_acceptance_sampler = get_acceptance_sampler(
         strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
     # Simulate temperature 0 probability distribution for target probabilities
@@ -224,9 +244,11 @@ def test_temperature_zero_target_distribution(seed: int,
     # 1.0 tokens in the target distribution we will reject all of them and
     # fallback to the greedy sampling for selecting 1 token for each sequence.
     # Verify the same.
-    output_token_ids = typical_acceptance_sampler(target_probs,
-                                                  bonus_token_ids,
-                                                  draft_token_ids)
+    output_token_ids = typical_acceptance_sampler(
+        target_probs,
+        bonus_token_ids,
+        draft_probs=None,
+        draft_token_ids=draft_token_ids)
     assert output_token_ids.shape[0] == batch_size
     assert output_token_ids.shape[1] == (k + 1)
     assert torch.all(output_token_ids[:, -1] == -1)
@@ -261,7 +283,7 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
     batch_size = 4
     vocab_size = 30_000
     torch.set_default_device(device)
-    typical_acceptance_sampler = TypicalAcceptanceSampler(
+    typical_acceptance_sampler = get_acceptance_sampler(
         strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
     # For sequences 0 and 2 set the distribution to a temperature
@@ -277,9 +299,11 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
                                     high=vocab_size,
                                     size=(batch_size, 1),
                                     dtype=torch.int64)
-    output_token_ids = typical_acceptance_sampler(target_probs,
-                                                  bonus_token_ids,
-                                                  draft_token_ids)
+    output_token_ids = typical_acceptance_sampler(
+        target_probs,
+        bonus_token_ids,
+        draft_probs=None,
+        draft_token_ids=draft_token_ids)
     # verify the shape of output_token_ids
     assert output_token_ids.shape[0] == batch_size
     assert output_token_ids.shape[1] == (k + 1)
@@ -326,7 +350,7 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
     batch_size = 1
     vocab_size = 30_000
     torch.set_default_device(device)
-    typical_acceptance_sampler = TypicalAcceptanceSampler(
+    typical_acceptance_sampler = get_acceptance_sampler(
         strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
     # Create a temperature zero target probability distribution and ensure
@@ -339,9 +363,11 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
                                     high=vocab_size,
                                     size=(batch_size, 1),
                                     dtype=torch.int64)
-    output_token_ids = typical_acceptance_sampler(target_probs,
-                                                  bonus_token_ids,
-                                                  draft_token_ids)
+    output_token_ids = typical_acceptance_sampler(
+        target_probs,
+        bonus_token_ids,
+        draft_probs=None,
+        draft_token_ids=draft_token_ids)
     assert output_token_ids.shape[0] == batch_size
     assert output_token_ids.shape[1] == (k + 1)
     assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids)
@@ -357,9 +383,11 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
         batch_size, k, vocab_size, zero_temperature_token_ids)
     draft_token_ids = torch.cat(
         (draft_token_ids[:, :2], draft_token_ids_to_replace[:, -3:]), dim=1)
-    output_token_ids = typical_acceptance_sampler(target_probs,
-                                                  bonus_token_ids,
-                                                  draft_token_ids)
+    output_token_ids = typical_acceptance_sampler(
+        target_probs,
+        bonus_token_ids,
+        draft_probs=None,
+        draft_token_ids=draft_token_ids)
     assert output_token_ids.shape[0] == batch_size
     assert output_token_ids.shape[1] == (k + 1)
     assert torch.all(output_token_ids[:, :2] == draft_token_ids[:, :2])
@@ -384,7 +412,7 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
     batch_size = 1
     vocab_size = 30_000
     torch.set_default_device(device)
-    typical_acceptance_sampler = TypicalAcceptanceSampler(
+    typical_acceptance_sampler = get_acceptance_sampler(
         strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
     # Simulate temperature 0 probability distribution for target
@@ -402,9 +430,11 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
                                     high=vocab_size,
                                     size=(batch_size, 1),
                                     dtype=torch.int64)
-    output_token_ids = typical_acceptance_sampler(target_probs,
-                                                  bonus_token_ids,
-                                                  draft_token_ids)
+    output_token_ids = typical_acceptance_sampler(
+        target_probs,
+        bonus_token_ids,
+        draft_probs=None,
+        draft_token_ids=draft_token_ids)
     assert output_token_ids.shape[0] == batch_size
     assert output_token_ids.shape[1] == (k + 1)
     assert torch.all(output_token_ids[:, 1:-1] == -1)
@@ -418,9 +448,11 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
         posterior_threshold=0.0,
         posterior_alpha=0.0)
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
-    output_token_ids = typical_acceptance_sampler(target_probs,
-                                                  bonus_token_ids,
-                                                  draft_token_ids)
+    output_token_ids = typical_acceptance_sampler(
+        target_probs,
+        bonus_token_ids,
+        draft_probs=None,
+        draft_token_ids=draft_token_ids)
     assert output_token_ids.shape[0] == batch_size
     assert output_token_ids.shape[1] == (k + 1)
     assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids)
@@ -451,7 +483,7 @@ def test_replacement_token_ids(seed: int, disable_bonus_tokens: bool,
     batch_size = 5
     vocab_size = 30_000
     torch.set_default_device(device)
-    typical_acceptance_sampler = TypicalAcceptanceSampler(
+    typical_acceptance_sampler = get_acceptance_sampler(
         strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
     target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)

View File

@@ -11,9 +11,15 @@ distribution matches the target model's output distribution (up to hardware
 numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
 equality. This gives us good coverage of temp=0.

+At temp=0, the TypicalAcceptanceSampler ensures that only the tokens with the
+highest probability in the target distribution are accepted. Therefore, we can
+expect greedy equality for the TypicalAcceptanceSampler at temp=0.
+
 For temp>0, we rely on unit tests on the rejection sampler to verify that the
 output distribution is the same with spec decode vs. no spec decode (this would
-be prohibitively expensive to run with a real model).
+be prohibitively expensive to run with a real model). Similarly, for the
+TypicalAcceptanceSampler, we rely on unit tests to validate the temp>0
+test cases.

 NOTE: Speculative decoding's distribution equality requires that the measured
 distributions of the target model and proposal model be deterministic given the
@@ -611,3 +617,49 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
                                          batch_size,
                                          max_output_len=output_len,
                                          force_output_len=True)
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "model": "JackFram/llama-160m",
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+        # Required for spec decode.
+        "use_v2_block_manager": True
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize(
+    "test_llm_kwargs",
+    [
+        {
+            "speculative_model": "JackFram/llama-68m",
+            "num_speculative_tokens": k,
+            "spec_decoding_acceptance_method": "typical_acceptance_sampler"
+        }
+        # Try a range of common k.
+        for k in [1, 2, 3]
+    ])
+@pytest.mark.parametrize("batch_size", [1, 32])
+@pytest.mark.parametrize(
+    "output_len",
+    [
+        # Use smaller output len for fast test.
+        32,
+    ])
+@pytest.mark.parametrize("seed", [1])
+def test_typical_acceptance_sampling(baseline_llm_generator,
+                                     test_llm_generator, batch_size: int,
+                                     output_len: int):
+    """Verify that speculative decoding produces exact equality to without spec
+    decode with TypicalAcceptanceSampler as the draft token acceptance
+    sampling method.
+    """
+    run_greedy_equality_correctness_test(baseline_llm_generator,
+                                         test_llm_generator,
+                                         batch_size,
+                                         max_output_len=output_len,
+                                         force_output_len=True)
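For orientation only, not part of the diff: a minimal sketch of how the new acceptance method could be enabled from the offline entrypoint. It reuses the exact kwargs exercised by the parametrization above; the model names, k, and prompt are simply mirrored from the test and are otherwise arbitrary.

from vllm import LLM, SamplingParams

# Sketch: values mirror the test parametrization above.
llm = LLM(
    model="JackFram/llama-160m",
    speculative_model="JackFram/llama-68m",
    num_speculative_tokens=3,
    spec_decoding_acceptance_method="typical_acceptance_sampler",
    use_v2_block_manager=True,
    enforce_eager=True,
)
outputs = llm.generate(["San Francisco is a"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)

With temperature 0 this is exactly the greedy-equality setting the e2e test above checks against a non-speculative baseline.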

View File

@@ -3,33 +3,35 @@ from unittest.mock import MagicMock, patch
 import pytest
 import torch

-from vllm.model_executor.layers.rejection_sampler import RejectionSampler
 from vllm.sequence import ExecuteModelRequest
 from vllm.spec_decode.metrics import AsyncMetricsCollector
 from vllm.spec_decode.multi_step_worker import MultiStepWorker
 from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
 from vllm.spec_decode.top1_proposer import Top1Proposer

+from .test_utils import mock_spec_decode_sampler
 from .utils import create_batch, mock_worker


 @pytest.mark.parametrize('queue_size', [4])
 @pytest.mark.parametrize('batch_size', [1])
 @pytest.mark.parametrize('k', [1])
+@pytest.mark.parametrize("acceptance_sampler_method",
+                         ["rejection_sampler", "typical_acceptance_sampler"])
 @torch.inference_mode()
-def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int):
+def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int,
+                             acceptance_sampler_method: str):
     """Verify that speculative tokens are disabled when the batch size
     exceeds the threshold.
     """
     disable_by_batch_size = 3
     draft_worker = mock_worker(cls=MultiStepWorker)
     target_worker = mock_worker()
-    rejection_sampler = MagicMock(spec=RejectionSampler)
     metrics_collector = MagicMock(spec=AsyncMetricsCollector)
     worker = SpecDecodeWorker(proposer_worker=draft_worker,
                               scorer_worker=target_worker,
-                              rejection_sampler=rejection_sampler,
+                              spec_decode_sampler=mock_spec_decode_sampler(
+                                  acceptance_sampler_method),
                               metrics_collector=metrics_collector,
                               disable_by_batch_size=disable_by_batch_size)

View File

@@ -10,16 +10,16 @@ from vllm.spec_decode.metrics import AsyncMetricsCollector
 def test_initial_call_returns_none():
     """Expect first call to get metrics to return None.
     """
-    rej_sampler = MagicMock()
-    rej_sampler.num_accepted_tokens = torch.tensor(0,
-                                                   dtype=torch.long,
-                                                   device='cuda')
-    rej_sampler.num_emitted_tokens = torch.tensor(0,
-                                                  dtype=torch.long,
-                                                  device='cuda')
-    rej_sampler.num_draft_tokens = 0
-    collector = AsyncMetricsCollector(rej_sampler)
+    spec_decode_sampler = MagicMock()
+    spec_decode_sampler.num_accepted_tokens = torch.tensor(0,
+                                                           dtype=torch.long,
+                                                           device='cuda')
+    spec_decode_sampler.num_emitted_tokens = torch.tensor(0,
+                                                          dtype=torch.long,
+                                                          device='cuda')
+    spec_decode_sampler.num_draft_tokens = 0
+    collector = AsyncMetricsCollector(spec_decode_sampler)
     collector.init_gpu_tensors(rank=0)
     maybe_metrics = collector.maybe_collect_rejsample_metrics(k=5)
     assert maybe_metrics is None
@@ -28,14 +28,14 @@ def test_initial_call_returns_none():
 def test_second_call_returns_metrics():
     """Expect second call to not return None.
     """
-    rej_sampler = MagicMock()
-    rej_sampler.num_accepted_tokens = torch.tensor(0,
-                                                   dtype=torch.long,
-                                                   device='cuda')
-    rej_sampler.num_emitted_tokens = torch.tensor(0,
-                                                  dtype=torch.long,
-                                                  device='cuda')
-    rej_sampler.num_draft_tokens = 0
+    spec_decode_sampler = MagicMock()
+    spec_decode_sampler.num_accepted_tokens = torch.tensor(0,
+                                                           dtype=torch.long,
+                                                           device='cuda')
+    spec_decode_sampler.num_emitted_tokens = torch.tensor(0,
+                                                          dtype=torch.long,
+                                                          device='cuda')
+    spec_decode_sampler.num_draft_tokens = 0
     collect_interval_s = 5.0
     timer = MagicMock()
@@ -43,7 +43,7 @@ def test_second_call_returns_metrics():
         0.0, collect_interval_s + 0.1, collect_interval_s + 0.2
     ]
-    collector = AsyncMetricsCollector(rejection_sampler=rej_sampler,
+    collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler,
                                       timer=timer,
                                       collect_interval_s=collect_interval_s)
     collector.init_gpu_tensors(rank=0)
@@ -56,16 +56,16 @@ def test_second_call_returns_metrics():
 def test_nonzero_rank_noop(rank):
     """Verify nonzero ranks don't collect metrics.
     """
-    rej_sampler = MagicMock()
-    rej_sampler.num_accepted_tokens = torch.tensor(0,
-                                                   dtype=torch.long,
-                                                   device='cuda')
-    rej_sampler.num_emitted_tokens = torch.tensor(0,
-                                                  dtype=torch.long,
-                                                  device='cuda')
-    rej_sampler.num_draft_tokens = 0
-    collector = AsyncMetricsCollector(rej_sampler)
+    spec_decode_sampler = MagicMock()
+    spec_decode_sampler.num_accepted_tokens = torch.tensor(0,
+                                                           dtype=torch.long,
+                                                           device='cuda')
+    spec_decode_sampler.num_emitted_tokens = torch.tensor(0,
+                                                          dtype=torch.long,
+                                                          device='cuda')
+    spec_decode_sampler.num_draft_tokens = 0
+    collector = AsyncMetricsCollector(spec_decode_sampler)
     collector.init_gpu_tensors(rank=rank)
     _ = collector.maybe_collect_rejsample_metrics(k=5)
     metrics = collector.maybe_collect_rejsample_metrics(k=5)
@@ -75,14 +75,14 @@ def test_nonzero_rank_noop(rank):
 def test_noop_until_time():
     """Verify metrics aren't collected until enough time passes.
     """
-    rej_sampler = MagicMock()
-    rej_sampler.num_accepted_tokens = torch.tensor(0,
-                                                   dtype=torch.long,
-                                                   device='cuda')
-    rej_sampler.num_emitted_tokens = torch.tensor(0,
-                                                  dtype=torch.long,
-                                                  device='cuda')
-    rej_sampler.num_draft_tokens = 0
+    spec_decode_sampler = MagicMock()
+    spec_decode_sampler.num_accepted_tokens = torch.tensor(0,
+                                                           dtype=torch.long,
+                                                           device='cuda')
+    spec_decode_sampler.num_emitted_tokens = torch.tensor(0,
+                                                          dtype=torch.long,
+                                                          device='cuda')
+    spec_decode_sampler.num_draft_tokens = 0
     collect_interval_s = 5.0
     timer = MagicMock()
@@ -91,7 +91,7 @@ def test_noop_until_time():
         collect_interval_s + 0.1, collect_interval_s + 0.1
     ]
-    collector = AsyncMetricsCollector(rejection_sampler=rej_sampler,
+    collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler,
                                       timer=timer,
                                       collect_interval_s=collect_interval_s)
     collector.init_gpu_tensors(rank=0)
@@ -122,14 +122,14 @@ def test_initial_metrics_has_correct_values(has_data: bool):
     max_num_emitted_tokens = AsyncMetricsCollector.get_max_num_emitted_tokens(
         num_draft_tokens, k)
-    rej_sampler = MagicMock()
-    rej_sampler.num_accepted_tokens = torch.tensor(num_accepted_tokens,
-                                                   dtype=torch.long,
-                                                   device='cuda')
-    rej_sampler.num_emitted_tokens = torch.tensor(num_emitted_tokens,
-                                                  dtype=torch.long,
-                                                  device='cuda')
-    rej_sampler.num_draft_tokens = num_draft_tokens
+    spec_decode_sampler = MagicMock()
+    spec_decode_sampler.num_accepted_tokens = torch.tensor(num_accepted_tokens,
+                                                           dtype=torch.long,
+                                                           device='cuda')
+    spec_decode_sampler.num_emitted_tokens = torch.tensor(num_emitted_tokens,
+                                                          dtype=torch.long,
+                                                          device='cuda')
+    spec_decode_sampler.num_draft_tokens = num_draft_tokens
     collect_interval_s = 5.0
     timer = MagicMock()
@@ -137,7 +137,7 @@ def test_initial_metrics_has_correct_values(has_data: bool):
         0.0, collect_interval_s + 0.1, collect_interval_s + 0.2
     ]
-    collector = AsyncMetricsCollector(rejection_sampler=rej_sampler,
+    collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler,
                                       timer=timer,
                                       collect_interval_s=collect_interval_s)
     collector.init_gpu_tensors(rank=0)

View File

@@ -6,7 +6,6 @@ from unittest.mock import MagicMock
 import pytest
 import torch

-from vllm.model_executor.layers.rejection_sampler import RejectionSampler
 from vllm.model_executor.utils import set_random_seed
 from vllm.sequence import ExecuteModelRequest, SamplerOutput, SequenceOutput
 from vllm.spec_decode.interfaces import SpeculativeProposals
@@ -16,23 +15,26 @@ from vllm.spec_decode.multi_step_worker import MultiStepWorker
 from vllm.spec_decode.spec_decode_worker import (SpecDecodeWorker,
                                                  split_num_cache_blocks_evenly)

+from .test_utils import mock_spec_decode_sampler
 from .utils import create_batch, create_sampler_output_list, mock_worker


 @pytest.mark.parametrize('k', [1, 2, 6])
 @pytest.mark.parametrize('batch_size', [1, 2, 32])
+@pytest.mark.parametrize("acceptance_sampler_method",
+                         ["rejection_sampler", "typical_acceptance_sampler"])
 @torch.inference_mode()
-def test_correctly_calls_draft_model(k: int, batch_size: int):
+def test_correctly_calls_draft_model(k: int, batch_size: int,
+                                     acceptance_sampler_method: str):
     """Verify SpecDecodeWorker calls the draft worker with correct
     inputs. Everything else is mocked out.
     """
     draft_worker = mock_worker(cls=MultiStepWorker)
     target_worker = mock_worker()
-    rejection_sampler = MagicMock(spec=RejectionSampler)
     metrics_collector = MagicMock(spec=AsyncMetricsCollector)
-    worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
-                              metrics_collector)
+    worker = SpecDecodeWorker(
+        draft_worker, target_worker,
+        mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)

     exception_secret = 'artificial stop'
     draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)
@@ -53,15 +55,16 @@ def test_correctly_calls_draft_model(k: int, batch_size: int):
 @pytest.mark.parametrize('k', [1, 2, 6])
 @pytest.mark.parametrize('batch_size', [1, 2, 32])
+@pytest.mark.parametrize("acceptance_sampler_method",
+                         ["rejection_sampler", "typical_acceptance_sampler"])
 @torch.inference_mode()
-def test_correctly_calls_target_model(k: int, batch_size: int):
+def test_correctly_calls_target_model(k: int, batch_size: int,
+                                      acceptance_sampler_method: str):
     """Verify SpecDecodeWorker calls the target model with correct
     inputs. Everything else is mocked out.
     """
     draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
     target_worker = mock_worker(use_spec=False)
-    rejection_sampler = MagicMock(spec=RejectionSampler)
-    rejection_sampler.token_id_dtype = torch.int64
     metrics_collector = MagicMock(spec=AsyncMetricsCollector)

     draft_worker.device = 'cuda'
@@ -69,8 +72,9 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
     set_random_seed(1)

-    worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
-                              metrics_collector)
+    worker = SpecDecodeWorker(
+        draft_worker, target_worker,
+        mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
     worker.init_device()

     vocab_size = 32_000
@@ -133,8 +137,11 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
 @pytest.mark.parametrize('k', [1, 2, 6])
 @pytest.mark.parametrize('batch_size', [1, 2, 32])
+@pytest.mark.parametrize("acceptance_sampler_method",
+                         ["rejection_sampler", "typical_acceptance_sampler"])
 @torch.inference_mode()
-def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
+def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int,
+                                             acceptance_sampler_method: str):
     """Verify SpecDecodeWorker calls the rejection sampler with
     correct inputs. Everything else is mocked out.
     """
@@ -144,15 +151,14 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
                                vocab_size=vocab_size,
                                use_spec=False)
     target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
-    rejection_sampler = MagicMock(spec=RejectionSampler)
-    rejection_sampler.token_id_dtype = torch.int64
+    spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
     metrics_collector = MagicMock(spec=AsyncMetricsCollector)

     draft_worker.device = 'cuda'
     target_worker.device = 'cuda'

     set_random_seed(1)

-    worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
+    worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler,
                               metrics_collector)
     worker.init_device()
@@ -199,15 +205,16 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
     target_worker.execute_model.return_value = [target_output[0]]

     exception_secret = 'artificial stop'
-    rejection_sampler.side_effect = ValueError(exception_secret)
+    spec_decode_sampler.side_effect = ValueError(exception_secret)

     with pytest.raises(ValueError, match=exception_secret):
         worker.execute_model(execute_model_req=ExecuteModelRequest(
             seq_group_metadata_list=seq_group_metadata_list,
             num_lookahead_slots=k))

-    assert len(rejection_sampler.call_args_list) == 1
-    _, kwargs = rejection_sampler.call_args_list[0]
+    assert len(spec_decode_sampler.call_args_list) == 1
+    _, kwargs = spec_decode_sampler.call_args_list[0]
     actual = SimpleNamespace(**kwargs)

     assert torch.equal(actual.bonus_token_ids,
@@ -221,8 +228,11 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
 @pytest.mark.parametrize('k', [1, 2, 6])
 @pytest.mark.parametrize('batch_size', [1, 2, 32])
+@pytest.mark.parametrize("acceptance_sampler_method",
+                         ["rejection_sampler", "typical_acceptance_sampler"])
 @torch.inference_mode()
-def test_correctly_formats_output(k: int, batch_size: int):
+def test_correctly_formats_output(k: int, batch_size: int,
+                                  acceptance_sampler_method: str):
     """Verify SpecDecodeWorker formats sampler output correctly.
     Everything else is mocked out.
     """
@@ -232,15 +242,13 @@ def test_correctly_formats_output(k: int, batch_size: int):
                                vocab_size=vocab_size,
                                use_spec=False)
     target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
-    rejection_sampler = MagicMock(spec=RejectionSampler)
-    rejection_sampler.token_id_dtype = torch.int64
     metrics_collector = MagicMock(spec=AsyncMetricsCollector)

     draft_worker.device = 'cuda'
     target_worker.device = 'cuda'

     set_random_seed(1)

-    worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
+    spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
+    worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler,
                               metrics_collector)
     worker.init_device()
@@ -286,24 +294,23 @@ def test_correctly_formats_output(k: int, batch_size: int):
     target_worker.execute_model.return_value = [target_output[0]]

-    rejection_sampler_output = torch.randint(low=0,
-                                             high=vocab_size,
-                                             size=(batch_size, k + 1),
-                                             dtype=torch.int64,
-                                             device='cuda')
+    spec_decode_sampler_output = torch.randint(low=0,
+                                               high=vocab_size,
+                                               size=(batch_size, k + 1),
+                                               dtype=torch.int64,
+                                               device='cuda')
     for i in range(batch_size):
         minimum_accepted_tokens = 1
-        rejection_sampler_output[i][
+        spec_decode_sampler_output[i][
             -random.randint(minimum_accepted_tokens, k + 1):] = -1
-    rejection_sampler.return_value = rejection_sampler_output
+    spec_decode_sampler.return_value = spec_decode_sampler_output

     output = worker.execute_model(execute_model_req=ExecuteModelRequest(
         seq_group_metadata_list=seq_group_metadata_list,
         num_lookahead_slots=k))

     expected_output = create_sampler_output_list(
-        token_ids=rejection_sampler_output.transpose(0, 1),
+        token_ids=spec_decode_sampler_output.transpose(0, 1),
         probs=[None for _ in range(k + 1)],
         logprobs=[None for _ in range(k + 1)])
@@ -350,8 +357,11 @@ def test_correctly_formats_output(k: int, batch_size: int):
 @pytest.mark.parametrize('k', [1, 2])
 @pytest.mark.parametrize('batch_size', [1])
 @pytest.mark.parametrize('returns_metrics', [True, False])
+@pytest.mark.parametrize("acceptance_sampler_method",
+                         ["rejection_sampler", "typical_acceptance_sampler"])
 @torch.inference_mode()
-def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
+def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool,
+                          acceptance_sampler_method: str):
     """Verify SpecDecodeWorker collects metrics.
     """
     vocab_size = 32_000
@@ -360,15 +370,14 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
                                vocab_size=vocab_size,
                                use_spec=False)
     target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
-    rejection_sampler = MagicMock(spec=RejectionSampler)
-    rejection_sampler.token_id_dtype = torch.int64
+    spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
     metrics_collector = MagicMock(spec=AsyncMetricsCollector)

     draft_worker.device = 'cuda'
     target_worker.device = 'cuda'

     set_random_seed(1)

-    worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
+    worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler,
                               metrics_collector)
     worker.init_device()
@@ -414,17 +423,16 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
     target_worker.execute_model.return_value = [target_output[0]]

-    rejection_sampler_output = torch.randint(low=0,
-                                             high=vocab_size,
-                                             size=(batch_size, k + 1),
-                                             dtype=torch.int64,
-                                             device='cuda')
+    spec_decode_sampler_output = torch.randint(low=0,
+                                               high=vocab_size,
+                                               size=(batch_size, k + 1),
+                                               dtype=torch.int64,
+                                               device='cuda')
     for i in range(batch_size):
         minimum_accepted_tokens = 1
-        rejection_sampler_output[i][
+        spec_decode_sampler_output[i][
             -random.randint(minimum_accepted_tokens, k + 1):] = -1
-    rejection_sampler.return_value = rejection_sampler_output
+    spec_decode_sampler.return_value = spec_decode_sampler_output

     mock_rejsample_metrics = MagicMock(
         spec=SpecDecodeWorkerMetrics) if returns_metrics else None
@@ -445,15 +453,16 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
 @pytest.mark.parametrize('k', [0])
 @pytest.mark.parametrize('batch_size', [1, 2, 32])
+@pytest.mark.parametrize("acceptance_sampler_method",
+                         ["rejection_sampler", "typical_acceptance_sampler"])
 @torch.inference_mode()
-def test_k_equals_zero(k: int, batch_size: int):
+def test_k_equals_zero(k: int, batch_size: int,
+                       acceptance_sampler_method: str):
     """Verify that the SpecDecodeWorker calls the draft and target workers
     when k is zero. This happens during prefill.
     """
     draft_worker = mock_worker(cls=MultiStepWorker)
     target_worker = mock_worker()
-    rejection_sampler = MagicMock(spec=RejectionSampler)
-    rejection_sampler.token_id_dtype = torch.int64
     metrics_collector = MagicMock(spec=AsyncMetricsCollector)

     sampler_output = MagicMock(spec=SamplerOutput)
@@ -465,8 +474,9 @@ def test_k_equals_zero(k: int, batch_size: int):
     set_random_seed(1)

-    worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
-                              metrics_collector)
+    worker = SpecDecodeWorker(
+        draft_worker, target_worker,
+        mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)

     seq_group_metadata_list, _, _ = create_batch(batch_size,
                                                  k,
@@ -487,16 +497,17 @@ def test_k_equals_zero(k: int, batch_size: int):
 @pytest.mark.parametrize('k', [0, 5])
 @pytest.mark.parametrize('batch_size', [0])
+@pytest.mark.parametrize("acceptance_sampler_method",
+                         ["rejection_sampler", "typical_acceptance_sampler"])
 @torch.inference_mode()
-def test_empty_input_batch(k: int, batch_size: int):
+def test_empty_input_batch(k: int, batch_size: int,
+                           acceptance_sampler_method: str):
     """Verify that the SpecDecodeWorker calls the draft and target workers
     when the input batch is empty. This can happen if the engine communicates
     to the workers information without scheduling a batch.
     """
     draft_worker = mock_worker(cls=MultiStepWorker)
     target_worker = mock_worker()
-    rejection_sampler = MagicMock(spec=RejectionSampler)
-    rejection_sampler.token_id_dtype = torch.int64
     metrics_collector = MagicMock(spec=AsyncMetricsCollector)

     sampler_output = MagicMock(spec=SamplerOutput)
@@ -508,8 +519,9 @@ def test_empty_input_batch(k: int, batch_size: int):
     set_random_seed(1)

-    worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
-                              metrics_collector)
+    worker = SpecDecodeWorker(
+        draft_worker, target_worker,
+        mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)

     seq_group_metadata_list, _, _ = create_batch(batch_size,
                                                  k,
@@ -528,18 +540,19 @@ def test_empty_input_batch(k: int, batch_size: int):
     target_worker.execute_model.assert_called_once_with(execute_model_req)


+@pytest.mark.parametrize("acceptance_sampler_method",
+                         ["rejection_sampler", "typical_acceptance_sampler"])
 @pytest.mark.skip_global_cleanup
-def test_init_device():
+def test_init_device(acceptance_sampler_method: str):
     """Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as
     well as other GPU initialization.
     """
     draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
     target_worker = mock_worker(use_spec=False)
-    rejection_sampler = MagicMock(spec=RejectionSampler)
-    rejection_sampler.token_id_dtype = torch.int64
+    spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
     metrics_collector = MagicMock(spec=AsyncMetricsCollector)

-    worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
+    worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler,
                               metrics_collector)
     worker.init_device()
@@ -549,22 +562,23 @@ def test_init_device():
     target_worker.init_device.assert_called_once()

     metrics_collector.init_gpu_tensors.assert_called_once()
-    rejection_sampler.init_gpu_tensors.assert_called_once()
+    spec_decode_sampler.init_gpu_tensors.assert_called_once()


+@pytest.mark.parametrize("acceptance_sampler_method",
+                         ["rejection_sampler", "typical_acceptance_sampler"])
 @torch.inference_mode()
-def test_initialize_cache():
+def test_initialize_cache(acceptance_sampler_method):
     """Verify SpecDecodeWorker invokes initialize_cache on proposer/scorer
     workers.
     """
     draft_worker = mock_worker(cls=MultiStepWorker)
     target_worker = mock_worker()
-    rejection_sampler = MagicMock(spec=RejectionSampler)
-    rejection_sampler.token_id_dtype = torch.int64
     metrics_collector = MagicMock(spec=AsyncMetricsCollector)

-    worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
-                              metrics_collector)
+    worker = SpecDecodeWorker(
+        draft_worker, target_worker,
+        mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)

     kwargs = {"num_gpu_blocks": 1024, "num_cpu_blocks": 1023}
     worker.initialize_cache(**kwargs)
@@ -577,19 +591,20 @@ def test_initialize_cache():
 @pytest.mark.parametrize('available_cpu_blocks', [500])
 @pytest.mark.parametrize('target_cache_block_size_bytes', [2 * 2 * 4096])
 @pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096])
+@pytest.mark.parametrize("acceptance_sampler_method",
+                         ["rejection_sampler", "typical_acceptance_sampler"])
 @pytest.mark.skip_global_cleanup
 def test_determine_num_available_blocks(available_gpu_blocks: int,
                                         available_cpu_blocks: int,
                                         target_cache_block_size_bytes: int,
-                                        draft_kv_size_bytes: int):
+                                        draft_kv_size_bytes: int,
+                                        acceptance_sampler_method: str):
     """Verify SpecDecodeWorker correctly profiles num available GPU blocks.
     Specifically, it should run profiling in the scorer worker, and then evenly
     split the blocks between proposer and scorer worker.
     """
     draft_worker = mock_worker(cls=MultiStepWorker)
     target_worker = mock_worker()
-    rejection_sampler = MagicMock(spec=RejectionSampler)
-    rejection_sampler.token_id_dtype = torch.int64
     metrics_collector = MagicMock(spec=AsyncMetricsCollector)

     target_worker.determine_num_available_blocks.return_value = (
@@ -598,8 +613,9 @@ def test_determine_num_available_blocks(available_gpu_blocks: int,
         target_cache_block_size_bytes)
     draft_worker.get_cache_block_size_bytes.return_value = draft_kv_size_bytes

-    worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
-                              metrics_collector)
+    worker = SpecDecodeWorker(
+        draft_worker, target_worker,
+        mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)

     num_gpu_blocks, num_cpu_blocks = worker.determine_num_available_blocks()

View File

@@ -1,7 +1,11 @@
 from unittest.mock import MagicMock

 import pytest
+import torch

+from vllm.model_executor.layers.rejection_sampler import RejectionSampler
+from vllm.model_executor.layers.typical_acceptance_sampler import (
+    TypicalAcceptanceSampler)
 from vllm.sequence import SequenceGroupMetadata, get_all_seq_ids
 from vllm.spec_decode.util import split_batch_by_proposal_len
@@ -109,3 +113,21 @@ def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata):
     assert filtered_groups == []
     assert indices == []
+
+
+def mock_spec_decode_sampler(acceptance_sampler_method):
+    """
+    Returns either a RejectionSampler or TypicalAcceptanceSampler
+    object depending on whether acceptance_sampler_method is
+    'rejection_sampler' or 'typical_acceptance_sampler' respectively.
+    """
+    if acceptance_sampler_method == "rejection_sampler":
+        sampler = MagicMock(spec=RejectionSampler)
+        sampler.token_id_dtype = torch.int64
+        return sampler
+    elif acceptance_sampler_method == "typical_acceptance_sampler":
+        sampler = MagicMock(spec=TypicalAcceptanceSampler)
+        sampler.token_id_dtype = torch.int64
+        return sampler
+    else:
+        raise ValueError(f"Invalid sampler name {acceptance_sampler_method}")
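As a usage note, condensed from the parametrized worker tests earlier in this commit (not an addition to the file itself): the tests build the worker with this helper and then assert against the mocked sampler. The draft_worker, target_worker, and metrics_collector names below are the existing mocks from the test utilities.

# Condensed from test_init_device above; sketch only.
spec_decode_sampler = mock_spec_decode_sampler("typical_acceptance_sampler")
worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler,
                          metrics_collector)
worker.init_device()
spec_decode_sampler.init_gpu_tensors.assert_called_once()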

View File

@@ -753,7 +753,6 @@ class SchedulerConfig:
         self.chunked_prefill_enabled = enable_chunked_prefill
         self.embedding_mode = embedding_mode
         self.preemption_mode = preemption_mode
-
         self._verify_args()

     def _verify_args(self) -> None:
@@ -834,6 +833,9 @@ class SpeculativeConfig:
         speculative_disable_by_batch_size: Optional[int],
         ngram_prompt_lookup_max: Optional[int],
         ngram_prompt_lookup_min: Optional[int],
+        draft_token_acceptance_method: str,
+        typical_acceptance_sampler_posterior_threshold: Optional[float],
+        typical_acceptance_sampler_posterior_alpha: Optional[float],
     ) -> Optional["SpeculativeConfig"]:
         """Create a SpeculativeConfig if possible, else return None.
@@ -870,6 +872,19 @@ class SpeculativeConfig:
                 window, if provided.
             ngram_prompt_lookup_min (Optional[int]): Min size of ngram token
                 window, if provided.
+            draft_token_acceptance_method (str): The method to use for
+                accepting draft tokens. This can take two possible
+                values 'rejection_sampler' and 'typical_acceptance_sampler'
+                for RejectionSampler and TypicalAcceptanceSampler
+                respectively.
+            typical_acceptance_sampler_posterior_threshold (Optional[float]):
+                A threshold value that sets a lower bound on the posterior
+                probability of a token in the target model for it to be
+                accepted. This threshold is used only when we use the
+                TypicalAcceptanceSampler for token acceptance.
+            typical_acceptance_sampler_posterior_alpha (Optional[float]):
+                A scaling factor for the entropy-based threshold in the
+                TypicalAcceptanceSampler.

         Returns:
             Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
@@ -984,6 +999,11 @@ class SpeculativeConfig:
                 "speculative_model unless the draft model config contains an "
                 "n_predict parameter.")

+        if typical_acceptance_sampler_posterior_threshold is None:
+            typical_acceptance_sampler_posterior_threshold = 0.09
+        if typical_acceptance_sampler_posterior_alpha is None:
+            typical_acceptance_sampler_posterior_alpha = 0.3
+
         return SpeculativeConfig(
             draft_model_config,
             draft_parallel_config,
@@ -991,6 +1011,11 @@ class SpeculativeConfig:
             speculative_disable_by_batch_size,
             ngram_prompt_lookup_max,
             ngram_prompt_lookup_min,
+            draft_token_acceptance_method=draft_token_acceptance_method,
+            typical_acceptance_sampler_posterior_threshold=\
+                typical_acceptance_sampler_posterior_threshold,
+            typical_acceptance_sampler_posterior_alpha=\
+                typical_acceptance_sampler_posterior_alpha,
         )

     @staticmethod
@@ -1072,6 +1097,9 @@ class SpeculativeConfig:
         speculative_disable_by_batch_size: Optional[int],
         ngram_prompt_lookup_max: Optional[int],
         ngram_prompt_lookup_min: Optional[int],
+        draft_token_acceptance_method: str,
+        typical_acceptance_sampler_posterior_threshold: float,
+        typical_acceptance_sampler_posterior_alpha: float,
     ):
         """Create a SpeculativeConfig object.
@@ -1085,6 +1113,19 @@ class SpeculativeConfig:
                 enqueue requests is larger than this value.
             ngram_prompt_lookup_max: Max size of ngram token window.
             ngram_prompt_lookup_min: Min size of ngram token window.
+            draft_token_acceptance_method (str): The method to use for
+                accepting draft tokens. This can take two possible
+                values 'rejection_sampler' and 'typical_acceptance_sampler'
+                for RejectionSampler and TypicalAcceptanceSampler
+                respectively.
+            typical_acceptance_sampler_posterior_threshold (Optional[float]):
+                A threshold value that sets a lower bound on the posterior
+                probability of a token in the target model for it to be
+                accepted. This threshold is used only when we use the
+                TypicalAcceptanceSampler for token acceptance.
+            typical_acceptance_sampler_posterior_alpha (Optional[float]):
+                A scaling factor for the entropy-based threshold in the
+                TypicalAcceptanceSampler.
         """
         self.draft_model_config = draft_model_config
         self.draft_parallel_config = draft_parallel_config
@@ -1093,6 +1134,11 @@ class SpeculativeConfig:
             speculative_disable_by_batch_size
         self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0
         self.ngram_prompt_lookup_min = ngram_prompt_lookup_min or 0
+        self.draft_token_acceptance_method = draft_token_acceptance_method
+        self.typical_acceptance_sampler_posterior_threshold = \
+            typical_acceptance_sampler_posterior_threshold
+        self.typical_acceptance_sampler_posterior_alpha = \
+            typical_acceptance_sampler_posterior_alpha

         self._verify_args()
@@ -1104,6 +1150,31 @@ class SpeculativeConfig:
         if self.draft_model_config:
             self.draft_model_config.verify_with_parallel_config(
                 self.draft_parallel_config)
+        # Validate and set draft token acceptance related settings.
+
+        if (self.draft_token_acceptance_method is None):
+            raise ValueError("draft_token_acceptance_method is not set. "
+                             "Expected values are rejection_sampler or "
+                             "typical_acceptance_sampler.")
+
+        if (self.draft_token_acceptance_method != 'rejection_sampler'
+                and self.draft_token_acceptance_method !=
+                'typical_acceptance_sampler'):
+            raise ValueError(
+                "Expected draft_token_acceptance_method to be either "
+                "rejection_sampler or typical_acceptance_sampler. Instead it "
+                f"is {self.draft_token_acceptance_method}")
+
+        if (self.typical_acceptance_sampler_posterior_threshold < 0
+                or self.typical_acceptance_sampler_posterior_alpha < 0):
+            raise ValueError(
+                "Expected typical_acceptance_sampler_posterior_threshold "
+                "and typical_acceptance_sampler_posterior_alpha to be > 0. "
+                "Instead found "
+                f"typical_acceptance_sampler_posterior_threshold = "
+                f"{self.typical_acceptance_sampler_posterior_threshold} and "
+                f"typical_acceptance_sampler_posterior_alpha = "
+                f"{self.typical_acceptance_sampler_posterior_alpha}")

     @property
     def num_lookahead_slots(self) -> int:
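The two knobs validated above feed the typical-acceptance rule. As a hedged sketch only (the authoritative version lives in TypicalAcceptanceSampler, not in this diff), a draft token is accepted when the probability the target model assigns to it clears an entropy-adaptive bar of roughly min(posterior_threshold, posterior_alpha * exp(-entropy)); a confident (low-entropy) target distribution raises the bar, a flat one lowers it. The helper name below is hypothetical.

import torch

def typical_acceptance_mask(target_probs: torch.Tensor,
                            draft_token_ids: torch.Tensor,
                            posterior_threshold: float = 0.09,
                            posterior_alpha: float = 0.3) -> torch.Tensor:
    # target_probs: [batch, k, vocab_size]; draft_token_ids: [batch, k].
    # Probability the target model assigns to each proposed draft token.
    candidate_probs = target_probs.gather(
        -1, draft_token_ids.unsqueeze(-1)).squeeze(-1)
    # Entropy of the target distribution at each position.
    entropy = -(target_probs * torch.log(target_probs + 1e-10)).sum(dim=-1)
    # Entropy-adaptive acceptance threshold, capped at posterior_threshold.
    threshold = torch.minimum(
        torch.full_like(entropy, posterior_threshold),
        posterior_alpha * torch.exp(-entropy))
    return candidate_probs > threshold

Raising posterior_threshold or posterior_alpha makes acceptance stricter (closer to greedy agreement with the target model); lowering them accepts more draft tokens at some cost in output quality, which is the trade-off the CLI help text in the next file describes.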

View File

@ -100,7 +100,9 @@ class EngineArgs:
speculative_disable_by_batch_size: Optional[int] = None speculative_disable_by_batch_size: Optional[int] = None
ngram_prompt_lookup_max: Optional[int] = None ngram_prompt_lookup_max: Optional[int] = None
ngram_prompt_lookup_min: Optional[int] = None ngram_prompt_lookup_min: Optional[int] = None
spec_decoding_acceptance_method: str = 'rejection_sampler'
typical_acceptance_sampler_posterior_threshold: Optional[float] = None
typical_acceptance_sampler_posterior_alpha: Optional[float] = None
qlora_adapter_name_or_path: Optional[str] = None qlora_adapter_name_or_path: Optional[str] = None
otlp_traces_endpoint: Optional[str] = None otlp_traces_endpoint: Optional[str] = None
@ -577,6 +579,38 @@ class EngineArgs:
help='Min size of window for ngram prompt lookup in speculative ' help='Min size of window for ngram prompt lookup in speculative '
'decoding.') 'decoding.')
parser.add_argument(
'--spec-decoding-acceptance-method',
type=str,
default=EngineArgs.spec_decoding_acceptance_method,
choices=['rejection_sampler', 'typical_acceptance_sampler'],
help='Specify the acceptance method to use during draft token '
'verification in speculative decoding. Two types of acceptance '
'routines are supported: '
'1) RejectionSampler which does not allow changing the '
'acceptance rate of draft tokens, '
'2) TypicalAcceptanceSampler which is configurable, allowing for '
'a higher acceptance rate at the cost of lower quality, '
'and vice versa.')
parser.add_argument(
'--typical-acceptance-sampler-posterior-threshold',
type=float,
default=EngineArgs.typical_acceptance_sampler_posterior_threshold,
help='Set the lower bound threshold for the posterior '
'probability of a token to be accepted. This threshold is '
'used by the TypicalAcceptanceSampler to make sampling decisions '
'during speculative decoding. Defaults to 0.09')
parser.add_argument(
'--typical-acceptance-sampler-posterior-alpha',
type=float,
default=EngineArgs.typical_acceptance_sampler_posterior_alpha,
help='A scaling factor for the entropy-based threshold for token '
'acceptance in the TypicalAcceptanceSampler. Typically defaults '
'to sqrt of --typical-acceptance-sampler-posterior-threshold '
'i.e. 0.3')
parser.add_argument('--model-loader-extra-config',
type=nullable_str,
default=EngineArgs.model_loader_extra_config,
@@ -737,6 +771,12 @@ class EngineArgs:
use_v2_block_manager=self.use_v2_block_manager,
ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
draft_token_acceptance_method=\
self.spec_decoding_acceptance_method,
typical_acceptance_sampler_posterior_threshold=self.
typical_acceptance_sampler_posterior_threshold,
typical_acceptance_sampler_posterior_alpha=self.
typical_acceptance_sampler_posterior_alpha,
)
scheduler_config = SchedulerConfig(
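For reference, a minimal sketch of how the new flags above might be supplied programmatically rather than on the command line, assuming the usual EngineArgs speculative-decoding fields; the model names and the speculative_model/num_speculative_tokens values are placeholders, not part of this change:

from vllm.engine.arg_utils import EngineArgs

# Hypothetical configuration: select the typical acceptance sampler and set
# both knobs explicitly; leaving them unset falls back to the documented
# defaults (threshold 0.09, alpha 0.3).
engine_args = EngineArgs(
    model="facebook/opt-6.7b",              # placeholder target model
    speculative_model="facebook/opt-125m",  # placeholder draft model
    num_speculative_tokens=5,
    use_v2_block_manager=True,
    spec_decoding_acceptance_method="typical_acceptance_sampler",
    typical_acceptance_sampler_posterior_threshold=0.09,
    typical_acceptance_sampler_posterior_alpha=0.3,
)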

View File

@@ -3,13 +3,12 @@ from typing import Tuple
import torch
import torch.jit
import torch.nn as nn
from vllm.model_executor.layers.spec_decode_base_sampler import (
SpecDecodeBaseSampler)
class RejectionSampler(SpecDecodeBaseSampler, nn.Module): class RejectionSampler(SpecDecodeBaseSampler):
"""Apply modified rejection sampling as described in "Accelerating Large """Apply modified rejection sampling as described in "Accelerating Large
Language Model Decoding with Speculative Sampling" Language Model Decoding with Speculative Sampling"
https://arxiv.org/pdf/2302.01318.pdf. https://arxiv.org/pdf/2302.01318.pdf.
@ -28,8 +27,8 @@ class RejectionSampler(SpecDecodeBaseSampler, nn.Module):
during sampling. This catches correctness issues but adds during sampling. This catches correctness issues but adds
nontrivial latency. nontrivial latency.
""" """
SpecDecodeBaseSampler.__init__(self, disable_bonus_tokens, strict_mode) super().__init__(disable_bonus_tokens=disable_bonus_tokens,
nn.Module.__init__(self) strict_mode=strict_mode)
def forward(
self,
@@ -78,11 +77,12 @@ class RejectionSampler(SpecDecodeBaseSampler, nn.Module):
self._raise_if_incorrect_input(target_probs, bonus_token_ids,
draft_probs, draft_token_ids)
accepted, recovered_token_ids = self._batch_modified_rejection_sampling( accepted, recovered_token_ids = (
self._batch_modified_rejection_sampling(
target_probs,
draft_probs,
draft_token_ids,
) ))
output_token_ids = self._create_output(
accepted,

View File

@@ -1,9 +1,12 @@
from abc import abstractmethod
from typing import Optional
import torch
import torch.jit
import torch.nn as nn
class SpecDecodeBaseSampler(): class SpecDecodeBaseSampler(nn.Module):
"""Base class for samplers used for Speculative Decoding verification """Base class for samplers used for Speculative Decoding verification
step. step.
""" """
@ -51,6 +54,16 @@ class SpecDecodeBaseSampler():
def token_id_dtype(self): def token_id_dtype(self):
return torch.int64 return torch.int64
@abstractmethod
def forward(
self,
target_probs: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
) -> torch.Tensor:
raise NotImplementedError
def _create_output(
self,
accepted: torch.Tensor, # [batch_size, k]
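To make the new abstract interface concrete, here is a toy subclass, purely illustrative and not part of this change: it accepts every draft token and skips the bonus-token gating and metrics bookkeeping that the real samplers delegate to self._create_output.

import torch

from vllm.model_executor.layers.spec_decode_base_sampler import (
    SpecDecodeBaseSampler)


class AcceptAllSampler(SpecDecodeBaseSampler):
    """Toy sampler: accept every draft token and append the bonus tokens."""

    def forward(
        self,
        target_probs: torch.Tensor,      # [batch_size, k, vocab_size]
        bonus_token_ids: torch.Tensor,   # [batch_size, num_bonus_tokens]
        draft_probs: torch.Tensor,       # unused by this toy sampler
        draft_token_ids: torch.Tensor,   # [batch_size, k]
    ) -> torch.Tensor:
        # Emit all draft tokens followed by the bonus token(s).
        # shape = [batch_size, k + num_bonus_tokens]
        return torch.cat([draft_token_ids, bonus_token_ids], dim=-1)

Because the base class is now an nn.Module, the subclass is callable via __call__ just like the concrete samplers.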

View File

@@ -1,12 +1,11 @@
import torch
import torch.jit
import torch.nn as nn
from vllm.model_executor.layers.spec_decode_base_sampler import (
SpecDecodeBaseSampler)
class TypicalAcceptanceSampler(SpecDecodeBaseSampler, nn.Module): class TypicalAcceptanceSampler(SpecDecodeBaseSampler):
"""Apply typical acceptance sampling as described in section 3.3.1 in """Apply typical acceptance sampling as described in section 3.3.1 in
"MEDUSA: Simple LLM Inference Acceleration Framework with "MEDUSA: Simple LLM Inference Acceleration Framework with
Multiple Decoding Heads" Multiple Decoding Heads"
@ -15,10 +14,10 @@ class TypicalAcceptanceSampler(SpecDecodeBaseSampler, nn.Module):
def __init__( def __init__(
self, self,
posterior_threshold: float,
posterior_alpha: float,
disable_bonus_tokens: bool = False,
strict_mode: bool = False,
posterior_threshold: float = 0.09,
posterior_alpha: float = 0.3,
):
"""Create a Typical Acceptance Sampler.
@@ -31,23 +30,20 @@ class TypicalAcceptanceSampler(SpecDecodeBaseSampler, nn.Module):
nontrivial latency.
posterior_threshold : A threshold value that sets a lower bound
on the posterior probability of a token in target model for it
to be accepted. Default is 0.09 to be accepted.
posterior_alpha : A scaling factor for the entropy-based
threshold in typical acceptance sampling. Typically defaults to threshold in typical acceptance sampling.
sqrt of posterior_threshold and is set to 0.3.
""" """
SpecDecodeBaseSampler.__init__(
self,
disable_bonus_tokens=disable_bonus_tokens,
strict_mode=strict_mode)
nn.Module.__init__(self)
self._posterior_threshold = posterior_threshold
self._posterior_alpha = posterior_alpha
super().__init__(disable_bonus_tokens=disable_bonus_tokens,
strict_mode=strict_mode)
def forward(
self,
target_probs: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
) -> torch.Tensor:
"""Sample token ids using typical acceptance sampling. This accepts
@@ -69,6 +65,8 @@ class TypicalAcceptanceSampler(SpecDecodeBaseSampler, nn.Module):
speculative tokens in a sequence are accepted.
shape = [batch_size, num_bonus_tokens]
draft_probs: This parameter is unused by the acceptance sampler.
draft_token_ids: The token ids that were sampled from the draft
probabilities.
shape = [batch_size, num_speculative_tokens]
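A minimal usage sketch of the updated call signature, with made-up shapes and a single bonus token; since the typical acceptance sampler ignores draft probabilities, draft_probs can be passed as None (init_gpu_tensors is assumed to place the metric tensors on the given GPU rank, so a CUDA device is used throughout):

import torch

from vllm.model_executor.layers.typical_acceptance_sampler import (
    TypicalAcceptanceSampler)

batch_size, k, vocab_size = 2, 4, 32_000
device = "cuda"

sampler = TypicalAcceptanceSampler(posterior_threshold=0.09,
                                   posterior_alpha=0.3)
sampler.init_gpu_tensors(rank=0)

target_probs = torch.rand(batch_size, k, vocab_size, device=device)
bonus_token_ids = torch.randint(0, vocab_size, (batch_size, 1), device=device)
draft_token_ids = torch.randint(0, vocab_size, (batch_size, k), device=device)

# shape = [batch_size, k + 1]; positions after the first rejected draft token
# are filled with -1.
output_token_ids = sampler(target_probs,
                           bonus_token_ids,
                           draft_probs=None,
                           draft_token_ids=draft_token_ids)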

View File

@@ -4,7 +4,8 @@ from typing import Callable, Optional
import torch
from vllm.model_executor.layers.rejection_sampler import RejectionSampler from vllm.model_executor.layers.spec_decode_base_sampler import (
SpecDecodeBaseSampler)
from vllm.utils import is_pin_memory_available
@@ -46,15 +47,15 @@ Timer = Callable[[], float]
class AsyncMetricsCollector:
"""Class which copies rejection sampler metrics from the device to CPU on a """Class which copies rejection/typical-acceptance sampler metrics
non-default Torch stream. from the device to CPU on a non-default Torch stream.
""" """
def __init__(self, def __init__(self,
rejection_sampler: RejectionSampler, spec_decode_sampler: SpecDecodeBaseSampler,
timer: Optional[Timer] = None,
collect_interval_s: float = 5.0):
self._rejection_sampler = rejection_sampler self.spec_decode_sampler = spec_decode_sampler
self._timer = time.time if timer is None else timer
self._rank: Optional[int] = None
@@ -95,7 +96,7 @@ class AsyncMetricsCollector:
return None
def _should_collect_rejsample_metrics(self, now: float) -> bool:
"""Return whether or not this iteration should print rejection sampling """Return whether or not this iteration should print sampling
metrics.
"""
if self._rank != 0:
@@ -107,8 +108,8 @@ class AsyncMetricsCollector:
return True
def _copy_rejsample_metrics_async(self) -> torch.cuda.Event:
"""Copy rejection sampling metrics (number of accepted tokens, etc) to """Copy rejection/typical-acceptance sampling metrics
CPU asynchronously. (number of accepted tokens, etc) to CPU asynchronously.
Returns a CUDA event recording when the copy is complete.
"""
@@ -117,13 +118,14 @@ class AsyncMetricsCollector:
with torch.cuda.stream(self._copy_stream):
self._aggregate_num_accepted_tokens.copy_(
self._rejection_sampler.num_accepted_tokens, non_blocking=True) self.spec_decode_sampler.num_accepted_tokens,
non_blocking=True)
self._aggregate_num_emitted_tokens.copy_(
self._rejection_sampler.num_emitted_tokens, non_blocking=True) self.spec_decode_sampler.num_emitted_tokens, non_blocking=True)
# Number of draft tokens is calculated on CPU, so no copy is
# required.
self._aggregate_num_draft_tokens = (
self._rejection_sampler.num_draft_tokens) self.spec_decode_sampler.num_draft_tokens)
aggregate_metrics_ready = torch.cuda.Event()
aggregate_metrics_ready.record(self._copy_stream)
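Because the collector is now typed against SpecDecodeBaseSampler, either acceptance routine can be wired in without further changes; a small sketch (the vllm.spec_decode.metrics module path is assumed):

from vllm.model_executor.layers.typical_acceptance_sampler import (
    TypicalAcceptanceSampler)
from vllm.spec_decode.metrics import AsyncMetricsCollector

sampler = TypicalAcceptanceSampler(posterior_threshold=0.09,
                                   posterior_alpha=0.3)
# The collector only touches the base-class metric tensors
# (num_accepted_tokens, num_emitted_tokens, num_draft_tokens).
collector = AsyncMetricsCollector(spec_decode_sampler=sampler)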

View File

@@ -7,6 +7,10 @@ from vllm.config import ParallelConfig, SpeculativeConfig
from vllm.distributed.communication_op import broadcast_tensor_dict
from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.layers.spec_decode_base_sampler import (
SpecDecodeBaseSampler)
from vllm.model_executor.layers.typical_acceptance_sampler import (
TypicalAcceptanceSampler)
from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest,
HiddenStates, SamplerOutput, SequenceGroupMetadata,
get_all_seq_ids)
@@ -56,7 +60,12 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
draft_worker_kwargs=draft_worker_kwargs,
disable_by_batch_size=speculative_config.
speculative_disable_by_batch_size,
) draft_token_acceptance_method=speculative_config.
draft_token_acceptance_method,
typical_acceptance_sampler_posterior_threshold=speculative_config.
typical_acceptance_sampler_posterior_threshold,
typical_acceptance_sampler_posterior_alpha=speculative_config.
typical_acceptance_sampler_posterior_alpha)
return spec_decode_worker
@@ -78,8 +87,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
welcome!).
* Only top-1 proposal and scoring are implemented. Tree-attention is left as
future work.
* Only lossless rejection sampling is supported. Contributions adding lossy
verification routines are welcome (e.g. Medusa's typical acceptance).
* All sequences in a batch must have the same proposal length, or zero. This
can be improved by having per-sequence speculation in the future.
* The scoring forward pass is done without an MQA kernel, which is
@@ -95,6 +102,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
scorer_worker: Worker,
draft_worker_kwargs: Dict[str, Any],
disable_by_batch_size: Optional[int],
draft_token_acceptance_method: str,
typical_acceptance_sampler_posterior_threshold: float,
typical_acceptance_sampler_posterior_alpha: float,
) -> "SpecDecodeWorker": ) -> "SpecDecodeWorker":
ngram_prompt_lookup_max = ( ngram_prompt_lookup_max = (
@ -127,17 +137,30 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
logger.info("Configuring SpecDecodeWorker with proposer=%s", logger.info("Configuring SpecDecodeWorker with proposer=%s",
type(proposer_worker)) type(proposer_worker))
spec_decode_sampler: SpecDecodeBaseSampler = None
if draft_token_acceptance_method == "rejection_sampler":
spec_decode_sampler = RejectionSampler(
disable_bonus_tokens=disable_bonus_tokens, )
elif draft_token_acceptance_method == "typical_acceptance_sampler":
spec_decode_sampler = TypicalAcceptanceSampler(
disable_bonus_tokens=disable_bonus_tokens,
posterior_threshold=\
typical_acceptance_sampler_posterior_threshold,
posterior_alpha=typical_acceptance_sampler_posterior_alpha,
)
logger.info("Configuring SpecDecodeWorker with sampler=%s",
type(spec_decode_sampler))
return SpecDecodeWorker(proposer_worker,
scorer_worker,
disable_by_batch_size=disable_by_batch_size,
rejection_sampler=RejectionSampler( spec_decode_sampler=spec_decode_sampler)
disable_bonus_tokens=disable_bonus_tokens))
def __init__(
self,
proposer_worker: ProposerWorkerBase,
scorer_worker: WorkerBase,
rejection_sampler: RejectionSampler, spec_decode_sampler: SpecDecodeBaseSampler,
metrics_collector: Optional[AsyncMetricsCollector] = None,
disable_by_batch_size: Optional[int] = None,
):
@@ -150,8 +173,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
scorer_worker: A worker that produces probabilities of speculative
tokens according to some base model. Typically a vanilla vLLM
Worker.
rejection_sampler: A Torch module used to perform modified rejection spec_decode_sampler: A Torch module used to perform acceptance
sampling for speculative decoding. sampling of the draft tokens in the verification step of
speculative decoding. Currently we support two different
types of sampler namely RejectionSampler and
TypicalAcceptanceSampler. 'spec_decode_sampler' is either an
instance of RejectionSampler or TypicalAcceptanceSampler.
disable_by_batch_size: If the batch size is larger than this,
disable speculative decoding for new incoming requests.
metrics_collector: Helper class for collecting metrics; can be set
@@ -160,15 +187,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self.proposer_worker = proposer_worker
self.scorer_worker = scorer_worker
self.disable_by_batch_size = disable_by_batch_size or float("inf")
self.rejection_sampler = rejection_sampler self.spec_decode_sampler = spec_decode_sampler
self._metrics = AsyncMetricsCollector(
rejection_sampler self.spec_decode_sampler
) if metrics_collector is None else metrics_collector
self.probs_dtype = self.spec_decode_sampler.probs_dtype
self.probs_dtype = self.rejection_sampler.probs_dtype self.token_id_dtype = self.spec_decode_sampler.token_id_dtype
self.token_id_dtype = self.rejection_sampler.token_id_dtype
# Lazy initialization.
self.scorer: SpeculativeScorer
@@ -189,7 +213,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self.proposer_worker.load_model()
self._metrics.init_gpu_tensors(self.rank)
self.rejection_sampler.init_gpu_tensors(self.rank) self.spec_decode_sampler.init_gpu_tensors(self.rank)
self.scorer = BatchExpansionTop1Scorer(
scorer_worker=self.scorer_worker,
device=self.device,
@@ -203,7 +228,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
def _configure_model_sampler_for_spec_decode(self):
"""Configure model sampler to emit GPU tensors. This allows spec decode
to keep data on device without transferring to CPU and serializing,
which significantly reduces overhead of rejection sampling. which significantly reduces overhead of sampling during verification.
NOTE(cade): This breaks abstraction boundaries pretty badly. The better
design is to have the "move to CPU and serialize" sampling decision be
@@ -481,7 +506,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# Get proposed tokens.
proposal_token_ids = proposals.proposal_token_ids[spec_indices]
accepted_token_ids = self.rejection_sampler( accepted_token_ids = self.spec_decode_sampler(
target_probs=proposal_verifier_probs,
bonus_token_ids=bonus_token_ids,
draft_probs=proposal_probs,
@@ -496,7 +521,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
accepted_token_ids = torch.cat(
[accepted_token_ids, non_spec_token_ids])
logprobs = proposal_scores.logprobs
# Rearrange so that results are in the order of the original seq group
# metadata.
accepted_token_ids[original_indices] = accepted_token_ids.clone()