# vllm/tests/spec_decode/test_spec_decode_worker.py
import random
from unittest.mock import MagicMock

import pytest
import torch

from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.utils import set_random_seed
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.metrics import (AsyncMetricsCollector,
                                      SpecDecodeWorkerMetrics)
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.spec_decode_worker import (
    SpecDecodeWorker, split_num_cache_blocks_evenly)

from .utils import (ExecuteModelData, create_batch,
                    create_sampler_output_list, mock_worker)


@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@torch.inference_mode()
def test_correctly_calls_draft_model(k: int, batch_size: int):
"""Verify SpecDecodeWorker calls the draft worker with correct
inputs. Everything else is mocked out.
"""
draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
rejection_sampler = MagicMock(spec=RejectionSampler)
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector)
    exception_secret = 'artificial stop'
draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)
execute_model_data, _, _ = create_batch(batch_size, k)
with pytest.raises(ValueError, match=exception_secret):
worker.execute_model(**execute_model_data.to_dict(), num_spec_tokens=k)
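    # The artificial ValueError raised by the mocked draft worker stops
    # execute_model right after the proposal call, so only the arguments
    # passed to get_spec_proposals need to be inspected below.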
call_args_list = draft_worker.get_spec_proposals.call_args_list
assert len(call_args_list) == 1
for args, _ in call_args_list:
(seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out,
blocks_to_copy, actual_k) = args
actual_execute_model_data = ExecuteModelData(seq_group_metadata_list,
blocks_to_swap_in,
blocks_to_swap_out,
blocks_to_copy)
assert actual_execute_model_data == execute_model_data
assert actual_k == k


@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@torch.inference_mode()
def test_correctly_calls_target_model(k: int, batch_size: int):
"""Verify SpecDecodeWorker calls the target model with correct
inputs. Everything else is mocked out.
"""
draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
draft_worker.device = 'cuda'
target_worker.device = 'cuda'
set_random_seed(1)
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector)
worker.init_model()
vocab_size = 32_000
proposal_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64,
device='cuda')
proposal_probs = torch.rand(batch_size,
k,
vocab_size,
dtype=torch.float32,
device='cuda')
proposal_lens = torch.ones(batch_size, dtype=torch.int64,
device='cuda') * k
execute_model_data, prompts, prev_output_tokens = create_batch(
batch_size, k)
draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
proposal_token_ids=proposal_token_ids,
proposal_probs=proposal_probs,
proposal_lens=proposal_lens)
    exception_secret = 'artificial stop'
target_worker.execute_model.side_effect = ValueError(exception_secret)
with pytest.raises(ValueError, match=exception_secret):
worker.execute_model(**execute_model_data.to_dict(), num_spec_tokens=k)
seen_contexts = []
call_args_list = target_worker.execute_model.call_args_list
assert len(call_args_list) == 1
for args, kwargs in call_args_list:
target_execute_model_data = ExecuteModelData.from_dict(kwargs)
assert len(target_execute_model_data.seq_group_metadata_list) == (
k + 1) * batch_size
for seq_group_metadata in (
target_execute_model_data.seq_group_metadata_list):
for seq_data in seq_group_metadata.seq_data.values():
seen_contexts.append(seq_data.get_token_ids())
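
    # The target model should see each sequence once per continuation length:
    # the original context (prompt + previously generated tokens) extended by
    # 0, 1, ..., k draft tokens.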
expected_seen_contexts = []
for prompt, prev_generated, draft_tokens in zip(
prompts, prev_output_tokens, proposal_token_ids.tolist()):
for i in range(len(draft_tokens) + 1):
expected_seen_contexts.append(prompt + prev_generated +
draft_tokens[:i])
seen_contexts.sort()
expected_seen_contexts.sort()
assert expected_seen_contexts == seen_contexts


@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@torch.inference_mode()
def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
"""Verify SpecDecodeWorker calls the rejection sampler with
correct inputs. Everything else is mocked out.
"""
vocab_size = 32_000
draft_worker = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size)
target_worker = mock_worker(vocab_size=vocab_size)
rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
draft_worker.device = 'cuda'
target_worker.device = 'cuda'
set_random_seed(1)
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector)
worker.init_model()
proposal_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64,
device='cuda')
proposal_probs = torch.rand(batch_size,
k,
vocab_size,
dtype=torch.float32,
device='cuda')
proposal_lens = torch.ones(batch_size, dtype=torch.int64,
device='cuda') * k
execute_model_data, _, _ = create_batch(batch_size, k)
draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
proposal_token_ids=proposal_token_ids,
proposal_probs=proposal_probs,
proposal_lens=proposal_lens)
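
    # The mocked target worker returns scores for all batch_size * (k + 1)
    # expanded sequences in a single flattened batch; SpecDecodeWorker is
    # expected to reshape these per sequence before calling the sampler.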
target_token_ids = torch.randint(low=0,
high=vocab_size,
size=(1, batch_size * (k + 1)),
dtype=torch.int64,
device='cuda')
target_token_probs = torch.rand(1,
batch_size * (k + 1),
vocab_size,
dtype=torch.float32,
device='cuda')
target_output = create_sampler_output_list(target_token_ids,
target_token_probs)
target_worker.execute_model.return_value = target_output[0]
    exception_secret = 'artificial stop'
rejection_sampler.side_effect = ValueError(exception_secret)
with pytest.raises(ValueError, match=exception_secret):
worker.execute_model(**execute_model_data.to_dict(), num_spec_tokens=k)
assert len(rejection_sampler.call_args_list) == 1
args, _ = rejection_sampler.call_args_list[0]
(actual_proposal_scores, actual_bonus_token_ids, actual_proposal_probs,
actual_proposal_token_ids) = args
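    # The last of the k + 1 target tokens per sequence is the bonus token;
    # the first k target probability slices are the proposal scores.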
assert torch.equal(actual_bonus_token_ids,
target_token_ids.reshape(batch_size, k + 1)[:, -1:])
assert torch.equal(
actual_proposal_scores,
target_token_probs.reshape(batch_size, k + 1, -1)[:, :-1])
assert torch.equal(actual_proposal_token_ids, proposal_token_ids)
assert torch.equal(actual_proposal_probs, proposal_probs)


@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@torch.inference_mode()
def test_correctly_formats_output(k: int, batch_size: int):
"""Verify SpecDecodeWorker formats sampler output correctly.
Everything else is mocked out.
"""
vocab_size = 32_000
draft_worker = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size)
target_worker = mock_worker(vocab_size=vocab_size)
rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
draft_worker.device = 'cuda'
target_worker.device = 'cuda'
set_random_seed(1)
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector)
worker.init_model()
proposal_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64,
device='cuda')
proposal_probs = torch.rand(batch_size,
k,
vocab_size,
dtype=torch.float32,
device='cuda')
proposal_lens = torch.ones(batch_size, dtype=torch.int64,
device='cuda') * k
execute_model_data, _, _ = create_batch(batch_size, k)
draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
proposal_token_ids=proposal_token_ids,
proposal_probs=proposal_probs,
proposal_lens=proposal_lens)
target_token_ids = torch.randint(low=0,
high=vocab_size,
size=(1, batch_size * (k + 1)),
dtype=torch.int64,
device='cuda')
target_token_probs = torch.rand(1,
batch_size * (k + 1),
vocab_size,
dtype=torch.float32,
device='cuda')
target_output = create_sampler_output_list(target_token_ids,
target_token_probs)
target_worker.execute_model.return_value = target_output[0]
rejection_sampler_output = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k + 1),
dtype=torch.int64,
device='cuda')
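    # Simulate rejection: for each sequence, mark a random-length tail of the
    # k + 1 sampled positions with -1, meaning no token is emitted there.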
for i in range(batch_size):
minimum_accepted_tokens = 1
rejection_sampler_output[i][
-random.randint(minimum_accepted_tokens, k + 1):] = -1
rejection_sampler.return_value = rejection_sampler_output
output = worker.execute_model(**execute_model_data.to_dict(),
num_spec_tokens=k)
expected_output = create_sampler_output_list(
rejection_sampler_output.transpose(0, 1), [None for _ in range(k + 1)])
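    # Group actual and expected samples by sequence id so they can be compared
    # step by step below.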
seq_ids = [
next(iter(seq_group_metadata.seq_data.keys()))
for seq_group_metadata in execute_model_data.seq_group_metadata_list
]
actual_output_by_seq = {seq_id: [] for seq_id in seq_ids}
expected_output_by_seq = {seq_id: [] for seq_id in seq_ids}
for step in output:
for seq_group in step:
for sample in seq_group.samples:
seq_id = sample.parent_seq_id
actual_output_by_seq[seq_id].append(sample)
for step in expected_output:
for seq_group in step:
for sample in seq_group.samples:
seq_id = sample.parent_seq_id
expected_output_by_seq[seq_id].append(sample)
all_seen_seq_ids = set(
list(actual_output_by_seq.keys()) +
list(expected_output_by_seq.keys()))
for seq_id in all_seen_seq_ids:
actual_by_step = actual_output_by_seq[seq_id]
expected_by_step = expected_output_by_seq[seq_id]
for i in range(k + 1):
if i >= len(actual_by_step):
assert expected_by_step[i].output_token == -1
continue
assert actual_by_step[i].output_token == expected_by_step[
i].output_token
assert actual_by_step[i].logprobs == expected_by_step[i].logprobs


@pytest.mark.parametrize('k', [1, 2])
@pytest.mark.parametrize('batch_size', [1])
@pytest.mark.parametrize('returns_metrics', [True, False])
@torch.inference_mode()
def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
"""Verify SpecDecodeWorker collects metrics.
"""
vocab_size = 32_000
draft_worker = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size)
target_worker = mock_worker(vocab_size=vocab_size)
rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
draft_worker.device = 'cuda'
target_worker.device = 'cuda'
set_random_seed(1)
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector)
worker.init_model()
proposal_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64,
device='cuda')
proposal_probs = torch.rand(batch_size,
k,
vocab_size,
dtype=torch.float32,
device='cuda')
proposal_lens = torch.ones(batch_size, dtype=torch.int64,
device='cuda') * k
execute_model_data, _, _ = create_batch(batch_size, k)
draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
proposal_token_ids=proposal_token_ids,
proposal_probs=proposal_probs,
proposal_lens=proposal_lens)
target_token_ids = torch.randint(low=0,
high=vocab_size,
size=(1, batch_size * (k + 1)),
dtype=torch.int64,
device='cuda')
target_token_probs = torch.rand(1,
batch_size * (k + 1),
vocab_size,
dtype=torch.float32,
device='cuda')
target_output = create_sampler_output_list(target_token_ids,
target_token_probs)
target_worker.execute_model.return_value = target_output[0]
rejection_sampler_output = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k + 1),
dtype=torch.int64,
device='cuda')
for i in range(batch_size):
minimum_accepted_tokens = 1
rejection_sampler_output[i][
-random.randint(minimum_accepted_tokens, k + 1):] = -1
rejection_sampler.return_value = rejection_sampler_output
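    # Whatever the collector returns (metrics or None) should be attached to
    # the first step's sampler output.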
mock_rejsample_metrics = MagicMock(
spec=SpecDecodeWorkerMetrics) if returns_metrics else None
    metrics_collector.maybe_collect_rejsample_metrics.return_value = (
        mock_rejsample_metrics)
output = worker.execute_model(**execute_model_data.to_dict(),
num_spec_tokens=k)
assert output[0].spec_decode_worker_metrics == mock_rejsample_metrics
    call_args_list = (
        metrics_collector.maybe_collect_rejsample_metrics.call_args_list)
assert len(call_args_list) == 1
args, kwargs = call_args_list[0]
assert args[0] == k or kwargs.get('k', -1) == k


@pytest.mark.parametrize('k', [0])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@torch.inference_mode()
def test_k_equals_zero(k: int, batch_size: int):
"""Verify that the SpecDecodeWorker calls the draft and target workers
when k is zero. This happens during prefill.
"""
draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
draft_worker.device = 'cuda'
target_worker.device = 'cuda'
set_random_seed(1)
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector)
execute_model_data, prompts, prev_output_tokens = create_batch(
batch_size, k, prev_output_token_len=0)
out = worker.execute_model(**execute_model_data.to_dict(),
num_spec_tokens=k)
assert len(out) == 1, f"expected only one token output when {k=}"
assert out[0].probs is None, "expect gpu tensor references to be None"
assert out[
0].sampled_tokens is None, "expect gpu tensor references to be None"
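    # With k == 0 there is nothing to score: both workers should each run a
    # single forward pass on the original inputs, and the draft worker's
    # Python output is not needed.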
draft_worker.execute_model.assert_called_once_with(
**execute_model_data.to_dict(), return_python_output=False)
target_worker.execute_model.assert_called_once_with(
**execute_model_data.to_dict())


@pytest.mark.parametrize('k', [0, 5])
@pytest.mark.parametrize('batch_size', [0])
@torch.inference_mode()
def test_empty_input_batch(k: int, batch_size: int):
"""Verify that the SpecDecodeWorker calls the draft and target workers
when the input batch is empty. This can happen if the engine communicates
to the workers information without scheduling a batch.
"""
draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
draft_worker.device = 'cuda'
target_worker.device = 'cuda'
set_random_seed(1)
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector)
execute_model_data, prompts, prev_output_tokens = create_batch(
batch_size, k, prev_output_token_len=0)
out = worker.execute_model(**execute_model_data.to_dict(),
num_spec_tokens=k)
assert len(out) == 1, f"expected only one token output when {k=}"
assert out[0].probs is None, "expect gpu tensor references to be None"
assert out[
0].sampled_tokens is None, "expect gpu tensor references to be None"
draft_worker.execute_model.assert_called_once_with(
**execute_model_data.to_dict(), return_python_output=False)
target_worker.execute_model.assert_called_once_with(
**execute_model_data.to_dict())


@torch.inference_mode()
def test_init_model():
"""Verify SpecDecodeWorker invokes proposer/scorer worker init_model, as
well as other GPU initialization.
"""
draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector)
worker.init_model()
draft_worker.init_model.assert_called_once()
target_worker.init_model.assert_called_once()
metrics_collector.init_gpu_tensors.assert_called_once()
rejection_sampler.init_gpu_tensors.assert_called_once()


@torch.inference_mode()
def test_init_cache_engine():
"""Verify SpecDecodeWorker invokes init_cache_engine on proposer/scorer
workers.
"""
draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector)
cache_config = MagicMock()
worker.init_cache_engine(cache_config)
draft_worker.init_cache_engine.assert_called_once_with(cache_config)
target_worker.init_cache_engine.assert_called_once_with(cache_config)


@pytest.mark.parametrize('available_gpu_blocks', [1, 1024])
@pytest.mark.parametrize('available_cpu_blocks', [500])
@pytest.mark.parametrize('target_cache_block_size_bytes', [2 * 2 * 4096])
@pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096])
@torch.inference_mode()
def test_profile_num_available_blocks(available_gpu_blocks: int,
available_cpu_blocks: int,
target_cache_block_size_bytes: int,
draft_kv_size_bytes: int):
"""Verify SpecDecodeWorker correctly profiles num available GPU blocks.
Specifically, it should run profiling in the scorer worker, and then evenly
split the blocks between proposer and scorer worker.
"""
draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
target_worker.profile_num_available_blocks.return_value = (
available_gpu_blocks, available_cpu_blocks)
    target_worker.get_cache_block_size_bytes.return_value = (
        target_cache_block_size_bytes)
draft_worker.get_cache_block_size_bytes.return_value = draft_kv_size_bytes
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector)
# These values do not directly impact the adjusted block size calculation,
# so they can be fixed.
gpu_memory_utilization = 0.9
cpu_swap_space = 100
block_size = 16
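    # Profiling is expected to be delegated to the scorer (target) worker,
    # with the resulting GPU blocks split according to the relative KV cache
    # block sizes of the draft and target models.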
num_gpu_blocks, num_cpu_blocks = worker.profile_num_available_blocks(
block_size, gpu_memory_utilization, cpu_swap_space, cache_dtype="auto")
target_worker.profile_num_available_blocks.assert_called_once_with(
block_size, gpu_memory_utilization, cpu_swap_space, "auto")
assert num_cpu_blocks == available_cpu_blocks
assert num_gpu_blocks == split_num_cache_blocks_evenly(
target_cache_block_size_bytes, draft_kv_size_bytes,
available_gpu_blocks)


@pytest.mark.parametrize('available_gpu_blocks',
list(range(20)) + [1024, 1024**2])
@pytest.mark.parametrize('target_cache_block_size_bytes',
[2 * 2 * 4096, 2 * 2 * 8192])
@pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096])
@torch.inference_mode()
def test_split_num_cache_blocks_evenly(available_gpu_blocks: int,
target_cache_block_size_bytes: int,
draft_kv_size_bytes: int):
"""Verify split_num_cache_blocks_evenly does not exceed original memory
allocation in bytes.
"""
num_blocks = split_num_cache_blocks_evenly(target_cache_block_size_bytes,
draft_kv_size_bytes,
available_gpu_blocks)
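    # The combined KV cache memory of both models must not exceed what the
    # original target-only allocation would have used.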
assert (num_blocks * target_cache_block_size_bytes) + (
num_blocks * draft_kv_size_bytes) <= (available_gpu_blocks *
target_cache_block_size_bytes)