import random
from collections import defaultdict
from types import SimpleNamespace
from typing import Dict, List, Set
from unittest.mock import MagicMock

import pytest
import torch

from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import ExecuteModelRequest, SequenceOutput
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.metrics import (AsyncMetricsCollector,
                                      SpecDecodeWorkerMetrics)
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.spec_decode_worker import (SpecDecodeWorker,
                                                 split_num_cache_blocks_evenly)

from .test_utils import mock_spec_decode_sampler
from .utils import create_batch, create_sampler_output_list, mock_worker
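

# The tests in this module drive SpecDecodeWorker against mocked proposer
# (draft) and scorer (target) workers, so no real model execution happens.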
@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_correctly_calls_draft_model(k: int, batch_size: int,
                                     acceptance_sampler_method: str):
    """Verify SpecDecodeWorker calls the draft worker with correct
    inputs. Everything else is mocked out.
    """
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    worker = SpecDecodeWorker(
        draft_worker,
        target_worker,
        mock_spec_decode_sampler(acceptance_sampler_method),
        disable_logprobs=False,
        metrics_collector=metrics_collector)
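
    # Raise inside get_spec_proposals so execution stops right after the
    # draft call; the test only inspects what was passed to the draft worker.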
    exception_secret = 'artificial stop'
    draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)

    seq_group_metadata_list, _, _ = create_batch(batch_size, k)
    execute_model_req = ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)

    with pytest.raises(ValueError, match=exception_secret):
        worker.execute_model(execute_model_req=execute_model_req)

    call_args_list = draft_worker.get_spec_proposals.call_args_list
    assert len(call_args_list) == 1

    for args, _ in call_args_list:
        actual_execute_model_data = args[0]
        assert actual_execute_model_data == execute_model_req


@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_batch_expansion_correctly_calls_target_model(
        k: int, batch_size: int, acceptance_sampler_method: str):
    """Verify SpecDecodeWorker calls the target model with correct
    inputs with batch expansion. Everything else is mocked out.
    """
    draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
    target_worker = mock_worker(use_spec=False)
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)

    draft_worker.device = 'cuda'
    target_worker.device = 'cuda'

    set_random_seed(1)

    worker = SpecDecodeWorker(
        draft_worker,
        target_worker,
        mock_spec_decode_sampler(acceptance_sampler_method),
        disable_logprobs=False,
        metrics_collector=metrics_collector,
        disable_mqa_scorer=True)
    worker.init_device()

    vocab_size = 32_000

    proposal_token_ids = torch.randint(low=0,
                                       high=vocab_size,
                                       size=(batch_size, k),
                                       dtype=torch.int64,
                                       device='cuda')
    proposal_probs = torch.rand(batch_size,
                                k,
                                vocab_size,
                                dtype=torch.float32,
                                device='cuda')
    proposal_lens = torch.ones(batch_size, dtype=torch.int64,
                               device='cuda') * k

    seq_group_metadata_list, prompts, prev_output_tokens = create_batch(
        batch_size, k)

    draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
        proposal_token_ids=proposal_token_ids,
        proposal_probs=proposal_probs,
        proposal_lens=proposal_lens)

    exception_secret = 'artificial stop'
    target_worker.execute_model.side_effect = ValueError(exception_secret)

    with pytest.raises(ValueError, match=exception_secret):
        worker.execute_model(execute_model_req=ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            num_lookahead_slots=k))

    seen_contexts: List[List[int]] = []

    call_args_list = target_worker.execute_model.call_args_list
    assert len(call_args_list) == 1
    for _, kwargs in call_args_list:
        seq_group_metadata_list = kwargs[
            "execute_model_req"].seq_group_metadata_list
        assert len(seq_group_metadata_list) == (k + 1) * batch_size
        for seq_group_metadata in seq_group_metadata_list:
            for seq_data in seq_group_metadata.seq_data.values():
                seen_contexts.append(seq_data.get_token_ids())

    expected_seen_contexts: List[List[int]] = []
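
    # Each expanded sequence should contain the prompt, the previously
    # generated tokens, and a prefix of the draft tokens (lengths 0..k).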
    for prompt, prev_generated, draft_tokens in zip(
            prompts, prev_output_tokens, proposal_token_ids.tolist()):

        for i in range(len(draft_tokens) + 1):
            expected_seen_contexts.append(prompt + prev_generated +
                                          draft_tokens[:i])

    seen_contexts.sort()
    expected_seen_contexts.sort()
    assert expected_seen_contexts == seen_contexts


@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int,
                                             acceptance_sampler_method: str):
    """Verify SpecDecodeWorker calls the spec decode sampler with
    correct inputs. Everything else is mocked out.
    """
    vocab_size = 32_000

    draft_worker = mock_worker(cls=MultiStepWorker,
                               vocab_size=vocab_size,
                               use_spec=False)
    target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
    spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    draft_worker.device = 'cuda'
    target_worker.device = 'cuda'

    set_random_seed(1)

    worker = SpecDecodeWorker(draft_worker,
                              target_worker,
                              spec_decode_sampler,
                              disable_logprobs=False,
                              metrics_collector=metrics_collector)
    worker.init_device()

    proposal_token_ids = torch.randint(low=0,
                                       high=vocab_size,
                                       size=(batch_size, k),
                                       dtype=torch.int64,
                                       device='cuda')
    proposal_probs = torch.rand(batch_size,
                                k,
                                vocab_size,
                                dtype=torch.float32,
                                device='cuda')

    proposal_lens = torch.ones(batch_size, dtype=torch.int64,
                               device='cuda') * k

    seq_group_metadata_list, _, _ = create_batch(batch_size, k)

    draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
        proposal_token_ids=proposal_token_ids,
        proposal_probs=proposal_probs,
        proposal_lens=proposal_lens)

    target_token_ids = torch.randint(low=0,
                                     high=vocab_size,
                                     size=(1, batch_size * (k + 1)),
                                     dtype=torch.int64,
                                     device='cuda')
    target_token_probs = torch.rand(1,
                                    batch_size * (k + 1),
                                    vocab_size,
                                    dtype=torch.float32,
                                    device='cuda')
    target_token_logprobs = torch.rand(1,
                                       batch_size * (k + 1),
                                       vocab_size,
                                       dtype=torch.float32,
                                       device='cuda')
    target_output = create_sampler_output_list(target_token_ids,
                                               target_token_probs,
                                               target_token_logprobs)

    target_worker.execute_model.return_value = [target_output[0]]

    exception_secret = 'artificial stop'

    spec_decode_sampler.side_effect = ValueError(exception_secret)

    with pytest.raises(ValueError, match=exception_secret):
        worker.execute_model(execute_model_req=ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            num_lookahead_slots=k))

    assert len(spec_decode_sampler.call_args_list) == 1
    _, kwargs = spec_decode_sampler.call_args_list[0]
    actual = SimpleNamespace(**kwargs)
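
    # The bonus token is the last of the k + 1 tokens sampled by the target
    # model for each sequence; the rest score the k draft tokens.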
    assert torch.equal(actual.bonus_token_ids,
                       target_token_ids.reshape(batch_size, k + 1)[:, -1:])
    assert torch.equal(actual.target_with_bonus_probs,
                       target_token_probs.reshape(batch_size, k + 1, -1))
    assert torch.equal(actual.draft_token_ids, proposal_token_ids)
    assert torch.equal(actual.draft_probs, proposal_probs)


@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_correctly_formats_output(k: int, batch_size: int,
                                  acceptance_sampler_method: str):
    """Verify SpecDecodeWorker formats sampler output correctly.
    Everything else is mocked out.
    """
    vocab_size = 32_000

    draft_worker = mock_worker(cls=MultiStepWorker,
                               vocab_size=vocab_size,
                               use_spec=False)
    target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    draft_worker.device = 'cuda'
    target_worker.device = 'cuda'

    set_random_seed(1)
    spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
    worker = SpecDecodeWorker(draft_worker,
                              target_worker,
                              spec_decode_sampler,
                              disable_logprobs=False,
                              metrics_collector=metrics_collector)
    worker.init_device()

    proposal_token_ids = torch.randint(low=0,
                                       high=vocab_size,
                                       size=(batch_size, k),
                                       dtype=torch.int64,
                                       device='cuda')
    proposal_probs = torch.rand(batch_size,
                                k,
                                vocab_size,
                                dtype=torch.float32,
                                device='cuda')

    proposal_lens = torch.ones(batch_size, dtype=torch.int64,
                               device='cuda') * k

    seq_group_metadata_list, _, _ = create_batch(batch_size, k)

    draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
        proposal_token_ids=proposal_token_ids,
        proposal_probs=proposal_probs,
        proposal_lens=proposal_lens)

    target_token_ids = torch.randint(low=0,
                                     high=vocab_size,
                                     size=(1, batch_size * (k + 1)),
                                     dtype=torch.int64,
                                     device='cuda')
    target_token_probs = torch.rand(1,
                                    batch_size * (k + 1),
                                    vocab_size,
                                    dtype=torch.float32,
                                    device='cuda')
    target_token_logprobs = torch.rand(1,
                                       batch_size * (k + 1),
                                       vocab_size,
                                       dtype=torch.float32,
                                       device='cuda')
    target_output = create_sampler_output_list(target_token_ids,
                                               target_token_probs,
                                               target_token_logprobs)

    target_worker.execute_model.return_value = [target_output[0]]

    spec_decode_sampler_output = torch.randint(low=0,
                                               high=vocab_size,
                                               size=(batch_size, k + 1),
                                               dtype=torch.int64,
                                               device='cuda')
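    # Simulate the acceptance sampler rejecting a random suffix of each
    # sequence's k + 1 tokens; -1 marks slots with no emitted token.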
    for i in range(batch_size):
        minimum_accepted_tokens = 1
        spec_decode_sampler_output[i][
            -random.randint(minimum_accepted_tokens, k + 1):] = -1

    spec_decode_sampler.return_value = spec_decode_sampler_output
    output = worker.execute_model(execute_model_req=ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list,
        num_lookahead_slots=k))

    expected_output = create_sampler_output_list(
        token_ids=spec_decode_sampler_output.transpose(0, 1),
        probs=[None for _ in range(k + 1)],
        logprobs=[None for _ in range(k + 1)])

    seq_ids = [
        next(iter(seq_group_metadata.seq_data.keys()))
        for seq_group_metadata in seq_group_metadata_list
    ]
    actual_output_by_seq: Dict[int, List[SequenceOutput]] = {
        seq_id: []
        for seq_id in seq_ids
    }
    expected_output_by_seq: Dict[int, List[SequenceOutput]] = {
        seq_id: []
        for seq_id in seq_ids
    }

    for step in output:
        for seq_group in step:
            for sample in seq_group.samples:
                seq_id = sample.parent_seq_id
                actual_output_by_seq[seq_id].append(sample)

    for step in expected_output:
        for seq_group in step:
            for sample in seq_group.samples:
                seq_id = sample.parent_seq_id
                expected_output_by_seq[seq_id].append(sample)

    all_seen_seq_ids = set(
        list(actual_output_by_seq.keys()) +
        list(expected_output_by_seq.keys()))
    for seq_id in all_seen_seq_ids:
        actual_by_step = actual_output_by_seq[seq_id]
        expected_by_step = expected_output_by_seq[seq_id]

        for i in range(k + 1):
            if i >= len(actual_by_step):
                assert expected_by_step[i].output_token == -1
                continue
            assert actual_by_step[i].output_token == expected_by_step[
                i].output_token


@pytest.mark.parametrize('k', [1, 2])
@pytest.mark.parametrize('batch_size', [1])
@pytest.mark.parametrize('returns_metrics', [True, False])
@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool,
                          acceptance_sampler_method: str):
    """Verify SpecDecodeWorker collects metrics.
    """
    vocab_size = 32_000

    draft_worker = mock_worker(cls=MultiStepWorker,
                               vocab_size=vocab_size,
                               use_spec=False)
    target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
    spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    draft_worker.device = 'cuda'
    target_worker.device = 'cuda'

    set_random_seed(1)

    worker = SpecDecodeWorker(draft_worker,
                              target_worker,
                              spec_decode_sampler,
                              disable_logprobs=False,
                              metrics_collector=metrics_collector)
    worker.init_device()

    proposal_token_ids = torch.randint(low=0,
                                       high=vocab_size,
                                       size=(batch_size, k),
                                       dtype=torch.int64,
                                       device='cuda')
    proposal_probs = torch.rand(batch_size,
                                k,
                                vocab_size,
                                dtype=torch.float32,
                                device='cuda')

    proposal_lens = torch.ones(batch_size, dtype=torch.int64,
                               device='cuda') * k

    seq_group_metadata_list, _, _ = create_batch(batch_size, k)

    draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
        proposal_token_ids=proposal_token_ids,
        proposal_probs=proposal_probs,
        proposal_lens=proposal_lens)

    target_token_ids = torch.randint(low=0,
                                     high=vocab_size,
                                     size=(1, batch_size * (k + 1)),
                                     dtype=torch.int64,
                                     device='cuda')
    target_token_probs = torch.rand(1,
                                    batch_size * (k + 1),
                                    vocab_size,
                                    dtype=torch.float32,
                                    device='cuda')
    target_token_logprobs = torch.rand(1,
                                       batch_size * (k + 1),
                                       vocab_size,
                                       dtype=torch.float32,
                                       device='cuda')
    target_output = create_sampler_output_list(target_token_ids,
                                               target_token_probs,
                                               target_token_logprobs)

    target_worker.execute_model.return_value = [target_output[0]]

    spec_decode_sampler_output = torch.randint(low=0,
                                               high=vocab_size,
                                               size=(batch_size, k + 1),
                                               dtype=torch.int64,
                                               device='cuda')
    for i in range(batch_size):
        minimum_accepted_tokens = 1
        spec_decode_sampler_output[i][
            -random.randint(minimum_accepted_tokens, k + 1):] = -1
    spec_decode_sampler.return_value = spec_decode_sampler_output

    mock_rejsample_metrics = MagicMock(
        spec=SpecDecodeWorkerMetrics) if returns_metrics else None
    metrics_collector.maybe_collect_rejsample_metrics.return_value = (
        mock_rejsample_metrics)

    output = worker.execute_model(execute_model_req=ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list,
        num_lookahead_slots=k))
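    # The collected metrics (or None when collection is skipped) should be
    # attached to the first step's sampler output.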
    assert output[0].spec_decode_worker_metrics == mock_rejsample_metrics

    call_args_list = (
        metrics_collector.maybe_collect_rejsample_metrics.call_args_list)
    assert len(call_args_list) == 1
    args, kwargs = call_args_list[0]
    assert args[0] == k or kwargs.get('k', -1) == k


@pytest.mark.parametrize('k', [0])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_k_equals_zero(k: int, batch_size: int,
                       acceptance_sampler_method: str):
    """Verify that the SpecDecodeWorker calls the draft and target workers
    when k is zero. This happens during prefill.
    """
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)

    sampler_output = MagicMock(spec=SamplerOutput)
    sampler_output.hidden_states = None
    target_worker.execute_model.return_value = [sampler_output]

    draft_worker.device = 'cuda'
    target_worker.device = 'cuda'

    set_random_seed(1)

    worker = SpecDecodeWorker(
        proposer_worker=draft_worker,
        scorer_worker=target_worker,
        spec_decode_sampler=mock_spec_decode_sampler(
            acceptance_sampler_method),
        disable_logprobs=False,
        metrics_collector=metrics_collector,
    )

    seq_group_metadata_list, _, _ = create_batch(batch_size,
                                                 k,
                                                 prev_output_token_len=0)
    execute_model_req = ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)

    out = worker.execute_model(execute_model_req=execute_model_req)

    assert len(out) == 1, f"expected only one token output when {k=}"
    assert out[0].sampled_token_probs is None, (
        "expect gpu tensor references to be None")
    assert out[
        0].sampled_token_ids is None, "expect gpu tensor references to be None"
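
    # With k == 0 there is nothing to speculate, so the original request is
    # forwarded to both workers unchanged.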
    draft_worker.execute_model.assert_called_once_with(execute_model_req)
    target_worker.execute_model.assert_called_once_with(execute_model_req)


@pytest.mark.parametrize('k', [0, 5])
@pytest.mark.parametrize('batch_size', [0])
@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_empty_input_batch(k: int, batch_size: int,
                           acceptance_sampler_method: str):
    """Verify that the SpecDecodeWorker calls the draft and target workers
    when the input batch is empty. This can happen if the engine communicates
    to the workers information without scheduling a batch.
    """
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)

    sampler_output = MagicMock(spec=SamplerOutput)
    sampler_output.hidden_states = None
    target_worker.execute_model.return_value = [sampler_output]

    draft_worker.device = 'cuda'
    target_worker.device = 'cuda'

    set_random_seed(1)

    worker = SpecDecodeWorker(
        proposer_worker=draft_worker,
        scorer_worker=target_worker,
        spec_decode_sampler=mock_spec_decode_sampler(
            acceptance_sampler_method),
        disable_logprobs=False,
        metrics_collector=metrics_collector,
    )

    seq_group_metadata_list, _, _ = create_batch(batch_size,
                                                 k,
                                                 prev_output_token_len=0)
    execute_model_req = ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)

    out = worker.execute_model(execute_model_req=execute_model_req)

    assert len(out) == 1, f"expected only one token output when {k=}"
    assert out[0].sampled_token_probs is None, (
        "expect gpu tensor references to be None")
    assert out[
        0].sampled_token_ids is None, "expect gpu tensor references to be None"

    draft_worker.execute_model.assert_called_once_with(execute_model_req)
    target_worker.execute_model.assert_called_once_with(execute_model_req)


@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@pytest.mark.skip_global_cleanup
def test_init_device(acceptance_sampler_method: str):
    """Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as
    well as other GPU initialization.
    """
    draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
    target_worker = mock_worker(use_spec=False)
    spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)

    worker = SpecDecodeWorker(
        proposer_worker=draft_worker,
        scorer_worker=target_worker,
        spec_decode_sampler=spec_decode_sampler,
        disable_logprobs=False,
        metrics_collector=metrics_collector,
    )
    worker.init_device()

    draft_worker.init_device.assert_called_once()

    target_worker.init_device.assert_called_once()

    metrics_collector.init_gpu_tensors.assert_called_once()
    spec_decode_sampler.init_gpu_tensors.assert_called_once()


@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_initialize_cache(acceptance_sampler_method):
    """Verify SpecDecodeWorker invokes initialize_cache on proposer/scorer
    workers.
    """
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)

    worker = SpecDecodeWorker(proposer_worker=draft_worker,
                              scorer_worker=target_worker,
                              spec_decode_sampler=mock_spec_decode_sampler(
                                  acceptance_sampler_method),
                              metrics_collector=metrics_collector)

    kwargs = {"num_gpu_blocks": 1024, "num_cpu_blocks": 1023}
    worker.initialize_cache(**kwargs)

    draft_worker.initialize_cache.assert_called_once_with(**kwargs)
    target_worker.initialize_cache.assert_called_once_with(**kwargs)


@pytest.mark.parametrize('available_gpu_blocks', [1, 1024])
@pytest.mark.parametrize('available_cpu_blocks', [500])
@pytest.mark.parametrize('target_cache_block_size_bytes', [2 * 2 * 4096])
@pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096])
@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@pytest.mark.skip_global_cleanup
def test_determine_num_available_blocks(available_gpu_blocks: int,
                                        available_cpu_blocks: int,
                                        target_cache_block_size_bytes: int,
                                        draft_kv_size_bytes: int,
                                        acceptance_sampler_method: str):
    """Verify SpecDecodeWorker correctly profiles num available GPU blocks.
    Specifically, it should run profiling in the scorer worker, and then evenly
    split the blocks between proposer and scorer worker.
    """
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)

    target_worker.determine_num_available_blocks.return_value = (
        available_gpu_blocks, available_cpu_blocks)
    target_worker.get_cache_block_size_bytes.return_value = (
        target_cache_block_size_bytes)
    draft_worker.get_cache_block_size_bytes.return_value = draft_kv_size_bytes

    worker = SpecDecodeWorker(
        draft_worker, target_worker,
        mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)

    num_gpu_blocks, num_cpu_blocks = worker.determine_num_available_blocks()

    target_worker.determine_num_available_blocks.assert_called_once()
    assert num_cpu_blocks == available_cpu_blocks
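
    # The GPU blocks are split so that the draft and target KV caches get the
    # same number of blocks within the memory profiled by the scorer worker.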
    assert num_gpu_blocks == split_num_cache_blocks_evenly(
        target_cache_block_size_bytes, draft_kv_size_bytes,
        available_gpu_blocks)


@pytest.mark.parametrize('available_gpu_blocks',
                         list(range(20)) + [1024, 1024**2])
@pytest.mark.parametrize('target_cache_block_size_bytes',
                         [2 * 2 * 4096, 2 * 2 * 8192])
@pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096])
@pytest.mark.skip_global_cleanup
def test_split_num_cache_blocks_evenly(available_gpu_blocks: int,
                                       target_cache_block_size_bytes: int,
                                       draft_kv_size_bytes: int):
    """Verify split_num_cache_blocks_evenly does not exceed original memory
    allocation in bytes.
    """
    num_blocks = split_num_cache_blocks_evenly(target_cache_block_size_bytes,
                                               draft_kv_size_bytes,
                                               available_gpu_blocks)
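    # The bytes used by num_blocks blocks in both caches must not exceed the
    # memory that the target cache alone could have used.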
    assert (num_blocks * target_cache_block_size_bytes) + (
        num_blocks * draft_kv_size_bytes) <= (available_gpu_blocks *
                                              target_cache_block_size_bytes)


@torch.inference_mode()
def test_populate_seq_ids_with_bonus_tokens():
    """
    Verify that a call to _create_output_sampler_list correctly updates
    seq_with_bonus_token_in_last_step.

    seq_with_bonus_token_in_last_step is an internal data structure in
    SpecDecodeWorker that tracks the sequence IDs which are assigned bonus
    tokens by the target model in their last forward pass. This state is
    maintained only for models relying on the KV cache, such as those using
    the MultiStepWorker.
    """
    batch_size = 10
    k = 5
    vocab_size = 10000
    num_sequences_with_bonus_tokens = 5
    target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)]
    target_worker.device = 'cuda'

    set_random_seed(1)
    draft_worker = mock_worker(cls=MultiStepWorker)
    draft_worker.device = 'cuda'
    # The sequence_ids attached to each sequence in the batch.
    # The sequence at index i has seq_id assigned_seq_ids[i]
    assigned_seq_ids = list(range(batch_size))
    seq_group_metadata_list, _, _ = create_batch(batch_size,
                                                 k,
                                                 seq_ids=assigned_seq_ids,
                                                 prev_output_token_len=10)
    target_token_logprobs = torch.rand(batch_size, (k + 1),
                                       vocab_size,
                                       dtype=torch.float32,
                                       device='cuda')
    accepted_token_ids = torch.randint(low=0,
                                       high=vocab_size,
                                       size=(batch_size, (k + 1)),
                                       dtype=torch.int64,
                                       device='cuda')
    expected_request_id_seq_ids_mapping: Dict[str, Set[int]] = defaultdict(set)
    for seq_group_metadata in seq_group_metadata_list:
        for seq_id in seq_group_metadata.seq_data:
            expected_request_id_seq_ids_mapping[
                seq_group_metadata.request_id].add(seq_id)
    # Generate a random sample of sequence indexes with bonus tokens
    seq_indexes_with_bonus_tokens = random.sample(
        range(batch_size), num_sequences_with_bonus_tokens)
    # Create a mask that is True for indices in seq_indexes_with_bonus_tokens
    mask = torch.ones(batch_size, dtype=torch.bool, device='cuda')
    mask[seq_indexes_with_bonus_tokens] = False
    # Set the last token ID to -1 for all indices not in
    # seq_indexes_with_bonus_tokens to indicate the lack of bonus token in
    # those indices.
    accepted_token_ids[mask, -1:] = -1
    worker = SpecDecodeWorker(draft_worker,
                              target_worker,
                              mock_spec_decode_sampler("rejection_sampler"),
                              disable_logprobs=False,
                              metrics_collector=metrics_collector)
    # Initialize _seq_with_bonus_token_in_last_step with a set of sequence IDs.
    # This set includes all sequence IDs in the batch as well as an additional
    # `num_extra_sequence_ids` sequence IDs. Note that the sequence IDs are in
    # the range [0, batch_size + num_extra_sequence_ids).
    num_extra_sequence_ids = 10
    worker._seq_with_bonus_token_in_last_step = set(
        range(batch_size + num_extra_sequence_ids))
    worker._create_output_sampler_list(
        seq_group_metadata_list=seq_group_metadata_list,
        accepted_token_ids=accepted_token_ids,
        target_logprobs=target_token_logprobs,
        k=k,
        stage_times=(0, 0, 0))
    # Verify that _seq_with_bonus_token_in_last_step contains the following:
    # 1. Sequence IDs that were already present in
    #    _seq_with_bonus_token_in_last_step but were not part of the current
    #    batch are retained.
    # 2. Of the sequence IDs present in the current batch, only those with a
    #    bonus token are retained in _seq_with_bonus_token_in_last_step.
    #    Sequence IDs that are present in the current batch but do not have
    #    bonus tokens are removed from _seq_with_bonus_token_in_last_step.
    expected_seq_ids_with_bonus_tokens = \
        set([assigned_seq_ids[i] for i in seq_indexes_with_bonus_tokens])
    additional_sequence_ids = \
        set(range(batch_size, batch_size + num_extra_sequence_ids))
    assert worker._seq_with_bonus_token_in_last_step == \
        expected_seq_ids_with_bonus_tokens.union(additional_sequence_ids)
    assert worker._request_id_seq_id_mapping == \
        expected_request_id_seq_ids_mapping


@torch.inference_mode()
def test_handle_finished_requests():
    """
    Test to verify that finished request IDs are appropriately processed to
    update the internal state of the SpecDecodeWorker.

    This test initializes the SpecDecodeWorker with mock data, marks certain
    requests as finished, and ensures that the corresponding sequence IDs are
    correctly removed from the internal mappings.
    """
    batch_size = 32
    k = 3
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    worker = SpecDecodeWorker(draft_worker, target_worker,
                              mock_spec_decode_sampler("rejection_sampler"),
                              metrics_collector)
    # Initialize the request_id_seq_id_mapping mapping dict with a few fake
    # request ids and corresponding sequence ids.
    worker._request_id_seq_id_mapping = \
        {'request-1': {1, 2, 3}, 'request-2': {4, 5, 6, 7},
         'request-3': {8, 9}, 'request-4': {10, 11}}
    # Initialize seq_with_bonus_token_in_last_step with a few fake
    # sequence ids.
    worker._seq_with_bonus_token_in_last_step = {1, 4, 5, 8, 9, 10}
    exception_secret = 'artificial stop'
    draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)

    seq_group_metadata_list, _, _ = create_batch(batch_size, k)
    # Mark requests with ids request-1 and request-3 as finished.
    execute_model_req = ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list,
        num_lookahead_slots=k,
        finished_requests_ids=['request-1', 'request-3'])

    with pytest.raises(ValueError, match=exception_secret):
        worker.execute_model(execute_model_req=execute_model_req)
    # Verify that request-1 and request-3 are removed from
    # request_id_seq_id_mapping
    assert worker._request_id_seq_id_mapping == \
        {'request-2': {4, 5, 6, 7}, 'request-4': {10, 11}}
    # Verify that all sequence ids corresponding to 'request-1'
    # and 'request-3' are removed from seq_with_bonus_token_in_last_step.
    assert worker._seq_with_bonus_token_in_last_step == \
        {4, 5, 10}


@pytest.mark.parametrize('k', [3])
@pytest.mark.parametrize('batch_size', [2, 32])
@pytest.mark.parametrize("batch_composition",
                         ["prefill_only", "decode_only", "mixed"])
@torch.inference_mode()
def test_chunked_prefill_flow(k: int, batch_size: int, batch_composition: str):
    """
    Verify SpecDecodeWorker calls match the expected flow.
    """
    vocab_size = 32_000
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    worker = SpecDecodeWorker(draft_worker,
                              target_worker,
                              mock_spec_decode_sampler("rejection_sampler"),
                              disable_logprobs=False,
                              metrics_collector=metrics_collector)
    exception_secret = 'artificial stop'
    worker.scorer = mock_worker(BatchExpansionTop1Scorer)
    worker.scorer.score_proposals.side_effect = ValueError(exception_secret)

    # Create batch with combination of terminal/non-terminal prefill chunks
    # and decodes (different seq_ids).
    decodes, _, _ = create_batch(batch_size, k)
    # Pre-chunking here, get 'batch_size' chunks.
    prefill, _, _ = create_batch(batch_size,
                                 k,
                                 prefill_chunk_size=4,
                                 seq_ids=list(range(batch_size,
                                                    batch_size * 2)))

    if batch_composition == "prefill_only":
        n_prefills = batch_size
    elif batch_composition == "decode_only":
        n_prefills = 0
    else:
        n_prefills = random.randint(1, batch_size - 1)
    n_decodes = batch_size - n_prefills

    prefill = random.sample(prefill, n_prefills)
    decodes = random.sample(decodes, n_decodes)
    target_group_metadata_list = prefill + decodes
    execute_model_req = ExecuteModelRequest(
        seq_group_metadata_list=target_group_metadata_list,
        # For prefill only batches we expect num_lookahead_slots = 0.
        num_lookahead_slots=k if n_decodes > 0 else 0)

    target_token_ids = torch.randint(low=0,
                                     high=vocab_size,
                                     size=(1, batch_size * (k + 1)),
                                     dtype=torch.int64,
                                     device='cuda')
    target_token_probs = torch.rand(1,
                                    batch_size * (k + 1),
                                    vocab_size,
                                    dtype=torch.float32,
                                    device='cuda')
    target_token_logprobs = torch.rand(1,
                                       batch_size * (k + 1),
                                       vocab_size,
                                       dtype=torch.float32,
                                       device='cuda')
    target_output = create_sampler_output_list(target_token_ids,
                                               target_token_probs,
                                               target_token_logprobs)

    target_worker.execute_model.return_value = [target_output[0]]

    if not len(decodes):
        worker.execute_model(execute_model_req=execute_model_req)
        # no spec run (prefill only)
        draft_worker.execute_model.assert_called_once_with(execute_model_req)
        target_worker.execute_model.assert_called_once_with(execute_model_req)
    else:
        # Decode-only run OR mixed batch, scorer call fails (it's mocked)
        with pytest.raises(ValueError, match=exception_secret):
            worker.execute_model(execute_model_req=execute_model_req)
        # but first draft still counted
        assert draft_worker.get_spec_proposals.call_count == 1