[Speculative decoding 3/9] Worker which speculates, scores, and applies rejection sampling (#3103)

This commit is contained in:
Cade Daniel 2024-03-08 23:32:46 -08:00 committed by GitHub
parent f48c6791b7
commit 8437bae6ef
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 2786 additions and 215 deletions

View File

@ -28,7 +28,7 @@ steps:
num_gpus: 2 # only support 1 or 2 for now.
- label: Engine Test
command: pytest -v -s engine
command: pytest -v -s engine test_sequence.py
- label: Entrypoints Test
command: pytest -v -s entrypoints
@ -52,6 +52,9 @@ steps:
- label: Worker Test
command: pytest -v -s worker
- label: Speculative decoding tests
command: pytest -v -s spec_decode
- label: LoRA Test
command: pytest -v -s lora --forked

View File

@ -0,0 +1,95 @@
import torch
import pytest
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from .utils import mock_worker, create_seq_group_metadata_from_prompts
@pytest.mark.parametrize('num_target_seq_ids', [100])
def test_create_target_seq_id_iterator(num_target_seq_ids: int):
"""Verify all new sequence ids are greater than all input
seq ids.
"""
scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000)
all_seq_ids = [
[1, 3, 5, 7],
list(range(100)) + [0],
[100],
]
for seq_ids in all_seq_ids:
max_seq_id = max(seq_ids)
iterator = scorer._create_target_seq_id_iterator(seq_ids) # pylint: disable=protected-access
for _ in range(num_target_seq_ids):
assert next(iterator) > max_seq_id
@pytest.mark.parametrize('k', [1, 2, 6])
def test_get_token_ids_to_score(k: int):
"""Verify correct tokens are selected for scoring.
"""
proposal_token_ids = torch.tensor(
list(range(k)),
dtype=torch.int64,
device='cuda',
)
expected_output = [
[],
]
for i in range(proposal_token_ids.shape[0]):
expected_output.append(proposal_token_ids[:i + 1].tolist())
scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000)
actual_output = scorer._get_token_ids_to_score(proposal_token_ids) # pylint: disable=protected-access
actual_output = [
x.tolist() if isinstance(x, torch.Tensor) else x for x in actual_output
]
assert actual_output == expected_output
@pytest.mark.parametrize('k', [1, 2, 6])
def test_create_single_target_seq_group_metadata(k: int):
"""Verify correct creation of a batch-expanded seq group metadata.
"""
prompt_tokens = [1, 2, 3]
prev_output_tokens = [4, 5, 6]
token_ids = list(range(k))
num_tokens_processed = len(prompt_tokens) + len(prev_output_tokens) - 1
final_seq_len = len(prompt_tokens) + len(prev_output_tokens) + len(
token_ids)
block_size = 32
input_seq_group_metadata = create_seq_group_metadata_from_prompts(
[prompt_tokens], 2048 // block_size, block_size, [final_seq_len],
[prev_output_tokens], [num_tokens_processed])[0]
input_seq_id = list(input_seq_group_metadata.seq_data.keys())[0]
target_seq_id = 100
scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000)
output = scorer._create_single_target_seq_group_metadata( # pylint: disable=protected-access
input_seq_group_metadata,
input_seq_id,
target_seq_id,
token_ids,
)
assert output.request_id == input_seq_group_metadata.request_id
assert len(output.seq_data) == 1
assert output.seq_data[target_seq_id].get_prompt_token_ids(
) == prompt_tokens
assert output.seq_data[target_seq_id].get_output_token_ids(
) == prev_output_tokens + token_ids
assert len(output.block_tables) == 1
assert output.block_tables[
target_seq_id] == input_seq_group_metadata.block_tables[input_seq_id]

View File

@ -0,0 +1,157 @@
import torch
import math
import pytest
from unittest.mock import MagicMock
from vllm.spec_decode.metrics import AsyncMetricsCollector
def test_initial_call_returns_none():
"""Expect first call to get metrics to return None.
"""
rej_sampler = MagicMock()
rej_sampler.num_accepted_tokens = torch.tensor(0,
dtype=torch.long,
device='cuda')
rej_sampler.num_emitted_tokens = torch.tensor(0,
dtype=torch.long,
device='cuda')
rej_sampler.num_draft_tokens = 0
collector = AsyncMetricsCollector(rej_sampler)
collector.init_gpu_tensors(rank=0)
maybe_metrics = collector.maybe_collect_rejsample_metrics(k=5)
assert maybe_metrics is None
def test_second_call_returns_metrics():
"""Expect second call to not return None.
"""
rej_sampler = MagicMock()
rej_sampler.num_accepted_tokens = torch.tensor(0,
dtype=torch.long,
device='cuda')
rej_sampler.num_emitted_tokens = torch.tensor(0,
dtype=torch.long,
device='cuda')
rej_sampler.num_draft_tokens = 0
collect_interval_s = 5.0
timer = MagicMock()
timer.side_effect = [
0.0, collect_interval_s + 0.1, collect_interval_s + 0.2
]
collector = AsyncMetricsCollector(rejection_sampler=rej_sampler,
timer=timer,
collect_interval_s=collect_interval_s)
collector.init_gpu_tensors(rank=0)
_ = collector.maybe_collect_rejsample_metrics(k=5)
metrics = collector.maybe_collect_rejsample_metrics(k=5)
assert metrics is not None
@pytest.mark.parametrize("rank", [1, 2, 3, 4])
def test_nonzero_rank_noop(rank):
"""Verify nonzero ranks don't collect metrics.
"""
rej_sampler = MagicMock()
rej_sampler.num_accepted_tokens = torch.tensor(0,
dtype=torch.long,
device='cuda')
rej_sampler.num_emitted_tokens = torch.tensor(0,
dtype=torch.long,
device='cuda')
rej_sampler.num_draft_tokens = 0
collector = AsyncMetricsCollector(rej_sampler)
collector.init_gpu_tensors(rank=rank)
_ = collector.maybe_collect_rejsample_metrics(k=5)
metrics = collector.maybe_collect_rejsample_metrics(k=5)
assert metrics is None
def test_noop_until_time():
"""Verify metrics aren't collected until enough time passes.
"""
rej_sampler = MagicMock()
rej_sampler.num_accepted_tokens = torch.tensor(0,
dtype=torch.long,
device='cuda')
rej_sampler.num_emitted_tokens = torch.tensor(0,
dtype=torch.long,
device='cuda')
rej_sampler.num_draft_tokens = 0
collect_interval_s = 5.0
timer = MagicMock()
timer.side_effect = [
0.0, collect_interval_s - 0.1, collect_interval_s - 0.1,
collect_interval_s + 0.1, collect_interval_s + 0.1
]
collector = AsyncMetricsCollector(rejection_sampler=rej_sampler,
timer=timer,
collect_interval_s=collect_interval_s)
collector.init_gpu_tensors(rank=0)
_ = collector.maybe_collect_rejsample_metrics(k=5)
metrics = collector.maybe_collect_rejsample_metrics(k=5)
assert metrics is None
_ = collector.maybe_collect_rejsample_metrics(k=5)
metrics = collector.maybe_collect_rejsample_metrics(k=5)
assert metrics is not None
@pytest.mark.parametrize("has_data", [True, False])
def test_initial_metrics_has_correct_values(has_data: bool):
"""Test correctness of metrics data.
"""
if has_data:
num_accepted_tokens = 103
num_emitted_tokens = 104
num_draft_tokens = 105
else:
num_accepted_tokens = 0
num_emitted_tokens = 0
num_draft_tokens = 0
k = 5
num_possible_tokens = AsyncMetricsCollector.get_max_num_accepted_tokens(
num_draft_tokens, k)
rej_sampler = MagicMock()
rej_sampler.num_accepted_tokens = torch.tensor(num_accepted_tokens,
dtype=torch.long,
device='cuda')
rej_sampler.num_emitted_tokens = torch.tensor(num_emitted_tokens,
dtype=torch.long,
device='cuda')
rej_sampler.num_draft_tokens = num_draft_tokens
collect_interval_s = 5.0
timer = MagicMock()
timer.side_effect = [
0.0, collect_interval_s + 0.1, collect_interval_s + 0.2
]
collector = AsyncMetricsCollector(rejection_sampler=rej_sampler,
timer=timer,
collect_interval_s=collect_interval_s)
collector.init_gpu_tensors(rank=0)
_ = collector.maybe_collect_rejsample_metrics(k)
metrics = collector.maybe_collect_rejsample_metrics(k)
assert metrics.num_spec_tokens == k
assert metrics.accepted_tokens == num_accepted_tokens
assert metrics.draft_tokens == num_draft_tokens
assert metrics.emitted_tokens == num_emitted_tokens
if has_data:
assert metrics.draft_acceptance_rate == num_accepted_tokens / num_draft_tokens
assert metrics.system_efficiency == num_emitted_tokens / num_possible_tokens
else:
assert math.isnan(metrics.draft_acceptance_rate)
assert math.isnan(metrics.system_efficiency)

View File

@ -3,14 +3,15 @@ import random
import pytest
from unittest.mock import MagicMock
from vllm.worker.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.multi_step_worker import MultiStepWorker, DraftModelTop1Proposer
from vllm.worker.worker import Worker
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplerOutput
from .utils import (create_execute_model_data, create_worker,
create_seq_group_metadata_from_prompts, zero_kv_cache,
patch_execute_model_with_seeds,
assert_logprobs_dict_allclose)
assert_logprobs_dict_allclose, create_batch)
@pytest.mark.parametrize('num_steps', list(range(1, 17)))
@ -259,3 +260,160 @@ def test_same_output_for_multi_step():
multi_step_output_logprobs, single_step_output_logprobs):
assert_logprobs_dict_allclose(multi_step_logprobs,
single_step_logprobs)
@torch.inference_mode()
def test_draft_proposals_full_speculation_len():
"""Verify DraftModelTop1Proposer correctly handles case where all sequences
can speculate.
"""
k = 10
batch_size = 32
vocab_size = 32_000
device = 'cuda:0'
draft_worker = MagicMock()
proposer = DraftModelTop1Proposer(
draft_worker=draft_worker,
device=device,
max_model_len=2048,
vocab_size=vocab_size,
)
draft_worker.execute_model_multi_step.return_value = [
SamplerOutput(
outputs=[],
sampled_token_probs=torch.rand(batch_size,
vocab_size,
device=device,
dtype=torch.float32),
sampled_token_ids=torch.randint(low=0,
high=vocab_size,
size=(batch_size, ),
device=device,
dtype=torch.long),
) for _ in range(k)
]
execute_model_data, _, _ = create_batch(batch_size, k)
proposals = proposer.get_proposals(
**execute_model_data.to_dict(),
max_proposal_len=k,
)
assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
assert proposals.proposal_token_ids.shape == torch.Size([batch_size, k])
assert proposals.proposal_probs.shape[:-1] == torch.Size([batch_size, k])
assert proposals.proposal_lens.shape == torch.Size([batch_size])
assert proposals.proposal_lens.tolist() == [k for _ in range(batch_size)]
@torch.inference_mode()
def test_draft_proposals_no_speculations():
"""Verify DraftModelTop1Proposer correctly handles case where no sequences
can speculate.
"""
k = 10
batch_size = 32
vocab_size = 32_000
device = 'cuda:0'
prompt_len = 10
draft_worker = MagicMock()
proposer = DraftModelTop1Proposer(
draft_worker=draft_worker,
device=device,
max_model_len=prompt_len + k - 1,
vocab_size=vocab_size,
)
execute_model_data, _, _ = create_batch(batch_size,
k,
prompt_len=prompt_len)
proposals = proposer.get_proposals(
**execute_model_data.to_dict(),
max_proposal_len=k,
)
assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
assert proposals.proposal_token_ids.shape == torch.Size([0, k])
assert proposals.proposal_probs.shape[:-1] == torch.Size([0, k])
assert proposals.proposal_lens.shape == torch.Size([batch_size])
assert proposals.proposal_lens.tolist() == [0 for _ in range(batch_size)]
@torch.inference_mode()
def test_draft_proposals_mixed_k():
"""Verify DraftModelTop1Proposer correctly handles case some sequences can
speculate and some can't.
"""
k = 10
batch_size = 32
vocab_size = 32_000
device = 'cuda:0'
small_prompt_len = 5
long_prompt_len = 10
prev_output_token_len = 20
expected_num_proposal_seqs = 6
expected_num_no_proposal_seqs = batch_size - expected_num_proposal_seqs
prompt_len = [
small_prompt_len for _ in range(expected_num_proposal_seqs - 1)
] + [long_prompt_len
for _ in range(expected_num_no_proposal_seqs)] + [small_prompt_len]
draft_worker = MagicMock()
proposer = DraftModelTop1Proposer(
draft_worker=draft_worker,
device=device,
max_model_len=long_prompt_len + prev_output_token_len + k - 1,
vocab_size=vocab_size,
)
draft_worker.execute_model_multi_step.return_value = [
SamplerOutput(
outputs=[],
sampled_token_probs=torch.rand(expected_num_proposal_seqs,
vocab_size,
device=device,
dtype=torch.float32),
sampled_token_ids=torch.randint(
low=0,
high=vocab_size,
size=(expected_num_proposal_seqs, ),
device=device,
dtype=torch.long),
) for _ in range(k)
]
execute_model_data, _, _ = create_batch(
batch_size,
k,
prompt_len=prompt_len,
prev_output_token_len=prev_output_token_len,
)
proposals = proposer.get_proposals(
**execute_model_data.to_dict(),
max_proposal_len=k,
)
assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
assert proposals.proposal_token_ids.shape == torch.Size([batch_size, k])
assert proposals.proposal_probs.shape[:-1] == torch.Size([batch_size, k])
assert proposals.proposal_lens.shape == torch.Size([batch_size])
assert proposals.proposal_lens.tolist() == [
k for _ in range(expected_num_proposal_seqs - 1)
] + [0 for _ in range(expected_num_no_proposal_seqs)] + [k]

View File

@ -0,0 +1,591 @@
import torch
import random
import pytest
from unittest.mock import MagicMock
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker, split_num_cache_blocks_evenly
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.model_executor.utils import set_random_seed
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from .utils import mock_worker, create_batch, ExecuteModelData, create_sampler_output_list
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics, AsyncMetricsCollector
@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@torch.inference_mode()
def test_correctly_calls_draft_model(k: int, batch_size: int):
"""Verify SpecDecodeWorker calls the draft worker with correct
inputs. Everything else is mocked out.
"""
draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
rejection_sampler = MagicMock(spec=RejectionSampler)
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector)
exception_secret = 'artifical stop'
draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)
execute_model_data, _, _ = create_batch(batch_size, k)
with pytest.raises(ValueError, match=exception_secret):
worker.execute_model(**execute_model_data.to_dict(), num_spec_tokens=k)
call_args_list = draft_worker.get_spec_proposals.call_args_list
assert len(call_args_list) == 1
for args, _ in call_args_list:
(seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out,
blocks_to_copy, actual_k) = args
actual_execute_model_data = ExecuteModelData(seq_group_metadata_list,
blocks_to_swap_in,
blocks_to_swap_out,
blocks_to_copy)
assert actual_execute_model_data == execute_model_data
assert actual_k == k
@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@torch.inference_mode()
def test_correctly_calls_target_model(k: int, batch_size: int):
"""Verify SpecDecodeWorker calls the target model with correct
inputs. Everything else is mocked out.
"""
draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
draft_worker.device = 'cuda'
target_worker.device = 'cuda'
set_random_seed(1)
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector)
worker.init_model()
vocab_size = 32_000
proposal_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64,
device='cuda')
proposal_probs = torch.rand(batch_size,
k,
vocab_size,
dtype=torch.float32,
device='cuda')
proposal_lens = torch.ones(batch_size, dtype=torch.int64,
device='cuda') * k
execute_model_data, prompts, prev_output_tokens = create_batch(
batch_size, k)
draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
proposal_token_ids=proposal_token_ids,
proposal_probs=proposal_probs,
proposal_lens=proposal_lens)
exception_secret = 'artifical stop'
target_worker.execute_model.side_effect = ValueError(exception_secret)
with pytest.raises(ValueError, match=exception_secret):
worker.execute_model(**execute_model_data.to_dict(), num_spec_tokens=k)
seen_contexts = []
call_args_list = target_worker.execute_model.call_args_list
assert len(call_args_list) == 1
for args, kwargs in call_args_list:
target_execute_model_data = ExecuteModelData.from_dict(kwargs)
assert len(target_execute_model_data.seq_group_metadata_list) == (
k + 1) * batch_size
for seq_group_metadata in (
target_execute_model_data.seq_group_metadata_list):
for seq_data in seq_group_metadata.seq_data.values():
seen_contexts.append(seq_data.get_token_ids())
expected_seen_contexts = []
for prompt, prev_generated, draft_tokens in zip(
prompts, prev_output_tokens, proposal_token_ids.tolist()):
for i in range(len(draft_tokens) + 1):
expected_seen_contexts.append(prompt + prev_generated +
draft_tokens[:i])
seen_contexts.sort()
expected_seen_contexts.sort()
assert expected_seen_contexts == seen_contexts
@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@torch.inference_mode()
def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
"""Verify SpecDecodeWorker calls the rejection sampler with
correct inputs. Everything else is mocked out.
"""
vocab_size = 32_000
draft_worker = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size)
target_worker = mock_worker(vocab_size=vocab_size)
rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
draft_worker.device = 'cuda'
target_worker.device = 'cuda'
set_random_seed(1)
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector)
worker.init_model()
proposal_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64,
device='cuda')
proposal_probs = torch.rand(batch_size,
k,
vocab_size,
dtype=torch.float32,
device='cuda')
proposal_lens = torch.ones(batch_size, dtype=torch.int64,
device='cuda') * k
execute_model_data, _, _ = create_batch(batch_size, k)
draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
proposal_token_ids=proposal_token_ids,
proposal_probs=proposal_probs,
proposal_lens=proposal_lens)
target_token_ids = torch.randint(low=0,
high=vocab_size,
size=(1, batch_size * (k + 1)),
dtype=torch.int64,
device='cuda')
target_token_probs = torch.rand(1,
batch_size * (k + 1),
vocab_size,
dtype=torch.float32,
device='cuda')
target_output = create_sampler_output_list(target_token_ids,
target_token_probs)
target_worker.execute_model.return_value = target_output[0]
exception_secret = 'artifical stop'
rejection_sampler.side_effect = ValueError(exception_secret)
with pytest.raises(ValueError, match=exception_secret):
worker.execute_model(**execute_model_data.to_dict(), num_spec_tokens=k)
assert len(rejection_sampler.call_args_list) == 1
args, _ = rejection_sampler.call_args_list[0]
(actual_proposal_scores, actual_bonus_token_ids, actual_proposal_probs,
actual_proposal_token_ids) = args
assert torch.equal(actual_bonus_token_ids,
target_token_ids.reshape(batch_size, k + 1)[:, -1:])
assert torch.equal(
actual_proposal_scores,
target_token_probs.reshape(batch_size, k + 1, -1)[:, :-1])
assert torch.equal(actual_proposal_token_ids, proposal_token_ids)
assert torch.equal(actual_proposal_probs, proposal_probs)
@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@torch.inference_mode()
def test_correctly_formats_output(k: int, batch_size: int):
"""Verify SpecDecodeWorker formats sampler output correctly.
Everything else is mocked out.
"""
vocab_size = 32_000
draft_worker = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size)
target_worker = mock_worker(vocab_size=vocab_size)
rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
draft_worker.device = 'cuda'
target_worker.device = 'cuda'
set_random_seed(1)
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector)
worker.init_model()
proposal_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64,
device='cuda')
proposal_probs = torch.rand(batch_size,
k,
vocab_size,
dtype=torch.float32,
device='cuda')
proposal_lens = torch.ones(batch_size, dtype=torch.int64,
device='cuda') * k
execute_model_data, _, _ = create_batch(batch_size, k)
draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
proposal_token_ids=proposal_token_ids,
proposal_probs=proposal_probs,
proposal_lens=proposal_lens)
target_token_ids = torch.randint(low=0,
high=vocab_size,
size=(1, batch_size * (k + 1)),
dtype=torch.int64,
device='cuda')
target_token_probs = torch.rand(1,
batch_size * (k + 1),
vocab_size,
dtype=torch.float32,
device='cuda')
target_output = create_sampler_output_list(target_token_ids,
target_token_probs)
target_worker.execute_model.return_value = target_output[0]
rejection_sampler_output = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k + 1),
dtype=torch.int64,
device='cuda')
for i in range(batch_size):
minimum_accepted_tokens = 1
rejection_sampler_output[i][
-random.randint(minimum_accepted_tokens, k + 1):] = -1
rejection_sampler.return_value = rejection_sampler_output
output = worker.execute_model(**execute_model_data.to_dict(),
num_spec_tokens=k)
expected_output = create_sampler_output_list(
rejection_sampler_output.transpose(0, 1), [None for _ in range(k + 1)])
seq_ids = [
next(iter(seq_group_metadata.seq_data.keys()))
for seq_group_metadata in execute_model_data.seq_group_metadata_list
]
actual_output_by_seq = {seq_id: [] for seq_id in seq_ids}
expected_output_by_seq = {seq_id: [] for seq_id in seq_ids}
for step in output:
for seq_group in step:
for sample in seq_group.samples:
seq_id = sample.parent_seq_id
actual_output_by_seq[seq_id].append(sample)
for step in expected_output:
for seq_group in step:
for sample in seq_group.samples:
seq_id = sample.parent_seq_id
expected_output_by_seq[seq_id].append(sample)
all_seen_seq_ids = set(
list(actual_output_by_seq.keys()) +
list(expected_output_by_seq.keys()))
for seq_id in all_seen_seq_ids:
actual_by_step = actual_output_by_seq[seq_id]
expected_by_step = expected_output_by_seq[seq_id]
for i in range(k + 1):
if i >= len(actual_by_step):
assert expected_by_step[i].output_token == -1
continue
assert actual_by_step[i].output_token == expected_by_step[
i].output_token
assert actual_by_step[i].logprobs == expected_by_step[i].logprobs
@pytest.mark.parametrize('k', [1, 2])
@pytest.mark.parametrize('batch_size', [1])
@pytest.mark.parametrize('returns_metrics', [True, False])
@torch.inference_mode()
def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
"""Verify SpecDecodeWorker collects metrics.
"""
vocab_size = 32_000
draft_worker = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size)
target_worker = mock_worker(vocab_size=vocab_size)
rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
draft_worker.device = 'cuda'
target_worker.device = 'cuda'
set_random_seed(1)
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector)
worker.init_model()
proposal_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64,
device='cuda')
proposal_probs = torch.rand(batch_size,
k,
vocab_size,
dtype=torch.float32,
device='cuda')
proposal_lens = torch.ones(batch_size, dtype=torch.int64,
device='cuda') * k
execute_model_data, _, _ = create_batch(batch_size, k)
draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
proposal_token_ids=proposal_token_ids,
proposal_probs=proposal_probs,
proposal_lens=proposal_lens)
target_token_ids = torch.randint(low=0,
high=vocab_size,
size=(1, batch_size * (k + 1)),
dtype=torch.int64,
device='cuda')
target_token_probs = torch.rand(1,
batch_size * (k + 1),
vocab_size,
dtype=torch.float32,
device='cuda')
target_output = create_sampler_output_list(target_token_ids,
target_token_probs)
target_worker.execute_model.return_value = target_output[0]
rejection_sampler_output = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k + 1),
dtype=torch.int64,
device='cuda')
for i in range(batch_size):
minimum_accepted_tokens = 1
rejection_sampler_output[i][
-random.randint(minimum_accepted_tokens, k + 1):] = -1
rejection_sampler.return_value = rejection_sampler_output
mock_rejsample_metrics = MagicMock(
spec=SpecDecodeWorkerMetrics) if returns_metrics else None
metrics_collector.maybe_collect_rejsample_metrics.return_value = mock_rejsample_metrics
output = worker.execute_model(**execute_model_data.to_dict(),
num_spec_tokens=k)
assert output[0].spec_decode_worker_metrics == mock_rejsample_metrics
call_args_list = metrics_collector.maybe_collect_rejsample_metrics.call_args_list
assert len(call_args_list) == 1
args, kwargs = call_args_list[0]
assert args[0] == k or kwargs.get('k', -1) == k
@pytest.mark.parametrize('k', [0])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@torch.inference_mode()
def test_k_equals_zero(k: int, batch_size: int):
"""Verify that the SpecDecodeWorker calls the draft and target workers
when k is zero. This happens during prefill.
"""
draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
draft_worker.device = 'cuda'
target_worker.device = 'cuda'
set_random_seed(1)
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector)
execute_model_data, prompts, prev_output_tokens = create_batch(
batch_size, k, prev_output_token_len=0)
out = worker.execute_model(**execute_model_data.to_dict(),
num_spec_tokens=k)
assert len(out) == 1, f"expected only one token output when {k=}"
assert out[0].probs is None, "expect gpu tensor references to be None"
assert out[
0].sampled_tokens is None, "expect gpu tensor references to be None"
draft_worker.execute_model.assert_called_once_with(
**execute_model_data.to_dict(), return_python_output=False)
target_worker.execute_model.assert_called_once_with(
**execute_model_data.to_dict())
@pytest.mark.parametrize('k', [0, 5])
@pytest.mark.parametrize('batch_size', [0])
@torch.inference_mode()
def test_empty_input_batch(k: int, batch_size: int):
"""Verify that the SpecDecodeWorker calls the draft and target workers
when the input batch is empty. This can happen if the engine communicates
to the workers information without scheduling a batch.
"""
draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
draft_worker.device = 'cuda'
target_worker.device = 'cuda'
set_random_seed(1)
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector)
execute_model_data, prompts, prev_output_tokens = create_batch(
batch_size, k, prev_output_token_len=0)
out = worker.execute_model(**execute_model_data.to_dict(),
num_spec_tokens=k)
assert len(out) == 1, f"expected only one token output when {k=}"
assert out[0].probs is None, "expect gpu tensor references to be None"
assert out[
0].sampled_tokens is None, "expect gpu tensor references to be None"
draft_worker.execute_model.assert_called_once_with(
**execute_model_data.to_dict(), return_python_output=False)
target_worker.execute_model.assert_called_once_with(
**execute_model_data.to_dict())
@torch.inference_mode()
def test_init_model():
"""Verify SpecDecodeWorker invokes proposer/scorer worker init_model, as
well as other GPU initialization.
"""
draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector)
worker.init_model()
draft_worker.init_model.assert_called_once()
target_worker.init_model.assert_called_once()
metrics_collector.init_gpu_tensors.assert_called_once()
rejection_sampler.init_gpu_tensors.assert_called_once()
@torch.inference_mode()
def test_init_cache_engine():
"""Verify SpecDecodeWorker invokes init_cache_engine on proposer/scorer
workers.
"""
draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector)
cache_config = MagicMock()
worker.init_cache_engine(cache_config)
draft_worker.init_cache_engine.assert_called_once_with(cache_config)
target_worker.init_cache_engine.assert_called_once_with(cache_config)
@pytest.mark.parametrize('available_gpu_blocks', [1, 1024])
@pytest.mark.parametrize('available_cpu_blocks', [500])
@pytest.mark.parametrize('target_cache_block_size_bytes', [2 * 2 * 4096])
@pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096])
@torch.inference_mode()
def test_profile_num_available_blocks(available_gpu_blocks: int,
available_cpu_blocks: int,
target_cache_block_size_bytes: int,
draft_kv_size_bytes: int):
"""Verify SpecDecodeWorker correctly profiles num available GPU blocks.
Specifically, it should run profiling in the scorer worker, and then evenly
split the blocks between proposer and scorer worker.
"""
draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
target_worker.profile_num_available_blocks.return_value = (
available_gpu_blocks, available_cpu_blocks)
target_worker.get_cache_block_size_bytes.return_value = target_cache_block_size_bytes
draft_worker.get_cache_block_size_bytes.return_value = draft_kv_size_bytes
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector)
# These values do not directly impact the adjusted block size calculation,
# so they can be fixed.
gpu_memory_utilization = 0.9
cpu_swap_space = 100
block_size = 16
num_gpu_blocks, num_cpu_blocks = worker.profile_num_available_blocks(
block_size, gpu_memory_utilization, cpu_swap_space, cache_dtype="auto")
target_worker.profile_num_available_blocks.assert_called_once_with(
block_size, gpu_memory_utilization, cpu_swap_space, "auto")
assert num_cpu_blocks == available_cpu_blocks
assert num_gpu_blocks == split_num_cache_blocks_evenly(
target_cache_block_size_bytes, draft_kv_size_bytes,
available_gpu_blocks)
@pytest.mark.parametrize('available_gpu_blocks',
list(range(20)) + [1024, 1024**2])
@pytest.mark.parametrize('target_cache_block_size_bytes',
[2 * 2 * 4096, 2 * 2 * 8192])
@pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096])
@torch.inference_mode()
def test_split_num_cache_blocks_evenly(available_gpu_blocks: int,
target_cache_block_size_bytes: int,
draft_kv_size_bytes: int):
"""Verify split_num_cache_blocks_evenly does not exceed original memory
allocation in bytes.
"""
num_blocks = split_num_cache_blocks_evenly(target_cache_block_size_bytes,
draft_kv_size_bytes,
available_gpu_blocks)
assert (num_blocks * target_cache_block_size_bytes) + (
num_blocks * draft_kv_size_bytes) <= (available_gpu_blocks *
target_cache_block_size_bytes)

View File

@ -0,0 +1,111 @@
from vllm.spec_decode.util import get_all_seq_ids
from vllm.sequence import SequenceGroupMetadata
from vllm.spec_decode.util import split_batch_by_proposal_len
import pytest
from unittest.mock import MagicMock
def test_get_all_seq_ids():
"""Verify get_all_seq_ids extracts all seq ids.
"""
expected_seq_ids = list(range(10)) + list(range(100, 110))
seq_group_metadata_list = [
SequenceGroupMetadata(
request_id=str(seq_id),
is_prompt=True,
seq_data={
seq_id: MagicMock(),
},
sampling_params=MagicMock(),
block_tables={
seq_id: MagicMock(),
},
lora_request=None,
) for seq_id in expected_seq_ids
]
actual_seq_ids = get_all_seq_ids(seq_group_metadata_list)
assert actual_seq_ids == expected_seq_ids
@pytest.fixture
def fake_sequence_group_metadata():
seq_ids = list(range(3))
return [
SequenceGroupMetadata(
request_id=str(i),
is_prompt=True,
seq_data={
i: MagicMock(),
},
sampling_params=MagicMock(),
block_tables={
i: MagicMock(),
},
lora_request=None,
) for i in seq_ids
]
def test_filter_zero_length_proposals(fake_sequence_group_metadata):
proposal_lens = [0, 1, 0]
filtered_groups, indices = split_batch_by_proposal_len(
fake_sequence_group_metadata,
proposal_lens,
select_proposal_len_zero=True)
expected_groups = [
fake_sequence_group_metadata[0], fake_sequence_group_metadata[2]
]
expected_indices = [0, 2]
assert filtered_groups == expected_groups
assert indices == expected_indices
def test_filter_non_zero_length_proposals(fake_sequence_group_metadata):
proposal_lens = [0, 1, 2]
filtered_groups, indices = split_batch_by_proposal_len(
fake_sequence_group_metadata,
proposal_lens,
select_proposal_len_zero=False)
expected_groups = [
fake_sequence_group_metadata[1], fake_sequence_group_metadata[2]
]
expected_indices = [1, 2]
assert filtered_groups == expected_groups
assert indices == expected_indices
def test_empty_inputs():
filtered_groups, indices = split_batch_by_proposal_len(
[], [], select_proposal_len_zero=True)
assert filtered_groups == []
assert indices == []
def test_all_zero_with_non_zero_filter(fake_sequence_group_metadata):
proposal_lens = [0, 0, 0]
filtered_groups, indices = split_batch_by_proposal_len(
fake_sequence_group_metadata,
proposal_lens,
select_proposal_len_zero=False)
assert filtered_groups == []
assert indices == []
def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata):
proposal_lens = [1, 1, 1]
filtered_groups, indices = split_batch_by_proposal_len(
fake_sequence_group_metadata,
proposal_lens,
select_proposal_len_zero=True)
assert filtered_groups == []
assert indices == []

View File

@ -1,13 +1,16 @@
import torch
from typing import List, Optional, Dict
from typing import List, Optional, Dict, Iterable, Union
from unittest.mock import MagicMock
from vllm.worker.worker import Worker
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.engine.arg_utils import EngineArgs
from vllm.sequence import Logprob, SequenceGroupMetadata, SequenceData
from vllm.sequence import (Logprob, SequenceGroupMetadata, SequenceData,
SamplerOutput, SequenceGroupOutput, SequenceOutput)
from vllm.sampling_params import SamplingParams
from vllm.worker.cache_engine import CacheEngine
from vllm.model_executor.utils import set_random_seed
from itertools import count
from dataclasses import dataclass, fields
@ -24,6 +27,11 @@ class ExecuteModelData:
return dict(
(field.name, getattr(self, field.name)) for field in fields(self))
@classmethod
def from_dict(cls, d):
cleaned = dict((field.name, d[field.name]) for field in fields(cls))
return cls(**cleaned)
def round_up_to_next_block(seq_len: int, block_size: int) -> int:
return (seq_len + block_size - 1) // block_size
@ -50,6 +58,21 @@ def create_execute_model_data(
)
def mock_worker(cls=None,
vocab_size: int = 30_000,
max_model_len: int = 2048,
rank: int = 0) -> MagicMock:
if cls is None:
cls = Worker
worker = MagicMock(spec=cls)
worker.vocab_size = vocab_size
worker.max_model_len = max_model_len
worker.rank = rank
worker.device = 'cuda:0'
return worker
def patch_execute_model_with_seeds(worker: Worker, rand_seeds: List[int]):
seed_iter = iter(rand_seeds)
original_execute_model = worker.execute_model
@ -117,25 +140,12 @@ def create_seq_group_metadata_from_prompts(
block_size: int,
final_seq_lens: List[int],
continuations: Optional[List[List[int]]] = None,
num_tokens_processed: Optional[List[int]] = None,
seq_ids: Optional[List[int]] = None,
) -> List[SequenceGroupMetadata]:
if continuations is None:
continuations = [[] for _ in prompts]
if num_tokens_processed is None:
# Default to 1 token missing from kv cache for generation sequences.
num_tokens_processed = []
for continuation, prompt in zip(continuations, prompts):
# If prefill, then default to zero tokens processed.
if not continuation:
num_tokens_processed.append(0)
else:
# If generation, then default to all but one tokens processed.
num_tokens_processed.append(
len(continuation) + len(prompt) - 1)
if seq_ids is None:
seq_ids = list(i for i, _ in enumerate(prompts))
@ -155,13 +165,15 @@ def create_seq_group_metadata_from_prompts(
is_prompt=len(cont_token_ids) == 0,
seq_data={
i:
SequenceData(prompt_token_ids=prompt_token_ids[:] +
cont_token_ids[:])
SequenceData(
prompt_token_ids=prompt_token_ids[:],
output_token_ids=cont_token_ids[:],
),
},
sampling_params=SamplingParams(temperature=0.0, ),
block_tables={i: block_allocations[i][:]},
) for i, (prompt_token_ids, cont_token_ids, num_tokens_saved) in
enumerate(zip(prompts, continuations, num_tokens_processed))
) for i, (prompt_token_ids,
cont_token_ids) in enumerate(zip(prompts, continuations))
]
@ -178,3 +190,68 @@ def assert_logprobs_dict_allclose(
expected = torch.tensor(
single_step_expected_logprobs[token_id].logprob)
assert torch.allclose(actual, expected)
def create_sampler_output_list(
token_ids: torch.Tensor,
probs: Iterable[Optional[torch.Tensor]],
seq_ids: Optional[List[int]] = None) -> List[SamplerOutput]:
num_steps, batch_size = token_ids.shape
token_ids_by_step = token_ids.tolist()
if seq_ids is None:
seq_ids = list(range(batch_size))
return [
SamplerOutput(outputs=[
SequenceGroupOutput(
samples=[
SequenceOutput(
output_token=token_id,
parent_seq_id=seq_ids[seq_index],
logprobs={token_id: 0},
)
],
prompt_logprobs=None,
) for seq_index, token_id in enumerate(token_ids_by_step[step])
],
sampled_token_probs=probs[step],
sampled_token_ids=token_ids[step])
for step in range(num_steps)
]
def create_batch(batch_size,
k,
prompt_len: Union[int, List[int]] = 10,
prev_output_token_len: int = 10,
seq_ids: Optional[List[int]] = None,
num_gpu_blocks: Optional[int] = None,
block_size: Optional[int] = None):
if block_size is None:
block_size = 8
if num_gpu_blocks is None:
num_gpu_blocks = 2048 // block_size
iterator = count()
if isinstance(prompt_len, int):
prompt_lens = [prompt_len for _ in range(batch_size)]
else:
prompt_lens = prompt_len
prompts = [[next(iterator) for _ in range(p_len)] for p_len in prompt_lens]
prev_output_tokens = [[
next(iterator) for _ in range(prev_output_token_len)
] for _ in range(batch_size)]
final_seq_lens = [
len(prompt) + len(prev_output_token) + k + 1
for prompt, prev_output_token in zip(prompts, prev_output_tokens)
]
execute_model_data = create_execute_model_data(
create_seq_group_metadata_from_prompts(prompts, num_gpu_blocks,
block_size, final_seq_lens,
prev_output_tokens, seq_ids), )
return execute_model_data, prompts, prev_output_tokens

50
tests/test_sequence.py Normal file
View File

@ -0,0 +1,50 @@
import pytest
from vllm.sequence import SequenceGroupOutput, SamplerOutput, SequenceOutput
@pytest.fixture
def sample_outputs():
return [
SequenceGroupOutput(samples=[
SequenceOutput(parent_seq_id=0, output_token=i, logprobs={})
],
prompt_logprobs=None) for i in range(5)
]
@pytest.fixture
def sampler_output(sample_outputs):
return SamplerOutput(outputs=sample_outputs)
def test_sampler_output_initialization(sampler_output, sample_outputs):
assert len(sampler_output) == len(sample_outputs)
assert sampler_output.sampled_token_probs is None
assert sampler_output.sampled_token_ids is None
assert sampler_output.spec_decode_worker_metrics is None
def test_sampler_output_getitem(sampler_output, sample_outputs):
assert sampler_output[2] == sample_outputs[2]
def test_sampler_output_setitem(sampler_output):
new_output = SequenceGroupOutput(samples=[
SequenceOutput(parent_seq_id=0, output_token=99, logprobs={})
],
prompt_logprobs=None)
sampler_output[2] = new_output
assert sampler_output[2] == new_output
def test_sampler_output_len(sampler_output, sample_outputs):
assert len(sampler_output) == len(sample_outputs)
def test_sampler_output_eq(sample_outputs):
sampler_output1 = SamplerOutput(outputs=sample_outputs)
sampler_output2 = SamplerOutput(outputs=sample_outputs.copy())
sampler_output3 = SamplerOutput(outputs=sample_outputs[:-1])
assert sampler_output1 == sampler_output2
assert sampler_output1 != sampler_output3

View File

@ -21,8 +21,6 @@ class RejectionSampler(nn.Module):
nontrivial latency.
"""
super().__init__()
self.probs_dtype = torch.float32
self.token_id_dtype = torch.int64
self._strict_mode = strict_mode
# NOTE: A "bonus token" is accepted iff all proposal tokens are
@ -44,6 +42,14 @@ class RejectionSampler(nn.Module):
dtype=torch.long,
device=device)
@property
def probs_dtype(self):
return torch.float32
@property
def token_id_dtype(self):
return torch.int64
def forward(
self,
target_probs: torch.Tensor,

View File

@ -587,4 +587,4 @@ def _build_sampler_output(
SequenceOutput(seq_ids[parent_id], next_token_id, logprobs))
sampler_output.append(
SequenceGroupOutput(seq_outputs, group_prompt_logprobs))
return sampler_output
return SamplerOutput(outputs=sampler_output)

View File

@ -2,12 +2,16 @@
import copy
import enum
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Union, TYPE_CHECKING
from vllm.block import LogicalTokenBlock
from vllm.sampling_params import SamplingParams
from vllm.lora.request import LoRARequest
if TYPE_CHECKING:
import torch
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
@dataclass
class Logprob:
@ -81,6 +85,8 @@ class SequenceData:
Args:
prompt_token_ids: The token IDs of the prompt.
output_token_ids: The token IDs of the output. Set to an empty list if
None.
Attributes:
prompt_token_ids: The token IDs of the prompt.
@ -91,9 +97,13 @@ class SequenceData:
def __init__(
self,
prompt_token_ids: List[int],
output_token_ids: Optional[List[int]] = None,
) -> None:
if output_token_ids is None:
output_token_ids = []
self.prompt_token_ids = prompt_token_ids
self.output_token_ids: List[int] = []
self.output_token_ids = output_token_ids
self.cumulative_logprob = 0.0
def append_token_id(self, token_id: int, logprob: float) -> None:
@ -117,6 +127,12 @@ class SequenceData:
return self.prompt_token_ids[-1]
return self.output_token_ids[-1]
def get_prompt_token_ids(self) -> int:
return self.prompt_token_ids
def get_output_token_ids(self) -> int:
return self.output_token_ids
def __repr__(self) -> str:
return (f"SequenceData("
f"prompt_token_ids={self.prompt_token_ids}, "
@ -506,6 +522,35 @@ class SequenceGroupOutput:
and self.prompt_logprobs == other.prompt_logprobs)
# For each sequence group, we generate a list of SequenceOutput object,
# each of which contains one possible candidate for the next token.
SamplerOutput = List[SequenceGroupOutput]
@dataclass
class SamplerOutput:
"""For each sequence group, we generate a list of SequenceOutput object,
each of which contains one possible candidate for the next token.
This datastructure implements methods so it can be used like a list, but
also has optional fields for device tensors.
"""
outputs: List[SequenceGroupOutput]
# On-device tensor containing probabilities of each token.
sampled_token_probs: Optional["torch.Tensor"] = None
# On-device tensor containing the sampled token ids.
sampled_token_ids: Optional["torch.Tensor"] = None
# Spec decode metrics populated by workers.
spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
def __getitem__(self, idx: int):
return self.outputs[idx]
def __setitem__(self, idx: int, value):
self.outputs[idx] = value
def __len__(self):
return len(self.outputs)
def __eq__(self, other: object):
return isinstance(other,
self.__class__) and self.outputs == other.outputs

View File

@ -0,0 +1,351 @@
from typing import Iterator, List, Tuple, Optional, Dict
from itertools import chain, count
import torch
from vllm.sequence import (SamplerOutput, SequenceGroupMetadata, SequenceData)
from vllm.worker.worker import Worker
from vllm.spec_decode.util import nvtx_range, sampler_output_to_torch, get_all_seq_ids, split_batch_by_proposal_len
from vllm.spec_decode.interfaces import SpeculativeScorer, SpeculativeProposals, SpeculativeScores
SeqId = int
TargetSeqId = int
TokenId = int
class BatchExpansionTop1Scorer(SpeculativeScorer):
"""Implements a speculative scorer that uses batch expansion to get
probabilities of speculative tokens according to the scoring model.
Batch expansion converts a list of sequences and multiple query positions
to a new batch of sequences, each with a single query position. This allows
for MQA-like scoring in speculative decoding without requiring an MQA
kernel.
It is strictly less efficient than MQA scoring.
It only supports scoring the top1 proposal tokens of the proposer, instead
of topk/tree.
"""
def __init__(self, scorer_worker: Worker, device: str, vocab_size: int):
self._scorer_worker = scorer_worker
self._device = device
self._vocab_size = vocab_size
@nvtx_range("BatchExpansionTop1Scorer.score_proposals")
def score_proposals(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Optional[Dict[int, int]],
blocks_to_swap_out: Optional[Dict[int, int]],
blocks_to_copy: Optional[Dict[int, List[int]]],
k: int,
proposals: SpeculativeProposals,
) -> SpeculativeScores:
"""Score the proposed tokens via the scorer model.
This converts each input sequence to a set of k+1 target sequences. The
target sequences have the unique continuations to be scored and a
unique sequence ID that is different from all input sequence ids.
If a speculative sequence length would exceed the max model length, then
no speculation is produced for that sequence.
Args:
seq_group_metadata_list: The input sequence group metadata.
blocks_to_swap_in: This is passed to the worker during scoring.
blocks_to_swap_out: This is passed to the worker during scoring.
blocks_to_copy: This is passed to the worker during scoring.
k: The fixed proposal length.
proposals: The speculative proposals to score.
Returns:
SpeculativeScores: The scores of each speculative token, along with
which sequences were ignored during scoring.
"""
# TODO(cade) perform this on GPU to remove blocking call.
proposal_lens_list = proposals.proposal_lens.tolist()
proposal_token_ids_list = proposals.proposal_token_ids.tolist()
spec_indices, non_spec_indices, target_seq_group_metadata_list, num_scoring_tokens = self._expand_batch(
seq_group_metadata_list=seq_group_metadata_list,
proposal_token_ids_list=proposal_token_ids_list,
proposal_lens_list=proposal_lens_list,
)
target_sampler_output = self._scorer_worker.execute_model(
seq_group_metadata_list=target_seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
return_python_output=False)
all_tokens, all_probs = self._contract_batch(
original_bs=len(seq_group_metadata_list),
target_sampler_output=target_sampler_output,
proposals=proposals,
num_scoring_tokens=num_scoring_tokens,
non_spec_indices=non_spec_indices,
spec_indices=spec_indices,
k=k,
)
return SpeculativeScores(
probs=all_probs,
token_ids=all_tokens,
)
def _expand_batch(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
proposal_token_ids_list: List[TokenId],
proposal_lens_list: List[int],
) -> Tuple[List[int], List[int], List[SequenceGroupMetadata], int]:
"""Given the input sequences and potentially multiple corresponding
proposal tokens, create a new batch where each sequence has a single
query token.
"""
# vLLM currently only supports proposal lens equal to zero or the batch
# proposal len. This adds some complexity (splitting the batch into spec
# and non spec sequences) and should be removed in the future. It can be
# done by supporting per-sequence proposal lens.
spec_seqs, spec_indices = split_batch_by_proposal_len(
seq_group_metadata_list,
proposal_lens_list,
select_proposal_len_zero=False)
non_spec_seqs, non_spec_indices = split_batch_by_proposal_len(
seq_group_metadata_list,
proposal_lens_list,
select_proposal_len_zero=True)
target_seq_group_metadata_list = self._create_scoring_model_input(
spec_seqs, proposal_token_ids_list)
num_scoring_tokens = len(target_seq_group_metadata_list)
target_seq_group_metadata_list.extend(non_spec_seqs)
return spec_indices, non_spec_indices, target_seq_group_metadata_list, num_scoring_tokens
def _contract_batch(self, original_bs: int,
target_sampler_output: List[SamplerOutput],
proposals: SpeculativeProposals,
num_scoring_tokens: int, non_spec_indices: List[int],
spec_indices: List[int],
k: int) -> Tuple[torch.Tensor, torch.Tensor]:
"""Contract the expanded batch back into its original size.
This maps the scores of speculative tokens back to their original
sequences.
"""
(target_token_ids, target_probs, non_spec_target_token_ids,
non_spec_target_probs) = self._split_scoring_output(
target_sampler_output, num_scoring_tokens)
# Map distinct sequences used to score each token
# of shape [batch_size * k + 1] back to [batch_size, k + 1].
batch_size, k = proposals.proposal_token_ids.shape
target_token_ids = target_token_ids.squeeze().reshape(
batch_size, k + 1)
target_probs = target_probs.squeeze().reshape(batch_size, k + 1,
self._vocab_size)
all_tokens = torch.full(size=(original_bs, k + 1),
fill_value=-1,
device=self._device,
dtype=torch.long)
all_probs = torch.zeros(original_bs,
k + 1,
self._vocab_size,
device=self._device,
dtype=torch.float32)
if non_spec_indices:
all_tokens[non_spec_indices, 0] = non_spec_target_token_ids
all_probs[non_spec_indices, :1, :] = non_spec_target_probs
if spec_indices:
all_tokens[spec_indices] = target_token_ids
all_probs[spec_indices] = target_probs
return all_tokens, all_probs
def _create_scoring_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
proposal_token_ids: List[List[TokenId]], # shape: [batch_size, k]
) -> List[SequenceGroupMetadata]:
"""Given the original input sequences and proposed tokens from the draft
model, create a list of target sequences that can be used for scoring.
"""
if not seq_group_metadata_list:
return []
target_seq_ids_iter = self._create_target_seq_id_iterator(
get_all_seq_ids(seq_group_metadata_list))
target_seq_group_metadata = list(
chain.from_iterable(
self._create_target_seq_group_metadata(
seq_group_metadata,
proposal_token_ids,
i,
target_seq_ids_iter,
) for i, seq_group_metadata in enumerate(
seq_group_metadata_list)))
return target_seq_group_metadata
def _create_target_seq_group_metadata(
self,
input_seq_group_metadata: SequenceGroupMetadata,
proposal_token_ids: List[TokenId], # shape: [batch_size, k]
batch_index: int,
target_seq_ids_iter: Iterator[TargetSeqId],
) -> List[SequenceGroupMetadata]:
"""Given an input sequence group metadata and a list of draft tokens,
create a list of target SequenceGroupMetadata, one for each
token id that needs to be scored.
Naive speculative decoding requires K target model scores, one for each
draft model token. However one can add a bonus token such that if each
token is accepted, then a final token may be sampled from the model.
This function creates K+1 target SequenceGroupMetadata to take
advantage of the bonus token.
"""
assert not input_seq_group_metadata.is_prompt, (
"Speculating on "
"prompts not yet supported")
assert len(input_seq_group_metadata.seq_data) == 1, (
"Beam search "
"not supported in speculative decoding")
input_seq_id = next(iter(input_seq_group_metadata.seq_data.keys()))
token_ids_to_score = self._get_token_ids_to_score(
proposal_token_ids[batch_index])
target_seq_group_metadata_list: List[SequenceGroupMetadata] = []
for token_ids in token_ids_to_score:
target_seq_group_metadata_list.append(
self._create_single_target_seq_group_metadata(
input_seq_group_metadata,
input_seq_id,
next(target_seq_ids_iter),
token_ids,
))
return target_seq_group_metadata_list
def _create_single_target_seq_group_metadata(
self,
seq_group_metadata: SequenceGroupMetadata,
seq_id: SeqId,
target_seq_id: TargetSeqId,
token_ids: List[TokenId],
) -> SequenceGroupMetadata:
"""Create a single target SequenceGroupMetadata.
Args:
seq_group_metadata: The metadata for the input sequence.
seq_id: The input sequence ID.
target_seq_id: The corresponding target sequence ID.
token_ids: The list of token ids that are to be appended to the
input sequence.
"""
seq_data = seq_group_metadata.seq_data[seq_id]
prompt_token_ids = seq_data.get_prompt_token_ids()
new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids]
return SequenceGroupMetadata(
request_id=seq_group_metadata.request_id,
is_prompt=seq_group_metadata.is_prompt,
seq_data={
target_seq_id:
SequenceData(
prompt_token_ids=prompt_token_ids,
output_token_ids=new_output_token_ids,
),
},
sampling_params=seq_group_metadata.sampling_params,
block_tables={
target_seq_id: seq_group_metadata.block_tables[seq_id],
},
lora_request=None,
)
def _split_scoring_output(
self, sampler_output: SamplerOutput, num_scoring_tokens: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Split the target model output into speculative and non-speculative
output.
"""
# vLLM currently only supports proposal lens equal to zero or the batch
# proposal len. This adds some complexity (splitting the batch into spec
# and non spec sequences) and should be removed in the future. It can be
# done by supporting per-sequence proposal lens.
#
# First samples are from speculative scoring, latter samples are non-
# speculative samples.
split_sizes = [
num_scoring_tokens,
sampler_output.sampled_token_ids.numel() - num_scoring_tokens
]
(spec_probs, non_spec_probs
) = sampler_output.sampled_token_probs.split(split_sizes)
(spec_sampled_tokens, non_spec_sampled_tokens
) = sampler_output.sampled_token_ids.flatten().split(split_sizes)
# Convert scores to tensors.
sampler_output.sampled_token_probs = spec_probs
sampler_output.sampled_token_ids = spec_sampled_tokens
target_token_ids, target_probs = sampler_output_to_torch(
[sampler_output])
# Convert non-speculative output tokens to tensors.
sampler_output.sampled_token_probs = non_spec_probs
sampler_output.sampled_token_ids = non_spec_sampled_tokens
non_spec_target_token_ids, non_spec_target_probs = sampler_output_to_torch(
[sampler_output])
return target_token_ids, target_probs, non_spec_target_token_ids, non_spec_target_probs
def _create_target_seq_id_iterator(
self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:
"""Create an iterator for creating target sequence ids.
Target sequence ids are distinct from sequence ids because we create a
distinct target sequence id for each proposal token to be scored.
This implementation increments a counter starting at 1 + max of all
provided input sequence ids.
"""
return count(start=max(seq_ids) + 1)
def _get_token_ids_to_score(
self,
full_spec_token_ids: List[TokenId] # shape: [k]
) -> List[List[TokenId]]:
"""Given an int tensor of proposal token ids, return a list of
token ids that should be scored.
Returns k+1 output lists. The additional one is used for generating the
bonus token.
Example:
Input: [0, 1, 2, 3] (k=4)
Output: (k+1 lists)
[]
[0]
[0, 1]
[0, 1, 2]
[0, 1, 2, 3]
"""
empty_token_ids = []
token_ids_to_score = [empty_token_ids]
token_ids_to_score.extend([
full_spec_token_ids[:i + 1]
for i in range(len(full_spec_token_ids))
])
return token_ids_to_score

View File

@ -0,0 +1,77 @@
from typing import List, Tuple, Optional, Dict
from dataclasses import dataclass
from abc import ABC, abstractmethod
import torch
from vllm.sequence import SequenceGroupMetadata
@dataclass
class SpeculativeProposals:
"""Datastructure used to represent proposal tokens from some proposer. It
also tracks how many speculative tokens each sequence has.
"""
# Speculative proposal tokens.
proposal_token_ids: torch.Tensor
# Probabilities of the proposal tokens according to the proposer.
proposal_probs: torch.Tensor
# The valid length of each proposal; can be zero.
proposal_lens: torch.Tensor
def __repr__(self):
return (f"SpeculativeProposals("
f"proposal_token_ids={self.proposal_token_ids.shape}, "
f"proposal_probs={self.proposal_probs.shape}, "
f"proposal_lens={self.proposal_lens.shape})")
@dataclass
class SpeculativeScores:
"""Datastructure used to represent the scores of speculative tokens
according to the scoring model.
"""
# Probabilities of the speculative tokens according to the scoring model.
probs: torch.Tensor
# Token ids sampled from the scoring model. Used for speculative bonus
# tokens and also non-speculative normal decoding.
token_ids: torch.Tensor
def __repr__(self):
return (f"SpeculativeScores("
f"probs={self.probs.shape}, "
f"token_ids={self.token_ids.shape})")
class SpeculativeProposer(ABC):
@abstractmethod
def get_proposals(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
max_proposal_len: int,
) -> SpeculativeProposals:
raise NotImplementedError
class SpeculativeScorer(ABC):
@abstractmethod
def score_proposals(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Optional[Dict[int, int]],
blocks_to_swap_out: Optional[Dict[int, int]],
blocks_to_copy: Optional[Dict[int, List[int]]],
k: int,
proposals: SpeculativeProposals,
) -> Tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError

174
vllm/spec_decode/metrics.py Normal file
View File

@ -0,0 +1,174 @@
import torch
from dataclasses import dataclass
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from typing import Optional
from vllm.utils import in_wsl
import time
from typing import Callable
@dataclass
class SpecDecodeWorkerMetrics:
"""Dataclass holding metrics emitted from the spec decode worker.
"""
# The empirical acceptance rate of the proposal method on a per-token basis.
# This is useful for evaluating how well the proposal method aligns with the
# scoring method.
draft_acceptance_rate: float
# The empirical efficiency, measured as the number of tokens emitted by the
# system divided by the number of tokens that could be emitted by the system
# if the proposal method were perfect.
system_efficiency: float
# The number of speculative tokens produced by the proposal method.
draft_tokens: int
# The number of tokens emitted by the entire system.
emitted_tokens: int
# The number of tokens accepted by the scoring model and verification
# routine, e.g. Llama2-70B and lossless rejection sampling.
#
# NOTE: Any token accepted by the verification routine is considered
# accepted (regardless of if the speculative prefix is also accepted). The
# user will usually see less accepted tokens. This metric is helpful when
# evaluating alignment of the proposal method with the scoring model.
accepted_tokens: int
# The number of speculative tokens per sequence.
num_spec_tokens: int
Timer = Callable[[], float]
class AsyncMetricsCollector:
"""Class which copies rejection sampler metrics from the device to CPU on a
non-default Torch stream.
"""
def __init__(self,
rejection_sampler: RejectionSampler,
timer: Optional[Timer] = None,
collect_interval_s: float = 5.0):
self._rejection_sampler = rejection_sampler
self._timer = time.time if timer is None else timer
self._rank: Optional[int] = None
# We don't have a device set yet.
self._copy_stream: Optional[torch.cuda.Stream] = None
self._in_flight_copy: Optional[torch.cuda.Event] = None
pin_memory = not in_wsl()
self._aggregate_num_accepted_tokens = torch.tensor(
0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
self._aggregate_num_emitted_tokens = torch.tensor(
0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
self._aggregate_num_draft_tokens = 0
self._rejsample_metrics_collect_interval_s = collect_interval_s
self._last_metrics_collect_time = self._timer()
def init_gpu_tensors(self, rank: int) -> None:
self._rank = rank
self._copy_stream = torch.cuda.Stream()
def maybe_collect_rejsample_metrics(
self, k: int) -> Optional[SpecDecodeWorkerMetrics]:
# If a copy was initiated in the previous call, collect and return.
if self._in_flight_copy is not None:
ready_event = self._in_flight_copy
self._in_flight_copy = None
return self._collect_rejsample_metrics(k, ready_event)
# Otherwise, check if we should start a new copy.
if self._should_collect_rejsample_metrics(self._timer()):
assert self._in_flight_copy is None
self._in_flight_copy = self._copy_rejsample_metrics_async()
return None
def _should_collect_rejsample_metrics(self, now: float) -> bool:
"""Return whether or not this iteration should print rejection sampling
metrics.
"""
if self._rank != 0:
return False
if (now - self._last_metrics_collect_time <
self._rejsample_metrics_collect_interval_s):
return False
return True
def _copy_rejsample_metrics_async(self) -> torch.cuda.Event:
"""Copy rejection sampling metrics (number of accepted tokens, etc) to
CPU asynchronously.
Returns a CUDA event recording when the copy is complete.
"""
self._copy_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self._copy_stream):
self._aggregate_num_accepted_tokens.copy_(
self._rejection_sampler.num_accepted_tokens, non_blocking=True)
self._aggregate_num_emitted_tokens.copy_(
self._rejection_sampler.num_emitted_tokens, non_blocking=True)
# Number of draft tokens is calculated on CPU, so no copy is
# required.
self._aggregate_num_draft_tokens = (
self._rejection_sampler.num_draft_tokens)
aggregate_metrics_ready = torch.cuda.Event()
aggregate_metrics_ready.record(self._copy_stream)
return aggregate_metrics_ready
def _collect_rejsample_metrics(
self, k: int,
ready_event: torch.cuda.Event) -> SpecDecodeWorkerMetrics:
"""Create metrics object from statistics copied asynchronously.
Args:
k: int. The number of speculative tokens; used to determine system
efficiency.
ready_event: torch.cuda.Event. The CUDA event recording when the
async GPU->CPU copy is complete.
"""
ready_event.synchronize()
accepted_tokens = self._aggregate_num_accepted_tokens.item()
emitted_tokens = self._aggregate_num_emitted_tokens.item()
draft_tokens = self._aggregate_num_draft_tokens
num_possible_tokens = self.get_max_num_accepted_tokens(draft_tokens, k)
if draft_tokens > 0:
draft_acceptance_rate = accepted_tokens / draft_tokens
else:
draft_acceptance_rate = float("nan")
if num_possible_tokens > 0:
system_efficiency = emitted_tokens / num_possible_tokens
else:
system_efficiency = float("nan")
return SpecDecodeWorkerMetrics(
num_spec_tokens=k,
draft_acceptance_rate=draft_acceptance_rate,
system_efficiency=system_efficiency,
accepted_tokens=accepted_tokens,
draft_tokens=draft_tokens,
emitted_tokens=emitted_tokens,
)
@staticmethod
def get_max_num_accepted_tokens(draft_tokens: int, k: int) -> int:
# Divide by k since batch size can be variable.
total_num_spec_seqs = draft_tokens / k
num_accepted_per_seq_if_all_accepted = k + 1
return int(total_num_spec_seqs / num_accepted_per_seq_if_all_accepted)

View File

@ -0,0 +1,366 @@
from typing import List, Dict, Optional, Tuple
import copy
import torch
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.worker.worker import Worker
from vllm.spec_decode.interfaces import SpeculativeProposals, SpeculativeProposer
from vllm.spec_decode.util import sampler_output_to_torch
class MultiStepWorker(Worker):
"""The MultiStepWorker is equivalent to a Worker except that it allows
multiple forward passes in a single call, assuming the scheduler has
allocated enough space to store the additional KV. This reduces overhead
by invoking the scheduler less.
The MultiStepWorker does not support cache swap operations, or beam search.
Cache swap operations do not require large modifications. On the other hand,
beam search requires memory allocations during sequence forks and thus
requires more thought for MultiStepWorker support.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._proposer: Optional[DraftModelTop1Proposer] = None
def init_model(self):
super().init_model()
self._proposer = DraftModelTop1Proposer(
self,
self.device,
self.max_model_len,
self.vocab_size,
)
@torch.inference_mode()
def execute_model_multi_step(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
num_steps: int,
) -> List[SamplerOutput]:
"""Run the model forward pass num_steps times. Returns the list of
sampler output, one per model forward pass.
"""
self._raise_if_unsupported(seq_group_metadata_list, blocks_to_swap_in,
blocks_to_swap_out, blocks_to_copy)
# Shallow copy input data so modifications (such as appending tokens)
# do not cause side-effects.
copied_seq_group_metadata_list = self._shallow_copy_inputs(
seq_group_metadata_list)
# Assert enough KV space for num_steps tokens per sequence.
self._assert_enough_kv_space(seq_group_metadata_list, num_steps)
# Run model num_steps times.
model_outputs = []
for _ in range(num_steps):
model_output = super().execute_model(
seq_group_metadata_list=copied_seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
)
self._append_new_tokens(model_output,
copied_seq_group_metadata_list)
model_outputs.append(model_output)
return model_outputs
def get_spec_proposals(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
max_proposal_len: int,
) -> SpeculativeProposals:
"""Produce speculations given an input batch of sequences. The number of
speculative tokens per sequence is determined by max_proposal_len.
"""
return self._proposer.get_proposals(
seq_group_metadata_list,
blocks_to_swap_in,
blocks_to_swap_out,
blocks_to_copy,
max_proposal_len,
)
def _append_new_tokens(
self, model_output: SamplerOutput,
seq_group_metadata_list: SequenceGroupMetadata) -> None:
"""Given model output from a single run, append the tokens to the
sequences. This is normally done outside of the worker, but it is
required if the worker is to perform multiple forward passes.
"""
for seq_group_metadata, sequence_group_outputs in zip(
seq_group_metadata_list, model_output):
seq_group_metadata.is_prompt = False
for seq_output in sequence_group_outputs.samples:
# NOTE: Beam search is not supported, so we can assume that
# parent_seq_id == seq_id.
seq = seq_group_metadata.seq_data[seq_output.parent_seq_id]
token_id = seq_output.output_token
token_logprob = seq_output.logprobs[token_id]
seq.append_token_id(token_id, token_logprob.logprob)
def _shallow_copy_inputs(
self, seq_group_metadata_list: List[SequenceGroupMetadata]
) -> List[SequenceGroupMetadata]:
"""Copy input data structures to remove side-effects when input data
structures are shared with other modules.
Helpful when the vLLM scheduler runs in the same process as the worker.
The alternative is deep-copying (or other form of deep copy); this has
performance downsides.
"""
# Shallow-copy the list of SequenceGroupMetadata. This allows us to
# append tokens and change is_prompt without external side-effects.
new_seq_group_metadata_list = []
for old_seq_group_metadata in seq_group_metadata_list:
# We must shallow-copy seq_group_metadata as is_prompt could change.
seq_group_metadata = copy.copy(old_seq_group_metadata)
new_seq_group_metadata_list.append(seq_group_metadata)
# We must shallow-copy seq_data as we will append token ids
new_seq_data = {}
for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
new_seq_data[seq_id] = copy.copy(old_seq_data)
new_seq_data[
seq_id].output_token_ids = old_seq_data.output_token_ids[:]
seq_group_metadata.seq_data = new_seq_data
return new_seq_group_metadata_list
def _assert_enough_kv_space(
self, seq_group_metadata_list: List[SequenceGroupMetadata],
num_steps: int) -> None:
"""Assert there are enough physical blocks per sequence to store the
current KV plus additional KV from num_steps tokens.
"""
assert self.model_runner.block_size is not None
for seq_group_metadata in seq_group_metadata_list:
# Only one seq_id is guaranteed because there is no beam search.
seq_id = list(seq_group_metadata.seq_data.keys())[0]
seq = seq_group_metadata.seq_data[seq_id]
# After num_steps, the seq len will be the current seq len
# plus one token per step.
final_seq_len = seq.get_len() + num_steps
# We will have final_seq_len - 1 KV because vLLM saves KV for a
# token in the iteration after the token was generated.
required_num_kv_slots = final_seq_len - 1
# The allocated number of kv slots is the number of allocated blocks
# times the number of slots of block.
number_physical_blocks = len(
seq_group_metadata.block_tables[seq_id])
allocated_kv_slots = (number_physical_blocks *
self.model_runner.block_size)
if required_num_kv_slots > allocated_kv_slots:
request_id = seq_group_metadata.request_id
raise ValueError(
"The worker attempted to run "
f"{num_steps} times but found insufficient KV space for "
f"{request_id=} {seq_id=}. ({allocated_kv_slots=} "
f"{required_num_kv_slots=}).")
def _raise_if_unsupported(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
) -> None:
"""MultiStepWorker does not yet implement support for cache swap
operations or beam search.
"""
if any([blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy]):
raise NotImplementedError(
"MultiStepWorker does not support cache operations")
if any(
len(seq_group_metadata.seq_data.keys()) != 1
for seq_group_metadata in seq_group_metadata_list):
raise NotImplementedError(
"MultiStepWorker does not support beam search.")
class DraftModelTop1Proposer(SpeculativeProposer):
"""Helper class which separates out sequences which would exceed the max
model length when speculated upon.
This allows combinations of models such as JackFram/llama-68m draft with
meta-llama/Llama2-13b-chat-hf, as llama-68m has max_position_embeddings of
2048 while Llama2-13b has max_position_embeddings of 4096.
We treat the sequences which exceed the proposal draft model length as
"non-spec sequences". Essentially they skip the draft model and go through
normal decoding in the target model.
Currently, only proposal_lens of 0 and k are supported, where k is a global
batch proposal length. In the future vLLM should support per-sequence
proposal lengths.
"""
def __init__(
self,
draft_worker: MultiStepWorker,
device: str,
max_model_len: int,
vocab_size: int,
):
self._draft_worker = draft_worker
self._device = device
self._max_model_len = max_model_len
self._vocab_size = vocab_size
def get_proposals(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
max_proposal_len: int,
) -> SpeculativeProposals:
"""Get speculative proposals given the input batch.
Sequences which would exceed the max model length are skipped during
speculation.
"""
# Split speculative- and non-speculative- sequences.
proposal_lens, nonzero_proposal_len_seqs, nonzero_proposal_len_indices = self._split_by_max_model_len(
seq_group_metadata_list, max_proposal_len)
if nonzero_proposal_len_seqs:
# Speculate tokens using the draft worker for the speculative
# sequences.
maybe_sampler_output = self._draft_worker.execute_model_multi_step(
seq_group_metadata_list=nonzero_proposal_len_seqs,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
num_steps=max_proposal_len,
)
else:
# If no sequences can be speculated, set sampler output to None.
maybe_sampler_output = None
# Combine speculative- and non-speculative sequences into the same
# representation.
proposal_tokens, proposal_probs, proposal_lens = self._merge_outputs(
batch_size=len(seq_group_metadata_list),
max_proposal_len=max_proposal_len,
maybe_sampler_output=maybe_sampler_output,
proposal_lens=proposal_lens,
nonzero_proposal_len_indices=nonzero_proposal_len_indices,
)
proposals = SpeculativeProposals(
proposal_token_ids=proposal_tokens,
proposal_probs=proposal_probs,
proposal_lens=proposal_lens,
)
return proposals
def _split_by_max_model_len(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
max_proposal_len: int,
) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]:
"""Determine which sequences would exceed the max model length.
"""
proposal_lens: List[int] = []
nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = []
nonzero_proposal_len_indices: List[int] = []
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
seq_data = next(iter(seq_group_metadata.seq_data.values()))
seq_len = seq_data.get_len()
# Currently only proposal lens of 0 or the global batch proposal len
# are supported.
if seq_len + max_proposal_len < self._max_model_len:
proposal_lens.append(max_proposal_len)
nonzero_proposal_len_seqs.append(seq_group_metadata)
nonzero_proposal_len_indices.append(i)
else:
proposal_lens.append(0)
return proposal_lens, nonzero_proposal_len_seqs, nonzero_proposal_len_indices
def _merge_outputs(
self,
batch_size: int,
max_proposal_len: int,
maybe_sampler_output: Optional[SamplerOutput],
proposal_lens: List[int],
nonzero_proposal_len_indices: List[int],
) -> Tuple[torch.Tensor, torch.tensor, torch.Tensor]:
"""After speculations are produced, merge the speculation results with
the skipped sequences.
"""
if maybe_sampler_output is None:
# If no speculative tokens, the sampler output will be None.
# In this case we return empty tensors.
proposal_tokens = torch.zeros(0,
max_proposal_len,
dtype=torch.long,
device=self._device)
proposal_probs = torch.zeros(0,
max_proposal_len,
self._vocab_size,
dtype=torch.float32,
device=self._device)
proposal_lens = torch.zeros(len(proposal_lens),
dtype=torch.long,
device=self._device)
return proposal_tokens, proposal_probs, proposal_lens
sampler_output = maybe_sampler_output
proposal_tokens, proposal_probs = sampler_output_to_torch(
sampler_output)
# Now, reformat the output GPU tensors such that each sequence has
# a proposal. the proposal can be empty, e.g. [-1, -1, -1]
entire_proposal_tokens = torch.full(size=(batch_size,
*proposal_tokens.shape[1:]),
fill_value=-1,
dtype=torch.long,
device=self._device)
entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens
entire_proposal_probs = torch.zeros(batch_size,
*proposal_probs.shape[1:],
dtype=torch.float32,
device=self._device)
entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs
proposal_tokens, proposal_probs = entire_proposal_tokens, entire_proposal_probs
proposal_lens = torch.zeros(batch_size,
dtype=torch.long,
device=self._device)
proposal_lens[nonzero_proposal_len_indices] = max_proposal_len
return proposal_tokens, proposal_probs, proposal_lens

View File

@ -0,0 +1,372 @@
from typing import List, Tuple, Optional, Dict
from functools import cached_property
import torch
from vllm.spec_decode.metrics import AsyncMetricsCollector
from vllm.sequence import (SamplerOutput, SequenceGroupMetadata,
SequenceGroupOutput, SequenceOutput)
from vllm.worker.worker import Worker
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.config import CacheConfig
from vllm.spec_decode.util import nvtx_range, get_all_seq_ids, split_batch_by_proposal_len
from vllm.spec_decode.interfaces import SpeculativeProposals, SpeculativeScores
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import SpeculativeScorer
class SpecDecodeWorker:
"""Worker which implements speculative decoding.
Speculative decoding reduces decoding per-token latency by using a proposal
method, such as a small draft model, to speculate ahead of a larger LLM. The
probabilities of the speculative tokens are then determined by the larger
LLM, after which some verification routine determines which (if any) of the
speculative tokens are accepted by the larger LLM.
See https://github.com/vllm-project/vllm/pull/2188 and
https://github.com/vllm-project/vllm/pull/3103 for more info.
The current implementation has the following limitations:
* Only draft-model proposal is implemented (contributions for more forms are
welcome!).
* Only top-1 proposal and scoring are implemented. Tree-attention is left as
future work.
* Only lossless rejection sampling is supported. Contributions adding lossy
verification routines are welcome (e.g. Medusa's typical acceptance).
* All sequences in a batch must have the same proposal length, or zero. This
can be improved by having per-sequence speculation in the future.
* The scoring forward pass is done without an MQA kernel, which is
suboptimal especially as the batch size, proposal length, and sequence
lengths grow. Contributions to add a MQA scoring are welcome once
correctness tests pass.
More info here https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit.
"""
def __init__(
self,
proposer_worker: MultiStepWorker,
scorer_worker: Worker,
rejection_sampler: RejectionSampler,
metrics_collector: Optional[AsyncMetricsCollector] = None,
):
"""
Create a SpecDecodeWorker.
Args:
proposer_worker: A worker that can produce speculative tokens for
sequences.
scorer_worker: A worker that produces probabilities of speculative
tokens according to some base model. Typically a vanilla vLLM
Worker.
rejection_sampler: A Torch module used to perform modified rejection
sampling for speculative decoding.
metrics_collector: Helper class for collecting metrics; can be set
for testing purposes.
"""
self.proposer_worker = proposer_worker
self.scorer_worker = scorer_worker
self.rejection_sampler = rejection_sampler
self._metrics = AsyncMetricsCollector(
rejection_sampler
) if metrics_collector is None else metrics_collector
self.probs_dtype = self.rejection_sampler.probs_dtype
self.token_id_dtype = self.rejection_sampler.token_id_dtype
self.scorer: SpeculativeScorer = None
def init_model(self) -> None:
"""Initialize both scorer and proposer models.
"""
# The scorer worker model is initialized first in case the proposer
# model has a smaller TP degree than the target worker.
self.scorer_worker.init_model()
self.proposer_worker.init_model()
self._metrics.init_gpu_tensors(self.rank)
self.rejection_sampler.init_gpu_tensors(self.rank)
self.scorer = BatchExpansionTop1Scorer(
scorer_worker=self.scorer_worker,
device=self.device,
vocab_size=self._vocab_size)
def profile_num_available_blocks(self, block_size: int,
gpu_memory_utilization: float,
cpu_swap_space: int,
cache_dtype: str) -> Tuple[int, int]:
"""Determine the number of cache blocks to use.
This is done by profiling the scorer model (which is typically the
larger of the two). Then the total memory which would be used by the
scorer cache is divided evenly between the proposer and scorer model KV,
such that the number of blocks is equal in both KV caches.
"""
num_gpu_blocks, num_cpu_blocks = (
self.scorer_worker.profile_num_available_blocks(
block_size, gpu_memory_utilization, cpu_swap_space,
cache_dtype))
scorer_cache_block_size_bytes = self.scorer_worker.get_cache_block_size_bytes(
block_size, cache_dtype)
proposer_cache_block_size_bytes = self.proposer_worker.get_cache_block_size_bytes(
block_size, cache_dtype)
new_num_gpu_blocks = split_num_cache_blocks_evenly(
scorer_cache_block_size_bytes, proposer_cache_block_size_bytes,
num_gpu_blocks)
return new_num_gpu_blocks, num_cpu_blocks
def init_cache_engine(self, cache_config: CacheConfig):
"""Initialize the cache engine of the scorer and proposer workers.
"""
self.scorer_worker.init_cache_engine(cache_config)
self.proposer_worker.init_cache_engine(cache_config)
@torch.inference_mode()
def execute_model(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Optional[Dict[int, int]],
blocks_to_swap_out: Optional[Dict[int, int]],
blocks_to_copy: Optional[Dict[int, List[int]]],
num_spec_tokens: int,
) -> List[SamplerOutput]:
"""Perform speculative decoding on the input batch.
"""
assert seq_group_metadata_list is not None, (
"speculative decoding "
"requires non-None seq_group_metadata_list")
# If no spec tokens, call the proposer and scorer workers normally.
# Used for prefill.
if num_spec_tokens == 0 or len(seq_group_metadata_list) == 0:
return self._run_no_spec(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
)
return self._run_speculative_decoding_step(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
k=num_spec_tokens,
)
@nvtx_range("spec_decode_worker._run_no_spec")
def _run_no_spec(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Optional[Dict[int, int]],
blocks_to_swap_out: Optional[Dict[int, int]],
blocks_to_copy: Optional[Dict[int, List[int]]],
) -> List[SamplerOutput]:
"""Run a prefill step, without any speculation. The input is sent to the
proposer and scorer model so that the KV cache is consistent between the
two.
"""
self.proposer_worker.execute_model(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
return_python_output=False)
sampler_output = self.scorer_worker.execute_model(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
)
# Clear device tensors from sampler output. This reduces communication
# overhead when the engine runs in a different process than the workers.
sampler_output.probs = None
sampler_output.sampled_tokens = None
return [sampler_output]
@nvtx_range("spec_decode_worker._run_speculative_decoding_step")
def _run_speculative_decoding_step(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Optional[Dict[int, int]],
blocks_to_swap_out: Optional[Dict[int, int]],
blocks_to_copy: Optional[Dict[int, List[int]]],
k: int,
) -> List[SamplerOutput]:
"""Execute a single step of speculative decoding.
This invokes the proposer worker to get k speculative tokens for each
sequence, then scores each speculative token using the scoring worker.
Returns a list of SamplerOutput, each containing a single token per
sequence.
"""
# Generate proposals using draft worker.
proposals = self.proposer_worker.get_spec_proposals(
seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out,
blocks_to_copy, k)
proposal_scores = self.scorer.score_proposals(
seq_group_metadata_list,
blocks_to_swap_in,
blocks_to_swap_out,
blocks_to_copy,
k,
proposals,
)
accepted_token_ids = self._verify_tokens(seq_group_metadata_list,
proposal_scores, proposals, k)
return self._create_output_sampler_list(seq_group_metadata_list,
accepted_token_ids, k)
@nvtx_range("spec_decode_worker._verify_tokens")
def _verify_tokens(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
proposal_scores: SpeculativeScores,
proposals: SpeculativeProposals,
max_proposal_len: int,
) -> torch.Tensor:
"""Determine which speculative tokens are accepted using the
probabilities of each token according to the proposer and scorer models.
"""
proposal_lens_list = proposals.proposal_lens.tolist()
# vLLM currently only supports proposal lens equal to zero or the batch
# proposal len. This adds some complexity (splitting the batch into spec
# and non spec sequences) and should be removed in the future. It can be
# done by supporting per-sequence proposal lens.
_, spec_indices = split_batch_by_proposal_len(
seq_group_metadata_list,
proposal_lens_list,
select_proposal_len_zero=False)
_, non_spec_indices = split_batch_by_proposal_len(
seq_group_metadata_list,
proposal_lens_list,
select_proposal_len_zero=True)
original_indices = spec_indices + non_spec_indices
proposal_probs = proposal_scores.probs[spec_indices, :-1]
bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:]
non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]
accepted_token_ids = self.rejection_sampler(
proposal_probs,
bonus_token_ids,
proposals.proposal_probs,
proposals.proposal_token_ids,
)
# Append output tokens from non-speculative sequences to
# the accepted token ids tensor.
non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len +
1).clone()
non_spec_token_ids[:, 1:] = -1
accepted_token_ids = torch.cat(
[accepted_token_ids, non_spec_token_ids])
# Rearrange so that results are in the order of the original seq group
# metadata.
accepted_token_ids[original_indices] = accepted_token_ids.clone()
return accepted_token_ids
def _create_output_sampler_list(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
accepted_token_ids: torch.Tensor, # shape: [batch_size, k+1]
k: int,
) -> List[SamplerOutput]:
"""Given the accepted token ids, create a list of SamplerOutput.
The output is padded with -1 tokens such that each sequence has
the same number of outputs.
"""
seq_ids = get_all_seq_ids(seq_group_metadata_list)
# shape: [k+1, batch_size]
accepted_token_ids_by_step = accepted_token_ids.transpose(0,
1).tolist()
sampler_output_list = []
for token_ids_by_step in accepted_token_ids_by_step:
if all(token_id == -1 for token_id in token_ids_by_step):
break
step_output_token_ids = []
for token_id, seq_id in zip(token_ids_by_step, seq_ids):
step_output_token_ids.append(
SequenceGroupOutput(
samples=[
SequenceOutput(
parent_seq_id=seq_id,
output_token=token_id,
# TODO Add verifier logprobs.
logprobs={token_id: 0.0},
)
],
prompt_logprobs=None,
))
sampler_output_list.append(
SamplerOutput(outputs=step_output_token_ids))
maybe_rejsample_metrics = self._metrics.maybe_collect_rejsample_metrics(
k)
if maybe_rejsample_metrics is not None:
sampler_output_list[
0].spec_decode_worker_metrics = maybe_rejsample_metrics
return sampler_output_list
@cached_property
def _vocab_size(self) -> int:
"""Get the vocab size of the model and make sure it's consistent between
draft and target workers.
"""
vocab_sizes = [
worker.vocab_size
for worker in [self.proposer_worker, self.scorer_worker]
]
assert all(vocab_sizes[0] == vocab_size for vocab_size in vocab_sizes)
return vocab_sizes[0]
@property
def rank(self):
return self.scorer_worker.rank
@property
def device(self):
return self.scorer_worker.device
def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int,
proposer_cache_block_size_bytes: int,
total_num_gpu_blocks: int) -> int:
"""Given total_num_gpu_blocks, the number of GPU blocks that could be
allocate to the target model, this function calculates how many blocks
should be given to the draft and target model.
Note that usually the block size, in bytes, of each model is different,
as it's a function of number of KV/layer, number of heads, and hidden
dimension size.
Since the target and draft models allocate the same number of blocks, we
simply calculate the number of blocks where if allocated by both models,
the total memory usage from KV cache is no larger than the number of
blocks allocatable by the target model alone.
"""
new_num_gpu_blocks = int(
total_num_gpu_blocks * scorer_cache_block_size_bytes /
(proposer_cache_block_size_bytes + scorer_cache_block_size_bytes))
return new_num_gpu_blocks

99
vllm/spec_decode/util.py Normal file
View File

@ -0,0 +1,99 @@
import torch
from typing import List, Tuple
from vllm.sequence import SequenceGroupMetadata, SamplerOutput
from contextlib import contextmanager
from itertools import chain
SeqId = int
def get_all_seq_ids(
seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[SeqId]:
"""Given a list of SequenceGroupMetadata, create a list of all
sequence ids.
"""
return list(
chain.from_iterable([
seq_group_metadata.seq_data.keys()
for seq_group_metadata in seq_group_metadata_list
]))
def split_batch_by_proposal_len(
seq_group_metadata_list: List[SequenceGroupMetadata],
proposal_lens: List[int], select_proposal_len_zero: bool
) -> Tuple[List[SequenceGroupMetadata], List[int]]:
"""Utility function that splits a batch based on whether the proposal len is
zero or not. We should remove this once vLLM supports per-sequence proposal
lens in a batch.
"""
if select_proposal_len_zero:
predicate = lambda proposal_len: proposal_len == 0
else:
predicate = lambda proposal_len: proposal_len != 0
indices = [
i for i, (_, proposal_len
) in enumerate(zip(seq_group_metadata_list, proposal_lens))
if predicate(proposal_len)
]
seq_groups = [
seq_group for seq_group, proposal_len in zip(
seq_group_metadata_list, proposal_lens) if predicate(proposal_len)
]
return seq_groups, indices
def sampler_output_to_torch(
sampler_output_list: List[SamplerOutput],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Utility function which converts a list of SamplerOutput to tensors.
Returns:
sampled_token_ids: torch.Tensor
shape: [batch_size, len(sampler_output_list)]
sampled_token_probs: torch.Tensor
shape: [batch_size, len(sampler_output_list), vocab_size]
"""
# shape: [batch_size, num_sampler_output, vocab_size]
sampled_token_probs = torch.stack(
[
sampler_output.sampled_token_probs
for sampler_output in sampler_output_list
],
dim=0,
).transpose(0, 1)
# shape: [batch_size, num_sampler_output]
sampled_token_ids = torch.stack(
[
sampler_output.sampled_token_ids.flatten()
for sampler_output in sampler_output_list
],
dim=0,
).transpose(0, 1)
return sampled_token_ids, sampled_token_probs
@contextmanager
def nvtx_range(msg, *args, **kwargs):
"""
Context manager / decorator that pushes an NVTX range at the beginning
of its scope, and pops it at the end. If extra arguments are given,
they are passed as arguments to msg.format().
If running with cuda graphs, you must enable nsys cuda graph profiling.
Arguments:
msg (string): message to associate with the range
"""
torch.cuda.nvtx.range_push(msg.format(*args, **kwargs))
try:
yield
finally:
torch.cuda.nvtx.range_pop()

View File

@ -97,8 +97,6 @@ class ModelRunner:
f"Loading model weights took {self.model_memory_usage / float(2**30):.4f} GB"
)
vocab_size = self.model.config.vocab_size
if self.lora_config:
assert hasattr(
self.model, "supported_lora_modules"
@ -111,7 +109,7 @@ class ModelRunner:
self.lora_manager = LRUCacheWorkerLoRAManager(
self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens +
self.scheduler_config.max_paddings, vocab_size,
self.scheduler_config.max_paddings, self.vocab_size,
self.lora_config, self.device, self.model.embedding_modules,
self.model.embedding_padding_modules)
self.model = self.lora_manager.create_lora_manager(self.model)
@ -607,8 +605,7 @@ class ModelRunner:
@torch.inference_mode()
def profile_run(self) -> None:
# Enable top-k sampling to reflect the accurate memory usage.
vocab_size = self.model_config.get_vocab_size()
sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1)
sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
max_num_seqs = self.scheduler_config.max_num_seqs
@ -774,6 +771,10 @@ class ModelRunner:
self.graph_runners.clear()
self.cupy_nccl_backend = None
@property
def vocab_size(self) -> int:
return self.model_config.get_vocab_size()
class CUDAGraphRunner:

View File

@ -1,178 +0,0 @@
from typing import List, Dict
import copy
import torch
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.worker.worker import Worker
class MultiStepWorker(Worker):
"""The MultiStepWorker is equivalent to a Worker except that it allows
multiple forward passes in a single call, assuming the scheduler has
allocated enough space to store the additional KV. This reduces overhead
by invoking the scheduler less.
The MultiStepWorker does not support cache swap operations, or beam search.
Cache swap operations do not require large modifications. On the other hand,
beam search requires memory allocations during sequence forks and thus
requires more thought for MultiStepWorker support.
"""
@torch.inference_mode()
def execute_model_multi_step(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
num_steps: int,
) -> List[SamplerOutput]:
"""Run the model forward pass num_steps times. Returns the list of
sampler output, one per model forward pass.
"""
self._raise_if_unsupported(seq_group_metadata_list, blocks_to_swap_in,
blocks_to_swap_out, blocks_to_copy)
# Shallow copy input data so modifications (such as appending tokens)
# do not cause side-effects.
copied_seq_group_metadata_list = self._shallow_copy_inputs(
seq_group_metadata_list)
# Assert enough KV space for num_steps tokens per sequence.
self._assert_enough_kv_space(seq_group_metadata_list, num_steps)
# Run model num_steps times.
model_outputs = []
for _ in range(num_steps):
model_output = super().execute_model(
seq_group_metadata_list=copied_seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
)
self._append_new_tokens(model_output,
copied_seq_group_metadata_list)
model_outputs.append(model_output)
return model_outputs
def _append_new_tokens(
self, model_output: SamplerOutput,
seq_group_metadata_list: SequenceGroupMetadata) -> None:
"""Given model output from a single run, append the tokens to the
sequences. This is normally done outside of the worker, but it is
required if the worker is to perform multiple forward passes.
"""
for seq_group_metadata, sequence_group_outputs in zip(
seq_group_metadata_list, model_output):
seq_group_metadata.is_prompt = False
for seq_output in sequence_group_outputs.samples:
# NOTE: Beam search is not supported, so we can assume that
# parent_seq_id == seq_id.
seq = seq_group_metadata.seq_data[seq_output.parent_seq_id]
token_id = seq_output.output_token
token_logprob = seq_output.logprobs[token_id]
seq.append_token_id(token_id, token_logprob.logprob)
def _shallow_copy_inputs(
self, seq_group_metadata_list: List[SequenceGroupMetadata]
) -> List[SequenceGroupMetadata]:
"""Copy input data structures to remove side-effects when input data
structures are shared with other modules.
The multi-step worker must be able to append tokens to sequences after
a forward pass. This necessitates modification of the data structures
used by the worker. Since these data structures are shared with other
parts of vLLM, like the scheduler, we must take care not to introduce
unexpected side-effects.
When Ray is used to orchestrate worker processes (such as when the
tensor-parallel degree is >1), this is not a problem because the input
datastructures will be serialized and created anew in the worker
process.
However, when Ray is not used to orchestrate the worker processes (such
as when the tensor-parallel degree is 1), this is a problem. We avoid
the problem by shallow-copying the input datastructures (specifically,
the parts that will change in multiple steps).
"""
# Shallow-copy the list of SequenceGroupMetadata. This allows us to
# append tokens and change is_prompt without external side-effects.
new_seq_group_metadata_list = []
for old_seq_group_metadata in seq_group_metadata_list:
# We must shallow-copy seq_group_metadata as is_prompt could change.
seq_group_metadata = copy.copy(old_seq_group_metadata)
new_seq_group_metadata_list.append(seq_group_metadata)
# We must shallow-copy seq_data as we will append token ids
new_seq_data = {}
for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
new_seq_data[seq_id] = copy.copy(old_seq_data)
new_seq_data[
seq_id].output_token_ids = old_seq_data.output_token_ids[:]
seq_group_metadata.seq_data = new_seq_data
return new_seq_group_metadata_list
def _assert_enough_kv_space(
self, seq_group_metadata_list: List[SequenceGroupMetadata],
num_steps: int) -> None:
"""Assert there are enough physical blocks per sequence to store the
current KV plus additional KV from num_steps tokens.
"""
assert self.model_runner.block_size is not None
for seq_group_metadata in seq_group_metadata_list:
# Only one seq_id is guaranteed because there is no beam search.
seq_id = list(seq_group_metadata.seq_data.keys())[0]
seq = seq_group_metadata.seq_data[seq_id]
# After num_steps, the seq len will be the current seq len
# plus one token per step.
final_seq_len = seq.get_len() + num_steps
# We will have final_seq_len - 1 KV because vLLM saves KV for a
# token in the iteration after the token was generated.
required_num_kv_slots = final_seq_len - 1
# The allocated number of kv slots is the number of allocated blocks
# times the number of slots of block.
number_physical_blocks = len(
seq_group_metadata.block_tables[seq_id])
allocated_kv_slots = (number_physical_blocks *
self.model_runner.block_size)
if required_num_kv_slots > allocated_kv_slots:
request_id = seq_group_metadata.request_id
raise ValueError(
"The worker attempted to run "
f"{num_steps} times but found insufficient KV space for "
f"{request_id=} {seq_id=}. ({allocated_kv_slots=} "
f"{required_num_kv_slots=}).")
def _raise_if_unsupported(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
) -> None:
"""MultiStepWorker does not yet implement support for cache swap
operations or beam search.
"""
if any([blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy]):
raise NotImplementedError(
"MultiStepWorker does not support cache operations")
if any(
len(seq_group_metadata.seq_data.keys()) != 1
for seq_group_metadata in seq_group_metadata_list):
raise NotImplementedError(
"MultiStepWorker does not support beam search.")

View File

@ -130,8 +130,8 @@ class Worker:
# GPU did not change their memory usage during the profiling.
peak_memory = self.init_gpu_memory - free_gpu_memory
cache_block_size = CacheEngine.get_cache_block_size(
block_size, cache_dtype, self.model_config, self.parallel_config)
cache_block_size = self.get_cache_block_size_bytes(
block_size, cache_dtype)
num_gpu_blocks = int(
(total_gpu_memory * gpu_memory_utilization - peak_memory) //
cache_block_size)
@ -232,6 +232,22 @@ class Worker:
def list_loras(self) -> Set[int]:
return self.model_runner.list_loras()
@property
def max_model_len(self) -> int:
return self.model_config.max_model_len
@property
def vocab_size(self) -> int:
return self.model_runner.vocab_size
def get_cache_block_size_bytes(self, block_size: int,
cache_dtype: str) -> int:
"""Get the size of the KV cache block size in bytes.
"""
return CacheEngine.get_cache_block_size(block_size, cache_dtype,
self.model_config,
self.parallel_config)
def init_distributed_environment(
parallel_config: ParallelConfig,