[Speculative Decoding 2/2 ] Integrate typical acceptance sampler into Spec Decode Worker (#5348)
This commit is contained in:
parent
614aa51203
commit
80ca1e6a3a
@ -52,6 +52,19 @@ def get_draft_token_ids(batch_size: int, k: int, vocab_size: int,
|
||||
return draft_token_ids
|
||||
|
||||
|
||||
def get_acceptance_sampler(
|
||||
posterior_threshold: float = 0.03,
|
||||
posterior_alpha: float = 0.9,
|
||||
disable_bonus_tokens: bool = False,
|
||||
strict_mode: bool = False,
|
||||
) -> TypicalAcceptanceSampler:
|
||||
"""
|
||||
Initializes and returns a TypicalAcceptanceSampler.
|
||||
"""
|
||||
return TypicalAcceptanceSampler(posterior_threshold, posterior_alpha,
|
||||
disable_bonus_tokens, strict_mode)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("k", list(range(1, 6)))
|
||||
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
|
||||
@pytest.mark.parametrize("batch_size", list(range(1, 32)))
|
||||
@ -64,7 +77,7 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
|
||||
different combinations of k, vocab_size, batch_size and num devices.
|
||||
"""
|
||||
torch.set_default_device(device)
|
||||
typical_acceptance_sampler = TypicalAcceptanceSampler()
|
||||
typical_acceptance_sampler = get_acceptance_sampler()
|
||||
typical_acceptance_sampler.init_gpu_tensors(rank=0)
|
||||
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
|
||||
bonus_token_ids = torch.randint(low=0,
|
||||
@ -76,7 +89,10 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
|
||||
size=(batch_size, k),
|
||||
dtype=torch.int64)
|
||||
# Verify that sampling succeeds for all cases.
|
||||
typical_acceptance_sampler(target_probs, bonus_token_ids, draft_token_ids)
|
||||
typical_acceptance_sampler(target_probs,
|
||||
bonus_token_ids,
|
||||
draft_probs=None,
|
||||
draft_token_ids=draft_token_ids)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"])
|
||||
@ -94,7 +110,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
|
||||
batch_size = 5
|
||||
vocab_size = 30_000
|
||||
torch.set_default_device(device)
|
||||
typical_acceptance_sampler = TypicalAcceptanceSampler(strict_mode=True)
|
||||
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
|
||||
typical_acceptance_sampler.init_gpu_tensors(rank=0)
|
||||
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
|
||||
bonus_token_ids = torch.randint(low=0,
|
||||
@ -125,8 +141,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
|
||||
oob_token_ids[0][0] = rogue_token_id
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
typical_acceptance_sampler(target_probs, bonus_token_ids,
|
||||
draft_token_ids)
|
||||
typical_acceptance_sampler(target_probs,
|
||||
bonus_token_ids,
|
||||
draft_probs=None,
|
||||
draft_token_ids=draft_token_ids)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", list(range(10)))
|
||||
@ -151,7 +169,7 @@ def test_uniform_target_distribution_accepts_all_tokens(
|
||||
batch_size = 5
|
||||
vocab_size = 30_000
|
||||
torch.set_default_device(device)
|
||||
typical_acceptance_sampler = TypicalAcceptanceSampler(
|
||||
typical_acceptance_sampler = get_acceptance_sampler(
|
||||
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
|
||||
typical_acceptance_sampler.init_gpu_tensors(rank=0)
|
||||
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
|
||||
@ -163,9 +181,11 @@ def test_uniform_target_distribution_accepts_all_tokens(
|
||||
high=vocab_size,
|
||||
size=(batch_size, 1),
|
||||
dtype=torch.int64)
|
||||
output_token_ids = typical_acceptance_sampler(target_probs,
|
||||
bonus_token_ids,
|
||||
draft_token_ids)
|
||||
output_token_ids = typical_acceptance_sampler(
|
||||
target_probs,
|
||||
bonus_token_ids,
|
||||
draft_probs=None,
|
||||
draft_token_ids=draft_token_ids)
|
||||
# We are using a uniform target probability distribution.
|
||||
# For a uniform distribution the entropy is very high and it
|
||||
# should lead to all draft tokens being accepted. Verify that.
|
||||
@ -203,7 +223,7 @@ def test_temperature_zero_target_distribution(seed: int,
|
||||
vocab_size = 30_000
|
||||
torch.set_default_device(device)
|
||||
|
||||
typical_acceptance_sampler = TypicalAcceptanceSampler(
|
||||
typical_acceptance_sampler = get_acceptance_sampler(
|
||||
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
|
||||
typical_acceptance_sampler.init_gpu_tensors(rank=0)
|
||||
# Simulate temperature 0 probability distribution for target probabilities
|
||||
@ -224,9 +244,11 @@ def test_temperature_zero_target_distribution(seed: int,
|
||||
# 1.0 tokens in the target distribution we will reject all of them and
|
||||
# fallback to the greedy sampling for selecting 1 token for each sequence.
|
||||
# Verify the same.
|
||||
output_token_ids = typical_acceptance_sampler(target_probs,
|
||||
bonus_token_ids,
|
||||
draft_token_ids)
|
||||
output_token_ids = typical_acceptance_sampler(
|
||||
target_probs,
|
||||
bonus_token_ids,
|
||||
draft_probs=None,
|
||||
draft_token_ids=draft_token_ids)
|
||||
assert output_token_ids.shape[0] == batch_size
|
||||
assert output_token_ids.shape[1] == (k + 1)
|
||||
assert torch.all(output_token_ids[:, -1] == -1)
|
||||
@ -261,7 +283,7 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
|
||||
batch_size = 4
|
||||
vocab_size = 30_000
|
||||
torch.set_default_device(device)
|
||||
typical_acceptance_sampler = TypicalAcceptanceSampler(
|
||||
typical_acceptance_sampler = get_acceptance_sampler(
|
||||
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
|
||||
typical_acceptance_sampler.init_gpu_tensors(rank=0)
|
||||
# For sequences 0 and 2 set the distribution to a temperature
|
||||
@ -277,9 +299,11 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
|
||||
high=vocab_size,
|
||||
size=(batch_size, 1),
|
||||
dtype=torch.int64)
|
||||
output_token_ids = typical_acceptance_sampler(target_probs,
|
||||
bonus_token_ids,
|
||||
draft_token_ids)
|
||||
output_token_ids = typical_acceptance_sampler(
|
||||
target_probs,
|
||||
bonus_token_ids,
|
||||
draft_probs=None,
|
||||
draft_token_ids=draft_token_ids)
|
||||
# verify the shape of output_token_ids
|
||||
assert output_token_ids.shape[0] == batch_size
|
||||
assert output_token_ids.shape[1] == (k + 1)
|
||||
@ -326,7 +350,7 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
|
||||
batch_size = 1
|
||||
vocab_size = 30_000
|
||||
torch.set_default_device(device)
|
||||
typical_acceptance_sampler = TypicalAcceptanceSampler(
|
||||
typical_acceptance_sampler = get_acceptance_sampler(
|
||||
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
|
||||
typical_acceptance_sampler.init_gpu_tensors(rank=0)
|
||||
# Create a temperature zero target probability distribution and ensure
|
||||
@ -339,9 +363,11 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
|
||||
high=vocab_size,
|
||||
size=(batch_size, 1),
|
||||
dtype=torch.int64)
|
||||
output_token_ids = typical_acceptance_sampler(target_probs,
|
||||
bonus_token_ids,
|
||||
draft_token_ids)
|
||||
output_token_ids = typical_acceptance_sampler(
|
||||
target_probs,
|
||||
bonus_token_ids,
|
||||
draft_probs=None,
|
||||
draft_token_ids=draft_token_ids)
|
||||
assert output_token_ids.shape[0] == batch_size
|
||||
assert output_token_ids.shape[1] == (k + 1)
|
||||
assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids)
|
||||
@ -357,9 +383,11 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
|
||||
batch_size, k, vocab_size, zero_temperature_token_ids)
|
||||
draft_token_ids = torch.cat(
|
||||
(draft_token_ids[:, :2], draft_token_ids_to_replace[:, -3:]), dim=1)
|
||||
output_token_ids = typical_acceptance_sampler(target_probs,
|
||||
bonus_token_ids,
|
||||
draft_token_ids)
|
||||
output_token_ids = typical_acceptance_sampler(
|
||||
target_probs,
|
||||
bonus_token_ids,
|
||||
draft_probs=None,
|
||||
draft_token_ids=draft_token_ids)
|
||||
assert output_token_ids.shape[0] == batch_size
|
||||
assert output_token_ids.shape[1] == (k + 1)
|
||||
assert torch.all(output_token_ids[:, :2] == draft_token_ids[:, :2])
|
||||
@ -384,7 +412,7 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
|
||||
batch_size = 1
|
||||
vocab_size = 30_000
|
||||
torch.set_default_device(device)
|
||||
typical_acceptance_sampler = TypicalAcceptanceSampler(
|
||||
typical_acceptance_sampler = get_acceptance_sampler(
|
||||
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
|
||||
typical_acceptance_sampler.init_gpu_tensors(rank=0)
|
||||
# Simulate temperature 0 probability distribution for target
|
||||
@ -402,9 +430,11 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
|
||||
high=vocab_size,
|
||||
size=(batch_size, 1),
|
||||
dtype=torch.int64)
|
||||
output_token_ids = typical_acceptance_sampler(target_probs,
|
||||
bonus_token_ids,
|
||||
draft_token_ids)
|
||||
output_token_ids = typical_acceptance_sampler(
|
||||
target_probs,
|
||||
bonus_token_ids,
|
||||
draft_probs=None,
|
||||
draft_token_ids=draft_token_ids)
|
||||
assert output_token_ids.shape[0] == batch_size
|
||||
assert output_token_ids.shape[1] == (k + 1)
|
||||
assert torch.all(output_token_ids[:, 1:-1] == -1)
|
||||
@ -418,9 +448,11 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
|
||||
posterior_threshold=0.0,
|
||||
posterior_alpha=0.0)
|
||||
typical_acceptance_sampler.init_gpu_tensors(rank=0)
|
||||
output_token_ids = typical_acceptance_sampler(target_probs,
|
||||
bonus_token_ids,
|
||||
draft_token_ids)
|
||||
output_token_ids = typical_acceptance_sampler(
|
||||
target_probs,
|
||||
bonus_token_ids,
|
||||
draft_probs=None,
|
||||
draft_token_ids=draft_token_ids)
|
||||
assert output_token_ids.shape[0] == batch_size
|
||||
assert output_token_ids.shape[1] == (k + 1)
|
||||
assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids)
|
||||
@ -451,7 +483,7 @@ def test_replacement_token_ids(seed: int, disable_bonus_tokens: bool,
|
||||
batch_size = 5
|
||||
vocab_size = 30_000
|
||||
torch.set_default_device(device)
|
||||
typical_acceptance_sampler = TypicalAcceptanceSampler(
|
||||
typical_acceptance_sampler = get_acceptance_sampler(
|
||||
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
|
||||
typical_acceptance_sampler.init_gpu_tensors(rank=0)
|
||||
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
|
||||
|
@ -11,9 +11,15 @@ distribution matches the target model's output distribution (up to hardware
|
||||
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
|
||||
equality. This gives us good coverage of temp=0.
|
||||
|
||||
At temp=0, the TypicalAcceptanceSampler ensures that only the tokens with the
|
||||
highest probability in the target distribution are accepted. Therefore, we can
|
||||
expect greedy equality for the TypicalAcceptanceSampler at temp=0.
|
||||
|
||||
For temp>0, we rely on unit tests on the rejection sampler to verify that the
|
||||
output distribution is the same with spec decode vs. no spec decode (this would
|
||||
be prohibitively expensive to run with a real model).
|
||||
be prohibitively expensive to run with a real model). Similarly, for the
|
||||
TypicalAcceptance sampler also, we rely on unit tests to validate temp>0
|
||||
test cases.
|
||||
|
||||
NOTE: Speculative decoding's distribution equality requires that the measured
|
||||
distributions of the target model and proposal model be deterministic given the
|
||||
@ -611,3 +617,49 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
force_output_len=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"model": "JackFram/llama-160m",
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize(
|
||||
"test_llm_kwargs",
|
||||
[
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": k,
|
||||
"spec_decoding_acceptance_method": "typical_acceptance_sampler"
|
||||
}
|
||||
# Try a range of common k.
|
||||
for k in [1, 2, 3]
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_typical_acceptance_sampling(baseline_llm_generator,
|
||||
test_llm_generator, batch_size: int,
|
||||
output_len: int):
|
||||
"""Verify that speculative decoding produces exact equality to without spec
|
||||
decode with TypicalAcceptanceSampler as the draft token acceptance
|
||||
sampling method.
|
||||
"""
|
||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
||||
test_llm_generator,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
force_output_len=True)
|
||||
|
@ -3,33 +3,35 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.spec_decode.metrics import AsyncMetricsCollector
|
||||
from vllm.spec_decode.multi_step_worker import MultiStepWorker
|
||||
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
|
||||
from vllm.spec_decode.top1_proposer import Top1Proposer
|
||||
|
||||
from .test_utils import mock_spec_decode_sampler
|
||||
from .utils import create_batch, mock_worker
|
||||
|
||||
|
||||
@pytest.mark.parametrize('queue_size', [4])
|
||||
@pytest.mark.parametrize('batch_size', [1])
|
||||
@pytest.mark.parametrize('k', [1])
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@torch.inference_mode()
|
||||
def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int):
|
||||
def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int,
|
||||
acceptance_sampler_method: str):
|
||||
"""Verify that speculative tokens are disabled when the batch size
|
||||
exceeds the threshold.
|
||||
"""
|
||||
disable_by_batch_size = 3
|
||||
|
||||
draft_worker = mock_worker(cls=MultiStepWorker)
|
||||
target_worker = mock_worker()
|
||||
rejection_sampler = MagicMock(spec=RejectionSampler)
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
worker = SpecDecodeWorker(proposer_worker=draft_worker,
|
||||
scorer_worker=target_worker,
|
||||
rejection_sampler=rejection_sampler,
|
||||
spec_decode_sampler=mock_spec_decode_sampler(
|
||||
acceptance_sampler_method),
|
||||
metrics_collector=metrics_collector,
|
||||
disable_by_batch_size=disable_by_batch_size)
|
||||
|
||||
|
@ -10,16 +10,16 @@ from vllm.spec_decode.metrics import AsyncMetricsCollector
|
||||
def test_initial_call_returns_none():
|
||||
"""Expect first call to get metrics to return None.
|
||||
"""
|
||||
rej_sampler = MagicMock()
|
||||
rej_sampler.num_accepted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
rej_sampler.num_emitted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
rej_sampler.num_draft_tokens = 0
|
||||
spec_decode_sampler = MagicMock()
|
||||
spec_decode_sampler.num_accepted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
spec_decode_sampler.num_emitted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
spec_decode_sampler.num_draft_tokens = 0
|
||||
|
||||
collector = AsyncMetricsCollector(rej_sampler)
|
||||
collector = AsyncMetricsCollector(spec_decode_sampler)
|
||||
collector.init_gpu_tensors(rank=0)
|
||||
maybe_metrics = collector.maybe_collect_rejsample_metrics(k=5)
|
||||
assert maybe_metrics is None
|
||||
@ -28,14 +28,14 @@ def test_initial_call_returns_none():
|
||||
def test_second_call_returns_metrics():
|
||||
"""Expect second call to not return None.
|
||||
"""
|
||||
rej_sampler = MagicMock()
|
||||
rej_sampler.num_accepted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
rej_sampler.num_emitted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
rej_sampler.num_draft_tokens = 0
|
||||
spec_decode_sampler = MagicMock()
|
||||
spec_decode_sampler.num_accepted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
spec_decode_sampler.num_emitted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
spec_decode_sampler.num_draft_tokens = 0
|
||||
|
||||
collect_interval_s = 5.0
|
||||
timer = MagicMock()
|
||||
@ -43,7 +43,7 @@ def test_second_call_returns_metrics():
|
||||
0.0, collect_interval_s + 0.1, collect_interval_s + 0.2
|
||||
]
|
||||
|
||||
collector = AsyncMetricsCollector(rejection_sampler=rej_sampler,
|
||||
collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler,
|
||||
timer=timer,
|
||||
collect_interval_s=collect_interval_s)
|
||||
collector.init_gpu_tensors(rank=0)
|
||||
@ -56,16 +56,16 @@ def test_second_call_returns_metrics():
|
||||
def test_nonzero_rank_noop(rank):
|
||||
"""Verify nonzero ranks don't collect metrics.
|
||||
"""
|
||||
rej_sampler = MagicMock()
|
||||
rej_sampler.num_accepted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
rej_sampler.num_emitted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
rej_sampler.num_draft_tokens = 0
|
||||
spec_decode_sampler = MagicMock()
|
||||
spec_decode_sampler.num_accepted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
spec_decode_sampler.num_emitted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
spec_decode_sampler.num_draft_tokens = 0
|
||||
|
||||
collector = AsyncMetricsCollector(rej_sampler)
|
||||
collector = AsyncMetricsCollector(spec_decode_sampler)
|
||||
collector.init_gpu_tensors(rank=rank)
|
||||
_ = collector.maybe_collect_rejsample_metrics(k=5)
|
||||
metrics = collector.maybe_collect_rejsample_metrics(k=5)
|
||||
@ -75,14 +75,14 @@ def test_nonzero_rank_noop(rank):
|
||||
def test_noop_until_time():
|
||||
"""Verify metrics aren't collected until enough time passes.
|
||||
"""
|
||||
rej_sampler = MagicMock()
|
||||
rej_sampler.num_accepted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
rej_sampler.num_emitted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
rej_sampler.num_draft_tokens = 0
|
||||
spec_decode_sampler = MagicMock()
|
||||
spec_decode_sampler.num_accepted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
spec_decode_sampler.num_emitted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
spec_decode_sampler.num_draft_tokens = 0
|
||||
|
||||
collect_interval_s = 5.0
|
||||
timer = MagicMock()
|
||||
@ -91,7 +91,7 @@ def test_noop_until_time():
|
||||
collect_interval_s + 0.1, collect_interval_s + 0.1
|
||||
]
|
||||
|
||||
collector = AsyncMetricsCollector(rejection_sampler=rej_sampler,
|
||||
collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler,
|
||||
timer=timer,
|
||||
collect_interval_s=collect_interval_s)
|
||||
collector.init_gpu_tensors(rank=0)
|
||||
@ -122,14 +122,14 @@ def test_initial_metrics_has_correct_values(has_data: bool):
|
||||
max_num_emitted_tokens = AsyncMetricsCollector.get_max_num_emitted_tokens(
|
||||
num_draft_tokens, k)
|
||||
|
||||
rej_sampler = MagicMock()
|
||||
rej_sampler.num_accepted_tokens = torch.tensor(num_accepted_tokens,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
rej_sampler.num_emitted_tokens = torch.tensor(num_emitted_tokens,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
rej_sampler.num_draft_tokens = num_draft_tokens
|
||||
spec_decode_sampler = MagicMock()
|
||||
spec_decode_sampler.num_accepted_tokens = torch.tensor(num_accepted_tokens,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
spec_decode_sampler.num_emitted_tokens = torch.tensor(num_emitted_tokens,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
spec_decode_sampler.num_draft_tokens = num_draft_tokens
|
||||
|
||||
collect_interval_s = 5.0
|
||||
timer = MagicMock()
|
||||
@ -137,7 +137,7 @@ def test_initial_metrics_has_correct_values(has_data: bool):
|
||||
0.0, collect_interval_s + 0.1, collect_interval_s + 0.2
|
||||
]
|
||||
|
||||
collector = AsyncMetricsCollector(rejection_sampler=rej_sampler,
|
||||
collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler,
|
||||
timer=timer,
|
||||
collect_interval_s=collect_interval_s)
|
||||
collector.init_gpu_tensors(rank=0)
|
||||
|
@ -6,7 +6,6 @@ from unittest.mock import MagicMock
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
from vllm.sequence import ExecuteModelRequest, SamplerOutput, SequenceOutput
|
||||
from vllm.spec_decode.interfaces import SpeculativeProposals
|
||||
@ -16,23 +15,26 @@ from vllm.spec_decode.multi_step_worker import MultiStepWorker
|
||||
from vllm.spec_decode.spec_decode_worker import (SpecDecodeWorker,
|
||||
split_num_cache_blocks_evenly)
|
||||
|
||||
from .test_utils import mock_spec_decode_sampler
|
||||
from .utils import create_batch, create_sampler_output_list, mock_worker
|
||||
|
||||
|
||||
@pytest.mark.parametrize('k', [1, 2, 6])
|
||||
@pytest.mark.parametrize('batch_size', [1, 2, 32])
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@torch.inference_mode()
|
||||
def test_correctly_calls_draft_model(k: int, batch_size: int):
|
||||
def test_correctly_calls_draft_model(k: int, batch_size: int,
|
||||
acceptance_sampler_method: str):
|
||||
"""Verify SpecDecodeWorker calls the draft worker with correct
|
||||
inputs. Everything else is mocked out.
|
||||
"""
|
||||
draft_worker = mock_worker(cls=MultiStepWorker)
|
||||
target_worker = mock_worker()
|
||||
rejection_sampler = MagicMock(spec=RejectionSampler)
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
|
||||
metrics_collector)
|
||||
|
||||
worker = SpecDecodeWorker(
|
||||
draft_worker, target_worker,
|
||||
mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
|
||||
exception_secret = 'artificial stop'
|
||||
draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)
|
||||
|
||||
@ -53,15 +55,16 @@ def test_correctly_calls_draft_model(k: int, batch_size: int):
|
||||
|
||||
@pytest.mark.parametrize('k', [1, 2, 6])
|
||||
@pytest.mark.parametrize('batch_size', [1, 2, 32])
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@torch.inference_mode()
|
||||
def test_correctly_calls_target_model(k: int, batch_size: int):
|
||||
def test_correctly_calls_target_model(k: int, batch_size: int,
|
||||
acceptance_sampler_method: str):
|
||||
"""Verify SpecDecodeWorker calls the target model with correct
|
||||
inputs. Everything else is mocked out.
|
||||
"""
|
||||
draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
|
||||
target_worker = mock_worker(use_spec=False)
|
||||
rejection_sampler = MagicMock(spec=RejectionSampler)
|
||||
rejection_sampler.token_id_dtype = torch.int64
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
|
||||
draft_worker.device = 'cuda'
|
||||
@ -69,8 +72,9 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
|
||||
|
||||
set_random_seed(1)
|
||||
|
||||
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
|
||||
metrics_collector)
|
||||
worker = SpecDecodeWorker(
|
||||
draft_worker, target_worker,
|
||||
mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
|
||||
worker.init_device()
|
||||
|
||||
vocab_size = 32_000
|
||||
@ -133,8 +137,11 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
|
||||
|
||||
@pytest.mark.parametrize('k', [1, 2, 6])
|
||||
@pytest.mark.parametrize('batch_size', [1, 2, 32])
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@torch.inference_mode()
|
||||
def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
|
||||
def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int,
|
||||
acceptance_sampler_method: str):
|
||||
"""Verify SpecDecodeWorker calls the rejection sampler with
|
||||
correct inputs. Everything else is mocked out.
|
||||
"""
|
||||
@ -144,15 +151,14 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
|
||||
vocab_size=vocab_size,
|
||||
use_spec=False)
|
||||
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
|
||||
rejection_sampler = MagicMock(spec=RejectionSampler)
|
||||
rejection_sampler.token_id_dtype = torch.int64
|
||||
spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
draft_worker.device = 'cuda'
|
||||
target_worker.device = 'cuda'
|
||||
|
||||
set_random_seed(1)
|
||||
|
||||
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
|
||||
worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler,
|
||||
metrics_collector)
|
||||
worker.init_device()
|
||||
|
||||
@ -199,15 +205,16 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
|
||||
target_worker.execute_model.return_value = [target_output[0]]
|
||||
|
||||
exception_secret = 'artificial stop'
|
||||
rejection_sampler.side_effect = ValueError(exception_secret)
|
||||
|
||||
spec_decode_sampler.side_effect = ValueError(exception_secret)
|
||||
|
||||
with pytest.raises(ValueError, match=exception_secret):
|
||||
worker.execute_model(execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k))
|
||||
|
||||
assert len(rejection_sampler.call_args_list) == 1
|
||||
_, kwargs = rejection_sampler.call_args_list[0]
|
||||
assert len(spec_decode_sampler.call_args_list) == 1
|
||||
_, kwargs = spec_decode_sampler.call_args_list[0]
|
||||
actual = SimpleNamespace(**kwargs)
|
||||
|
||||
assert torch.equal(actual.bonus_token_ids,
|
||||
@ -221,8 +228,11 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
|
||||
|
||||
@pytest.mark.parametrize('k', [1, 2, 6])
|
||||
@pytest.mark.parametrize('batch_size', [1, 2, 32])
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@torch.inference_mode()
|
||||
def test_correctly_formats_output(k: int, batch_size: int):
|
||||
def test_correctly_formats_output(k: int, batch_size: int,
|
||||
acceptance_sampler_method: str):
|
||||
"""Verify SpecDecodeWorker formats sampler output correctly.
|
||||
Everything else is mocked out.
|
||||
"""
|
||||
@ -232,15 +242,13 @@ def test_correctly_formats_output(k: int, batch_size: int):
|
||||
vocab_size=vocab_size,
|
||||
use_spec=False)
|
||||
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
|
||||
rejection_sampler = MagicMock(spec=RejectionSampler)
|
||||
rejection_sampler.token_id_dtype = torch.int64
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
draft_worker.device = 'cuda'
|
||||
target_worker.device = 'cuda'
|
||||
|
||||
set_random_seed(1)
|
||||
|
||||
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
|
||||
spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
|
||||
worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler,
|
||||
metrics_collector)
|
||||
worker.init_device()
|
||||
|
||||
@ -286,24 +294,23 @@ def test_correctly_formats_output(k: int, batch_size: int):
|
||||
|
||||
target_worker.execute_model.return_value = [target_output[0]]
|
||||
|
||||
rejection_sampler_output = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, k + 1),
|
||||
dtype=torch.int64,
|
||||
device='cuda')
|
||||
spec_decode_sampler_output = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, k + 1),
|
||||
dtype=torch.int64,
|
||||
device='cuda')
|
||||
for i in range(batch_size):
|
||||
minimum_accepted_tokens = 1
|
||||
rejection_sampler_output[i][
|
||||
spec_decode_sampler_output[i][
|
||||
-random.randint(minimum_accepted_tokens, k + 1):] = -1
|
||||
|
||||
rejection_sampler.return_value = rejection_sampler_output
|
||||
|
||||
spec_decode_sampler.return_value = spec_decode_sampler_output
|
||||
output = worker.execute_model(execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k))
|
||||
|
||||
expected_output = create_sampler_output_list(
|
||||
token_ids=rejection_sampler_output.transpose(0, 1),
|
||||
token_ids=spec_decode_sampler_output.transpose(0, 1),
|
||||
probs=[None for _ in range(k + 1)],
|
||||
logprobs=[None for _ in range(k + 1)])
|
||||
|
||||
@ -350,8 +357,11 @@ def test_correctly_formats_output(k: int, batch_size: int):
|
||||
@pytest.mark.parametrize('k', [1, 2])
|
||||
@pytest.mark.parametrize('batch_size', [1])
|
||||
@pytest.mark.parametrize('returns_metrics', [True, False])
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@torch.inference_mode()
|
||||
def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
|
||||
def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool,
|
||||
acceptance_sampler_method: str):
|
||||
"""Verify SpecDecodeWorker collects metrics.
|
||||
"""
|
||||
vocab_size = 32_000
|
||||
@ -360,15 +370,14 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
|
||||
vocab_size=vocab_size,
|
||||
use_spec=False)
|
||||
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
|
||||
rejection_sampler = MagicMock(spec=RejectionSampler)
|
||||
rejection_sampler.token_id_dtype = torch.int64
|
||||
spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
draft_worker.device = 'cuda'
|
||||
target_worker.device = 'cuda'
|
||||
|
||||
set_random_seed(1)
|
||||
|
||||
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
|
||||
worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler,
|
||||
metrics_collector)
|
||||
worker.init_device()
|
||||
|
||||
@ -414,17 +423,16 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
|
||||
|
||||
target_worker.execute_model.return_value = [target_output[0]]
|
||||
|
||||
rejection_sampler_output = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, k + 1),
|
||||
dtype=torch.int64,
|
||||
device='cuda')
|
||||
spec_decode_sampler_output = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, k + 1),
|
||||
dtype=torch.int64,
|
||||
device='cuda')
|
||||
for i in range(batch_size):
|
||||
minimum_accepted_tokens = 1
|
||||
rejection_sampler_output[i][
|
||||
spec_decode_sampler_output[i][
|
||||
-random.randint(minimum_accepted_tokens, k + 1):] = -1
|
||||
|
||||
rejection_sampler.return_value = rejection_sampler_output
|
||||
spec_decode_sampler.return_value = spec_decode_sampler_output
|
||||
|
||||
mock_rejsample_metrics = MagicMock(
|
||||
spec=SpecDecodeWorkerMetrics) if returns_metrics else None
|
||||
@ -445,15 +453,16 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
|
||||
|
||||
@pytest.mark.parametrize('k', [0])
|
||||
@pytest.mark.parametrize('batch_size', [1, 2, 32])
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@torch.inference_mode()
|
||||
def test_k_equals_zero(k: int, batch_size: int):
|
||||
def test_k_equals_zero(k: int, batch_size: int,
|
||||
acceptance_sampler_method: str):
|
||||
"""Verify that the SpecDecodeWorker calls the draft and target workers
|
||||
when k is zero. This happens during prefill.
|
||||
"""
|
||||
draft_worker = mock_worker(cls=MultiStepWorker)
|
||||
target_worker = mock_worker()
|
||||
rejection_sampler = MagicMock(spec=RejectionSampler)
|
||||
rejection_sampler.token_id_dtype = torch.int64
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
|
||||
sampler_output = MagicMock(spec=SamplerOutput)
|
||||
@ -465,8 +474,9 @@ def test_k_equals_zero(k: int, batch_size: int):
|
||||
|
||||
set_random_seed(1)
|
||||
|
||||
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
|
||||
metrics_collector)
|
||||
worker = SpecDecodeWorker(
|
||||
draft_worker, target_worker,
|
||||
mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
|
||||
|
||||
seq_group_metadata_list, _, _ = create_batch(batch_size,
|
||||
k,
|
||||
@ -487,16 +497,17 @@ def test_k_equals_zero(k: int, batch_size: int):
|
||||
|
||||
@pytest.mark.parametrize('k', [0, 5])
|
||||
@pytest.mark.parametrize('batch_size', [0])
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@torch.inference_mode()
|
||||
def test_empty_input_batch(k: int, batch_size: int):
|
||||
def test_empty_input_batch(k: int, batch_size: int,
|
||||
acceptance_sampler_method: str):
|
||||
"""Verify that the SpecDecodeWorker calls the draft and target workers
|
||||
when the input batch is empty. This can happen if the engine communicates
|
||||
to the workers information without scheduling a batch.
|
||||
"""
|
||||
draft_worker = mock_worker(cls=MultiStepWorker)
|
||||
target_worker = mock_worker()
|
||||
rejection_sampler = MagicMock(spec=RejectionSampler)
|
||||
rejection_sampler.token_id_dtype = torch.int64
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
|
||||
sampler_output = MagicMock(spec=SamplerOutput)
|
||||
@ -508,8 +519,9 @@ def test_empty_input_batch(k: int, batch_size: int):
|
||||
|
||||
set_random_seed(1)
|
||||
|
||||
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
|
||||
metrics_collector)
|
||||
worker = SpecDecodeWorker(
|
||||
draft_worker, target_worker,
|
||||
mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
|
||||
|
||||
seq_group_metadata_list, _, _ = create_batch(batch_size,
|
||||
k,
|
||||
@ -528,18 +540,19 @@ def test_empty_input_batch(k: int, batch_size: int):
|
||||
target_worker.execute_model.assert_called_once_with(execute_model_req)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_init_device():
|
||||
def test_init_device(acceptance_sampler_method: str):
|
||||
"""Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as
|
||||
well as other GPU initialization.
|
||||
"""
|
||||
draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
|
||||
target_worker = mock_worker(use_spec=False)
|
||||
rejection_sampler = MagicMock(spec=RejectionSampler)
|
||||
rejection_sampler.token_id_dtype = torch.int64
|
||||
spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
|
||||
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
|
||||
worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler,
|
||||
metrics_collector)
|
||||
|
||||
worker.init_device()
|
||||
@ -549,22 +562,23 @@ def test_init_device():
|
||||
target_worker.init_device.assert_called_once()
|
||||
|
||||
metrics_collector.init_gpu_tensors.assert_called_once()
|
||||
rejection_sampler.init_gpu_tensors.assert_called_once()
|
||||
spec_decode_sampler.init_gpu_tensors.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@torch.inference_mode()
|
||||
def test_initialize_cache():
|
||||
def test_initialize_cache(acceptance_sampler_method):
|
||||
"""Verify SpecDecodeWorker invokes initialize_cache on proposer/scorer
|
||||
workers.
|
||||
"""
|
||||
draft_worker = mock_worker(cls=MultiStepWorker)
|
||||
target_worker = mock_worker()
|
||||
rejection_sampler = MagicMock(spec=RejectionSampler)
|
||||
rejection_sampler.token_id_dtype = torch.int64
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
|
||||
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
|
||||
metrics_collector)
|
||||
worker = SpecDecodeWorker(
|
||||
draft_worker, target_worker,
|
||||
mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
|
||||
|
||||
kwargs = {"num_gpu_blocks": 1024, "num_cpu_blocks": 1023}
|
||||
worker.initialize_cache(**kwargs)
|
||||
@ -577,19 +591,20 @@ def test_initialize_cache():
|
||||
@pytest.mark.parametrize('available_cpu_blocks', [500])
|
||||
@pytest.mark.parametrize('target_cache_block_size_bytes', [2 * 2 * 4096])
|
||||
@pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096])
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_determine_num_available_blocks(available_gpu_blocks: int,
|
||||
available_cpu_blocks: int,
|
||||
target_cache_block_size_bytes: int,
|
||||
draft_kv_size_bytes: int):
|
||||
draft_kv_size_bytes: int,
|
||||
acceptance_sampler_method: str):
|
||||
"""Verify SpecDecodeWorker correctly profiles num available GPU blocks.
|
||||
Specifically, it should run profiling in the scorer worker, and then evenly
|
||||
split the blocks between proposer and scorer worker.
|
||||
"""
|
||||
draft_worker = mock_worker(cls=MultiStepWorker)
|
||||
target_worker = mock_worker()
|
||||
rejection_sampler = MagicMock(spec=RejectionSampler)
|
||||
rejection_sampler.token_id_dtype = torch.int64
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
|
||||
target_worker.determine_num_available_blocks.return_value = (
|
||||
@ -598,8 +613,9 @@ def test_determine_num_available_blocks(available_gpu_blocks: int,
|
||||
target_cache_block_size_bytes)
|
||||
draft_worker.get_cache_block_size_bytes.return_value = draft_kv_size_bytes
|
||||
|
||||
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
|
||||
metrics_collector)
|
||||
worker = SpecDecodeWorker(
|
||||
draft_worker, target_worker,
|
||||
mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
|
||||
|
||||
num_gpu_blocks, num_cpu_blocks = worker.determine_num_available_blocks()
|
||||
|
||||
|
@ -1,7 +1,11 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
|
||||
from vllm.model_executor.layers.typical_acceptance_sampler import (
|
||||
TypicalAcceptanceSampler)
|
||||
from vllm.sequence import SequenceGroupMetadata, get_all_seq_ids
|
||||
from vllm.spec_decode.util import split_batch_by_proposal_len
|
||||
|
||||
@ -109,3 +113,21 @@ def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata):
|
||||
|
||||
assert filtered_groups == []
|
||||
assert indices == []
|
||||
|
||||
|
||||
def mock_spec_decode_sampler(acceptance_sampler_method):
|
||||
"""
|
||||
Returns either a RejectionSampler or TypicalAcceptanceSampler
|
||||
object depending on whether acceptance_sampler_method is
|
||||
'rejection_sampler' or 'typical_acceptance_sampler' respectively.
|
||||
"""
|
||||
if acceptance_sampler_method == "rejection_sampler":
|
||||
sampler = MagicMock(spec=RejectionSampler)
|
||||
sampler.token_id_dtype = torch.int64
|
||||
return sampler
|
||||
elif acceptance_sampler_method == "typical_acceptance_sampler":
|
||||
sampler = MagicMock(spec=TypicalAcceptanceSampler)
|
||||
sampler.token_id_dtype = torch.int64
|
||||
return sampler
|
||||
else:
|
||||
raise ValueError(f"Invalid sampler name {acceptance_sampler_method}")
|
||||
|
@ -753,7 +753,6 @@ class SchedulerConfig:
|
||||
self.chunked_prefill_enabled = enable_chunked_prefill
|
||||
self.embedding_mode = embedding_mode
|
||||
self.preemption_mode = preemption_mode
|
||||
|
||||
self._verify_args()
|
||||
|
||||
def _verify_args(self) -> None:
|
||||
@ -834,6 +833,9 @@ class SpeculativeConfig:
|
||||
speculative_disable_by_batch_size: Optional[int],
|
||||
ngram_prompt_lookup_max: Optional[int],
|
||||
ngram_prompt_lookup_min: Optional[int],
|
||||
draft_token_acceptance_method: str,
|
||||
typical_acceptance_sampler_posterior_threshold: Optional[float],
|
||||
typical_acceptance_sampler_posterior_alpha: Optional[float],
|
||||
) -> Optional["SpeculativeConfig"]:
|
||||
"""Create a SpeculativeConfig if possible, else return None.
|
||||
|
||||
@ -870,7 +872,20 @@ class SpeculativeConfig:
|
||||
window, if provided.
|
||||
ngram_prompt_lookup_min (Optional[int]): Min size of ngram token
|
||||
window, if provided.
|
||||
|
||||
draft_token_acceptance_method (str): The method to use for
|
||||
accepting draft tokens. This can take two possible
|
||||
values 'rejection_sampler' and 'typical_acceptance_sampler'
|
||||
for RejectionSampler and TypicalAcceptanceSampler
|
||||
respectively.
|
||||
typical_acceptance_sampler_posterior_threshold (Optional[float]):
|
||||
A threshold value that sets a lower bound on the posterior
|
||||
probability of a token in the target model for it to be
|
||||
accepted. This threshold is used only when we use the
|
||||
TypicalAcceptanceSampler for token acceptance.
|
||||
typical_acceptance_sampler_posterior_alpha (Optional[float]):
|
||||
A scaling factor for the entropy-based threshold in the
|
||||
TypicalAcceptanceSampler.
|
||||
|
||||
Returns:
|
||||
Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
|
||||
the necessary conditions are met, else None.
|
||||
@ -984,6 +999,11 @@ class SpeculativeConfig:
|
||||
"speculative_model unless the draft model config contains an "
|
||||
"n_predict parameter.")
|
||||
|
||||
if typical_acceptance_sampler_posterior_threshold is None:
|
||||
typical_acceptance_sampler_posterior_threshold = 0.09
|
||||
if typical_acceptance_sampler_posterior_alpha is None:
|
||||
typical_acceptance_sampler_posterior_alpha = 0.3
|
||||
|
||||
return SpeculativeConfig(
|
||||
draft_model_config,
|
||||
draft_parallel_config,
|
||||
@ -991,6 +1011,11 @@ class SpeculativeConfig:
|
||||
speculative_disable_by_batch_size,
|
||||
ngram_prompt_lookup_max,
|
||||
ngram_prompt_lookup_min,
|
||||
draft_token_acceptance_method=draft_token_acceptance_method,
|
||||
typical_acceptance_sampler_posterior_threshold=\
|
||||
typical_acceptance_sampler_posterior_threshold,
|
||||
typical_acceptance_sampler_posterior_alpha=\
|
||||
typical_acceptance_sampler_posterior_alpha,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@ -1072,6 +1097,9 @@ class SpeculativeConfig:
|
||||
speculative_disable_by_batch_size: Optional[int],
|
||||
ngram_prompt_lookup_max: Optional[int],
|
||||
ngram_prompt_lookup_min: Optional[int],
|
||||
draft_token_acceptance_method: str,
|
||||
typical_acceptance_sampler_posterior_threshold: float,
|
||||
typical_acceptance_sampler_posterior_alpha: float,
|
||||
):
|
||||
"""Create a SpeculativeConfig object.
|
||||
|
||||
@ -1085,6 +1113,19 @@ class SpeculativeConfig:
|
||||
enqueue requests is larger than this value.
|
||||
ngram_prompt_lookup_max: Max size of ngram token window.
|
||||
ngram_prompt_lookup_min: Min size of ngram token window.
|
||||
draft_token_acceptance_method (str): The method to use for
|
||||
accepting draft tokens. This can take two possible
|
||||
values 'rejection_sampler' and 'typical_acceptance_sampler'
|
||||
for RejectionSampler and TypicalAcceptanceSampler
|
||||
respectively.
|
||||
typical_acceptance_sampler_posterior_threshold (Optional[float]):
|
||||
A threshold value that sets a lower bound on the posterior
|
||||
probability of a token in the target model for it to be
|
||||
accepted. This threshold is used only when we use the
|
||||
TypicalAcceptanceSampler for token acceptance.
|
||||
typical_acceptance_sampler_posterior_alpha (Optional[float]):
|
||||
A scaling factor for the entropy-based threshold in the
|
||||
TypicalAcceptanceSampler.
|
||||
"""
|
||||
self.draft_model_config = draft_model_config
|
||||
self.draft_parallel_config = draft_parallel_config
|
||||
@ -1093,6 +1134,11 @@ class SpeculativeConfig:
|
||||
speculative_disable_by_batch_size
|
||||
self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0
|
||||
self.ngram_prompt_lookup_min = ngram_prompt_lookup_min or 0
|
||||
self.draft_token_acceptance_method = draft_token_acceptance_method
|
||||
self.typical_acceptance_sampler_posterior_threshold = \
|
||||
typical_acceptance_sampler_posterior_threshold
|
||||
self.typical_acceptance_sampler_posterior_alpha = \
|
||||
typical_acceptance_sampler_posterior_alpha
|
||||
|
||||
self._verify_args()
|
||||
|
||||
@ -1104,6 +1150,31 @@ class SpeculativeConfig:
|
||||
if self.draft_model_config:
|
||||
self.draft_model_config.verify_with_parallel_config(
|
||||
self.draft_parallel_config)
|
||||
# Validate and set draft token acceptance related settings.
|
||||
|
||||
if (self.draft_token_acceptance_method is None):
|
||||
raise ValueError("draft_token_acceptance_method is not set. "
|
||||
"Expected values are rejection_sampler or "
|
||||
"typical_acceptance_sampler.")
|
||||
|
||||
if (self.draft_token_acceptance_method != 'rejection_sampler'
|
||||
and self.draft_token_acceptance_method !=
|
||||
'typical_acceptance_sampler'):
|
||||
raise ValueError(
|
||||
"Expected draft_token_acceptance_method to be either "
|
||||
"rejection_sampler or typical_acceptance_sampler. Instead it "
|
||||
f"is {self.draft_token_acceptance_method}")
|
||||
|
||||
if (self.typical_acceptance_sampler_posterior_threshold < 0
|
||||
or self.typical_acceptance_sampler_posterior_alpha < 0):
|
||||
raise ValueError(
|
||||
"Expected typical_acceptance_sampler_posterior_threshold "
|
||||
"and typical_acceptance_sampler_posterior_alpha to be > 0. "
|
||||
"Instead found "
|
||||
f"typical_acceptance_sampler_posterior_threshold = "
|
||||
f"{self.typical_acceptance_sampler_posterior_threshold} and "
|
||||
f"typical_acceptance_sampler_posterior_alpha = "
|
||||
f"{self.typical_acceptance_sampler_posterior_alpha}")
|
||||
|
||||
@property
|
||||
def num_lookahead_slots(self) -> int:
|
||||
|
@ -100,7 +100,9 @@ class EngineArgs:
|
||||
speculative_disable_by_batch_size: Optional[int] = None
|
||||
ngram_prompt_lookup_max: Optional[int] = None
|
||||
ngram_prompt_lookup_min: Optional[int] = None
|
||||
|
||||
spec_decoding_acceptance_method: str = 'rejection_sampler'
|
||||
typical_acceptance_sampler_posterior_threshold: Optional[float] = None
|
||||
typical_acceptance_sampler_posterior_alpha: Optional[float] = None
|
||||
qlora_adapter_name_or_path: Optional[str] = None
|
||||
|
||||
otlp_traces_endpoint: Optional[str] = None
|
||||
@ -577,6 +579,38 @@ class EngineArgs:
|
||||
help='Min size of window for ngram prompt lookup in speculative '
|
||||
'decoding.')
|
||||
|
||||
parser.add_argument(
|
||||
'--spec-decoding-acceptance-method',
|
||||
type=str,
|
||||
default=EngineArgs.spec_decoding_acceptance_method,
|
||||
choices=['rejection_sampler', 'typical_acceptance_sampler'],
|
||||
help='Specify the acceptance method to use during draft token '
|
||||
'verification in speculative decoding. Two types of acceptance '
|
||||
'routines are supported: '
|
||||
'1) RejectionSampler which does not allow changing the '
|
||||
'acceptance rate of draft tokens, '
|
||||
'2) TypicalAcceptanceSampler which is configurable, allowing for '
|
||||
'a higher acceptance rate at the cost of lower quality, '
|
||||
'and vice versa.')
|
||||
|
||||
parser.add_argument(
|
||||
'--typical-acceptance-sampler-posterior-threshold',
|
||||
type=float,
|
||||
default=EngineArgs.typical_acceptance_sampler_posterior_threshold,
|
||||
help='Set the lower bound threshold for the posterior '
|
||||
'probability of a token to be accepted. This threshold is '
|
||||
'used by the TypicalAcceptanceSampler to make sampling decisions '
|
||||
'during speculative decoding. Defaults to 0.09')
|
||||
|
||||
parser.add_argument(
|
||||
'--typical-acceptance-sampler-posterior-alpha',
|
||||
type=float,
|
||||
default=EngineArgs.typical_acceptance_sampler_posterior_alpha,
|
||||
help='A scaling factor for the entropy-based threshold for token '
|
||||
'acceptance in the TypicalAcceptanceSampler. Typically defaults '
|
||||
'to sqrt of --typical-acceptance-sampler-posterior-threshold '
|
||||
'i.e. 0.3')
|
||||
|
||||
parser.add_argument('--model-loader-extra-config',
|
||||
type=nullable_str,
|
||||
default=EngineArgs.model_loader_extra_config,
|
||||
@ -737,6 +771,12 @@ class EngineArgs:
|
||||
use_v2_block_manager=self.use_v2_block_manager,
|
||||
ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
|
||||
ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
|
||||
draft_token_acceptance_method=\
|
||||
self.spec_decoding_acceptance_method,
|
||||
typical_acceptance_sampler_posterior_threshold=self.
|
||||
typical_acceptance_sampler_posterior_threshold,
|
||||
typical_acceptance_sampler_posterior_alpha=self.
|
||||
typical_acceptance_sampler_posterior_alpha,
|
||||
)
|
||||
|
||||
scheduler_config = SchedulerConfig(
|
||||
|
@ -457,4 +457,4 @@ class PrometheusStatLogger(StatLoggerBase):
|
||||
|
||||
class RayPrometheusStatLogger(PrometheusStatLogger):
|
||||
"""RayPrometheusStatLogger uses Ray metrics instead."""
|
||||
_metrics_cls = RayMetrics
|
||||
_metrics_cls = RayMetrics
|
@ -3,13 +3,12 @@ from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.jit
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.model_executor.layers.spec_decode_base_sampler import (
|
||||
SpecDecodeBaseSampler)
|
||||
|
||||
|
||||
class RejectionSampler(SpecDecodeBaseSampler, nn.Module):
|
||||
class RejectionSampler(SpecDecodeBaseSampler):
|
||||
"""Apply modified rejection sampling as described in "Accelerating Large
|
||||
Language Model Decoding with Speculative Sampling"
|
||||
https://arxiv.org/pdf/2302.01318.pdf.
|
||||
@ -28,8 +27,8 @@ class RejectionSampler(SpecDecodeBaseSampler, nn.Module):
|
||||
during sampling. This catches correctness issues but adds
|
||||
nontrivial latency.
|
||||
"""
|
||||
SpecDecodeBaseSampler.__init__(self, disable_bonus_tokens, strict_mode)
|
||||
nn.Module.__init__(self)
|
||||
super().__init__(disable_bonus_tokens=disable_bonus_tokens,
|
||||
strict_mode=strict_mode)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -78,11 +77,12 @@ class RejectionSampler(SpecDecodeBaseSampler, nn.Module):
|
||||
self._raise_if_incorrect_input(target_probs, bonus_token_ids,
|
||||
draft_probs, draft_token_ids)
|
||||
|
||||
accepted, recovered_token_ids = self._batch_modified_rejection_sampling(
|
||||
target_probs,
|
||||
draft_probs,
|
||||
draft_token_ids,
|
||||
)
|
||||
accepted, recovered_token_ids = (
|
||||
self._batch_modified_rejection_sampling(
|
||||
target_probs,
|
||||
draft_probs,
|
||||
draft_token_ids,
|
||||
))
|
||||
|
||||
output_token_ids = self._create_output(
|
||||
accepted,
|
||||
|
@ -1,9 +1,12 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.jit
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class SpecDecodeBaseSampler():
|
||||
class SpecDecodeBaseSampler(nn.Module):
|
||||
"""Base class for samplers used for Speculative Decoding verification
|
||||
step.
|
||||
"""
|
||||
@ -51,6 +54,16 @@ class SpecDecodeBaseSampler():
|
||||
def token_id_dtype(self):
|
||||
return torch.int64
|
||||
|
||||
@abstractmethod
|
||||
def forward(
|
||||
self,
|
||||
target_probs: torch.Tensor,
|
||||
bonus_token_ids: torch.Tensor,
|
||||
draft_probs: torch.Tensor,
|
||||
draft_token_ids: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def _create_output(
|
||||
self,
|
||||
accepted: torch.Tensor, # [batch_size, k]
|
||||
|
@ -1,12 +1,11 @@
|
||||
import torch
|
||||
import torch.jit
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.model_executor.layers.spec_decode_base_sampler import (
|
||||
SpecDecodeBaseSampler)
|
||||
|
||||
|
||||
class TypicalAcceptanceSampler(SpecDecodeBaseSampler, nn.Module):
|
||||
class TypicalAcceptanceSampler(SpecDecodeBaseSampler):
|
||||
"""Apply typical acceptance sampling as described in section 3.3.1 in
|
||||
"MEDUSA: Simple LLM Inference Acceleration Framework with
|
||||
Multiple Decoding Heads"
|
||||
@ -15,10 +14,10 @@ class TypicalAcceptanceSampler(SpecDecodeBaseSampler, nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
posterior_threshold: float,
|
||||
posterior_alpha: float,
|
||||
disable_bonus_tokens: bool = False,
|
||||
strict_mode: bool = False,
|
||||
posterior_threshold: float = 0.09,
|
||||
posterior_alpha: float = 0.3,
|
||||
):
|
||||
"""Create a Typical Acceptance Sampler.
|
||||
|
||||
@ -31,23 +30,20 @@ class TypicalAcceptanceSampler(SpecDecodeBaseSampler, nn.Module):
|
||||
nontrivial latency.
|
||||
posterior_threshold : A threshold value that sets a lower bound
|
||||
on the posterior probability of a token in target model for it
|
||||
to be accepted. Default is 0.09
|
||||
to be accepted.
|
||||
posterior_alpha : A scaling factor for the entropy-based
|
||||
threshold in typical acceptance sampling. Typically defaults to
|
||||
sqrt of posterior_threshold and is set to 0.3.
|
||||
threshold in typical acceptance sampling.
|
||||
"""
|
||||
SpecDecodeBaseSampler.__init__(
|
||||
self,
|
||||
disable_bonus_tokens=disable_bonus_tokens,
|
||||
strict_mode=strict_mode)
|
||||
nn.Module.__init__(self)
|
||||
self._posterior_threshold = posterior_threshold
|
||||
self._posterior_alpha = posterior_alpha
|
||||
super().__init__(disable_bonus_tokens=disable_bonus_tokens,
|
||||
strict_mode=strict_mode)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
target_probs: torch.Tensor,
|
||||
bonus_token_ids: torch.Tensor,
|
||||
draft_probs: torch.Tensor,
|
||||
draft_token_ids: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Sample token ids using typical acceptance sampling. This accepts
|
||||
@ -69,6 +65,8 @@ class TypicalAcceptanceSampler(SpecDecodeBaseSampler, nn.Module):
|
||||
speculative tokens in a sequence are accepted.
|
||||
shape = [batch_size, num_bonus_tokens]
|
||||
|
||||
draft_probs: This parameter is unused by the acceptance sampler.
|
||||
|
||||
draft_token_ids: The token ids that were sampled from the draft
|
||||
probabilities.
|
||||
shape = [batch_size, num_speculative_tokens]
|
||||
|
@ -4,7 +4,8 @@ from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
|
||||
from vllm.model_executor.layers.spec_decode_base_sampler import (
|
||||
SpecDecodeBaseSampler)
|
||||
from vllm.utils import is_pin_memory_available
|
||||
|
||||
|
||||
@ -46,15 +47,15 @@ Timer = Callable[[], float]
|
||||
|
||||
|
||||
class AsyncMetricsCollector:
|
||||
"""Class which copies rejection sampler metrics from the device to CPU on a
|
||||
non-default Torch stream.
|
||||
"""Class which copies rejection/typical-acceptance sampler metrics
|
||||
from the device to CPU on a non-default Torch stream.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
rejection_sampler: RejectionSampler,
|
||||
spec_decode_sampler: SpecDecodeBaseSampler,
|
||||
timer: Optional[Timer] = None,
|
||||
collect_interval_s: float = 5.0):
|
||||
self._rejection_sampler = rejection_sampler
|
||||
self.spec_decode_sampler = spec_decode_sampler
|
||||
self._timer = time.time if timer is None else timer
|
||||
|
||||
self._rank: Optional[int] = None
|
||||
@ -95,7 +96,7 @@ class AsyncMetricsCollector:
|
||||
return None
|
||||
|
||||
def _should_collect_rejsample_metrics(self, now: float) -> bool:
|
||||
"""Return whether or not this iteration should print rejection sampling
|
||||
"""Return whether or not this iteration should print sampling
|
||||
metrics.
|
||||
"""
|
||||
if self._rank != 0:
|
||||
@ -107,8 +108,8 @@ class AsyncMetricsCollector:
|
||||
return True
|
||||
|
||||
def _copy_rejsample_metrics_async(self) -> torch.cuda.Event:
|
||||
"""Copy rejection sampling metrics (number of accepted tokens, etc) to
|
||||
CPU asynchronously.
|
||||
"""Copy rejection/typical-acceptance sampling metrics
|
||||
(number of accepted tokens, etc) to CPU asynchronously.
|
||||
|
||||
Returns a CUDA event recording when the copy is complete.
|
||||
"""
|
||||
@ -117,13 +118,14 @@ class AsyncMetricsCollector:
|
||||
|
||||
with torch.cuda.stream(self._copy_stream):
|
||||
self._aggregate_num_accepted_tokens.copy_(
|
||||
self._rejection_sampler.num_accepted_tokens, non_blocking=True)
|
||||
self.spec_decode_sampler.num_accepted_tokens,
|
||||
non_blocking=True)
|
||||
self._aggregate_num_emitted_tokens.copy_(
|
||||
self._rejection_sampler.num_emitted_tokens, non_blocking=True)
|
||||
self.spec_decode_sampler.num_emitted_tokens, non_blocking=True)
|
||||
# Number of draft tokens is calculated on CPU, so no copy is
|
||||
# required.
|
||||
self._aggregate_num_draft_tokens = (
|
||||
self._rejection_sampler.num_draft_tokens)
|
||||
self.spec_decode_sampler.num_draft_tokens)
|
||||
|
||||
aggregate_metrics_ready = torch.cuda.Event()
|
||||
aggregate_metrics_ready.record(self._copy_stream)
|
||||
|
@ -7,6 +7,10 @@ from vllm.config import ParallelConfig, SpeculativeConfig
|
||||
from vllm.distributed.communication_op import broadcast_tensor_dict
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
|
||||
from vllm.model_executor.layers.spec_decode_base_sampler import (
|
||||
SpecDecodeBaseSampler)
|
||||
from vllm.model_executor.layers.typical_acceptance_sampler import (
|
||||
TypicalAcceptanceSampler)
|
||||
from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest,
|
||||
HiddenStates, SamplerOutput, SequenceGroupMetadata,
|
||||
get_all_seq_ids)
|
||||
@ -56,7 +60,12 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
|
||||
draft_worker_kwargs=draft_worker_kwargs,
|
||||
disable_by_batch_size=speculative_config.
|
||||
speculative_disable_by_batch_size,
|
||||
)
|
||||
draft_token_acceptance_method=speculative_config.
|
||||
draft_token_acceptance_method,
|
||||
typical_acceptance_sampler_posterior_threshold=speculative_config.
|
||||
typical_acceptance_sampler_posterior_threshold,
|
||||
typical_acceptance_sampler_posterior_alpha=speculative_config.
|
||||
typical_acceptance_sampler_posterior_alpha)
|
||||
|
||||
return spec_decode_worker
|
||||
|
||||
@ -78,8 +87,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
welcome!).
|
||||
* Only top-1 proposal and scoring are implemented. Tree-attention is left as
|
||||
future work.
|
||||
* Only lossless rejection sampling is supported. Contributions adding lossy
|
||||
verification routines are welcome (e.g. Medusa's typical acceptance).
|
||||
* All sequences in a batch must have the same proposal length, or zero. This
|
||||
can be improved by having per-sequence speculation in the future.
|
||||
* The scoring forward pass is done without an MQA kernel, which is
|
||||
@ -95,6 +102,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
scorer_worker: Worker,
|
||||
draft_worker_kwargs: Dict[str, Any],
|
||||
disable_by_batch_size: Optional[int],
|
||||
draft_token_acceptance_method: str,
|
||||
typical_acceptance_sampler_posterior_threshold: float,
|
||||
typical_acceptance_sampler_posterior_alpha: float,
|
||||
) -> "SpecDecodeWorker":
|
||||
|
||||
ngram_prompt_lookup_max = (
|
||||
@ -127,17 +137,30 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
logger.info("Configuring SpecDecodeWorker with proposer=%s",
|
||||
type(proposer_worker))
|
||||
|
||||
spec_decode_sampler: SpecDecodeBaseSampler = None
|
||||
if draft_token_acceptance_method == "rejection_sampler":
|
||||
spec_decode_sampler = RejectionSampler(
|
||||
disable_bonus_tokens=disable_bonus_tokens, )
|
||||
elif draft_token_acceptance_method == "typical_acceptance_sampler":
|
||||
spec_decode_sampler = TypicalAcceptanceSampler(
|
||||
disable_bonus_tokens=disable_bonus_tokens,
|
||||
posterior_threshold=\
|
||||
typical_acceptance_sampler_posterior_threshold,
|
||||
posterior_alpha=typical_acceptance_sampler_posterior_alpha,
|
||||
)
|
||||
logger.info("Configuring SpecDecodeWorker with sampler=%s",
|
||||
type(spec_decode_sampler))
|
||||
|
||||
return SpecDecodeWorker(proposer_worker,
|
||||
scorer_worker,
|
||||
disable_by_batch_size=disable_by_batch_size,
|
||||
rejection_sampler=RejectionSampler(
|
||||
disable_bonus_tokens=disable_bonus_tokens))
|
||||
spec_decode_sampler=spec_decode_sampler)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
proposer_worker: ProposerWorkerBase,
|
||||
scorer_worker: WorkerBase,
|
||||
rejection_sampler: RejectionSampler,
|
||||
spec_decode_sampler: SpecDecodeBaseSampler,
|
||||
metrics_collector: Optional[AsyncMetricsCollector] = None,
|
||||
disable_by_batch_size: Optional[int] = None,
|
||||
):
|
||||
@ -150,8 +173,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
scorer_worker: A worker that produces probabilities of speculative
|
||||
tokens according to some base model. Typically a vanilla vLLM
|
||||
Worker.
|
||||
rejection_sampler: A Torch module used to perform modified rejection
|
||||
sampling for speculative decoding.
|
||||
spec_decode_sampler: A Torch module used to perform acceptance
|
||||
sampling of the draft tokens in the verification step of
|
||||
speculative decoding. Currently we support two different
|
||||
types of sampler namely RejectionSampler and
|
||||
TypicalAcceptanceSampler. 'spec_decode_sampler' is either an
|
||||
instance of RejectionSampler or TypicalAcceptanceSampler.
|
||||
disable_by_batch_size: If the batch size is larger than this,
|
||||
disable speculative decoding for new incoming requests.
|
||||
metrics_collector: Helper class for collecting metrics; can be set
|
||||
@ -160,15 +187,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
self.proposer_worker = proposer_worker
|
||||
self.scorer_worker = scorer_worker
|
||||
self.disable_by_batch_size = disable_by_batch_size or float("inf")
|
||||
self.rejection_sampler = rejection_sampler
|
||||
|
||||
self.spec_decode_sampler = spec_decode_sampler
|
||||
self._metrics = AsyncMetricsCollector(
|
||||
rejection_sampler
|
||||
self.spec_decode_sampler
|
||||
) if metrics_collector is None else metrics_collector
|
||||
|
||||
self.probs_dtype = self.rejection_sampler.probs_dtype
|
||||
self.token_id_dtype = self.rejection_sampler.token_id_dtype
|
||||
|
||||
self.probs_dtype = self.spec_decode_sampler.probs_dtype
|
||||
self.token_id_dtype = self.spec_decode_sampler.token_id_dtype
|
||||
# Lazy initiazliation.
|
||||
self.scorer: SpeculativeScorer
|
||||
|
||||
@ -189,7 +213,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
self.proposer_worker.load_model()
|
||||
|
||||
self._metrics.init_gpu_tensors(self.rank)
|
||||
self.rejection_sampler.init_gpu_tensors(self.rank)
|
||||
self.spec_decode_sampler.init_gpu_tensors(self.rank)
|
||||
|
||||
self.scorer = BatchExpansionTop1Scorer(
|
||||
scorer_worker=self.scorer_worker,
|
||||
device=self.device,
|
||||
@ -203,7 +228,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
def _configure_model_sampler_for_spec_decode(self):
|
||||
"""Configure model sampler to emit GPU tensors. This allows spec decode
|
||||
to keep data on device without transferring to CPU and serializing,
|
||||
which significantly reduces overhead of rejection sampling.
|
||||
which significantly reduces overhead of sampling during verification.
|
||||
|
||||
NOTE(cade): This breaks abstraction boundaries pretty badly. The better
|
||||
design is to have the "move to CPU and serialize" sampling decision be
|
||||
@ -481,7 +506,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
# Get proposed tokens.
|
||||
proposal_token_ids = proposals.proposal_token_ids[spec_indices]
|
||||
|
||||
accepted_token_ids = self.rejection_sampler(
|
||||
accepted_token_ids = self.spec_decode_sampler(
|
||||
target_probs=proposal_verifier_probs,
|
||||
bonus_token_ids=bonus_token_ids,
|
||||
draft_probs=proposal_probs,
|
||||
@ -496,7 +521,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
accepted_token_ids = torch.cat(
|
||||
[accepted_token_ids, non_spec_token_ids])
|
||||
logprobs = proposal_scores.logprobs
|
||||
|
||||
# Rearrange so that results are in the order of the original seq group
|
||||
# metadata.
|
||||
accepted_token_ids[original_indices] = accepted_token_ids.clone()
|
||||
|
Loading…
x
Reference in New Issue
Block a user