[Speculative Decoding 2/2] Integrate typical acceptance sampler into Spec Decode Worker (#5348)
parent 614aa51203
commit 80ca1e6a3a
@@ -52,6 +52,19 @@ def get_draft_token_ids(batch_size: int, k: int, vocab_size: int,
     return draft_token_ids
 
 
+def get_acceptance_sampler(
+    posterior_threshold: float = 0.03,
+    posterior_alpha: float = 0.9,
+    disable_bonus_tokens: bool = False,
+    strict_mode: bool = False,
+) -> TypicalAcceptanceSampler:
+    """
+    Initializes and returns a TypicalAcceptanceSampler.
+    """
+    return TypicalAcceptanceSampler(posterior_threshold, posterior_alpha,
+                                    disable_bonus_tokens, strict_mode)
+
+
 @pytest.mark.parametrize("k", list(range(1, 6)))
 @pytest.mark.parametrize("vocab_size", [30_000, 50_000])
 @pytest.mark.parametrize("batch_size", list(range(1, 32)))
@@ -64,7 +77,7 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
     different combinations of k, vocab_size, batch_size and num devices.
     """
     torch.set_default_device(device)
-    typical_acceptance_sampler = TypicalAcceptanceSampler()
+    typical_acceptance_sampler = get_acceptance_sampler()
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
     target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
     bonus_token_ids = torch.randint(low=0,
@@ -76,7 +89,10 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
                                     size=(batch_size, k),
                                     dtype=torch.int64)
     # Verify that sampling succeeds for all cases.
-    typical_acceptance_sampler(target_probs, bonus_token_ids, draft_token_ids)
+    typical_acceptance_sampler(target_probs,
+                               bonus_token_ids,
+                               draft_probs=None,
+                               draft_token_ids=draft_token_ids)
 
 
 @pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"])
@@ -94,7 +110,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
     batch_size = 5
     vocab_size = 30_000
     torch.set_default_device(device)
-    typical_acceptance_sampler = TypicalAcceptanceSampler(strict_mode=True)
+    typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
     target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
     bonus_token_ids = torch.randint(low=0,
@@ -125,8 +141,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
     oob_token_ids[0][0] = rogue_token_id
 
     with pytest.raises(AssertionError):
-        typical_acceptance_sampler(target_probs, bonus_token_ids,
-                                   draft_token_ids)
+        typical_acceptance_sampler(target_probs,
+                                   bonus_token_ids,
+                                   draft_probs=None,
+                                   draft_token_ids=draft_token_ids)
 
 
 @pytest.mark.parametrize("seed", list(range(10)))
@@ -151,7 +169,7 @@ def test_uniform_target_distribution_accepts_all_tokens(
     batch_size = 5
     vocab_size = 30_000
     torch.set_default_device(device)
-    typical_acceptance_sampler = TypicalAcceptanceSampler(
+    typical_acceptance_sampler = get_acceptance_sampler(
         strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
     target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
@@ -163,9 +181,11 @@ def test_uniform_target_distribution_accepts_all_tokens(
                                     high=vocab_size,
                                     size=(batch_size, 1),
                                     dtype=torch.int64)
-    output_token_ids = typical_acceptance_sampler(target_probs,
-                                                  bonus_token_ids,
-                                                  draft_token_ids)
+    output_token_ids = typical_acceptance_sampler(
+        target_probs,
+        bonus_token_ids,
+        draft_probs=None,
+        draft_token_ids=draft_token_ids)
     # We are using a uniform target probability distribution.
     # For a uniform distribution the entropy is very high and it
     # should lead to all draft tokens being accepted. Verify that.
@@ -203,7 +223,7 @@ def test_temperature_zero_target_distribution(seed: int,
     vocab_size = 30_000
     torch.set_default_device(device)
 
-    typical_acceptance_sampler = TypicalAcceptanceSampler(
+    typical_acceptance_sampler = get_acceptance_sampler(
         strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
     # Simulate temperature 0 probability distribution for target probabilities
@@ -224,9 +244,11 @@ def test_temperature_zero_target_distribution(seed: int,
     # 1.0 tokens in the target distribution we will reject all of them and
     # fallback to the greedy sampling for selecting 1 token for each sequence.
     # Verify the same.
-    output_token_ids = typical_acceptance_sampler(target_probs,
-                                                  bonus_token_ids,
-                                                  draft_token_ids)
+    output_token_ids = typical_acceptance_sampler(
+        target_probs,
+        bonus_token_ids,
+        draft_probs=None,
+        draft_token_ids=draft_token_ids)
     assert output_token_ids.shape[0] == batch_size
     assert output_token_ids.shape[1] == (k + 1)
     assert torch.all(output_token_ids[:, -1] == -1)
@@ -261,7 +283,7 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
     batch_size = 4
     vocab_size = 30_000
     torch.set_default_device(device)
-    typical_acceptance_sampler = TypicalAcceptanceSampler(
+    typical_acceptance_sampler = get_acceptance_sampler(
         strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
     # For sequences 0 and 2 set the distribution to a temperature
@@ -277,9 +299,11 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
                                     high=vocab_size,
                                     size=(batch_size, 1),
                                     dtype=torch.int64)
-    output_token_ids = typical_acceptance_sampler(target_probs,
-                                                  bonus_token_ids,
-                                                  draft_token_ids)
+    output_token_ids = typical_acceptance_sampler(
+        target_probs,
+        bonus_token_ids,
+        draft_probs=None,
+        draft_token_ids=draft_token_ids)
     # verify the shape of output_token_ids
     assert output_token_ids.shape[0] == batch_size
     assert output_token_ids.shape[1] == (k + 1)
@@ -326,7 +350,7 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
     batch_size = 1
     vocab_size = 30_000
     torch.set_default_device(device)
-    typical_acceptance_sampler = TypicalAcceptanceSampler(
+    typical_acceptance_sampler = get_acceptance_sampler(
         strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
     # Create a temperature zero target probability distribution and ensure
@@ -339,9 +363,11 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
                                     high=vocab_size,
                                     size=(batch_size, 1),
                                     dtype=torch.int64)
-    output_token_ids = typical_acceptance_sampler(target_probs,
-                                                  bonus_token_ids,
-                                                  draft_token_ids)
+    output_token_ids = typical_acceptance_sampler(
+        target_probs,
+        bonus_token_ids,
+        draft_probs=None,
+        draft_token_ids=draft_token_ids)
     assert output_token_ids.shape[0] == batch_size
     assert output_token_ids.shape[1] == (k + 1)
     assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids)
@@ -357,9 +383,11 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
         batch_size, k, vocab_size, zero_temperature_token_ids)
     draft_token_ids = torch.cat(
         (draft_token_ids[:, :2], draft_token_ids_to_replace[:, -3:]), dim=1)
-    output_token_ids = typical_acceptance_sampler(target_probs,
-                                                  bonus_token_ids,
-                                                  draft_token_ids)
+    output_token_ids = typical_acceptance_sampler(
+        target_probs,
+        bonus_token_ids,
+        draft_probs=None,
+        draft_token_ids=draft_token_ids)
     assert output_token_ids.shape[0] == batch_size
     assert output_token_ids.shape[1] == (k + 1)
     assert torch.all(output_token_ids[:, :2] == draft_token_ids[:, :2])
@@ -384,7 +412,7 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
     batch_size = 1
     vocab_size = 30_000
     torch.set_default_device(device)
-    typical_acceptance_sampler = TypicalAcceptanceSampler(
+    typical_acceptance_sampler = get_acceptance_sampler(
         strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
     # Simulate temperature 0 probability distribution for target
@@ -402,9 +430,11 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
                                     high=vocab_size,
                                     size=(batch_size, 1),
                                     dtype=torch.int64)
-    output_token_ids = typical_acceptance_sampler(target_probs,
-                                                  bonus_token_ids,
-                                                  draft_token_ids)
+    output_token_ids = typical_acceptance_sampler(
+        target_probs,
+        bonus_token_ids,
+        draft_probs=None,
+        draft_token_ids=draft_token_ids)
     assert output_token_ids.shape[0] == batch_size
     assert output_token_ids.shape[1] == (k + 1)
     assert torch.all(output_token_ids[:, 1:-1] == -1)
@@ -418,9 +448,11 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
         posterior_threshold=0.0,
         posterior_alpha=0.0)
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
-    output_token_ids = typical_acceptance_sampler(target_probs,
-                                                  bonus_token_ids,
-                                                  draft_token_ids)
+    output_token_ids = typical_acceptance_sampler(
+        target_probs,
+        bonus_token_ids,
+        draft_probs=None,
+        draft_token_ids=draft_token_ids)
     assert output_token_ids.shape[0] == batch_size
     assert output_token_ids.shape[1] == (k + 1)
     assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids)
@@ -451,7 +483,7 @@ def test_replacement_token_ids(seed: int, disable_bonus_tokens: bool,
     batch_size = 5
     vocab_size = 30_000
     torch.set_default_device(device)
-    typical_acceptance_sampler = TypicalAcceptanceSampler(
+    typical_acceptance_sampler = get_acceptance_sampler(
         strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
     target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
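Taken together, the unit-test hunks above switch every call site to the keyword-based interface (`draft_probs=None`, `draft_token_ids=...`) and route construction through the new `get_acceptance_sampler` helper. The following standalone sketch is not part of the diff; it only illustrates how that interface is exercised. The shapes and the meaning of `-1` entries come from the assertions above, and it assumes a CUDA device exactly as the tests do.

```python
import torch

from vllm.model_executor.layers.typical_acceptance_sampler import (
    TypicalAcceptanceSampler)

torch.set_default_device("cuda")

batch_size, k, vocab_size = 2, 3, 30_000
sampler = TypicalAcceptanceSampler(posterior_threshold=0.03,
                                   posterior_alpha=0.9)
sampler.init_gpu_tensors(rank=0)

target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
bonus_token_ids = torch.randint(0, vocab_size, (batch_size, 1))
draft_token_ids = torch.randint(0, vocab_size, (batch_size, k))

# Keyword-based call signature that this commit migrates the tests to.
output_token_ids = sampler(target_probs,
                           bonus_token_ids,
                           draft_probs=None,
                           draft_token_ids=draft_token_ids)

# Shape is [batch_size, k + 1]; positions from the first rejection onward
# (and the bonus slot, when bonus tokens are disabled) are filled with -1.
assert output_token_ids.shape == (batch_size, k + 1)
```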
@@ -11,9 +11,15 @@ distribution matches the target model's output distribution (up to hardware
 numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
 equality. This gives us good coverage of temp=0.
 
+At temp=0, the TypicalAcceptanceSampler ensures that only the tokens with the
+highest probability in the target distribution are accepted. Therefore, we can
+expect greedy equality for the TypicalAcceptanceSampler at temp=0.
+
 For temp>0, we rely on unit tests on the rejection sampler to verify that the
 output distribution is the same with spec decode vs. no spec decode (this would
-be prohibitively expensive to run with a real model).
+be prohibitively expensive to run with a real model). Similarly, for the
+TypicalAcceptance sampler also, we rely on unit tests to validate temp>0
+test cases.
 
 NOTE: Speculative decoding's distribution equality requires that the measured
 distributions of the target model and proposal model be deterministic given the
@@ -611,3 +617,49 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
                                          batch_size,
                                          max_output_len=output_len,
                                          force_output_len=True)
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "model": "JackFram/llama-160m",
+
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize(
+    "test_llm_kwargs",
+    [
+        {
+            "speculative_model": "JackFram/llama-68m",
+            "num_speculative_tokens": k,
+            "spec_decoding_acceptance_method": "typical_acceptance_sampler"
+        }
+        # Try a range of common k.
+        for k in [1, 2, 3]
+    ])
+@pytest.mark.parametrize("batch_size", [1, 32])
+@pytest.mark.parametrize(
+    "output_len",
+    [
+        # Use smaller output len for fast test.
+        32,
+    ])
+@pytest.mark.parametrize("seed", [1])
+def test_typical_acceptance_sampling(baseline_llm_generator,
+                                     test_llm_generator, batch_size: int,
+                                     output_len: int):
+    """Verify that speculative decoding produces exact equality to without spec
+    decode with TypicalAcceptanceSampler as the draft token acceptance
+    sampling method.
+    """
+    run_greedy_equality_correctness_test(baseline_llm_generator,
+                                         test_llm_generator,
+                                         batch_size,
+                                         max_output_len=output_len,
+                                         force_output_len=True)
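The new end-to-end test drives the typical acceptance sampler purely through engine kwargs. A hedged usage sketch follows (not part of the diff), assuming these test kwargs map one-to-one onto `LLM` keyword arguments the way the e2e harness applies them:

```python
from vllm import LLM, SamplingParams

# Same knobs the new test parametrizes; illustrative values only.
llm = LLM(
    model="JackFram/llama-160m",
    speculative_model="JackFram/llama-68m",
    num_speculative_tokens=3,
    spec_decoding_acceptance_method="typical_acceptance_sampler",
    use_v2_block_manager=True,  # required for spec decode
    enforce_eager=True,         # skip cuda graph capture, as in the test
)

outputs = llm.generate(["San Francisco is a"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```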
@@ -3,33 +3,35 @@ from unittest.mock import MagicMock, patch
 import pytest
 import torch
 
-from vllm.model_executor.layers.rejection_sampler import RejectionSampler
 from vllm.sequence import ExecuteModelRequest
 from vllm.spec_decode.metrics import AsyncMetricsCollector
 from vllm.spec_decode.multi_step_worker import MultiStepWorker
 from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
 from vllm.spec_decode.top1_proposer import Top1Proposer
 
+from .test_utils import mock_spec_decode_sampler
 from .utils import create_batch, mock_worker
 
 
 @pytest.mark.parametrize('queue_size', [4])
 @pytest.mark.parametrize('batch_size', [1])
 @pytest.mark.parametrize('k', [1])
+@pytest.mark.parametrize("acceptance_sampler_method",
+                         ["rejection_sampler", "typical_acceptance_sampler"])
 @torch.inference_mode()
-def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int):
+def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int,
+                             acceptance_sampler_method: str):
     """Verify that speculative tokens are disabled when the batch size
     exceeds the threshold.
     """
     disable_by_batch_size = 3
 
     draft_worker = mock_worker(cls=MultiStepWorker)
     target_worker = mock_worker()
-    rejection_sampler = MagicMock(spec=RejectionSampler)
     metrics_collector = MagicMock(spec=AsyncMetricsCollector)
     worker = SpecDecodeWorker(proposer_worker=draft_worker,
                               scorer_worker=target_worker,
-                              rejection_sampler=rejection_sampler,
+                              spec_decode_sampler=mock_spec_decode_sampler(
+                                  acceptance_sampler_method),
                               metrics_collector=metrics_collector,
                               disable_by_batch_size=disable_by_batch_size)
 
@@ -10,16 +10,16 @@ from vllm.spec_decode.metrics import AsyncMetricsCollector
 def test_initial_call_returns_none():
     """Expect first call to get metrics to return None.
     """
-    rej_sampler = MagicMock()
-    rej_sampler.num_accepted_tokens = torch.tensor(0,
+    spec_decode_sampler = MagicMock()
+    spec_decode_sampler.num_accepted_tokens = torch.tensor(0,
                                                    dtype=torch.long,
                                                    device='cuda')
-    rej_sampler.num_emitted_tokens = torch.tensor(0,
+    spec_decode_sampler.num_emitted_tokens = torch.tensor(0,
                                                   dtype=torch.long,
                                                   device='cuda')
-    rej_sampler.num_draft_tokens = 0
+    spec_decode_sampler.num_draft_tokens = 0
 
-    collector = AsyncMetricsCollector(rej_sampler)
+    collector = AsyncMetricsCollector(spec_decode_sampler)
     collector.init_gpu_tensors(rank=0)
     maybe_metrics = collector.maybe_collect_rejsample_metrics(k=5)
     assert maybe_metrics is None
@@ -28,14 +28,14 @@ def test_initial_call_returns_none():
 def test_second_call_returns_metrics():
     """Expect second call to not return None.
     """
-    rej_sampler = MagicMock()
-    rej_sampler.num_accepted_tokens = torch.tensor(0,
+    spec_decode_sampler = MagicMock()
+    spec_decode_sampler.num_accepted_tokens = torch.tensor(0,
                                                    dtype=torch.long,
                                                    device='cuda')
-    rej_sampler.num_emitted_tokens = torch.tensor(0,
+    spec_decode_sampler.num_emitted_tokens = torch.tensor(0,
                                                   dtype=torch.long,
                                                   device='cuda')
-    rej_sampler.num_draft_tokens = 0
+    spec_decode_sampler.num_draft_tokens = 0
 
     collect_interval_s = 5.0
     timer = MagicMock()
@@ -43,7 +43,7 @@ def test_second_call_returns_metrics():
         0.0, collect_interval_s + 0.1, collect_interval_s + 0.2
     ]
 
-    collector = AsyncMetricsCollector(rejection_sampler=rej_sampler,
+    collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler,
                                       timer=timer,
                                       collect_interval_s=collect_interval_s)
     collector.init_gpu_tensors(rank=0)
@@ -56,16 +56,16 @@ def test_second_call_returns_metrics():
 def test_nonzero_rank_noop(rank):
     """Verify nonzero ranks don't collect metrics.
     """
-    rej_sampler = MagicMock()
-    rej_sampler.num_accepted_tokens = torch.tensor(0,
+    spec_decode_sampler = MagicMock()
+    spec_decode_sampler.num_accepted_tokens = torch.tensor(0,
                                                    dtype=torch.long,
                                                    device='cuda')
-    rej_sampler.num_emitted_tokens = torch.tensor(0,
+    spec_decode_sampler.num_emitted_tokens = torch.tensor(0,
                                                   dtype=torch.long,
                                                   device='cuda')
-    rej_sampler.num_draft_tokens = 0
+    spec_decode_sampler.num_draft_tokens = 0
 
-    collector = AsyncMetricsCollector(rej_sampler)
+    collector = AsyncMetricsCollector(spec_decode_sampler)
     collector.init_gpu_tensors(rank=rank)
     _ = collector.maybe_collect_rejsample_metrics(k=5)
     metrics = collector.maybe_collect_rejsample_metrics(k=5)
@@ -75,14 +75,14 @@ def test_nonzero_rank_noop(rank):
 def test_noop_until_time():
     """Verify metrics aren't collected until enough time passes.
     """
-    rej_sampler = MagicMock()
-    rej_sampler.num_accepted_tokens = torch.tensor(0,
+    spec_decode_sampler = MagicMock()
+    spec_decode_sampler.num_accepted_tokens = torch.tensor(0,
                                                    dtype=torch.long,
                                                    device='cuda')
-    rej_sampler.num_emitted_tokens = torch.tensor(0,
+    spec_decode_sampler.num_emitted_tokens = torch.tensor(0,
                                                   dtype=torch.long,
                                                   device='cuda')
-    rej_sampler.num_draft_tokens = 0
+    spec_decode_sampler.num_draft_tokens = 0
 
     collect_interval_s = 5.0
     timer = MagicMock()
@@ -91,7 +91,7 @@ def test_noop_until_time():
         collect_interval_s + 0.1, collect_interval_s + 0.1
     ]
 
-    collector = AsyncMetricsCollector(rejection_sampler=rej_sampler,
+    collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler,
                                       timer=timer,
                                       collect_interval_s=collect_interval_s)
     collector.init_gpu_tensors(rank=0)
@@ -122,14 +122,14 @@ def test_initial_metrics_has_correct_values(has_data: bool):
     max_num_emitted_tokens = AsyncMetricsCollector.get_max_num_emitted_tokens(
         num_draft_tokens, k)
 
-    rej_sampler = MagicMock()
-    rej_sampler.num_accepted_tokens = torch.tensor(num_accepted_tokens,
+    spec_decode_sampler = MagicMock()
+    spec_decode_sampler.num_accepted_tokens = torch.tensor(num_accepted_tokens,
                                                    dtype=torch.long,
                                                    device='cuda')
-    rej_sampler.num_emitted_tokens = torch.tensor(num_emitted_tokens,
+    spec_decode_sampler.num_emitted_tokens = torch.tensor(num_emitted_tokens,
                                                   dtype=torch.long,
                                                   device='cuda')
-    rej_sampler.num_draft_tokens = num_draft_tokens
+    spec_decode_sampler.num_draft_tokens = num_draft_tokens
 
     collect_interval_s = 5.0
     timer = MagicMock()
@@ -137,7 +137,7 @@ def test_initial_metrics_has_correct_values(has_data: bool):
         0.0, collect_interval_s + 0.1, collect_interval_s + 0.2
     ]
 
-    collector = AsyncMetricsCollector(rejection_sampler=rej_sampler,
+    collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler,
                                       timer=timer,
                                       collect_interval_s=collect_interval_s)
     collector.init_gpu_tensors(rank=0)
@@ -6,7 +6,6 @@ from unittest.mock import MagicMock
 import pytest
 import torch
 
-from vllm.model_executor.layers.rejection_sampler import RejectionSampler
 from vllm.model_executor.utils import set_random_seed
 from vllm.sequence import ExecuteModelRequest, SamplerOutput, SequenceOutput
 from vllm.spec_decode.interfaces import SpeculativeProposals
@@ -16,23 +15,26 @@ from vllm.spec_decode.multi_step_worker import MultiStepWorker
 from vllm.spec_decode.spec_decode_worker import (SpecDecodeWorker,
                                                  split_num_cache_blocks_evenly)
 
+from .test_utils import mock_spec_decode_sampler
 from .utils import create_batch, create_sampler_output_list, mock_worker
 
 
 @pytest.mark.parametrize('k', [1, 2, 6])
 @pytest.mark.parametrize('batch_size', [1, 2, 32])
+@pytest.mark.parametrize("acceptance_sampler_method",
+                         ["rejection_sampler", "typical_acceptance_sampler"])
 @torch.inference_mode()
-def test_correctly_calls_draft_model(k: int, batch_size: int):
+def test_correctly_calls_draft_model(k: int, batch_size: int,
+                                     acceptance_sampler_method: str):
     """Verify SpecDecodeWorker calls the draft worker with correct
     inputs. Everything else is mocked out.
     """
     draft_worker = mock_worker(cls=MultiStepWorker)
     target_worker = mock_worker()
-    rejection_sampler = MagicMock(spec=RejectionSampler)
     metrics_collector = MagicMock(spec=AsyncMetricsCollector)
-    worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
-                              metrics_collector)
+    worker = SpecDecodeWorker(
+        draft_worker, target_worker,
+        mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
     exception_secret = 'artificial stop'
     draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)
 
@@ -53,15 +55,16 @@ def test_correctly_calls_draft_model(k: int, batch_size: int):
 
 @pytest.mark.parametrize('k', [1, 2, 6])
 @pytest.mark.parametrize('batch_size', [1, 2, 32])
+@pytest.mark.parametrize("acceptance_sampler_method",
+                         ["rejection_sampler", "typical_acceptance_sampler"])
 @torch.inference_mode()
-def test_correctly_calls_target_model(k: int, batch_size: int):
+def test_correctly_calls_target_model(k: int, batch_size: int,
+                                      acceptance_sampler_method: str):
     """Verify SpecDecodeWorker calls the target model with correct
     inputs. Everything else is mocked out.
     """
     draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
     target_worker = mock_worker(use_spec=False)
-    rejection_sampler = MagicMock(spec=RejectionSampler)
-    rejection_sampler.token_id_dtype = torch.int64
     metrics_collector = MagicMock(spec=AsyncMetricsCollector)
 
     draft_worker.device = 'cuda'
@@ -69,8 +72,9 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
 
     set_random_seed(1)
 
-    worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
-                              metrics_collector)
+    worker = SpecDecodeWorker(
+        draft_worker, target_worker,
+        mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
     worker.init_device()
 
     vocab_size = 32_000
@@ -133,8 +137,11 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
 
 @pytest.mark.parametrize('k', [1, 2, 6])
 @pytest.mark.parametrize('batch_size', [1, 2, 32])
+@pytest.mark.parametrize("acceptance_sampler_method",
+                         ["rejection_sampler", "typical_acceptance_sampler"])
 @torch.inference_mode()
-def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
+def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int,
+                                             acceptance_sampler_method: str):
     """Verify SpecDecodeWorker calls the rejection sampler with
     correct inputs. Everything else is mocked out.
     """
@@ -144,15 +151,14 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
                                 vocab_size=vocab_size,
                                 use_spec=False)
     target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
-    rejection_sampler = MagicMock(spec=RejectionSampler)
-    rejection_sampler.token_id_dtype = torch.int64
+    spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
     metrics_collector = MagicMock(spec=AsyncMetricsCollector)
     draft_worker.device = 'cuda'
     target_worker.device = 'cuda'
 
     set_random_seed(1)
 
-    worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
+    worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler,
                               metrics_collector)
     worker.init_device()
 
@@ -199,15 +205,16 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
     target_worker.execute_model.return_value = [target_output[0]]
 
     exception_secret = 'artificial stop'
-    rejection_sampler.side_effect = ValueError(exception_secret)
+
+    spec_decode_sampler.side_effect = ValueError(exception_secret)
 
     with pytest.raises(ValueError, match=exception_secret):
         worker.execute_model(execute_model_req=ExecuteModelRequest(
             seq_group_metadata_list=seq_group_metadata_list,
             num_lookahead_slots=k))
 
-    assert len(rejection_sampler.call_args_list) == 1
-    _, kwargs = rejection_sampler.call_args_list[0]
+    assert len(spec_decode_sampler.call_args_list) == 1
+    _, kwargs = spec_decode_sampler.call_args_list[0]
     actual = SimpleNamespace(**kwargs)
 
     assert torch.equal(actual.bonus_token_ids,
@@ -221,8 +228,11 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
 
 @pytest.mark.parametrize('k', [1, 2, 6])
 @pytest.mark.parametrize('batch_size', [1, 2, 32])
+@pytest.mark.parametrize("acceptance_sampler_method",
+                         ["rejection_sampler", "typical_acceptance_sampler"])
 @torch.inference_mode()
-def test_correctly_formats_output(k: int, batch_size: int):
+def test_correctly_formats_output(k: int, batch_size: int,
+                                  acceptance_sampler_method: str):
     """Verify SpecDecodeWorker formats sampler output correctly.
     Everything else is mocked out.
     """
@@ -232,15 +242,13 @@ def test_correctly_formats_output(k: int, batch_size: int):
                                 vocab_size=vocab_size,
                                 use_spec=False)
     target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
-    rejection_sampler = MagicMock(spec=RejectionSampler)
-    rejection_sampler.token_id_dtype = torch.int64
     metrics_collector = MagicMock(spec=AsyncMetricsCollector)
     draft_worker.device = 'cuda'
     target_worker.device = 'cuda'
 
     set_random_seed(1)
 
-    worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
+    spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
+    worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler,
                               metrics_collector)
     worker.init_device()
 
@@ -286,24 +294,23 @@ def test_correctly_formats_output(k: int, batch_size: int):
 
     target_worker.execute_model.return_value = [target_output[0]]
 
-    rejection_sampler_output = torch.randint(low=0,
+    spec_decode_sampler_output = torch.randint(low=0,
                                              high=vocab_size,
                                              size=(batch_size, k + 1),
                                              dtype=torch.int64,
                                              device='cuda')
     for i in range(batch_size):
         minimum_accepted_tokens = 1
-        rejection_sampler_output[i][
+        spec_decode_sampler_output[i][
             -random.randint(minimum_accepted_tokens, k + 1):] = -1
 
-    rejection_sampler.return_value = rejection_sampler_output
+    spec_decode_sampler.return_value = spec_decode_sampler_output
 
     output = worker.execute_model(execute_model_req=ExecuteModelRequest(
         seq_group_metadata_list=seq_group_metadata_list,
         num_lookahead_slots=k))
 
     expected_output = create_sampler_output_list(
-        token_ids=rejection_sampler_output.transpose(0, 1),
+        token_ids=spec_decode_sampler_output.transpose(0, 1),
         probs=[None for _ in range(k + 1)],
         logprobs=[None for _ in range(k + 1)])
 
@@ -350,8 +357,11 @@ def test_correctly_formats_output(k: int, batch_size: int):
 @pytest.mark.parametrize('k', [1, 2])
 @pytest.mark.parametrize('batch_size', [1])
 @pytest.mark.parametrize('returns_metrics', [True, False])
+@pytest.mark.parametrize("acceptance_sampler_method",
+                         ["rejection_sampler", "typical_acceptance_sampler"])
 @torch.inference_mode()
-def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
+def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool,
+                          acceptance_sampler_method: str):
     """Verify SpecDecodeWorker collects metrics.
     """
     vocab_size = 32_000
@@ -360,15 +370,14 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
                                 vocab_size=vocab_size,
                                 use_spec=False)
     target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
-    rejection_sampler = MagicMock(spec=RejectionSampler)
-    rejection_sampler.token_id_dtype = torch.int64
+    spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
     metrics_collector = MagicMock(spec=AsyncMetricsCollector)
     draft_worker.device = 'cuda'
     target_worker.device = 'cuda'
 
     set_random_seed(1)
 
-    worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
+    worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler,
                               metrics_collector)
     worker.init_device()
 
@@ -414,17 +423,16 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
 
     target_worker.execute_model.return_value = [target_output[0]]
 
-    rejection_sampler_output = torch.randint(low=0,
+    spec_decode_sampler_output = torch.randint(low=0,
                                              high=vocab_size,
                                              size=(batch_size, k + 1),
                                              dtype=torch.int64,
                                              device='cuda')
     for i in range(batch_size):
         minimum_accepted_tokens = 1
-        rejection_sampler_output[i][
+        spec_decode_sampler_output[i][
             -random.randint(minimum_accepted_tokens, k + 1):] = -1
-
-    rejection_sampler.return_value = rejection_sampler_output
+    spec_decode_sampler.return_value = spec_decode_sampler_output
 
     mock_rejsample_metrics = MagicMock(
         spec=SpecDecodeWorkerMetrics) if returns_metrics else None
@@ -445,15 +453,16 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
 
 @pytest.mark.parametrize('k', [0])
 @pytest.mark.parametrize('batch_size', [1, 2, 32])
+@pytest.mark.parametrize("acceptance_sampler_method",
+                         ["rejection_sampler", "typical_acceptance_sampler"])
 @torch.inference_mode()
-def test_k_equals_zero(k: int, batch_size: int):
+def test_k_equals_zero(k: int, batch_size: int,
+                       acceptance_sampler_method: str):
     """Verify that the SpecDecodeWorker calls the draft and target workers
     when k is zero. This happens during prefill.
     """
     draft_worker = mock_worker(cls=MultiStepWorker)
     target_worker = mock_worker()
-    rejection_sampler = MagicMock(spec=RejectionSampler)
-    rejection_sampler.token_id_dtype = torch.int64
     metrics_collector = MagicMock(spec=AsyncMetricsCollector)
 
     sampler_output = MagicMock(spec=SamplerOutput)
@@ -465,8 +474,9 @@ def test_k_equals_zero(k: int, batch_size: int):
 
     set_random_seed(1)
 
-    worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
-                              metrics_collector)
+    worker = SpecDecodeWorker(
+        draft_worker, target_worker,
+        mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
 
     seq_group_metadata_list, _, _ = create_batch(batch_size,
                                                  k,
@@ -487,16 +497,17 @@ def test_k_equals_zero(k: int, batch_size: int):
 
 @pytest.mark.parametrize('k', [0, 5])
 @pytest.mark.parametrize('batch_size', [0])
+@pytest.mark.parametrize("acceptance_sampler_method",
+                         ["rejection_sampler", "typical_acceptance_sampler"])
 @torch.inference_mode()
-def test_empty_input_batch(k: int, batch_size: int):
+def test_empty_input_batch(k: int, batch_size: int,
+                           acceptance_sampler_method: str):
     """Verify that the SpecDecodeWorker calls the draft and target workers
     when the input batch is empty. This can happen if the engine communicates
     to the workers information without scheduling a batch.
     """
     draft_worker = mock_worker(cls=MultiStepWorker)
     target_worker = mock_worker()
-    rejection_sampler = MagicMock(spec=RejectionSampler)
-    rejection_sampler.token_id_dtype = torch.int64
     metrics_collector = MagicMock(spec=AsyncMetricsCollector)
 
     sampler_output = MagicMock(spec=SamplerOutput)
@@ -508,8 +519,9 @@ def test_empty_input_batch(k: int, batch_size: int):
 
     set_random_seed(1)
 
-    worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
-                              metrics_collector)
+    worker = SpecDecodeWorker(
+        draft_worker, target_worker,
+        mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
 
     seq_group_metadata_list, _, _ = create_batch(batch_size,
                                                  k,
@@ -528,18 +540,19 @@ def test_empty_input_batch(k: int, batch_size: int):
     target_worker.execute_model.assert_called_once_with(execute_model_req)
 
 
+@pytest.mark.parametrize("acceptance_sampler_method",
+                         ["rejection_sampler", "typical_acceptance_sampler"])
 @pytest.mark.skip_global_cleanup
-def test_init_device():
+def test_init_device(acceptance_sampler_method: str):
     """Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as
     well as other GPU initialization.
     """
     draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
     target_worker = mock_worker(use_spec=False)
-    rejection_sampler = MagicMock(spec=RejectionSampler)
-    rejection_sampler.token_id_dtype = torch.int64
+    spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
 
-    worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
+    worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler,
                               metrics_collector)
 
     worker.init_device()
@@ -549,22 +562,23 @@ def test_init_device():
     target_worker.init_device.assert_called_once()
 
     metrics_collector.init_gpu_tensors.assert_called_once()
-    rejection_sampler.init_gpu_tensors.assert_called_once()
+    spec_decode_sampler.init_gpu_tensors.assert_called_once()
 
 
+@pytest.mark.parametrize("acceptance_sampler_method",
+                         ["rejection_sampler", "typical_acceptance_sampler"])
 @torch.inference_mode()
-def test_initialize_cache():
+def test_initialize_cache(acceptance_sampler_method):
    """Verify SpecDecodeWorker invokes initialize_cache on proposer/scorer
     workers.
     """
     draft_worker = mock_worker(cls=MultiStepWorker)
     target_worker = mock_worker()
-    rejection_sampler = MagicMock(spec=RejectionSampler)
-    rejection_sampler.token_id_dtype = torch.int64
     metrics_collector = MagicMock(spec=AsyncMetricsCollector)
 
-    worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
-                              metrics_collector)
+    worker = SpecDecodeWorker(
+        draft_worker, target_worker,
+        mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
 
     kwargs = {"num_gpu_blocks": 1024, "num_cpu_blocks": 1023}
     worker.initialize_cache(**kwargs)
@@ -577,19 +591,20 @@ def test_initialize_cache():
 @pytest.mark.parametrize('available_cpu_blocks', [500])
 @pytest.mark.parametrize('target_cache_block_size_bytes', [2 * 2 * 4096])
 @pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096])
+@pytest.mark.parametrize("acceptance_sampler_method",
+                         ["rejection_sampler", "typical_acceptance_sampler"])
 @pytest.mark.skip_global_cleanup
 def test_determine_num_available_blocks(available_gpu_blocks: int,
                                         available_cpu_blocks: int,
                                         target_cache_block_size_bytes: int,
-                                        draft_kv_size_bytes: int):
+                                        draft_kv_size_bytes: int,
+                                        acceptance_sampler_method: str):
     """Verify SpecDecodeWorker correctly profiles num available GPU blocks.
     Specifically, it should run profiling in the scorer worker, and then evenly
     split the blocks between proposer and scorer worker.
     """
     draft_worker = mock_worker(cls=MultiStepWorker)
     target_worker = mock_worker()
-    rejection_sampler = MagicMock(spec=RejectionSampler)
-    rejection_sampler.token_id_dtype = torch.int64
     metrics_collector = MagicMock(spec=AsyncMetricsCollector)
 
     target_worker.determine_num_available_blocks.return_value = (
@@ -598,8 +613,9 @@ def test_determine_num_available_blocks(available_gpu_blocks: int,
         target_cache_block_size_bytes)
     draft_worker.get_cache_block_size_bytes.return_value = draft_kv_size_bytes
 
-    worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
-                              metrics_collector)
+    worker = SpecDecodeWorker(
+        draft_worker, target_worker,
+        mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
 
     num_gpu_blocks, num_cpu_blocks = worker.determine_num_available_blocks()
 
@ -1,7 +1,11 @@
|
|||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
|
||||||
|
from vllm.model_executor.layers.typical_acceptance_sampler import (
|
||||||
|
TypicalAcceptanceSampler)
|
||||||
from vllm.sequence import SequenceGroupMetadata, get_all_seq_ids
|
from vllm.sequence import SequenceGroupMetadata, get_all_seq_ids
|
||||||
from vllm.spec_decode.util import split_batch_by_proposal_len
|
from vllm.spec_decode.util import split_batch_by_proposal_len
|
||||||
|
|
||||||
@ -109,3 +113,21 @@ def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata):
|
|||||||
|
|
||||||
assert filtered_groups == []
|
assert filtered_groups == []
|
||||||
assert indices == []
|
assert indices == []
|
||||||
|
|
||||||
|
|
||||||
|
def mock_spec_decode_sampler(acceptance_sampler_method):
|
||||||
|
"""
|
||||||
|
Returns either a RejectionSampler or TypicalAcceptanceSampler
|
||||||
|
object depending on whether acceptance_sampler_method is
|
||||||
|
'rejection_sampler' or 'typical_acceptance_sampler' respectively.
|
||||||
|
"""
|
||||||
|
if acceptance_sampler_method == "rejection_sampler":
|
||||||
|
sampler = MagicMock(spec=RejectionSampler)
|
||||||
|
sampler.token_id_dtype = torch.int64
|
||||||
|
return sampler
|
||||||
|
elif acceptance_sampler_method == "typical_acceptance_sampler":
|
||||||
|
sampler = MagicMock(spec=TypicalAcceptanceSampler)
|
||||||
|
sampler.token_id_dtype = torch.int64
|
||||||
|
return sampler
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid sampler name {acceptance_sampler_method}")
|
||||||
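The helper above lets worker tests be parametrized over both acceptance methods without duplicating setup. A small usage sketch, assuming mock_spec_decode_sampler from the hunk above is in scope (the test body is illustrative):

    import pytest
    import torch


    @pytest.mark.parametrize("acceptance_sampler_method",
                             ["rejection_sampler", "typical_acceptance_sampler"])
    def test_mock_sampler_exposes_token_id_dtype(acceptance_sampler_method):
        # Both mock variants expose the same dtype, so worker tests can stay
        # agnostic to which acceptance method is being exercised.
        sampler = mock_spec_decode_sampler(acceptance_sampler_method)
        assert sampler.token_id_dtype == torch.int64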
@ -753,7 +753,6 @@ class SchedulerConfig:
         self.chunked_prefill_enabled = enable_chunked_prefill
         self.embedding_mode = embedding_mode
         self.preemption_mode = preemption_mode

         self._verify_args()

     def _verify_args(self) -> None:
@ -834,6 +833,9 @@ class SpeculativeConfig:
         speculative_disable_by_batch_size: Optional[int],
         ngram_prompt_lookup_max: Optional[int],
         ngram_prompt_lookup_min: Optional[int],
+        draft_token_acceptance_method: str,
+        typical_acceptance_sampler_posterior_threshold: Optional[float],
+        typical_acceptance_sampler_posterior_alpha: Optional[float],
     ) -> Optional["SpeculativeConfig"]:
         """Create a SpeculativeConfig if possible, else return None.

@ -870,6 +872,19 @@ class SpeculativeConfig:
                 window, if provided.
             ngram_prompt_lookup_min (Optional[int]): Min size of ngram token
                 window, if provided.
+            draft_token_acceptance_method (str): The method to use for
+                accepting draft tokens. This can take two possible
+                values 'rejection_sampler' and 'typical_acceptance_sampler'
+                for RejectionSampler and TypicalAcceptanceSampler
+                respectively.
+            typical_acceptance_sampler_posterior_threshold (Optional[float]):
+                A threshold value that sets a lower bound on the posterior
+                probability of a token in the target model for it to be
+                accepted. This threshold is used only when we use the
+                TypicalAcceptanceSampler for token acceptance.
+            typical_acceptance_sampler_posterior_alpha (Optional[float]):
+                A scaling factor for the entropy-based threshold in the
+                TypicalAcceptanceSampler.

         Returns:
             Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
@ -984,6 +999,11 @@ class SpeculativeConfig:
                 "speculative_model unless the draft model config contains an "
                 "n_predict parameter.")

+        if typical_acceptance_sampler_posterior_threshold is None:
+            typical_acceptance_sampler_posterior_threshold = 0.09
+        if typical_acceptance_sampler_posterior_alpha is None:
+            typical_acceptance_sampler_posterior_alpha = 0.3

         return SpeculativeConfig(
             draft_model_config,
             draft_parallel_config,
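The fallback values chosen above (0.09 and 0.3) follow the heuristic repeated in the docstrings and CLI help later in this change: the alpha scaling factor is roughly the square root of the posterior threshold. A quick sanity check of that relationship:

    import math

    posterior_threshold = 0.09
    posterior_alpha = 0.3
    assert math.isclose(posterior_alpha, math.sqrt(posterior_threshold))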
@ -991,6 +1011,11 @@ class SpeculativeConfig:
             speculative_disable_by_batch_size,
             ngram_prompt_lookup_max,
             ngram_prompt_lookup_min,
+            draft_token_acceptance_method=draft_token_acceptance_method,
+            typical_acceptance_sampler_posterior_threshold=\
+                typical_acceptance_sampler_posterior_threshold,
+            typical_acceptance_sampler_posterior_alpha=\
+                typical_acceptance_sampler_posterior_alpha,
         )

     @staticmethod
@ -1072,6 +1097,9 @@ class SpeculativeConfig:
         speculative_disable_by_batch_size: Optional[int],
         ngram_prompt_lookup_max: Optional[int],
         ngram_prompt_lookup_min: Optional[int],
+        draft_token_acceptance_method: str,
+        typical_acceptance_sampler_posterior_threshold: float,
+        typical_acceptance_sampler_posterior_alpha: float,
     ):
         """Create a SpeculativeConfig object.

@ -1085,6 +1113,19 @@ class SpeculativeConfig:
                 enqueue requests is larger than this value.
             ngram_prompt_lookup_max: Max size of ngram token window.
             ngram_prompt_lookup_min: Min size of ngram token window.
+            draft_token_acceptance_method (str): The method to use for
+                accepting draft tokens. This can take two possible
+                values 'rejection_sampler' and 'typical_acceptance_sampler'
+                for RejectionSampler and TypicalAcceptanceSampler
+                respectively.
+            typical_acceptance_sampler_posterior_threshold (Optional[float]):
+                A threshold value that sets a lower bound on the posterior
+                probability of a token in the target model for it to be
+                accepted. This threshold is used only when we use the
+                TypicalAcceptanceSampler for token acceptance.
+            typical_acceptance_sampler_posterior_alpha (Optional[float]):
+                A scaling factor for the entropy-based threshold in the
+                TypicalAcceptanceSampler.
         """
         self.draft_model_config = draft_model_config
         self.draft_parallel_config = draft_parallel_config
@ -1093,6 +1134,11 @@ class SpeculativeConfig:
             speculative_disable_by_batch_size
         self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0
         self.ngram_prompt_lookup_min = ngram_prompt_lookup_min or 0
+        self.draft_token_acceptance_method = draft_token_acceptance_method
+        self.typical_acceptance_sampler_posterior_threshold = \
+            typical_acceptance_sampler_posterior_threshold
+        self.typical_acceptance_sampler_posterior_alpha = \
+            typical_acceptance_sampler_posterior_alpha

         self._verify_args()

@ -1104,6 +1150,31 @@ class SpeculativeConfig:
         if self.draft_model_config:
             self.draft_model_config.verify_with_parallel_config(
                 self.draft_parallel_config)
+        # Validate and set draft token acceptance related settings.

+        if (self.draft_token_acceptance_method is None):
+            raise ValueError("draft_token_acceptance_method is not set. "
+                             "Expected values are rejection_sampler or "
+                             "typical_acceptance_sampler.")

+        if (self.draft_token_acceptance_method != 'rejection_sampler'
+                and self.draft_token_acceptance_method !=
+                'typical_acceptance_sampler'):
+            raise ValueError(
+                "Expected draft_token_acceptance_method to be either "
+                "rejection_sampler or typical_acceptance_sampler. Instead it "
+                f"is {self.draft_token_acceptance_method}")

+        if (self.typical_acceptance_sampler_posterior_threshold < 0
+                or self.typical_acceptance_sampler_posterior_alpha < 0):
+            raise ValueError(
+                "Expected typical_acceptance_sampler_posterior_threshold "
+                "and typical_acceptance_sampler_posterior_alpha to be > 0. "
+                "Instead found "
+                f"typical_acceptance_sampler_posterior_threshold = "
+                f"{self.typical_acceptance_sampler_posterior_threshold} and "
+                f"typical_acceptance_sampler_posterior_alpha = "
+                f"{self.typical_acceptance_sampler_posterior_alpha}")

     @property
     def num_lookahead_slots(self) -> int:
@ -100,7 +100,9 @@ class EngineArgs:
     speculative_disable_by_batch_size: Optional[int] = None
     ngram_prompt_lookup_max: Optional[int] = None
     ngram_prompt_lookup_min: Optional[int] = None
+    spec_decoding_acceptance_method: str = 'rejection_sampler'
+    typical_acceptance_sampler_posterior_threshold: Optional[float] = None
+    typical_acceptance_sampler_posterior_alpha: Optional[float] = None
     qlora_adapter_name_or_path: Optional[str] = None

     otlp_traces_endpoint: Optional[str] = None
@ -577,6 +579,38 @@ class EngineArgs:
             help='Min size of window for ngram prompt lookup in speculative '
             'decoding.')

+        parser.add_argument(
+            '--spec-decoding-acceptance-method',
+            type=str,
+            default=EngineArgs.spec_decoding_acceptance_method,
+            choices=['rejection_sampler', 'typical_acceptance_sampler'],
+            help='Specify the acceptance method to use during draft token '
+            'verification in speculative decoding. Two types of acceptance '
+            'routines are supported: '
+            '1) RejectionSampler which does not allow changing the '
+            'acceptance rate of draft tokens, '
+            '2) TypicalAcceptanceSampler which is configurable, allowing for '
+            'a higher acceptance rate at the cost of lower quality, '
+            'and vice versa.')

+        parser.add_argument(
+            '--typical-acceptance-sampler-posterior-threshold',
+            type=float,
+            default=EngineArgs.typical_acceptance_sampler_posterior_threshold,
+            help='Set the lower bound threshold for the posterior '
+            'probability of a token to be accepted. This threshold is '
+            'used by the TypicalAcceptanceSampler to make sampling decisions '
+            'during speculative decoding. Defaults to 0.09')

+        parser.add_argument(
+            '--typical-acceptance-sampler-posterior-alpha',
+            type=float,
+            default=EngineArgs.typical_acceptance_sampler_posterior_alpha,
+            help='A scaling factor for the entropy-based threshold for token '
+            'acceptance in the TypicalAcceptanceSampler. Typically defaults '
+            'to sqrt of --typical-acceptance-sampler-posterior-threshold '
+            'i.e. 0.3')

         parser.add_argument('--model-loader-extra-config',
                             type=nullable_str,
                             default=EngineArgs.model_loader_extra_config,
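Translated to the Python API, the same settings can be passed through EngineArgs. A sketch only: the model names are placeholders, and speculative_model / num_speculative_tokens are pre-existing EngineArgs fields that are not part of this diff.

    from vllm.engine.arg_utils import EngineArgs

    engine_args = EngineArgs(
        model="facebook/opt-6.7b",              # target model (placeholder)
        speculative_model="facebook/opt-125m",  # draft model (placeholder)
        num_speculative_tokens=4,
        spec_decoding_acceptance_method="typical_acceptance_sampler",
        typical_acceptance_sampler_posterior_threshold=0.09,
        typical_acceptance_sampler_posterior_alpha=0.3,
    )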
@ -737,6 +771,12 @@ class EngineArgs:
             use_v2_block_manager=self.use_v2_block_manager,
             ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
             ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
+            draft_token_acceptance_method=\
+                self.spec_decoding_acceptance_method,
+            typical_acceptance_sampler_posterior_threshold=self.
+            typical_acceptance_sampler_posterior_threshold,
+            typical_acceptance_sampler_posterior_alpha=self.
+            typical_acceptance_sampler_posterior_alpha,
         )

         scheduler_config = SchedulerConfig(
@ -3,13 +3,12 @@ from typing import Tuple

 import torch
 import torch.jit
-import torch.nn as nn

 from vllm.model_executor.layers.spec_decode_base_sampler import (
     SpecDecodeBaseSampler)


-class RejectionSampler(SpecDecodeBaseSampler, nn.Module):
+class RejectionSampler(SpecDecodeBaseSampler):
     """Apply modified rejection sampling as described in "Accelerating Large
     Language Model Decoding with Speculative Sampling"
     https://arxiv.org/pdf/2302.01318.pdf.
@ -28,8 +27,8 @@ class RejectionSampler(SpecDecodeBaseSampler, nn.Module):
             during sampling. This catches correctness issues but adds
             nontrivial latency.
         """
-        SpecDecodeBaseSampler.__init__(self, disable_bonus_tokens, strict_mode)
-        nn.Module.__init__(self)
+        super().__init__(disable_bonus_tokens=disable_bonus_tokens,
+                         strict_mode=strict_mode)

     def forward(
         self,
@ -78,11 +77,12 @@ class RejectionSampler(SpecDecodeBaseSampler, nn.Module):
         self._raise_if_incorrect_input(target_probs, bonus_token_ids,
                                        draft_probs, draft_token_ids)

-        accepted, recovered_token_ids = self._batch_modified_rejection_sampling(
+        accepted, recovered_token_ids = (
+            self._batch_modified_rejection_sampling(
                 target_probs,
                 draft_probs,
                 draft_token_ids,
-        )
+            ))

         output_token_ids = self._create_output(
             accepted,
@ -1,9 +1,12 @@
+from abc import abstractmethod
 from typing import Optional

 import torch
+import torch.jit
+import torch.nn as nn


-class SpecDecodeBaseSampler():
+class SpecDecodeBaseSampler(nn.Module):
     """Base class for samplers used for Speculative Decoding verification
     step.
     """
@ -51,6 +54,16 @@ class SpecDecodeBaseSampler():
     def token_id_dtype(self):
         return torch.int64

+    @abstractmethod
+    def forward(
+        self,
+        target_probs: torch.Tensor,
+        bonus_token_ids: torch.Tensor,
+        draft_probs: torch.Tensor,
+        draft_token_ids: torch.Tensor,
+    ) -> torch.Tensor:
+        raise NotImplementedError

     def _create_output(
         self,
         accepted: torch.Tensor,  # [batch_size, k]
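With the base class now an nn.Module that declares forward() as the verification entry point, every acceptance sampler plugs into the worker through the same four-tensor call. A standalone sketch of that shape, using toy classes rather than the vLLM ones:

    from abc import abstractmethod

    import torch
    import torch.nn as nn


    class ToyBaseSampler(nn.Module):
        # Mirrors the pattern above: forward() is marked abstract and raises
        # if a subclass forgets to override it (nn.Module is not an ABC, so
        # the raise is the actual guard).
        @abstractmethod
        def forward(self, target_probs, bonus_token_ids, draft_probs,
                    draft_token_ids) -> torch.Tensor:
            raise NotImplementedError


    class AcceptAllSampler(ToyBaseSampler):
        # Trivial subclass: accept every draft token and append the bonus
        # token, returning ids of shape [batch_size, k + 1].
        def forward(self, target_probs, bonus_token_ids, draft_probs,
                    draft_token_ids) -> torch.Tensor:
            return torch.cat([draft_token_ids, bonus_token_ids], dim=-1)


    sampler = AcceptAllSampler()
    out = sampler(target_probs=torch.rand(2, 3, 10),
                  bonus_token_ids=torch.zeros(2, 1, dtype=torch.int64),
                  draft_probs=torch.rand(2, 3, 10),
                  draft_token_ids=torch.randint(0, 10, (2, 3)))
    assert out.shape == (2, 4)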
@ -1,12 +1,11 @@
 import torch
 import torch.jit
-import torch.nn as nn

 from vllm.model_executor.layers.spec_decode_base_sampler import (
     SpecDecodeBaseSampler)


-class TypicalAcceptanceSampler(SpecDecodeBaseSampler, nn.Module):
+class TypicalAcceptanceSampler(SpecDecodeBaseSampler):
     """Apply typical acceptance sampling as described in section 3.3.1 in
     "MEDUSA: Simple LLM Inference Acceleration Framework with
     Multiple Decoding Heads"
@ -15,10 +14,10 @@ class TypicalAcceptanceSampler(SpecDecodeBaseSampler, nn.Module):

     def __init__(
         self,
+        posterior_threshold: float,
+        posterior_alpha: float,
         disable_bonus_tokens: bool = False,
         strict_mode: bool = False,
-        posterior_threshold: float = 0.09,
-        posterior_alpha: float = 0.3,
     ):
         """Create a Typical Acceptance Sampler.

@ -31,23 +30,20 @@ class TypicalAcceptanceSampler(SpecDecodeBaseSampler, nn.Module):
                 nontrivial latency.
             posterior_threshold : A threshold value that sets a lower bound
                 on the posterior probability of a token in target model for it
-                to be accepted. Default is 0.09
+                to be accepted.
             posterior_alpha : A scaling factor for the entropy-based
-                threshold in typical acceptance sampling. Typically defaults to
-                sqrt of posterior_threshold and is set to 0.3.
+                threshold in typical acceptance sampling.
         """
-        SpecDecodeBaseSampler.__init__(
-            self,
-            disable_bonus_tokens=disable_bonus_tokens,
-            strict_mode=strict_mode)
-        nn.Module.__init__(self)
         self._posterior_threshold = posterior_threshold
         self._posterior_alpha = posterior_alpha
+        super().__init__(disable_bonus_tokens=disable_bonus_tokens,
+                         strict_mode=strict_mode)

     def forward(
         self,
         target_probs: torch.Tensor,
         bonus_token_ids: torch.Tensor,
+        draft_probs: torch.Tensor,
         draft_token_ids: torch.Tensor,
     ) -> torch.Tensor:
         """Sample token ids using typical acceptance sampling. This accepts
@ -69,6 +65,8 @@ class TypicalAcceptanceSampler(SpecDecodeBaseSampler, nn.Module):
                 speculative tokens in a sequence are accepted.
                 shape = [batch_size, num_bonus_tokens]

+            draft_probs: This parameter is unused by the acceptance sampler.

             draft_token_ids: The token ids that were sampled from the draft
                 probabilities.
                 shape = [batch_size, num_speculative_tokens]
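For reference, the rule that posterior_threshold and posterior_alpha control is, per section 3.3.1 of the Medusa paper, roughly: accept a draft token when the target model's probability for it exceeds min(posterior_threshold, posterior_alpha * exp(-H)), where H is the entropy of the target distribution at that position. A self-contained sketch of that criterion (not the vLLM kernel itself; the function name and epsilon are illustrative):

    import torch


    def typical_acceptance_mask(target_probs: torch.Tensor,
                                draft_token_ids: torch.Tensor,
                                posterior_threshold: float = 0.09,
                                posterior_alpha: float = 0.3) -> torch.Tensor:
        """Return a [batch_size, k] bool mask of draft tokens that pass the
        typical-acceptance test, given target_probs of shape
        [batch_size, k, vocab_size]."""
        # Probability the target model assigns to each proposed draft token.
        candidate_probs = target_probs.gather(
            -1, draft_token_ids.unsqueeze(-1)).squeeze(-1)
        # Entropy of the target posterior at each position.
        entropy = -torch.sum(target_probs * torch.log(target_probs + 1e-5),
                             dim=-1)
        # Adaptive threshold: a hard floor, relaxed when the target
        # distribution is itself high-entropy (uncertain).
        threshold = torch.minimum(
            torch.full_like(entropy, posterior_threshold),
            posterior_alpha * torch.exp(-entropy))
        return candidate_probs > threshold


    probs = torch.softmax(torch.randn(2, 3, 16), dim=-1)
    drafts = torch.randint(0, 16, (2, 3))
    mask = typical_acceptance_mask(probs, drafts)
    assert mask.shape == (2, 3) and mask.dtype == torch.bool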
@ -4,7 +4,8 @@ from typing import Callable, Optional

 import torch

-from vllm.model_executor.layers.rejection_sampler import RejectionSampler
+from vllm.model_executor.layers.spec_decode_base_sampler import (
+    SpecDecodeBaseSampler)
 from vllm.utils import is_pin_memory_available


@ -46,15 +47,15 @@ Timer = Callable[[], float]


 class AsyncMetricsCollector:
-    """Class which copies rejection sampler metrics from the device to CPU on a
-    non-default Torch stream.
+    """Class which copies rejection/typical-acceptance sampler metrics
+    from the device to CPU on a non-default Torch stream.
     """

     def __init__(self,
-                 rejection_sampler: RejectionSampler,
+                 spec_decode_sampler: SpecDecodeBaseSampler,
                  timer: Optional[Timer] = None,
                  collect_interval_s: float = 5.0):
-        self._rejection_sampler = rejection_sampler
+        self.spec_decode_sampler = spec_decode_sampler
         self._timer = time.time if timer is None else timer

         self._rank: Optional[int] = None
@ -95,7 +96,7 @@ class AsyncMetricsCollector:
         return None

     def _should_collect_rejsample_metrics(self, now: float) -> bool:
-        """Return whether or not this iteration should print rejection sampling
+        """Return whether or not this iteration should print sampling
         metrics.
         """
         if self._rank != 0:
@ -107,8 +108,8 @@ class AsyncMetricsCollector:
             return True

     def _copy_rejsample_metrics_async(self) -> torch.cuda.Event:
-        """Copy rejection sampling metrics (number of accepted tokens, etc) to
-        CPU asynchronously.
+        """Copy rejection/typical-acceptance sampling metrics
+        (number of accepted tokens, etc) to CPU asynchronously.

         Returns a CUDA event recording when the copy is complete.
         """
@ -117,13 +118,14 @@ class AsyncMetricsCollector:

         with torch.cuda.stream(self._copy_stream):
             self._aggregate_num_accepted_tokens.copy_(
-                self._rejection_sampler.num_accepted_tokens, non_blocking=True)
+                self.spec_decode_sampler.num_accepted_tokens,
+                non_blocking=True)
             self._aggregate_num_emitted_tokens.copy_(
-                self._rejection_sampler.num_emitted_tokens, non_blocking=True)
+                self.spec_decode_sampler.num_emitted_tokens, non_blocking=True)
             # Number of draft tokens is calculated on CPU, so no copy is
             # required.
             self._aggregate_num_draft_tokens = (
-                self._rejection_sampler.num_draft_tokens)
+                self.spec_decode_sampler.num_draft_tokens)

         aggregate_metrics_ready = torch.cuda.Event()
         aggregate_metrics_ready.record(self._copy_stream)
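The collector above keeps the counters on the GPU and drains them on a dedicated stream so the default stream never blocks on the copy. A minimal standalone sketch of that pattern, assuming a CUDA device is present:

    import torch

    if torch.cuda.is_available():
        copy_stream = torch.cuda.Stream()
        # Device-side counter and a pinned CPU destination for the async copy.
        num_accepted_gpu = torch.zeros(1, dtype=torch.long, device="cuda")
        num_accepted_cpu = torch.zeros(1, dtype=torch.long, pin_memory=True)

        with torch.cuda.stream(copy_stream):
            num_accepted_cpu.copy_(num_accepted_gpu, non_blocking=True)
            ready = torch.cuda.Event()
            ready.record(copy_stream)

        # Later, poll the event before trusting the CPU-side value.
        if ready.query():
            print("accepted so far:", num_accepted_cpu.item())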
@ -7,6 +7,10 @@ from vllm.config import ParallelConfig, SpeculativeConfig
 from vllm.distributed.communication_op import broadcast_tensor_dict
 from vllm.logger import init_logger
 from vllm.model_executor.layers.rejection_sampler import RejectionSampler
+from vllm.model_executor.layers.spec_decode_base_sampler import (
+    SpecDecodeBaseSampler)
+from vllm.model_executor.layers.typical_acceptance_sampler import (
+    TypicalAcceptanceSampler)
 from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest,
                            HiddenStates, SamplerOutput, SequenceGroupMetadata,
                            get_all_seq_ids)
@ -56,7 +60,12 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
         draft_worker_kwargs=draft_worker_kwargs,
         disable_by_batch_size=speculative_config.
         speculative_disable_by_batch_size,
-    )
+        draft_token_acceptance_method=speculative_config.
+        draft_token_acceptance_method,
+        typical_acceptance_sampler_posterior_threshold=speculative_config.
+        typical_acceptance_sampler_posterior_threshold,
+        typical_acceptance_sampler_posterior_alpha=speculative_config.
+        typical_acceptance_sampler_posterior_alpha)

     return spec_decode_worker

@ -78,8 +87,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         welcome!).
     * Only top-1 proposal and scoring are implemented. Tree-attention is left as
         future work.
-    * Only lossless rejection sampling is supported. Contributions adding lossy
-        verification routines are welcome (e.g. Medusa's typical acceptance).
     * All sequences in a batch must have the same proposal length, or zero. This
         can be improved by having per-sequence speculation in the future.
     * The scoring forward pass is done without an MQA kernel, which is
@ -95,6 +102,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         scorer_worker: Worker,
         draft_worker_kwargs: Dict[str, Any],
         disable_by_batch_size: Optional[int],
+        draft_token_acceptance_method: str,
+        typical_acceptance_sampler_posterior_threshold: float,
+        typical_acceptance_sampler_posterior_alpha: float,
     ) -> "SpecDecodeWorker":

         ngram_prompt_lookup_max = (
@ -127,17 +137,30 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         logger.info("Configuring SpecDecodeWorker with proposer=%s",
                     type(proposer_worker))

+        spec_decode_sampler: SpecDecodeBaseSampler = None
+        if draft_token_acceptance_method == "rejection_sampler":
+            spec_decode_sampler = RejectionSampler(
+                disable_bonus_tokens=disable_bonus_tokens, )
+        elif draft_token_acceptance_method == "typical_acceptance_sampler":
+            spec_decode_sampler = TypicalAcceptanceSampler(
+                disable_bonus_tokens=disable_bonus_tokens,
+                posterior_threshold=\
+                    typical_acceptance_sampler_posterior_threshold,
+                posterior_alpha=typical_acceptance_sampler_posterior_alpha,
+            )
+        logger.info("Configuring SpecDecodeWorker with sampler=%s",
+                    type(spec_decode_sampler))

         return SpecDecodeWorker(proposer_worker,
                                 scorer_worker,
                                 disable_by_batch_size=disable_by_batch_size,
-                                rejection_sampler=RejectionSampler(
-                                    disable_bonus_tokens=disable_bonus_tokens))
+                                spec_decode_sampler=spec_decode_sampler)

     def __init__(
         self,
         proposer_worker: ProposerWorkerBase,
         scorer_worker: WorkerBase,
-        rejection_sampler: RejectionSampler,
+        spec_decode_sampler: SpecDecodeBaseSampler,
         metrics_collector: Optional[AsyncMetricsCollector] = None,
         disable_by_batch_size: Optional[int] = None,
     ):
@ -150,8 +173,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             scorer_worker: A worker that produces probabilities of speculative
                 tokens according to some base model. Typically a vanilla vLLM
                 Worker.
-            rejection_sampler: A Torch module used to perform modified rejection
-                sampling for speculative decoding.
+            spec_decode_sampler: A Torch module used to perform acceptance
+                sampling of the draft tokens in the verification step of
+                speculative decoding. Currently we support two different
+                types of sampler namely RejectionSampler and
+                TypicalAcceptanceSampler. 'spec_decode_sampler' is either an
+                instance of RejectionSampler or TypicalAcceptanceSampler.
             disable_by_batch_size: If the batch size is larger than this,
                 disable speculative decoding for new incoming requests.
             metrics_collector: Helper class for collecting metrics; can be set
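Because both acceptance samplers implement the SpecDecodeBaseSampler interface, the worker only relies on the shared surface: probs_dtype, token_id_dtype, and the four-tensor forward(). A quick check of that interchangeability, assuming vLLM is importable (constructor arguments are taken from the hunks above; the probs_dtype assertion hedges on the exact float dtype):

    import torch

    from vllm.model_executor.layers.rejection_sampler import RejectionSampler
    from vllm.model_executor.layers.typical_acceptance_sampler import (
        TypicalAcceptanceSampler)

    samplers = [
        RejectionSampler(disable_bonus_tokens=False),
        TypicalAcceptanceSampler(posterior_threshold=0.09,
                                 posterior_alpha=0.3,
                                 disable_bonus_tokens=False),
    ]
    for sampler in samplers:
        # The worker reads these two attributes and then calls the sampler
        # like any other torch module during verification.
        assert sampler.token_id_dtype == torch.int64
        assert sampler.probs_dtype in (torch.float32, torch.float16)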
@ -160,15 +187,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         self.proposer_worker = proposer_worker
         self.scorer_worker = scorer_worker
         self.disable_by_batch_size = disable_by_batch_size or float("inf")
-        self.rejection_sampler = rejection_sampler
+        self.spec_decode_sampler = spec_decode_sampler

         self._metrics = AsyncMetricsCollector(
-            rejection_sampler
+            self.spec_decode_sampler
         ) if metrics_collector is None else metrics_collector
-        self.probs_dtype = self.rejection_sampler.probs_dtype
-        self.token_id_dtype = self.rejection_sampler.token_id_dtype
+        self.probs_dtype = self.spec_decode_sampler.probs_dtype
+        self.token_id_dtype = self.spec_decode_sampler.token_id_dtype

         # Lazy initiazliation.
         self.scorer: SpeculativeScorer

@ -189,7 +213,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         self.proposer_worker.load_model()

         self._metrics.init_gpu_tensors(self.rank)
-        self.rejection_sampler.init_gpu_tensors(self.rank)
+        self.spec_decode_sampler.init_gpu_tensors(self.rank)

         self.scorer = BatchExpansionTop1Scorer(
             scorer_worker=self.scorer_worker,
             device=self.device,
@ -203,7 +228,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
     def _configure_model_sampler_for_spec_decode(self):
         """Configure model sampler to emit GPU tensors. This allows spec decode
         to keep data on device without transferring to CPU and serializing,
-        which significantly reduces overhead of rejection sampling.
+        which significantly reduces overhead of sampling during verification.

         NOTE(cade): This breaks abstraction boundaries pretty badly. The better
         design is to have the "move to CPU and serialize" sampling decision be
@ -481,7 +506,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         # Get proposed tokens.
         proposal_token_ids = proposals.proposal_token_ids[spec_indices]

-        accepted_token_ids = self.rejection_sampler(
+        accepted_token_ids = self.spec_decode_sampler(
             target_probs=proposal_verifier_probs,
             bonus_token_ids=bonus_token_ids,
             draft_probs=proposal_probs,
@ -496,7 +521,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         accepted_token_ids = torch.cat(
             [accepted_token_ids, non_spec_token_ids])
         logprobs = proposal_scores.logprobs

         # Rearrange so that results are in the order of the original seq group
         # metadata.
         accepted_token_ids[original_indices] = accepted_token_ids.clone()