
# SPDX-License-Identifier: Apache-2.0

import random
from collections import defaultdict
from types import SimpleNamespace
from typing import Dict, List, Set
from unittest.mock import MagicMock

import pytest
import torch

from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import ExecuteModelRequest, SequenceOutput
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.metrics import (AsyncMetricsCollector,
                                      SpecDecodeWorkerMetrics)
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.spec_decode_worker import (SpecDecodeWorker,
                                                 split_num_cache_blocks_evenly)

from .test_utils import mock_spec_decode_sampler
from .utils import create_batch, create_sampler_output_list, mock_worker


@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_correctly_calls_draft_model(k: int, batch_size: int,
                                     acceptance_sampler_method: str):
    """Verify SpecDecodeWorker calls the draft worker with correct
    inputs. Everything else is mocked out.
    """
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    worker = SpecDecodeWorker(
        draft_worker,
        target_worker,
        mock_spec_decode_sampler(acceptance_sampler_method),
        disable_logprobs=False,
        metrics_collector=metrics_collector)
    exception_secret = 'artificial stop'
    draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)

    seq_group_metadata_list, _, _ = create_batch(batch_size, k)
    execute_model_req = ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)

    with pytest.raises(ValueError, match=exception_secret):
        worker.execute_model(execute_model_req=execute_model_req)

    call_args_list = draft_worker.get_spec_proposals.call_args_list
    assert len(call_args_list) == 1

    for args, _ in call_args_list:
        actual_execute_model_data = args[0]
        assert actual_execute_model_data == execute_model_req


@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_batch_expansion_correctly_calls_target_model(
        k: int, batch_size: int, acceptance_sampler_method: str):
    """Verify SpecDecodeWorker calls the target model with correct
    inputs with batch expansion. Everything else is mocked out.
    """
    draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
    target_worker = mock_worker(use_spec=False)
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)

    draft_worker.device = 'cuda'
    target_worker.device = 'cuda'

    set_random_seed(1)

    worker = SpecDecodeWorker(
        draft_worker,
        target_worker,
        mock_spec_decode_sampler(acceptance_sampler_method),
        disable_logprobs=False,
        metrics_collector=metrics_collector,
        disable_mqa_scorer=True)
    worker.init_device()

    vocab_size = 32_000

    proposal_token_ids = torch.randint(low=0,
                                       high=vocab_size,
                                       size=(batch_size, k),
                                       dtype=torch.int64,
                                       device='cuda')
    proposal_probs = torch.rand(batch_size,
                                k,
                                vocab_size,
                                dtype=torch.float32,
                                device='cuda')
    proposal_lens = torch.ones(batch_size, dtype=torch.int64,
                               device='cuda') * k

    seq_group_metadata_list, prompts, prev_output_tokens = create_batch(
        batch_size, k)

    draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
        proposal_token_ids=proposal_token_ids,
        proposal_probs=proposal_probs,
        proposal_lens=proposal_lens)

    exception_secret = 'artificial stop'
    target_worker.execute_model.side_effect = ValueError(exception_secret)

    with pytest.raises(ValueError, match=exception_secret):
        worker.execute_model(execute_model_req=ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            num_lookahead_slots=k))

    seen_contexts: List[List[int]] = []

    call_args_list = target_worker.execute_model.call_args_list
    assert len(call_args_list) == 1
    for _, kwargs in call_args_list:
        seq_group_metadata_list = kwargs[
            "execute_model_req"].seq_group_metadata_list

        assert len(seq_group_metadata_list) == (k + 1) * batch_size
        for seq_group_metadata in seq_group_metadata_list:
            for seq_data in seq_group_metadata.seq_data.values():
                seen_contexts.append(seq_data.get_token_ids())

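    # Batch expansion scores each sequence against k + 1 contexts: the prompt
    # plus previously generated tokens, extended by the first i draft tokens
    # for i in [0, k]. Build that expected set of contexts here.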
    expected_seen_contexts: List[List[int]] = []

    for prompt, prev_generated, draft_tokens in zip(
            prompts, prev_output_tokens, proposal_token_ids.tolist()):

        for i in range(len(draft_tokens) + 1):
            expected_seen_contexts.append(prompt + prev_generated +
                                          draft_tokens[:i])

    seen_contexts.sort()
    expected_seen_contexts.sort()
    assert expected_seen_contexts == seen_contexts


@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int,
                                             acceptance_sampler_method: str):
    """Verify SpecDecodeWorker calls the rejection sampler with
    correct inputs. Everything else is mocked out.
    """
    vocab_size = 32_000

    draft_worker = mock_worker(cls=MultiStepWorker,
                               vocab_size=vocab_size,
                               use_spec=False)
    target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
    spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    draft_worker.device = 'cuda'
    target_worker.device = 'cuda'

    set_random_seed(1)

    worker = SpecDecodeWorker(draft_worker,
                              target_worker,
                              spec_decode_sampler,
                              disable_logprobs=False,
                              metrics_collector=metrics_collector)
    worker.init_device()

    proposal_token_ids = torch.randint(low=0,
                                       high=vocab_size,
                                       size=(batch_size, k),
                                       dtype=torch.int64,
                                       device='cuda')
    proposal_probs = torch.rand(batch_size,
                                k,
                                vocab_size,
                                dtype=torch.float32,
                                device='cuda')

    proposal_lens = torch.ones(batch_size, dtype=torch.int64,
                               device='cuda') * k

    seq_group_metadata_list, _, _ = create_batch(batch_size, k)

    draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
        proposal_token_ids=proposal_token_ids,
        proposal_probs=proposal_probs,
        proposal_lens=proposal_lens)

    target_token_ids = torch.randint(low=0,
                                     high=vocab_size,
                                     size=(1, batch_size * (k + 1)),
                                     dtype=torch.int64,
                                     device='cuda')
    target_token_probs = torch.rand(1,
                                    batch_size * (k + 1),
                                    vocab_size,
                                    dtype=torch.float32,
                                    device='cuda')
    target_token_logprobs = torch.rand(1,
                                       batch_size * (k + 1),
                                       vocab_size,
                                       dtype=torch.float32,
                                       device='cuda')
    target_output = create_sampler_output_list(target_token_ids,
                                               target_token_probs,
                                               target_token_logprobs)

    target_worker.execute_model.return_value = [target_output[0]]

    exception_secret = 'artificial stop'

    spec_decode_sampler.side_effect = ValueError(exception_secret)

    with pytest.raises(ValueError, match=exception_secret):
        worker.execute_model(execute_model_req=ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            num_lookahead_slots=k))

    assert len(spec_decode_sampler.call_args_list) == 1
    _, kwargs = spec_decode_sampler.call_args_list[0]
    actual = SimpleNamespace(**kwargs)

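    # The target model scores batch_size * (k + 1) flattened positions; once
    # reshaped to (batch_size, k + 1), the last column of each row holds the
    # bonus token sampled from the target distribution.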
    assert torch.equal(actual.bonus_token_ids,
                       target_token_ids.reshape(batch_size, k + 1)[:, -1:])
    assert torch.equal(actual.target_with_bonus_probs,
                       target_token_probs.reshape(batch_size, k + 1, -1))
    assert torch.equal(actual.draft_token_ids, proposal_token_ids)
    assert torch.equal(actual.draft_probs, proposal_probs)


@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_correctly_formats_output(k: int, batch_size: int,
                                  acceptance_sampler_method: str):
    """Verify SpecDecodeWorker formats sampler output correctly.
    Everything else is mocked out.
    """
    vocab_size = 32_000

    draft_worker = mock_worker(cls=MultiStepWorker,
                               vocab_size=vocab_size,
                               use_spec=False)
    target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    draft_worker.device = 'cuda'
    target_worker.device = 'cuda'

    set_random_seed(1)
    spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
    worker = SpecDecodeWorker(draft_worker,
                              target_worker,
                              spec_decode_sampler,
                              disable_logprobs=False,
                              metrics_collector=metrics_collector)
    worker.init_device()

    proposal_token_ids = torch.randint(low=0,
                                       high=vocab_size,
                                       size=(batch_size, k),
                                       dtype=torch.int64,
                                       device='cuda')
    proposal_probs = torch.rand(batch_size,
                                k,
                                vocab_size,
                                dtype=torch.float32,
                                device='cuda')

    proposal_lens = torch.ones(batch_size, dtype=torch.int64,
                               device='cuda') * k

    seq_group_metadata_list, _, _ = create_batch(batch_size, k)

    draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
        proposal_token_ids=proposal_token_ids,
        proposal_probs=proposal_probs,
        proposal_lens=proposal_lens)

    target_token_ids = torch.randint(low=0,
                                     high=vocab_size,
                                     size=(1, batch_size * (k + 1)),
                                     dtype=torch.int64,
                                     device='cuda')
    target_token_probs = torch.rand(1,
                                    batch_size * (k + 1),
                                    vocab_size,
                                    dtype=torch.float32,
                                    device='cuda')
    target_token_logprobs = torch.rand(1,
                                       batch_size * (k + 1),
                                       vocab_size,
                                       dtype=torch.float32,
                                       device='cuda')
    target_output = create_sampler_output_list(target_token_ids,
                                               target_token_probs,
                                               target_token_logprobs)

    target_worker.execute_model.return_value = [target_output[0]]

    spec_decode_sampler_output = torch.randint(low=0,
                                               high=vocab_size,
                                               size=(batch_size, k + 1),
                                               dtype=torch.int64,
                                               device='cuda')
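    # Mark a random-length suffix of each row as rejected; -1 denotes
    # positions for which no token was emitted.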
    for i in range(batch_size):
        minimum_accepted_tokens = 1
        spec_decode_sampler_output[i][
            -random.randint(minimum_accepted_tokens, k + 1):] = -1

    spec_decode_sampler.return_value = spec_decode_sampler_output
    output = worker.execute_model(execute_model_req=ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list,
        num_lookahead_slots=k))

    expected_output = create_sampler_output_list(
        token_ids=spec_decode_sampler_output.transpose(0, 1),
        probs=[None for _ in range(k + 1)],
        logprobs=[None for _ in range(k + 1)])

    seq_ids = [
        next(iter(seq_group_metadata.seq_data.keys()))
        for seq_group_metadata in seq_group_metadata_list
    ]
    actual_output_by_seq: Dict[int, List[SequenceOutput]] = {
        seq_id: []
        for seq_id in seq_ids
    }
    expected_output_by_seq: Dict[int, List[SequenceOutput]] = {
        seq_id: []
        for seq_id in seq_ids
    }

    for step in output:
        for seq_group in step:
            for sample in seq_group.samples:
                seq_id = sample.parent_seq_id
                actual_output_by_seq[seq_id].append(sample)

    for step in expected_output:
        for seq_group in step:
            for sample in seq_group.samples:
                seq_id = sample.parent_seq_id
                expected_output_by_seq[seq_id].append(sample)

    all_seen_seq_ids = set(
        list(actual_output_by_seq.keys()) +
        list(expected_output_by_seq.keys()))
    for seq_id in all_seen_seq_ids:
        actual_by_step = actual_output_by_seq[seq_id]
        expected_by_step = expected_output_by_seq[seq_id]

        for i in range(k + 1):
            if i >= len(actual_by_step):
                assert expected_by_step[i].output_token == -1
                continue
            assert actual_by_step[i].output_token == expected_by_step[
                i].output_token


@pytest.mark.parametrize('k', [1, 2])
@pytest.mark.parametrize('batch_size', [1])
@pytest.mark.parametrize('returns_metrics', [True, False])
@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool,
                          acceptance_sampler_method: str):
    """Verify SpecDecodeWorker collects metrics.
    """
    vocab_size = 32_000

    draft_worker = mock_worker(cls=MultiStepWorker,
                               vocab_size=vocab_size,
                               use_spec=False)
    target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
    spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    draft_worker.device = 'cuda'
    target_worker.device = 'cuda'

    set_random_seed(1)

    worker = SpecDecodeWorker(draft_worker,
                              target_worker,
                              spec_decode_sampler,
                              disable_logprobs=False,
                              metrics_collector=metrics_collector)
    worker.init_device()

    proposal_token_ids = torch.randint(low=0,
                                       high=vocab_size,
                                       size=(batch_size, k),
                                       dtype=torch.int64,
                                       device='cuda')
    proposal_probs = torch.rand(batch_size,
                                k,
                                vocab_size,
                                dtype=torch.float32,
                                device='cuda')

    proposal_lens = torch.ones(batch_size, dtype=torch.int64,
                               device='cuda') * k

    seq_group_metadata_list, _, _ = create_batch(batch_size, k)

    draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
        proposal_token_ids=proposal_token_ids,
        proposal_probs=proposal_probs,
        proposal_lens=proposal_lens)

    target_token_ids = torch.randint(low=0,
                                     high=vocab_size,
                                     size=(1, batch_size * (k + 1)),
                                     dtype=torch.int64,
                                     device='cuda')
    target_token_probs = torch.rand(1,
                                    batch_size * (k + 1),
                                    vocab_size,
                                    dtype=torch.float32,
                                    device='cuda')
    target_token_logprobs = torch.rand(1,
                                       batch_size * (k + 1),
                                       vocab_size,
                                       dtype=torch.float32,
                                       device='cuda')
    target_output = create_sampler_output_list(target_token_ids,
                                               target_token_probs,
                                               target_token_logprobs)

    target_worker.execute_model.return_value = [target_output[0]]

    spec_decode_sampler_output = torch.randint(low=0,
                                               high=vocab_size,
                                               size=(batch_size, k + 1),
                                               dtype=torch.int64,
                                               device='cuda')
    for i in range(batch_size):
        minimum_accepted_tokens = 1
        spec_decode_sampler_output[i][
            -random.randint(minimum_accepted_tokens, k + 1):] = -1
    spec_decode_sampler.return_value = spec_decode_sampler_output

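    # When returns_metrics is False, the collector yields None and the worker
    # should propagate that None into the sampler output unchanged.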
    mock_rejsample_metrics = MagicMock(
        spec=SpecDecodeWorkerMetrics) if returns_metrics else None
    metrics_collector.maybe_collect_rejsample_metrics.return_value = (
        mock_rejsample_metrics)

    output = worker.execute_model(execute_model_req=ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list,
        num_lookahead_slots=k))
    assert output[0].spec_decode_worker_metrics == mock_rejsample_metrics

    call_args_list = (
        metrics_collector.maybe_collect_rejsample_metrics.call_args_list)
    assert len(call_args_list) == 1
    args, kwargs = call_args_list[0]
    assert args[0] == k or kwargs.get('k', -1) == k


@pytest.mark.parametrize('k', [0])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_k_equals_zero(k: int, batch_size: int,
                       acceptance_sampler_method: str):
    """Verify that the SpecDecodeWorker calls the draft and target workers
    when k is zero. This happens during prefill.
    """
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)

    sampler_output = MagicMock(spec=SamplerOutput)
    sampler_output.hidden_states = None
    target_worker.execute_model.return_value = [sampler_output]

    draft_worker.device = 'cuda'
    target_worker.device = 'cuda'

    set_random_seed(1)

    worker = SpecDecodeWorker(
        proposer_worker=draft_worker,
        scorer_worker=target_worker,
        spec_decode_sampler=mock_spec_decode_sampler(
            acceptance_sampler_method),
        disable_logprobs=False,
        metrics_collector=metrics_collector,
    )

    seq_group_metadata_list, _, _ = create_batch(batch_size,
                                                 k,
                                                 prev_output_token_len=0)
    execute_model_req = ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)

    out = worker.execute_model(execute_model_req=execute_model_req)

    assert len(out) == 1, f"expected only one token output when {k=}"
    assert out[0].sampled_token_probs is None, (
        "expect gpu tensor references to be None")
    assert out[
        0].sampled_token_ids is None, "expect gpu tensor references to be None"

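    # With no lookahead slots there is nothing to speculate, so the request is
    # forwarded unchanged to both the draft and target workers.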
    draft_worker.execute_model.assert_called_once_with(execute_model_req)
    target_worker.execute_model.assert_called_once_with(execute_model_req)


@pytest.mark.parametrize('k', [0, 5])
@pytest.mark.parametrize('batch_size', [0])
@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_empty_input_batch(k: int, batch_size: int,
                           acceptance_sampler_method: str):
    """Verify that the SpecDecodeWorker calls the draft and target workers
    when the input batch is empty. This can happen if the engine communicates
    to the workers information without scheduling a batch.
    """
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)

    sampler_output = MagicMock(spec=SamplerOutput)
    sampler_output.hidden_states = None
    target_worker.execute_model.return_value = [sampler_output]

    draft_worker.device = 'cuda'
    target_worker.device = 'cuda'

    set_random_seed(1)

    worker = SpecDecodeWorker(
        proposer_worker=draft_worker,
        scorer_worker=target_worker,
        spec_decode_sampler=mock_spec_decode_sampler(
            acceptance_sampler_method),
        disable_logprobs=False,
        metrics_collector=metrics_collector,
    )

    seq_group_metadata_list, _, _ = create_batch(batch_size,
                                                 k,
                                                 prev_output_token_len=0)
    execute_model_req = ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)

    out = worker.execute_model(execute_model_req=execute_model_req)

    assert len(out) == 1, f"expected only one token output when {k=}"
    assert out[0].sampled_token_probs is None, (
        "expect gpu tensor references to be None")
    assert out[
        0].sampled_token_ids is None, "expect gpu tensor references to be None"

    draft_worker.execute_model.assert_called_once_with(execute_model_req)
    target_worker.execute_model.assert_called_once_with(execute_model_req)


@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@pytest.mark.skip_global_cleanup
def test_init_device(acceptance_sampler_method: str):
    """Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as
    well as other GPU initialization.
    """
    draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
    target_worker = mock_worker(use_spec=False)
    spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)

    worker = SpecDecodeWorker(
        proposer_worker=draft_worker,
        scorer_worker=target_worker,
        spec_decode_sampler=spec_decode_sampler,
        disable_logprobs=False,
        metrics_collector=metrics_collector,
    )
    worker.init_device()

    draft_worker.init_device.assert_called_once()

    target_worker.init_device.assert_called_once()

    metrics_collector.init_tensors.assert_called_once()
    spec_decode_sampler.init_tensors.assert_called_once()


@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_initialize_cache(acceptance_sampler_method):
    """Verify SpecDecodeWorker invokes initialize_cache on proposer/scorer
    workers.
    """
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)

    worker = SpecDecodeWorker(proposer_worker=draft_worker,
                              scorer_worker=target_worker,
                              spec_decode_sampler=mock_spec_decode_sampler(
                                  acceptance_sampler_method),
                              metrics_collector=metrics_collector)

    kwargs = {"num_gpu_blocks": 1024, "num_cpu_blocks": 1023}
    worker.initialize_cache(**kwargs)

    draft_worker.initialize_cache.assert_called_once_with(**kwargs)
    target_worker.initialize_cache.assert_called_once_with(**kwargs)


@pytest.mark.parametrize('available_gpu_blocks', [1, 1024])
@pytest.mark.parametrize('available_cpu_blocks', [500])
@pytest.mark.parametrize('target_cache_block_size_bytes', [2 * 2 * 4096])
@pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096])
@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@pytest.mark.skip_global_cleanup
def test_determine_num_available_blocks(available_gpu_blocks: int,
                                        available_cpu_blocks: int,
                                        target_cache_block_size_bytes: int,
                                        draft_kv_size_bytes: int,
                                        acceptance_sampler_method: str):
    """Verify SpecDecodeWorker correctly profiles num available GPU blocks.
    Specifically, it should run profiling in the scorer worker, and then evenly
    split the blocks between proposer and scorer worker.
    """
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)

    target_worker.determine_num_available_blocks.return_value = (
        available_gpu_blocks, available_cpu_blocks)
    target_worker.get_cache_block_size_bytes.return_value = (
        target_cache_block_size_bytes)
    draft_worker.get_cache_block_size_bytes.return_value = draft_kv_size_bytes

    worker = SpecDecodeWorker(
        draft_worker, target_worker,
        mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)

    num_gpu_blocks, num_cpu_blocks = worker.determine_num_available_blocks()

    target_worker.determine_num_available_blocks.assert_called_once()
    assert num_cpu_blocks == available_cpu_blocks

    assert num_gpu_blocks == split_num_cache_blocks_evenly(
        target_cache_block_size_bytes, draft_kv_size_bytes,
        available_gpu_blocks)


@pytest.mark.parametrize('available_gpu_blocks',
                         list(range(20)) + [1024, 1024**2])
@pytest.mark.parametrize('target_cache_block_size_bytes',
                         [2 * 2 * 4096, 2 * 2 * 8192])
@pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096])
@pytest.mark.skip_global_cleanup
def test_split_num_cache_blocks_evenly(available_gpu_blocks: int,
                                       target_cache_block_size_bytes: int,
                                       draft_kv_size_bytes: int):
    """Verify split_num_cache_blocks_evenly does not exceed original memory
    allocation in bytes.
    """
    num_blocks = split_num_cache_blocks_evenly(target_cache_block_size_bytes,
                                               draft_kv_size_bytes,
                                               available_gpu_blocks)
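    # The combined target + draft KV cache memory for num_blocks must fit in
    # the memory originally profiled for the target cache alone.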
    assert (num_blocks * target_cache_block_size_bytes) + (
        num_blocks * draft_kv_size_bytes) <= (available_gpu_blocks *
                                              target_cache_block_size_bytes)


@torch.inference_mode()
def test_populate_seq_ids_with_bonus_tokens():
    """
    Verify that a call to _create_output_sampler_list correctly updates
    seq_with_bonus_token_in_last_step.

    seq_with_bonus_token_in_last_step is an internal data structure in
    SpecDecodeWorker that tracks the sequence IDs which are assigned bonus
    tokens by the target model in their last forward pass. This state is
    maintained only for models relying on the KV cache, such as those using
    the MultiStepWorker.
    """
    batch_size = 10
    k = 5
    vocab_size = 10000
    num_sequences_with_bonus_tokens = 5
    target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)]
    target_worker.device = 'cuda'

    set_random_seed(1)
    draft_worker = mock_worker(cls=MultiStepWorker)
    draft_worker.device = 'cuda'
    # The sequence_ids attached to each sequence in the batch.
    # The sequence at index i has seq_id assigned_seq_ids[i]
    assigned_seq_ids = list(range(batch_size))
    seq_group_metadata_list, _, _ = create_batch(batch_size,
                                                 k,
                                                 seq_ids=assigned_seq_ids,
                                                 prev_output_token_len=10)
    target_token_logprobs = torch.rand(batch_size, (k + 1),
                                       vocab_size,
                                       dtype=torch.float32,
                                       device='cuda')
    accepted_token_ids = torch.randint(low=0,
                                       high=vocab_size,
                                       size=(batch_size, (k + 1)),
                                       dtype=torch.int64,
                                       device='cuda')
    expected_request_id_seq_ids_mapping: Dict[str, Set[int]] = defaultdict(set)
    for seq_group_metadata in seq_group_metadata_list:
        for seq_id in seq_group_metadata.seq_data:
            expected_request_id_seq_ids_mapping[
                seq_group_metadata.request_id].add(seq_id)
    # Generate a random sample of sequence indexes with bonus tokens
    seq_indexes_with_bonus_tokens = random.sample(
        range(batch_size), num_sequences_with_bonus_tokens)
    # Create a mask that is True for indices in seq_indexes_with_bonus_tokens
    mask = torch.ones(batch_size, dtype=torch.bool, device='cuda')
    mask[seq_indexes_with_bonus_tokens] = False
    # Set the last token ID to -1 for all indices not in
    # seq_indexes_with_bonus_tokens to indicate the lack of bonus token in
    # those indices.
    accepted_token_ids[mask, -1:] = -1
    worker = SpecDecodeWorker(draft_worker,
                              target_worker,
                              mock_spec_decode_sampler("rejection_sampler"),
                              disable_logprobs=False,
                              metrics_collector=metrics_collector)
    # Initialize _seq_with_bonus_token_in_last_step with a set of sequence IDs.
    # This set includes all sequence IDs in the batch as well as an additional
    # `num_extra_sequence_ids` sequence IDs. Note that the sequence IDs are in
    # the range [0, batch_size + num_extra_sequence_ids).
    num_extra_sequence_ids = 10
    worker._seq_with_bonus_token_in_last_step = set(
        range(batch_size + num_extra_sequence_ids))
    worker._create_output_sampler_list(
        seq_group_metadata_list=seq_group_metadata_list,
        accepted_token_ids=accepted_token_ids,
        target_logprobs=target_token_logprobs,
        prompt_logprobs=None,
        k=k,
        stage_times=(0, 0, 0))
    # Verify that _seq_with_bonus_token_in_last_step contains the following:
    # 1. Sequence IDs that were already present in
    #    _seq_with_bonus_token_in_last_step but were not part of the current
    #    batch are retained.
    # 2. Of the sequence IDs present in the current batch, only those with a
    #    bonus token are retained in _seq_with_bonus_token_in_last_step.
    #    Sequence IDs that are present in the current batch but do not have
    #    bonus tokens are removed from _seq_with_bonus_token_in_last_step.
    expected_seq_ids_with_bonus_tokens = \
        set([assigned_seq_ids[i] for i in seq_indexes_with_bonus_tokens])
    additional_sequence_ids = \
        set(range(batch_size, batch_size + num_extra_sequence_ids))
    assert worker._seq_with_bonus_token_in_last_step == \
        expected_seq_ids_with_bonus_tokens.union(additional_sequence_ids)
    assert worker._request_id_seq_id_mapping == \
        expected_request_id_seq_ids_mapping


@torch.inference_mode()
def test_handle_finished_requests():
    """
    Test to verify that finished request IDs are appropriately processed to
    update the internal state of the SpecDecodeWorker.

    This test initializes the SpecDecodeWorker with mock data, marks certain
    requests as finished, and ensures that the corresponding sequence IDs are
    correctly removed from the internal mappings.
    """
    batch_size = 32
    k = 3
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    worker = SpecDecodeWorker(draft_worker, target_worker,
                              mock_spec_decode_sampler("rejection_sampler"),
                              metrics_collector)
    # Initialize the request_id_seq_id_mapping mapping dict with a few fake
    # request ids and corresponding sequence ids.
    worker._request_id_seq_id_mapping = \
        {'request-1': {1,2,3}, 'request-2': {4,5,6,7},
         'request-3': {8,9}, 'request-4': {10,11}}
    # Initialize seq_with_bonus_token_in_last_step with a few fake
    # sequence ids.
    worker._seq_with_bonus_token_in_last_step = {1, 4, 5, 8, 9, 10}
    exception_secret = 'artificial stop'
    draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)

    seq_group_metadata_list, _, _ = create_batch(batch_size, k)
    # Mark requests with ids request-1 and request-3 as finished.
    execute_model_req = ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list,
        num_lookahead_slots=k,
        finished_requests_ids=['request-1', 'request-3'])

    with pytest.raises(ValueError, match=exception_secret):
        worker.execute_model(execute_model_req=execute_model_req)
    # Verify that request-1 and request-3 are removed from
    # request_id_seq_id_mapping
    assert worker._request_id_seq_id_mapping == \
        {'request-2': {4,5,6,7}, 'request-4': {10,11}}
    # Verify that all sequence ids corresponding to 'request-1'
    # and 'request-3' are removed from seq_with_bonus_token_in_last_step.
    assert worker._seq_with_bonus_token_in_last_step == \
        {4,5,10}


@pytest.mark.parametrize('k', [3])
@pytest.mark.parametrize('batch_size', [2, 32])
@pytest.mark.parametrize("batch_composition",
                         ["prefill_only", "decode_only", "mixed"])
@torch.inference_mode()
def test_chunked_prefill_flow(k: int, batch_size: int, batch_composition: str):
    """
    Verify SpecDecodeWorker calls match the expected flow.
    """
    vocab_size = 32_000
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    worker = SpecDecodeWorker(draft_worker,
                              target_worker,
                              mock_spec_decode_sampler("rejection_sampler"),
                              disable_logprobs=False,
                              metrics_collector=metrics_collector)
    exception_secret = 'artificial stop'
    worker.scorer = mock_worker(BatchExpansionTop1Scorer)
    worker.scorer.score_proposals.side_effect = ValueError(exception_secret)

    # Create batch with combination of terminal/non-terminal prefill chunks
    # and decodes (different seq_ids).
    decodes, _, _ = create_batch(batch_size, k)
    # Pre-chunking here, get 'batch_size' chunks.
    prefill, _, _ = create_batch(batch_size,
                                 k,
                                 prefill_chunk_size=4,
                                 seq_ids=list(range(batch_size,
                                                    batch_size * 2)))

    if batch_composition == "prefill_only":
        n_prefills = batch_size
    elif batch_composition == "decode_only":
        n_prefills = 0
    else:
        n_prefills = random.randint(1, batch_size - 1)
    n_decodes = batch_size - n_prefills

    prefill = random.sample(prefill, n_prefills)
    decodes = random.sample(decodes, n_decodes)
    target_group_metadata_list = prefill + decodes
    execute_model_req = ExecuteModelRequest(
        seq_group_metadata_list=target_group_metadata_list,
        # For prefill only batches we expect num_lookahead_slots = 0.
        num_lookahead_slots=k if n_decodes > 0 else 0)

    target_token_ids = torch.randint(low=0,
                                     high=vocab_size,
                                     size=(1, batch_size * (k + 1)),
                                     dtype=torch.int64,
                                     device='cuda')
    target_token_probs = torch.rand(1,
                                    batch_size * (k + 1),
                                    vocab_size,
                                    dtype=torch.float32,
                                    device='cuda')
    target_token_logprobs = torch.rand(1,
                                       batch_size * (k + 1),
                                       vocab_size,
                                       dtype=torch.float32,
                                       device='cuda')
    target_output = create_sampler_output_list(target_token_ids,
                                               target_token_probs,
                                               target_token_logprobs)

    target_worker.execute_model.return_value = [target_output[0]]

    if not len(decodes):
        worker.execute_model(execute_model_req=execute_model_req)
        # no spec run (prefill only)
        draft_worker.execute_model.assert_called_once_with(execute_model_req)
        target_worker.execute_model.assert_called_once_with(execute_model_req)
    else:
        # Decode-only run OR mixed batch, scorer call fails (it's mocked)
        with pytest.raises(ValueError, match=exception_secret):
            worker.execute_model(execute_model_req=execute_model_req)
        # but first draft still counted
        assert draft_worker.get_spec_proposals.call_count == 1