[Speculative decoding 3/9] Worker which speculates, scores, and applies rejection sampling (#3103)

parent f48c6791b7
commit 8437bae6ef
@@ -28,7 +28,7 @@ steps:
   num_gpus: 2 # only support 1 or 2 for now.

 - label: Engine Test
-  command: pytest -v -s engine
+  command: pytest -v -s engine test_sequence.py

 - label: Entrypoints Test
   command: pytest -v -s entrypoints
@@ -52,6 +52,9 @@ steps:
 - label: Worker Test
   command: pytest -v -s worker

+- label: Speculative decoding tests
+  command: pytest -v -s spec_decode
+
 - label: LoRA Test
   command: pytest -v -s lora --forked
tests/spec_decode/test_batch_expansion.py (new file, 95 lines)
@@ -0,0 +1,95 @@
import torch
import pytest

from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer

from .utils import mock_worker, create_seq_group_metadata_from_prompts


@pytest.mark.parametrize('num_target_seq_ids', [100])
def test_create_target_seq_id_iterator(num_target_seq_ids: int):
    """Verify all new sequence ids are greater than all input
    seq ids.
    """
    scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000)

    all_seq_ids = [
        [1, 3, 5, 7],
        list(range(100)) + [0],
        [100],
    ]

    for seq_ids in all_seq_ids:
        max_seq_id = max(seq_ids)
        iterator = scorer._create_target_seq_id_iterator(seq_ids)  # pylint: disable=protected-access
        for _ in range(num_target_seq_ids):
            assert next(iterator) > max_seq_id


@pytest.mark.parametrize('k', [1, 2, 6])
def test_get_token_ids_to_score(k: int):
    """Verify correct tokens are selected for scoring.
    """
    proposal_token_ids = torch.tensor(
        list(range(k)),
        dtype=torch.int64,
        device='cuda',
    )

    expected_output = [
        [],
    ]
    for i in range(proposal_token_ids.shape[0]):
        expected_output.append(proposal_token_ids[:i + 1].tolist())

    scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000)
    actual_output = scorer._get_token_ids_to_score(proposal_token_ids)  # pylint: disable=protected-access

    actual_output = [
        x.tolist() if isinstance(x, torch.Tensor) else x for x in actual_output
    ]

    assert actual_output == expected_output


@pytest.mark.parametrize('k', [1, 2, 6])
def test_create_single_target_seq_group_metadata(k: int):
    """Verify correct creation of a batch-expanded seq group metadata.
    """

    prompt_tokens = [1, 2, 3]
    prev_output_tokens = [4, 5, 6]

    token_ids = list(range(k))

    num_tokens_processed = len(prompt_tokens) + len(prev_output_tokens) - 1

    final_seq_len = len(prompt_tokens) + len(prev_output_tokens) + len(
        token_ids)

    block_size = 32
    input_seq_group_metadata = create_seq_group_metadata_from_prompts(
        [prompt_tokens], 2048 // block_size, block_size, [final_seq_len],
        [prev_output_tokens], [num_tokens_processed])[0]

    input_seq_id = list(input_seq_group_metadata.seq_data.keys())[0]
    target_seq_id = 100

    scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000)
    output = scorer._create_single_target_seq_group_metadata(  # pylint: disable=protected-access
        input_seq_group_metadata,
        input_seq_id,
        target_seq_id,
        token_ids,
    )

    assert output.request_id == input_seq_group_metadata.request_id
    assert len(output.seq_data) == 1
    assert output.seq_data[target_seq_id].get_prompt_token_ids(
    ) == prompt_tokens
    assert output.seq_data[target_seq_id].get_output_token_ids(
    ) == prev_output_tokens + token_ids

    assert len(output.block_tables) == 1
    assert output.block_tables[
        target_seq_id] == input_seq_group_metadata.block_tables[input_seq_id]
tests/spec_decode/test_metrics.py (new file, 157 lines)
@@ -0,0 +1,157 @@
import torch
import math
import pytest

from unittest.mock import MagicMock

from vllm.spec_decode.metrics import AsyncMetricsCollector


def test_initial_call_returns_none():
    """Expect first call to get metrics to return None.
    """
    rej_sampler = MagicMock()
    rej_sampler.num_accepted_tokens = torch.tensor(0,
                                                   dtype=torch.long,
                                                   device='cuda')
    rej_sampler.num_emitted_tokens = torch.tensor(0,
                                                  dtype=torch.long,
                                                  device='cuda')
    rej_sampler.num_draft_tokens = 0

    collector = AsyncMetricsCollector(rej_sampler)
    collector.init_gpu_tensors(rank=0)
    maybe_metrics = collector.maybe_collect_rejsample_metrics(k=5)
    assert maybe_metrics is None


def test_second_call_returns_metrics():
    """Expect second call to not return None.
    """
    rej_sampler = MagicMock()
    rej_sampler.num_accepted_tokens = torch.tensor(0,
                                                   dtype=torch.long,
                                                   device='cuda')
    rej_sampler.num_emitted_tokens = torch.tensor(0,
                                                  dtype=torch.long,
                                                  device='cuda')
    rej_sampler.num_draft_tokens = 0

    collect_interval_s = 5.0
    timer = MagicMock()
    timer.side_effect = [
        0.0, collect_interval_s + 0.1, collect_interval_s + 0.2
    ]

    collector = AsyncMetricsCollector(rejection_sampler=rej_sampler,
                                      timer=timer,
                                      collect_interval_s=collect_interval_s)
    collector.init_gpu_tensors(rank=0)
    _ = collector.maybe_collect_rejsample_metrics(k=5)
    metrics = collector.maybe_collect_rejsample_metrics(k=5)
    assert metrics is not None


@pytest.mark.parametrize("rank", [1, 2, 3, 4])
def test_nonzero_rank_noop(rank):
    """Verify nonzero ranks don't collect metrics.
    """
    rej_sampler = MagicMock()
    rej_sampler.num_accepted_tokens = torch.tensor(0,
                                                   dtype=torch.long,
                                                   device='cuda')
    rej_sampler.num_emitted_tokens = torch.tensor(0,
                                                  dtype=torch.long,
                                                  device='cuda')
    rej_sampler.num_draft_tokens = 0

    collector = AsyncMetricsCollector(rej_sampler)
    collector.init_gpu_tensors(rank=rank)
    _ = collector.maybe_collect_rejsample_metrics(k=5)
    metrics = collector.maybe_collect_rejsample_metrics(k=5)
    assert metrics is None


def test_noop_until_time():
    """Verify metrics aren't collected until enough time passes.
    """
    rej_sampler = MagicMock()
    rej_sampler.num_accepted_tokens = torch.tensor(0,
                                                   dtype=torch.long,
                                                   device='cuda')
    rej_sampler.num_emitted_tokens = torch.tensor(0,
                                                  dtype=torch.long,
                                                  device='cuda')
    rej_sampler.num_draft_tokens = 0

    collect_interval_s = 5.0
    timer = MagicMock()
    timer.side_effect = [
        0.0, collect_interval_s - 0.1, collect_interval_s - 0.1,
        collect_interval_s + 0.1, collect_interval_s + 0.1
    ]

    collector = AsyncMetricsCollector(rejection_sampler=rej_sampler,
                                      timer=timer,
                                      collect_interval_s=collect_interval_s)
    collector.init_gpu_tensors(rank=0)

    _ = collector.maybe_collect_rejsample_metrics(k=5)
    metrics = collector.maybe_collect_rejsample_metrics(k=5)
    assert metrics is None

    _ = collector.maybe_collect_rejsample_metrics(k=5)
    metrics = collector.maybe_collect_rejsample_metrics(k=5)
    assert metrics is not None


@pytest.mark.parametrize("has_data", [True, False])
def test_initial_metrics_has_correct_values(has_data: bool):
    """Test correctness of metrics data.
    """
    if has_data:
        num_accepted_tokens = 103
        num_emitted_tokens = 104
        num_draft_tokens = 105
    else:
        num_accepted_tokens = 0
        num_emitted_tokens = 0
        num_draft_tokens = 0
    k = 5

    num_possible_tokens = AsyncMetricsCollector.get_max_num_accepted_tokens(
        num_draft_tokens, k)

    rej_sampler = MagicMock()
    rej_sampler.num_accepted_tokens = torch.tensor(num_accepted_tokens,
                                                   dtype=torch.long,
                                                   device='cuda')
    rej_sampler.num_emitted_tokens = torch.tensor(num_emitted_tokens,
                                                  dtype=torch.long,
                                                  device='cuda')
    rej_sampler.num_draft_tokens = num_draft_tokens

    collect_interval_s = 5.0
    timer = MagicMock()
    timer.side_effect = [
        0.0, collect_interval_s + 0.1, collect_interval_s + 0.2
    ]

    collector = AsyncMetricsCollector(rejection_sampler=rej_sampler,
                                      timer=timer,
                                      collect_interval_s=collect_interval_s)
    collector.init_gpu_tensors(rank=0)
    _ = collector.maybe_collect_rejsample_metrics(k)
    metrics = collector.maybe_collect_rejsample_metrics(k)

    assert metrics.num_spec_tokens == k
    assert metrics.accepted_tokens == num_accepted_tokens
    assert metrics.draft_tokens == num_draft_tokens
    assert metrics.emitted_tokens == num_emitted_tokens

    if has_data:
        assert metrics.draft_acceptance_rate == num_accepted_tokens / num_draft_tokens
        assert metrics.system_efficiency == num_emitted_tokens / num_possible_tokens
    else:
        assert math.isnan(metrics.draft_acceptance_rate)
        assert math.isnan(metrics.system_efficiency)
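Note: the assertions above pin down how the derived rates relate to the raw counters. A minimal sketch of that arithmetic, for illustration only; the helper name get_max_num_accepted_tokens comes from the test, but its internal formula (each window of k draft tokens can emit at most k + 1 tokens) is an assumption here, not taken from this diff:

def derived_metrics(num_accepted: int, num_emitted: int, num_draft: int, k: int):
    # Acceptance rate: fraction of draft tokens the rejection sampler accepted.
    draft_acceptance_rate = (num_accepted / num_draft
                             if num_draft > 0 else float("nan"))
    # Assumed upper bound on emitted tokens: k accepted plus one bonus token
    # per window of k draft tokens.
    num_possible = num_draft // k * (k + 1)
    system_efficiency = (num_emitted / num_possible
                         if num_possible > 0 else float("nan"))
    return draft_acceptance_rate, system_efficiency

print(derived_metrics(103, 104, 105, 5))  # roughly (0.98, 0.83) for the test's values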
@@ -3,14 +3,15 @@ import random
 import pytest
 from unittest.mock import MagicMock

-from vllm.worker.spec_decode.multi_step_worker import MultiStepWorker
+from vllm.spec_decode.multi_step_worker import MultiStepWorker, DraftModelTop1Proposer
 from vllm.worker.worker import Worker
 from vllm.model_executor.utils import set_random_seed
+from vllm.sequence import SamplerOutput

 from .utils import (create_execute_model_data, create_worker,
                     create_seq_group_metadata_from_prompts, zero_kv_cache,
                     patch_execute_model_with_seeds,
-                    assert_logprobs_dict_allclose)
+                    assert_logprobs_dict_allclose, create_batch)


 @pytest.mark.parametrize('num_steps', list(range(1, 17)))
@@ -259,3 +260,160 @@ def test_same_output_for_multi_step():
            multi_step_output_logprobs, single_step_output_logprobs):
        assert_logprobs_dict_allclose(multi_step_logprobs,
                                      single_step_logprobs)


@torch.inference_mode()
def test_draft_proposals_full_speculation_len():
    """Verify DraftModelTop1Proposer correctly handles case where all sequences
    can speculate.
    """
    k = 10
    batch_size = 32
    vocab_size = 32_000
    device = 'cuda:0'

    draft_worker = MagicMock()
    proposer = DraftModelTop1Proposer(
        draft_worker=draft_worker,
        device=device,
        max_model_len=2048,
        vocab_size=vocab_size,
    )
    draft_worker.execute_model_multi_step.return_value = [
        SamplerOutput(
            outputs=[],
            sampled_token_probs=torch.rand(batch_size,
                                           vocab_size,
                                           device=device,
                                           dtype=torch.float32),
            sampled_token_ids=torch.randint(low=0,
                                            high=vocab_size,
                                            size=(batch_size, ),
                                            device=device,
                                            dtype=torch.long),
        ) for _ in range(k)
    ]

    execute_model_data, _, _ = create_batch(batch_size, k)

    proposals = proposer.get_proposals(
        **execute_model_data.to_dict(),
        max_proposal_len=k,
    )

    assert torch.is_tensor(proposals.proposal_token_ids)
    assert torch.is_tensor(proposals.proposal_probs)

    assert proposals.proposal_token_ids.shape == torch.Size([batch_size, k])
    assert proposals.proposal_probs.shape[:-1] == torch.Size([batch_size, k])

    assert proposals.proposal_lens.shape == torch.Size([batch_size])
    assert proposals.proposal_lens.tolist() == [k for _ in range(batch_size)]


@torch.inference_mode()
def test_draft_proposals_no_speculations():
    """Verify DraftModelTop1Proposer correctly handles case where no sequences
    can speculate.
    """
    k = 10
    batch_size = 32
    vocab_size = 32_000
    device = 'cuda:0'
    prompt_len = 10

    draft_worker = MagicMock()
    proposer = DraftModelTop1Proposer(
        draft_worker=draft_worker,
        device=device,
        max_model_len=prompt_len + k - 1,
        vocab_size=vocab_size,
    )

    execute_model_data, _, _ = create_batch(batch_size,
                                            k,
                                            prompt_len=prompt_len)

    proposals = proposer.get_proposals(
        **execute_model_data.to_dict(),
        max_proposal_len=k,
    )

    assert torch.is_tensor(proposals.proposal_token_ids)
    assert torch.is_tensor(proposals.proposal_probs)

    assert proposals.proposal_token_ids.shape == torch.Size([0, k])
    assert proposals.proposal_probs.shape[:-1] == torch.Size([0, k])

    assert proposals.proposal_lens.shape == torch.Size([batch_size])
    assert proposals.proposal_lens.tolist() == [0 for _ in range(batch_size)]


@torch.inference_mode()
def test_draft_proposals_mixed_k():
    """Verify DraftModelTop1Proposer correctly handles case some sequences can
    speculate and some can't.
    """
    k = 10
    batch_size = 32
    vocab_size = 32_000
    device = 'cuda:0'

    small_prompt_len = 5
    long_prompt_len = 10
    prev_output_token_len = 20

    expected_num_proposal_seqs = 6
    expected_num_no_proposal_seqs = batch_size - expected_num_proposal_seqs

    prompt_len = [
        small_prompt_len for _ in range(expected_num_proposal_seqs - 1)
    ] + [long_prompt_len
         for _ in range(expected_num_no_proposal_seqs)] + [small_prompt_len]

    draft_worker = MagicMock()
    proposer = DraftModelTop1Proposer(
        draft_worker=draft_worker,
        device=device,
        max_model_len=long_prompt_len + prev_output_token_len + k - 1,
        vocab_size=vocab_size,
    )

    draft_worker.execute_model_multi_step.return_value = [
        SamplerOutput(
            outputs=[],
            sampled_token_probs=torch.rand(expected_num_proposal_seqs,
                                           vocab_size,
                                           device=device,
                                           dtype=torch.float32),
            sampled_token_ids=torch.randint(
                low=0,
                high=vocab_size,
                size=(expected_num_proposal_seqs, ),
                device=device,
                dtype=torch.long),
        ) for _ in range(k)
    ]

    execute_model_data, _, _ = create_batch(
        batch_size,
        k,
        prompt_len=prompt_len,
        prev_output_token_len=prev_output_token_len,
    )

    proposals = proposer.get_proposals(
        **execute_model_data.to_dict(),
        max_proposal_len=k,
    )

    assert torch.is_tensor(proposals.proposal_token_ids)
    assert torch.is_tensor(proposals.proposal_probs)

    assert proposals.proposal_token_ids.shape == torch.Size([batch_size, k])
    assert proposals.proposal_probs.shape[:-1] == torch.Size([batch_size, k])

    assert proposals.proposal_lens.shape == torch.Size([batch_size])
    assert proposals.proposal_lens.tolist() == [
        k for _ in range(expected_num_proposal_seqs - 1)
    ] + [0 for _ in range(expected_num_no_proposal_seqs)] + [k]
tests/spec_decode/test_spec_decode_worker.py (new file, 591 lines)
@@ -0,0 +1,591 @@
import torch
import random
import pytest
from unittest.mock import MagicMock

from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker, split_num_cache_blocks_evenly
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.model_executor.utils import set_random_seed
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from .utils import mock_worker, create_batch, ExecuteModelData, create_sampler_output_list
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics, AsyncMetricsCollector


@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@torch.inference_mode()
def test_correctly_calls_draft_model(k: int, batch_size: int):
    """Verify SpecDecodeWorker calls the draft worker with correct
    inputs. Everything else is mocked out.
    """
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    rejection_sampler = MagicMock(spec=RejectionSampler)
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
                              metrics_collector)

    exception_secret = 'artifical stop'
    draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)

    execute_model_data, _, _ = create_batch(batch_size, k)

    with pytest.raises(ValueError, match=exception_secret):
        worker.execute_model(**execute_model_data.to_dict(), num_spec_tokens=k)

    call_args_list = draft_worker.get_spec_proposals.call_args_list
    assert len(call_args_list) == 1

    for args, _ in call_args_list:
        (seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out,
         blocks_to_copy, actual_k) = args
        actual_execute_model_data = ExecuteModelData(seq_group_metadata_list,
                                                     blocks_to_swap_in,
                                                     blocks_to_swap_out,
                                                     blocks_to_copy)
        assert actual_execute_model_data == execute_model_data
        assert actual_k == k


@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@torch.inference_mode()
def test_correctly_calls_target_model(k: int, batch_size: int):
    """Verify SpecDecodeWorker calls the target model with correct
    inputs. Everything else is mocked out.
    """
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    rejection_sampler = MagicMock(spec=RejectionSampler)
    rejection_sampler.token_id_dtype = torch.int64
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)

    draft_worker.device = 'cuda'
    target_worker.device = 'cuda'

    set_random_seed(1)

    worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
                              metrics_collector)
    worker.init_model()

    vocab_size = 32_000

    proposal_token_ids = torch.randint(low=0,
                                       high=vocab_size,
                                       size=(batch_size, k),
                                       dtype=torch.int64,
                                       device='cuda')
    proposal_probs = torch.rand(batch_size,
                                k,
                                vocab_size,
                                dtype=torch.float32,
                                device='cuda')
    proposal_lens = torch.ones(batch_size, dtype=torch.int64,
                               device='cuda') * k

    execute_model_data, prompts, prev_output_tokens = create_batch(
        batch_size, k)

    draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
        proposal_token_ids=proposal_token_ids,
        proposal_probs=proposal_probs,
        proposal_lens=proposal_lens)

    exception_secret = 'artifical stop'
    target_worker.execute_model.side_effect = ValueError(exception_secret)

    with pytest.raises(ValueError, match=exception_secret):
        worker.execute_model(**execute_model_data.to_dict(), num_spec_tokens=k)

    seen_contexts = []

    call_args_list = target_worker.execute_model.call_args_list
    assert len(call_args_list) == 1
    for args, kwargs in call_args_list:
        target_execute_model_data = ExecuteModelData.from_dict(kwargs)

        assert len(target_execute_model_data.seq_group_metadata_list) == (
            k + 1) * batch_size
        for seq_group_metadata in (
                target_execute_model_data.seq_group_metadata_list):
            for seq_data in seq_group_metadata.seq_data.values():
                seen_contexts.append(seq_data.get_token_ids())

    expected_seen_contexts = []

    for prompt, prev_generated, draft_tokens in zip(
            prompts, prev_output_tokens, proposal_token_ids.tolist()):

        for i in range(len(draft_tokens) + 1):
            expected_seen_contexts.append(prompt + prev_generated +
                                          draft_tokens[:i])

    seen_contexts.sort()
    expected_seen_contexts.sort()
    assert expected_seen_contexts == seen_contexts


@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@torch.inference_mode()
def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
    """Verify SpecDecodeWorker calls the rejection sampler with
    correct inputs. Everything else is mocked out.
    """
    vocab_size = 32_000

    draft_worker = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size)
    target_worker = mock_worker(vocab_size=vocab_size)
    rejection_sampler = MagicMock(spec=RejectionSampler)
    rejection_sampler.token_id_dtype = torch.int64
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    draft_worker.device = 'cuda'
    target_worker.device = 'cuda'

    set_random_seed(1)

    worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
                              metrics_collector)
    worker.init_model()

    proposal_token_ids = torch.randint(low=0,
                                       high=vocab_size,
                                       size=(batch_size, k),
                                       dtype=torch.int64,
                                       device='cuda')
    proposal_probs = torch.rand(batch_size,
                                k,
                                vocab_size,
                                dtype=torch.float32,
                                device='cuda')

    proposal_lens = torch.ones(batch_size, dtype=torch.int64,
                               device='cuda') * k

    execute_model_data, _, _ = create_batch(batch_size, k)

    draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
        proposal_token_ids=proposal_token_ids,
        proposal_probs=proposal_probs,
        proposal_lens=proposal_lens)

    target_token_ids = torch.randint(low=0,
                                     high=vocab_size,
                                     size=(1, batch_size * (k + 1)),
                                     dtype=torch.int64,
                                     device='cuda')
    target_token_probs = torch.rand(1,
                                    batch_size * (k + 1),
                                    vocab_size,
                                    dtype=torch.float32,
                                    device='cuda')
    target_output = create_sampler_output_list(target_token_ids,
                                               target_token_probs)

    target_worker.execute_model.return_value = target_output[0]

    exception_secret = 'artifical stop'
    rejection_sampler.side_effect = ValueError(exception_secret)

    with pytest.raises(ValueError, match=exception_secret):
        worker.execute_model(**execute_model_data.to_dict(), num_spec_tokens=k)

    assert len(rejection_sampler.call_args_list) == 1
    args, _ = rejection_sampler.call_args_list[0]
    (actual_proposal_scores, actual_bonus_token_ids, actual_proposal_probs,
     actual_proposal_token_ids) = args

    assert torch.equal(actual_bonus_token_ids,
                       target_token_ids.reshape(batch_size, k + 1)[:, -1:])
    assert torch.equal(
        actual_proposal_scores,
        target_token_probs.reshape(batch_size, k + 1, -1)[:, :-1])
    assert torch.equal(actual_proposal_token_ids, proposal_token_ids)
    assert torch.equal(actual_proposal_probs, proposal_probs)


@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@torch.inference_mode()
def test_correctly_formats_output(k: int, batch_size: int):
    """Verify SpecDecodeWorker formats sampler output correctly.
    Everything else is mocked out.
    """
    vocab_size = 32_000

    draft_worker = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size)
    target_worker = mock_worker(vocab_size=vocab_size)
    rejection_sampler = MagicMock(spec=RejectionSampler)
    rejection_sampler.token_id_dtype = torch.int64
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    draft_worker.device = 'cuda'
    target_worker.device = 'cuda'

    set_random_seed(1)

    worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
                              metrics_collector)
    worker.init_model()

    proposal_token_ids = torch.randint(low=0,
                                       high=vocab_size,
                                       size=(batch_size, k),
                                       dtype=torch.int64,
                                       device='cuda')
    proposal_probs = torch.rand(batch_size,
                                k,
                                vocab_size,
                                dtype=torch.float32,
                                device='cuda')

    proposal_lens = torch.ones(batch_size, dtype=torch.int64,
                               device='cuda') * k

    execute_model_data, _, _ = create_batch(batch_size, k)

    draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
        proposal_token_ids=proposal_token_ids,
        proposal_probs=proposal_probs,
        proposal_lens=proposal_lens)

    target_token_ids = torch.randint(low=0,
                                     high=vocab_size,
                                     size=(1, batch_size * (k + 1)),
                                     dtype=torch.int64,
                                     device='cuda')
    target_token_probs = torch.rand(1,
                                    batch_size * (k + 1),
                                    vocab_size,
                                    dtype=torch.float32,
                                    device='cuda')
    target_output = create_sampler_output_list(target_token_ids,
                                               target_token_probs)

    target_worker.execute_model.return_value = target_output[0]

    rejection_sampler_output = torch.randint(low=0,
                                             high=vocab_size,
                                             size=(batch_size, k + 1),
                                             dtype=torch.int64,
                                             device='cuda')
    for i in range(batch_size):
        minimum_accepted_tokens = 1
        rejection_sampler_output[i][
            -random.randint(minimum_accepted_tokens, k + 1):] = -1

    rejection_sampler.return_value = rejection_sampler_output

    output = worker.execute_model(**execute_model_data.to_dict(),
                                  num_spec_tokens=k)

    expected_output = create_sampler_output_list(
        rejection_sampler_output.transpose(0, 1), [None for _ in range(k + 1)])

    seq_ids = [
        next(iter(seq_group_metadata.seq_data.keys()))
        for seq_group_metadata in execute_model_data.seq_group_metadata_list
    ]
    actual_output_by_seq = {seq_id: [] for seq_id in seq_ids}
    expected_output_by_seq = {seq_id: [] for seq_id in seq_ids}

    for step in output:
        for seq_group in step:
            for sample in seq_group.samples:
                seq_id = sample.parent_seq_id
                actual_output_by_seq[seq_id].append(sample)

    for step in expected_output:
        for seq_group in step:
            for sample in seq_group.samples:
                seq_id = sample.parent_seq_id
                expected_output_by_seq[seq_id].append(sample)

    all_seen_seq_ids = set(
        list(actual_output_by_seq.keys()) +
        list(expected_output_by_seq.keys()))
    for seq_id in all_seen_seq_ids:
        actual_by_step = actual_output_by_seq[seq_id]
        expected_by_step = expected_output_by_seq[seq_id]

        for i in range(k + 1):
            if i >= len(actual_by_step):
                assert expected_by_step[i].output_token == -1
                continue
            assert actual_by_step[i].output_token == expected_by_step[
                i].output_token
            assert actual_by_step[i].logprobs == expected_by_step[i].logprobs


@pytest.mark.parametrize('k', [1, 2])
@pytest.mark.parametrize('batch_size', [1])
@pytest.mark.parametrize('returns_metrics', [True, False])
@torch.inference_mode()
def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
    """Verify SpecDecodeWorker collects metrics.
    """
    vocab_size = 32_000

    draft_worker = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size)
    target_worker = mock_worker(vocab_size=vocab_size)
    rejection_sampler = MagicMock(spec=RejectionSampler)
    rejection_sampler.token_id_dtype = torch.int64
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    draft_worker.device = 'cuda'
    target_worker.device = 'cuda'

    set_random_seed(1)

    worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
                              metrics_collector)
    worker.init_model()

    proposal_token_ids = torch.randint(low=0,
                                       high=vocab_size,
                                       size=(batch_size, k),
                                       dtype=torch.int64,
                                       device='cuda')
    proposal_probs = torch.rand(batch_size,
                                k,
                                vocab_size,
                                dtype=torch.float32,
                                device='cuda')

    proposal_lens = torch.ones(batch_size, dtype=torch.int64,
                               device='cuda') * k

    execute_model_data, _, _ = create_batch(batch_size, k)

    draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
        proposal_token_ids=proposal_token_ids,
        proposal_probs=proposal_probs,
        proposal_lens=proposal_lens)

    target_token_ids = torch.randint(low=0,
                                     high=vocab_size,
                                     size=(1, batch_size * (k + 1)),
                                     dtype=torch.int64,
                                     device='cuda')
    target_token_probs = torch.rand(1,
                                    batch_size * (k + 1),
                                    vocab_size,
                                    dtype=torch.float32,
                                    device='cuda')
    target_output = create_sampler_output_list(target_token_ids,
                                               target_token_probs)

    target_worker.execute_model.return_value = target_output[0]

    rejection_sampler_output = torch.randint(low=0,
                                             high=vocab_size,
                                             size=(batch_size, k + 1),
                                             dtype=torch.int64,
                                             device='cuda')
    for i in range(batch_size):
        minimum_accepted_tokens = 1
        rejection_sampler_output[i][
            -random.randint(minimum_accepted_tokens, k + 1):] = -1

    rejection_sampler.return_value = rejection_sampler_output

    mock_rejsample_metrics = MagicMock(
        spec=SpecDecodeWorkerMetrics) if returns_metrics else None
    metrics_collector.maybe_collect_rejsample_metrics.return_value = mock_rejsample_metrics

    output = worker.execute_model(**execute_model_data.to_dict(),
                                  num_spec_tokens=k)
    assert output[0].spec_decode_worker_metrics == mock_rejsample_metrics

    call_args_list = metrics_collector.maybe_collect_rejsample_metrics.call_args_list
    assert len(call_args_list) == 1
    args, kwargs = call_args_list[0]
    assert args[0] == k or kwargs.get('k', -1) == k


@pytest.mark.parametrize('k', [0])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@torch.inference_mode()
def test_k_equals_zero(k: int, batch_size: int):
    """Verify that the SpecDecodeWorker calls the draft and target workers
    when k is zero. This happens during prefill.
    """
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    rejection_sampler = MagicMock(spec=RejectionSampler)
    rejection_sampler.token_id_dtype = torch.int64
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)

    draft_worker.device = 'cuda'
    target_worker.device = 'cuda'

    set_random_seed(1)

    worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
                              metrics_collector)

    execute_model_data, prompts, prev_output_tokens = create_batch(
        batch_size, k, prev_output_token_len=0)

    out = worker.execute_model(**execute_model_data.to_dict(),
                               num_spec_tokens=k)

    assert len(out) == 1, f"expected only one token output when {k=}"
    assert out[0].probs is None, "expect gpu tensor references to be None"
    assert out[
        0].sampled_tokens is None, "expect gpu tensor references to be None"

    draft_worker.execute_model.assert_called_once_with(
        **execute_model_data.to_dict(), return_python_output=False)
    target_worker.execute_model.assert_called_once_with(
        **execute_model_data.to_dict())


@pytest.mark.parametrize('k', [0, 5])
@pytest.mark.parametrize('batch_size', [0])
@torch.inference_mode()
def test_empty_input_batch(k: int, batch_size: int):
    """Verify that the SpecDecodeWorker calls the draft and target workers
    when the input batch is empty. This can happen if the engine communicates
    to the workers information without scheduling a batch.
    """
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    rejection_sampler = MagicMock(spec=RejectionSampler)
    rejection_sampler.token_id_dtype = torch.int64
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)

    draft_worker.device = 'cuda'
    target_worker.device = 'cuda'

    set_random_seed(1)

    worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
                              metrics_collector)

    execute_model_data, prompts, prev_output_tokens = create_batch(
        batch_size, k, prev_output_token_len=0)

    out = worker.execute_model(**execute_model_data.to_dict(),
                               num_spec_tokens=k)

    assert len(out) == 1, f"expected only one token output when {k=}"
    assert out[0].probs is None, "expect gpu tensor references to be None"
    assert out[
        0].sampled_tokens is None, "expect gpu tensor references to be None"

    draft_worker.execute_model.assert_called_once_with(
        **execute_model_data.to_dict(), return_python_output=False)
    target_worker.execute_model.assert_called_once_with(
        **execute_model_data.to_dict())


@torch.inference_mode()
def test_init_model():
    """Verify SpecDecodeWorker invokes proposer/scorer worker init_model, as
    well as other GPU initialization.
    """
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    rejection_sampler = MagicMock(spec=RejectionSampler)
    rejection_sampler.token_id_dtype = torch.int64
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)

    worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
                              metrics_collector)

    worker.init_model()

    draft_worker.init_model.assert_called_once()

    target_worker.init_model.assert_called_once()

    metrics_collector.init_gpu_tensors.assert_called_once()
    rejection_sampler.init_gpu_tensors.assert_called_once()


@torch.inference_mode()
def test_init_cache_engine():
    """Verify SpecDecodeWorker invokes init_cache_engine on proposer/scorer
    workers.
    """
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    rejection_sampler = MagicMock(spec=RejectionSampler)
    rejection_sampler.token_id_dtype = torch.int64
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)

    worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
                              metrics_collector)

    cache_config = MagicMock()

    worker.init_cache_engine(cache_config)

    draft_worker.init_cache_engine.assert_called_once_with(cache_config)
    target_worker.init_cache_engine.assert_called_once_with(cache_config)


@pytest.mark.parametrize('available_gpu_blocks', [1, 1024])
@pytest.mark.parametrize('available_cpu_blocks', [500])
@pytest.mark.parametrize('target_cache_block_size_bytes', [2 * 2 * 4096])
@pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096])
@torch.inference_mode()
def test_profile_num_available_blocks(available_gpu_blocks: int,
                                      available_cpu_blocks: int,
                                      target_cache_block_size_bytes: int,
                                      draft_kv_size_bytes: int):
    """Verify SpecDecodeWorker correctly profiles num available GPU blocks.
    Specifically, it should run profiling in the scorer worker, and then evenly
    split the blocks between proposer and scorer worker.
    """
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    rejection_sampler = MagicMock(spec=RejectionSampler)
    rejection_sampler.token_id_dtype = torch.int64
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)

    target_worker.profile_num_available_blocks.return_value = (
        available_gpu_blocks, available_cpu_blocks)
    target_worker.get_cache_block_size_bytes.return_value = target_cache_block_size_bytes
    draft_worker.get_cache_block_size_bytes.return_value = draft_kv_size_bytes

    worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
                              metrics_collector)

    # These values do not directly impact the adjusted block size calculation,
    # so they can be fixed.
    gpu_memory_utilization = 0.9
    cpu_swap_space = 100
    block_size = 16

    num_gpu_blocks, num_cpu_blocks = worker.profile_num_available_blocks(
        block_size, gpu_memory_utilization, cpu_swap_space, cache_dtype="auto")

    target_worker.profile_num_available_blocks.assert_called_once_with(
        block_size, gpu_memory_utilization, cpu_swap_space, "auto")
    assert num_cpu_blocks == available_cpu_blocks

    assert num_gpu_blocks == split_num_cache_blocks_evenly(
        target_cache_block_size_bytes, draft_kv_size_bytes,
        available_gpu_blocks)


@pytest.mark.parametrize('available_gpu_blocks',
                         list(range(20)) + [1024, 1024**2])
@pytest.mark.parametrize('target_cache_block_size_bytes',
                         [2 * 2 * 4096, 2 * 2 * 8192])
@pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096])
@torch.inference_mode()
def test_split_num_cache_blocks_evenly(available_gpu_blocks: int,
                                       target_cache_block_size_bytes: int,
                                       draft_kv_size_bytes: int):
    """Verify split_num_cache_blocks_evenly does not exceed original memory
    allocation in bytes.
    """
    num_blocks = split_num_cache_blocks_evenly(target_cache_block_size_bytes,
                                               draft_kv_size_bytes,
                                               available_gpu_blocks)
    assert (num_blocks * target_cache_block_size_bytes) + (
        num_blocks * draft_kv_size_bytes) <= (available_gpu_blocks *
                                              target_cache_block_size_bytes)
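Note: the tests above only pin down the invariant that the split block count never exceeds the bytes profiled for the scorer's cache. A minimal sketch of one allocation rule that satisfies that invariant; this is an illustration under that assumption, not necessarily the exact body of split_num_cache_blocks_evenly:

def split_num_cache_blocks_evenly_sketch(scorer_block_size_bytes: int,
                                         proposer_block_size_bytes: int,
                                         total_num_gpu_blocks: int) -> int:
    # Profiling assumed all memory goes to the scorer's cache. Sharing it with
    # the proposer means each logical block now costs the sum of both block
    # sizes, so fewer blocks fit in the same byte budget.
    return int(total_num_gpu_blocks * scorer_block_size_bytes /
               (scorer_block_size_bytes + proposer_block_size_bytes))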
tests/spec_decode/test_utils.py (new file, 111 lines)
@@ -0,0 +1,111 @@
from vllm.spec_decode.util import get_all_seq_ids
from vllm.sequence import SequenceGroupMetadata
from vllm.spec_decode.util import split_batch_by_proposal_len

import pytest
from unittest.mock import MagicMock


def test_get_all_seq_ids():
    """Verify get_all_seq_ids extracts all seq ids.
    """
    expected_seq_ids = list(range(10)) + list(range(100, 110))

    seq_group_metadata_list = [
        SequenceGroupMetadata(
            request_id=str(seq_id),
            is_prompt=True,
            seq_data={
                seq_id: MagicMock(),
            },
            sampling_params=MagicMock(),
            block_tables={
                seq_id: MagicMock(),
            },
            lora_request=None,
        ) for seq_id in expected_seq_ids
    ]

    actual_seq_ids = get_all_seq_ids(seq_group_metadata_list)
    assert actual_seq_ids == expected_seq_ids


@pytest.fixture
def fake_sequence_group_metadata():
    seq_ids = list(range(3))
    return [
        SequenceGroupMetadata(
            request_id=str(i),
            is_prompt=True,
            seq_data={
                i: MagicMock(),
            },
            sampling_params=MagicMock(),
            block_tables={
                i: MagicMock(),
            },
            lora_request=None,
        ) for i in seq_ids
    ]


def test_filter_zero_length_proposals(fake_sequence_group_metadata):
    proposal_lens = [0, 1, 0]
    filtered_groups, indices = split_batch_by_proposal_len(
        fake_sequence_group_metadata,
        proposal_lens,
        select_proposal_len_zero=True)

    expected_groups = [
        fake_sequence_group_metadata[0], fake_sequence_group_metadata[2]
    ]
    expected_indices = [0, 2]

    assert filtered_groups == expected_groups
    assert indices == expected_indices


def test_filter_non_zero_length_proposals(fake_sequence_group_metadata):
    proposal_lens = [0, 1, 2]
    filtered_groups, indices = split_batch_by_proposal_len(
        fake_sequence_group_metadata,
        proposal_lens,
        select_proposal_len_zero=False)

    expected_groups = [
        fake_sequence_group_metadata[1], fake_sequence_group_metadata[2]
    ]
    expected_indices = [1, 2]

    assert filtered_groups == expected_groups
    assert indices == expected_indices


def test_empty_inputs():
    filtered_groups, indices = split_batch_by_proposal_len(
        [], [], select_proposal_len_zero=True)

    assert filtered_groups == []
    assert indices == []


def test_all_zero_with_non_zero_filter(fake_sequence_group_metadata):
    proposal_lens = [0, 0, 0]
    filtered_groups, indices = split_batch_by_proposal_len(
        fake_sequence_group_metadata,
        proposal_lens,
        select_proposal_len_zero=False)

    assert filtered_groups == []
    assert indices == []


def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata):
    proposal_lens = [1, 1, 1]
    filtered_groups, indices = split_batch_by_proposal_len(
        fake_sequence_group_metadata,
        proposal_lens,
        select_proposal_len_zero=True)

    assert filtered_groups == []
    assert indices == []
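Note: the tests above fully characterize the splitting behavior. A minimal sketch of that behavior inferred from the assertions (the function name below is hypothetical; split_batch_by_proposal_len itself lives in vllm.spec_decode.util):

from typing import Any, List, Tuple

def split_by_proposal_len_sketch(
        seq_group_metadata_list: List[Any],
        proposal_lens: List[int],
        select_proposal_len_zero: bool) -> Tuple[List[Any], List[int]]:
    # Keep the groups on the requested side of the split, together with their
    # original batch indices so results can be scattered back later.
    keep = ((lambda plen: plen == 0)
            if select_proposal_len_zero else (lambda plen: plen != 0))
    selected = [(i, group) for i, (group, plen) in enumerate(
        zip(seq_group_metadata_list, proposal_lens)) if keep(plen)]
    return [group for _, group in selected], [i for i, _ in selected]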
@@ -1,13 +1,16 @@
 import torch
-from typing import List, Optional, Dict
+from typing import List, Optional, Dict, Iterable, Union
 from unittest.mock import MagicMock

 from vllm.worker.worker import Worker
 from vllm.utils import get_distributed_init_method, get_ip, get_open_port
 from vllm.engine.arg_utils import EngineArgs
-from vllm.sequence import Logprob, SequenceGroupMetadata, SequenceData
+from vllm.sequence import (Logprob, SequenceGroupMetadata, SequenceData,
+                           SamplerOutput, SequenceGroupOutput, SequenceOutput)
 from vllm.sampling_params import SamplingParams
 from vllm.worker.cache_engine import CacheEngine
 from vllm.model_executor.utils import set_random_seed
 from itertools import count
 from dataclasses import dataclass, fields

@@ -24,6 +27,11 @@ class ExecuteModelData:
        return dict(
            (field.name, getattr(self, field.name)) for field in fields(self))

+    @classmethod
+    def from_dict(cls, d):
+        cleaned = dict((field.name, d[field.name]) for field in fields(cls))
+        return cls(**cleaned)
+

def round_up_to_next_block(seq_len: int, block_size: int) -> int:
    return (seq_len + block_size - 1) // block_size

@@ -50,6 +58,21 @@ def create_execute_model_data(
    )


+def mock_worker(cls=None,
+                vocab_size: int = 30_000,
+                max_model_len: int = 2048,
+                rank: int = 0) -> MagicMock:
+    if cls is None:
+        cls = Worker
+
+    worker = MagicMock(spec=cls)
+    worker.vocab_size = vocab_size
+    worker.max_model_len = max_model_len
+    worker.rank = rank
+    worker.device = 'cuda:0'
+    return worker
+
+
def patch_execute_model_with_seeds(worker: Worker, rand_seeds: List[int]):
    seed_iter = iter(rand_seeds)
    original_execute_model = worker.execute_model

@@ -117,25 +140,12 @@ def create_seq_group_metadata_from_prompts(
        block_size: int,
        final_seq_lens: List[int],
        continuations: Optional[List[List[int]]] = None,
-        num_tokens_processed: Optional[List[int]] = None,
        seq_ids: Optional[List[int]] = None,
) -> List[SequenceGroupMetadata]:

    if continuations is None:
        continuations = [[] for _ in prompts]

-    if num_tokens_processed is None:
-        # Default to 1 token missing from kv cache for generation sequences.
-        num_tokens_processed = []
-        for continuation, prompt in zip(continuations, prompts):
-            # If prefill, then default to zero tokens processed.
-            if not continuation:
-                num_tokens_processed.append(0)
-            else:
-                # If generation, then default to all but one tokens processed.
-                num_tokens_processed.append(
-                    len(continuation) + len(prompt) - 1)
-
    if seq_ids is None:
        seq_ids = list(i for i, _ in enumerate(prompts))

@@ -155,13 +165,15 @@ def create_seq_group_metadata_from_prompts(
            is_prompt=len(cont_token_ids) == 0,
            seq_data={
                i:
-                SequenceData(prompt_token_ids=prompt_token_ids[:] +
-                             cont_token_ids[:])
+                SequenceData(
+                    prompt_token_ids=prompt_token_ids[:],
+                    output_token_ids=cont_token_ids[:],
+                ),
            },
            sampling_params=SamplingParams(temperature=0.0, ),
            block_tables={i: block_allocations[i][:]},
-        ) for i, (prompt_token_ids, cont_token_ids, num_tokens_saved) in
-        enumerate(zip(prompts, continuations, num_tokens_processed))
+        ) for i, (prompt_token_ids,
+                  cont_token_ids) in enumerate(zip(prompts, continuations))
    ]

@@ -178,3 +190,68 @@ def assert_logprobs_dict_allclose(
            expected = torch.tensor(
                single_step_expected_logprobs[token_id].logprob)
            assert torch.allclose(actual, expected)


def create_sampler_output_list(
        token_ids: torch.Tensor,
        probs: Iterable[Optional[torch.Tensor]],
        seq_ids: Optional[List[int]] = None) -> List[SamplerOutput]:
    num_steps, batch_size = token_ids.shape
    token_ids_by_step = token_ids.tolist()

    if seq_ids is None:
        seq_ids = list(range(batch_size))

    return [
        SamplerOutput(outputs=[
            SequenceGroupOutput(
                samples=[
                    SequenceOutput(
                        output_token=token_id,
                        parent_seq_id=seq_ids[seq_index],
                        logprobs={token_id: 0},
                    )
                ],
                prompt_logprobs=None,
            ) for seq_index, token_id in enumerate(token_ids_by_step[step])
        ],
                      sampled_token_probs=probs[step],
                      sampled_token_ids=token_ids[step])
        for step in range(num_steps)
    ]


def create_batch(batch_size,
                 k,
                 prompt_len: Union[int, List[int]] = 10,
                 prev_output_token_len: int = 10,
                 seq_ids: Optional[List[int]] = None,
                 num_gpu_blocks: Optional[int] = None,
                 block_size: Optional[int] = None):
    if block_size is None:
        block_size = 8

    if num_gpu_blocks is None:
        num_gpu_blocks = 2048 // block_size

    iterator = count()

    if isinstance(prompt_len, int):
        prompt_lens = [prompt_len for _ in range(batch_size)]
    else:
        prompt_lens = prompt_len

    prompts = [[next(iterator) for _ in range(p_len)] for p_len in prompt_lens]
    prev_output_tokens = [[
        next(iterator) for _ in range(prev_output_token_len)
    ] for _ in range(batch_size)]
    final_seq_lens = [
        len(prompt) + len(prev_output_token) + k + 1
        for prompt, prev_output_token in zip(prompts, prev_output_tokens)
    ]

    execute_model_data = create_execute_model_data(
        create_seq_group_metadata_from_prompts(prompts, num_gpu_blocks,
                                               block_size, final_seq_lens,
                                               prev_output_tokens, seq_ids), )
    return execute_model_data, prompts, prev_output_tokens
tests/test_sequence.py (new file, 50 lines)
@@ -0,0 +1,50 @@
import pytest

from vllm.sequence import SequenceGroupOutput, SamplerOutput, SequenceOutput


@pytest.fixture
def sample_outputs():
    return [
        SequenceGroupOutput(samples=[
            SequenceOutput(parent_seq_id=0, output_token=i, logprobs={})
        ],
                            prompt_logprobs=None) for i in range(5)
    ]


@pytest.fixture
def sampler_output(sample_outputs):
    return SamplerOutput(outputs=sample_outputs)


def test_sampler_output_initialization(sampler_output, sample_outputs):
    assert len(sampler_output) == len(sample_outputs)
    assert sampler_output.sampled_token_probs is None
    assert sampler_output.sampled_token_ids is None
    assert sampler_output.spec_decode_worker_metrics is None


def test_sampler_output_getitem(sampler_output, sample_outputs):
    assert sampler_output[2] == sample_outputs[2]


def test_sampler_output_setitem(sampler_output):
    new_output = SequenceGroupOutput(samples=[
        SequenceOutput(parent_seq_id=0, output_token=99, logprobs={})
    ],
                                     prompt_logprobs=None)
    sampler_output[2] = new_output
    assert sampler_output[2] == new_output


def test_sampler_output_len(sampler_output, sample_outputs):
    assert len(sampler_output) == len(sample_outputs)


def test_sampler_output_eq(sample_outputs):
    sampler_output1 = SamplerOutput(outputs=sample_outputs)
    sampler_output2 = SamplerOutput(outputs=sample_outputs.copy())
    sampler_output3 = SamplerOutput(outputs=sample_outputs[:-1])
    assert sampler_output1 == sampler_output2
    assert sampler_output1 != sampler_output3
@@ -21,8 +21,6 @@ class RejectionSampler(nn.Module):
            nontrivial latency.
        """
        super().__init__()
-        self.probs_dtype = torch.float32
-        self.token_id_dtype = torch.int64
        self._strict_mode = strict_mode

        # NOTE: A "bonus token" is accepted iff all proposal tokens are
@@ -44,6 +42,14 @@ class RejectionSampler(nn.Module):
                                                dtype=torch.long,
                                                device=device)

+    @property
+    def probs_dtype(self):
+        return torch.float32
+
+    @property
+    def token_id_dtype(self):
+        return torch.int64
+
    def forward(
        self,
        target_probs: torch.Tensor,
|
@ -587,4 +587,4 @@ def _build_sampler_output(
|
||||
SequenceOutput(seq_ids[parent_id], next_token_id, logprobs))
|
||||
sampler_output.append(
|
||||
SequenceGroupOutput(seq_outputs, group_prompt_logprobs))
|
||||
return sampler_output
|
||||
return SamplerOutput(outputs=sampler_output)
|
||||
|
@@ -2,12 +2,16 @@
 import copy
 import enum
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Union, TYPE_CHECKING

 from vllm.block import LogicalTokenBlock
 from vllm.sampling_params import SamplingParams
 from vllm.lora.request import LoRARequest

+if TYPE_CHECKING:
+    import torch
+    from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
+

 @dataclass
 class Logprob:
@@ -81,6 +85,8 @@ class SequenceData:

    Args:
        prompt_token_ids: The token IDs of the prompt.
+        output_token_ids: The token IDs of the output. Set to an empty list if
+            None.

    Attributes:
        prompt_token_ids: The token IDs of the prompt.
@@ -91,9 +97,13 @@ class SequenceData:
    def __init__(
        self,
        prompt_token_ids: List[int],
+        output_token_ids: Optional[List[int]] = None,
    ) -> None:
+        if output_token_ids is None:
+            output_token_ids = []
+
        self.prompt_token_ids = prompt_token_ids
-        self.output_token_ids: List[int] = []
+        self.output_token_ids = output_token_ids
        self.cumulative_logprob = 0.0

    def append_token_id(self, token_id: int, logprob: float) -> None:
@@ -117,6 +127,12 @@ class SequenceData:
            return self.prompt_token_ids[-1]
        return self.output_token_ids[-1]

+    def get_prompt_token_ids(self) -> int:
+        return self.prompt_token_ids
+
+    def get_output_token_ids(self) -> int:
+        return self.output_token_ids
+
    def __repr__(self) -> str:
        return (f"SequenceData("
                f"prompt_token_ids={self.prompt_token_ids}, "
@@ -506,6 +522,35 @@ class SequenceGroupOutput:
                and self.prompt_logprobs == other.prompt_logprobs)


-# For each sequence group, we generate a list of SequenceOutput object,
-# each of which contains one possible candidate for the next token.
-SamplerOutput = List[SequenceGroupOutput]
+@dataclass
+class SamplerOutput:
+    """For each sequence group, we generate a list of SequenceOutput object,
+    each of which contains one possible candidate for the next token.
+
+    This datastructure implements methods so it can be used like a list, but
+    also has optional fields for device tensors.
+    """
+
+    outputs: List[SequenceGroupOutput]
+
+    # On-device tensor containing probabilities of each token.
+    sampled_token_probs: Optional["torch.Tensor"] = None
+
+    # On-device tensor containing the sampled token ids.
+    sampled_token_ids: Optional["torch.Tensor"] = None
+
+    # Spec decode metrics populated by workers.
+    spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
+
+    def __getitem__(self, idx: int):
+        return self.outputs[idx]
+
+    def __setitem__(self, idx: int, value):
+        self.outputs[idx] = value
+
+    def __len__(self):
+        return len(self.outputs)
+
+    def __eq__(self, other: object):
+        return isinstance(other,
+                          self.__class__) and self.outputs == other.outputs
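Note: a short usage illustration of the list-like SamplerOutput dataclass introduced above (constructor and dunder methods are taken from this diff; the concrete token values are made up):

from vllm.sequence import SamplerOutput, SequenceGroupOutput, SequenceOutput

step = SamplerOutput(outputs=[
    SequenceGroupOutput(samples=[
        SequenceOutput(parent_seq_id=0, output_token=42, logprobs={})
    ],
                        prompt_logprobs=None)
])

assert len(step) == 1                        # still behaves like the old List[SequenceGroupOutput]
assert step[0].samples[0].output_token == 42
assert step.sampled_token_ids is None        # device tensors are optional extras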
vllm/spec_decode/batch_expansion.py (new file, 351 lines; listing truncated)
@@ -0,0 +1,351 @@
from typing import Iterator, List, Tuple, Optional, Dict
from itertools import chain, count

import torch

from vllm.sequence import (SamplerOutput, SequenceGroupMetadata, SequenceData)
from vllm.worker.worker import Worker
from vllm.spec_decode.util import nvtx_range, sampler_output_to_torch, get_all_seq_ids, split_batch_by_proposal_len
from vllm.spec_decode.interfaces import SpeculativeScorer, SpeculativeProposals, SpeculativeScores

SeqId = int
TargetSeqId = int
TokenId = int


class BatchExpansionTop1Scorer(SpeculativeScorer):
    """Implements a speculative scorer that uses batch expansion to get
    probabilities of speculative tokens according to the scoring model.

    Batch expansion converts a list of sequences and multiple query positions
    to a new batch of sequences, each with a single query position. This allows
    for MQA-like scoring in speculative decoding without requiring an MQA
    kernel.

    It is strictly less efficient than MQA scoring.

    It only supports scoring the top1 proposal tokens of the proposer, instead
    of topk/tree.
    """

    def __init__(self, scorer_worker: Worker, device: str, vocab_size: int):
        self._scorer_worker = scorer_worker
        self._device = device
        self._vocab_size = vocab_size

    @nvtx_range("BatchExpansionTop1Scorer.score_proposals")
    def score_proposals(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Optional[Dict[int, int]],
        blocks_to_swap_out: Optional[Dict[int, int]],
        blocks_to_copy: Optional[Dict[int, List[int]]],
        k: int,
        proposals: SpeculativeProposals,
    ) -> SpeculativeScores:
        """Score the proposed tokens via the scorer model.

        This converts each input sequence to a set of k+1 target sequences. The
        target sequences have the unique continuations to be scored and a
        unique sequence ID that is different from all input sequence ids.

        If a speculative sequence length would exceed the max model length, then
        no speculation is produced for that sequence.

        Args:
            seq_group_metadata_list: The input sequence group metadata.
            blocks_to_swap_in: This is passed to the worker during scoring.
            blocks_to_swap_out: This is passed to the worker during scoring.
            blocks_to_copy: This is passed to the worker during scoring.
            k: The fixed proposal length.
            proposals: The speculative proposals to score.
        Returns:
            SpeculativeScores: The scores of each speculative token, along with
                which sequences were ignored during scoring.
        """

        # TODO(cade) perform this on GPU to remove blocking call.
        proposal_lens_list = proposals.proposal_lens.tolist()
        proposal_token_ids_list = proposals.proposal_token_ids.tolist()

        spec_indices, non_spec_indices, target_seq_group_metadata_list, num_scoring_tokens = self._expand_batch(
            seq_group_metadata_list=seq_group_metadata_list,
            proposal_token_ids_list=proposal_token_ids_list,
            proposal_lens_list=proposal_lens_list,
        )

        target_sampler_output = self._scorer_worker.execute_model(
            seq_group_metadata_list=target_seq_group_metadata_list,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
            return_python_output=False)

        all_tokens, all_probs = self._contract_batch(
            original_bs=len(seq_group_metadata_list),
            target_sampler_output=target_sampler_output,
            proposals=proposals,
            num_scoring_tokens=num_scoring_tokens,
            non_spec_indices=non_spec_indices,
            spec_indices=spec_indices,
            k=k,
        )

        return SpeculativeScores(
            probs=all_probs,
            token_ids=all_tokens,
        )

    def _expand_batch(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_token_ids_list: List[TokenId],
        proposal_lens_list: List[int],
    ) -> Tuple[List[int], List[int], List[SequenceGroupMetadata], int]:
        """Given the input sequences and potentially multiple corresponding
        proposal tokens, create a new batch where each sequence has a single
        query token.
        """

        # vLLM currently only supports proposal lens equal to zero or the batch
        # proposal len. This adds some complexity (splitting the batch into spec
        # and non spec sequences) and should be removed in the future. It can be
        # done by supporting per-sequence proposal lens.
        spec_seqs, spec_indices = split_batch_by_proposal_len(
            seq_group_metadata_list,
            proposal_lens_list,
            select_proposal_len_zero=False)
        non_spec_seqs, non_spec_indices = split_batch_by_proposal_len(
            seq_group_metadata_list,
            proposal_lens_list,
            select_proposal_len_zero=True)

        target_seq_group_metadata_list = self._create_scoring_model_input(
            spec_seqs, proposal_token_ids_list)
        num_scoring_tokens = len(target_seq_group_metadata_list)
        target_seq_group_metadata_list.extend(non_spec_seqs)

        return spec_indices, non_spec_indices, target_seq_group_metadata_list, num_scoring_tokens

    def _contract_batch(self, original_bs: int,
                        target_sampler_output: List[SamplerOutput],
                        proposals: SpeculativeProposals,
                        num_scoring_tokens: int, non_spec_indices: List[int],
                        spec_indices: List[int],
|
||||
k: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Contract the expanded batch back into its original size.
|
||||
This maps the scores of speculative tokens back to their original
|
||||
sequences.
|
||||
"""
|
||||
(target_token_ids, target_probs, non_spec_target_token_ids,
|
||||
non_spec_target_probs) = self._split_scoring_output(
|
||||
target_sampler_output, num_scoring_tokens)
|
||||
|
||||
# Map distinct sequences used to score each token
|
||||
# of shape [batch_size * (k + 1)] back to [batch_size, k + 1].
|
||||
batch_size, k = proposals.proposal_token_ids.shape
|
||||
|
||||
target_token_ids = target_token_ids.squeeze().reshape(
|
||||
batch_size, k + 1)
|
||||
target_probs = target_probs.squeeze().reshape(batch_size, k + 1,
|
||||
self._vocab_size)
|
||||
|
||||
all_tokens = torch.full(size=(original_bs, k + 1),
|
||||
fill_value=-1,
|
||||
device=self._device,
|
||||
dtype=torch.long)
|
||||
all_probs = torch.zeros(original_bs,
|
||||
k + 1,
|
||||
self._vocab_size,
|
||||
device=self._device,
|
||||
dtype=torch.float32)
|
||||
|
||||
if non_spec_indices:
|
||||
all_tokens[non_spec_indices, 0] = non_spec_target_token_ids
|
||||
all_probs[non_spec_indices, :1, :] = non_spec_target_probs
|
||||
|
||||
if spec_indices:
|
||||
all_tokens[spec_indices] = target_token_ids
|
||||
all_probs[spec_indices] = target_probs
|
||||
|
||||
return all_tokens, all_probs
|
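As a quick shape check of the expand/score/contract round trip described above, here is a toy standalone sketch (sizes are made up and nothing here is vLLM API):

import torch

batch_size, k, vocab_size = 3, 4, 11
num_target_seqs = batch_size * (k + 1)              # rows in the expanded batch

# Stand-ins for the flat scoring output of the target model.
flat_token_ids = torch.arange(num_target_seqs)            # [batch_size * (k+1)]
flat_probs = torch.rand(num_target_seqs, vocab_size)      # [batch_size * (k+1), vocab]

# Contraction maps the flat output back to one row of k+1 scores per sequence.
token_ids = flat_token_ids.reshape(batch_size, k + 1)
probs = flat_probs.reshape(batch_size, k + 1, vocab_size)
assert token_ids.shape == (batch_size, k + 1)
assert probs.shape == (batch_size, k + 1, vocab_size)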
||||
|
||||
def _create_scoring_model_input(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
proposal_token_ids: List[List[TokenId]], # shape: [batch_size, k]
|
||||
) -> List[SequenceGroupMetadata]:
|
||||
"""Given the original input sequences and proposed tokens from the draft
|
||||
model, create a list of target sequences that can be used for scoring.
|
||||
"""
|
||||
|
||||
if not seq_group_metadata_list:
|
||||
return []
|
||||
|
||||
target_seq_ids_iter = self._create_target_seq_id_iterator(
|
||||
get_all_seq_ids(seq_group_metadata_list))
|
||||
|
||||
target_seq_group_metadata = list(
|
||||
chain.from_iterable(
|
||||
self._create_target_seq_group_metadata(
|
||||
seq_group_metadata,
|
||||
proposal_token_ids,
|
||||
i,
|
||||
target_seq_ids_iter,
|
||||
) for i, seq_group_metadata in enumerate(
|
||||
seq_group_metadata_list)))
|
||||
|
||||
return target_seq_group_metadata
|
||||
|
||||
def _create_target_seq_group_metadata(
|
||||
self,
|
||||
input_seq_group_metadata: SequenceGroupMetadata,
|
||||
proposal_token_ids: List[TokenId], # shape: [batch_size, k]
|
||||
batch_index: int,
|
||||
target_seq_ids_iter: Iterator[TargetSeqId],
|
||||
) -> List[SequenceGroupMetadata]:
|
||||
"""Given an input sequence group metadata and a list of draft tokens,
|
||||
create a list of target SequenceGroupMetadata, one for each
|
||||
token id that needs to be scored.
|
||||
|
||||
Naive speculative decoding requires K target model scores, one for each
|
||||
draft model token. However one can add a bonus token such that if each
|
||||
token is accepted, then a final token may be sampled from the model.
|
||||
This function creates K+1 target SequenceGroupMetadata to take
|
||||
advantage of the bonus token.
|
||||
"""
|
||||
assert not input_seq_group_metadata.is_prompt, (
|
||||
"Speculating on "
|
||||
"prompts not yet supported")
|
||||
assert len(input_seq_group_metadata.seq_data) == 1, (
|
||||
"Beam search "
|
||||
"not supported in speculative decoding")
|
||||
input_seq_id = next(iter(input_seq_group_metadata.seq_data.keys()))
|
||||
|
||||
token_ids_to_score = self._get_token_ids_to_score(
|
||||
proposal_token_ids[batch_index])
|
||||
|
||||
target_seq_group_metadata_list: List[SequenceGroupMetadata] = []
|
||||
for token_ids in token_ids_to_score:
|
||||
target_seq_group_metadata_list.append(
|
||||
self._create_single_target_seq_group_metadata(
|
||||
input_seq_group_metadata,
|
||||
input_seq_id,
|
||||
next(target_seq_ids_iter),
|
||||
token_ids,
|
||||
))
|
||||
|
||||
return target_seq_group_metadata_list
|
||||
|
||||
def _create_single_target_seq_group_metadata(
|
||||
self,
|
||||
seq_group_metadata: SequenceGroupMetadata,
|
||||
seq_id: SeqId,
|
||||
target_seq_id: TargetSeqId,
|
||||
token_ids: List[TokenId],
|
||||
) -> SequenceGroupMetadata:
|
||||
"""Create a single target SequenceGroupMetadata.
|
||||
|
||||
Args:
|
||||
seq_group_metadata: The metadata for the input sequence.
|
||||
seq_id: The input sequence ID.
|
||||
target_seq_id: The corresponding target sequence ID.
|
||||
token_ids: The list of token ids that are to be appended to the
|
||||
input sequence.
|
||||
"""
|
||||
seq_data = seq_group_metadata.seq_data[seq_id]
|
||||
prompt_token_ids = seq_data.get_prompt_token_ids()
|
||||
new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids]
|
||||
|
||||
return SequenceGroupMetadata(
|
||||
request_id=seq_group_metadata.request_id,
|
||||
is_prompt=seq_group_metadata.is_prompt,
|
||||
seq_data={
|
||||
target_seq_id:
|
||||
SequenceData(
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
output_token_ids=new_output_token_ids,
|
||||
),
|
||||
},
|
||||
sampling_params=seq_group_metadata.sampling_params,
|
||||
block_tables={
|
||||
target_seq_id: seq_group_metadata.block_tables[seq_id],
|
||||
},
|
||||
lora_request=None,
|
||||
)
|
||||
|
||||
def _split_scoring_output(
|
||||
self, sampler_output: SamplerOutput, num_scoring_tokens: int
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Split the target model output into speculative and non-speculative
|
||||
output.
|
||||
"""
|
||||
|
||||
# vLLM currently only supports proposal lens equal to zero or the batch
|
||||
# proposal len. This adds some complexity (splitting the batch into spec
|
||||
# and non spec sequences) and should be removed in the future. It can be
|
||||
# done by supporting per-sequence proposal lens.
|
||||
#
|
||||
# First samples are from speculative scoring, latter samples are non-
|
||||
# speculative samples.
|
||||
split_sizes = [
|
||||
num_scoring_tokens,
|
||||
sampler_output.sampled_token_ids.numel() - num_scoring_tokens
|
||||
]
|
||||
(spec_probs, non_spec_probs
|
||||
) = sampler_output.sampled_token_probs.split(split_sizes)
|
||||
(spec_sampled_tokens, non_spec_sampled_tokens
|
||||
) = sampler_output.sampled_token_ids.flatten().split(split_sizes)
|
||||
|
||||
# Convert scores to tensors.
|
||||
sampler_output.sampled_token_probs = spec_probs
|
||||
sampler_output.sampled_token_ids = spec_sampled_tokens
|
||||
target_token_ids, target_probs = sampler_output_to_torch(
|
||||
[sampler_output])
|
||||
|
||||
# Convert non-speculative output tokens to tensors.
|
||||
sampler_output.sampled_token_probs = non_spec_probs
|
||||
sampler_output.sampled_token_ids = non_spec_sampled_tokens
|
||||
non_spec_target_token_ids, non_spec_target_probs = sampler_output_to_torch(
|
||||
[sampler_output])
|
||||
|
||||
return target_token_ids, target_probs, non_spec_target_token_ids, non_spec_target_probs
|
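A toy illustration of the split performed above, using a plain tensor in place of the sampler output (sizes are made up):

import torch

k = 2
num_spec_seqs, num_non_spec_seqs = 2, 2
num_scoring_tokens = num_spec_seqs * (k + 1)        # 6 rows from batch expansion

flat_ids = torch.arange(num_scoring_tokens + num_non_spec_seqs)
spec_ids, non_spec_ids = flat_ids.split(
    [num_scoring_tokens, flat_ids.numel() - num_scoring_tokens])

assert spec_ids.tolist() == [0, 1, 2, 3, 4, 5]      # speculative scoring rows first
assert non_spec_ids.tolist() == [6, 7]              # then normal decode rows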
||||
|
||||
def _create_target_seq_id_iterator(
|
||||
self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:
|
||||
"""Create an iterator for creating target sequence ids.
|
||||
Target sequence ids are distinct from sequence ids because we create a
|
||||
distinct target sequence id for each proposal token to be scored.
|
||||
|
||||
This implementation increments a counter starting at 1 + max of all
|
||||
provided input sequence ids.
|
||||
"""
|
||||
return count(start=max(seq_ids) + 1)
|
||||
|
||||
def _get_token_ids_to_score(
|
||||
self,
|
||||
full_spec_token_ids: List[TokenId] # shape: [k]
|
||||
) -> List[List[TokenId]]:
|
||||
"""Given an int tensor of proposal token ids, return a list of
|
||||
token ids that should be scored.
|
||||
|
||||
Returns k+1 output lists. The additional one is used for generating the
|
||||
bonus token.
|
||||
|
||||
Example:
|
||||
Input: [0, 1, 2, 3] (k=4)
|
||||
Output: (k+1 lists)
|
||||
[]
|
||||
[0]
|
||||
[0, 1]
|
||||
[0, 1, 2]
|
||||
[0, 1, 2, 3]
|
||||
"""
|
||||
empty_token_ids = []
|
||||
|
||||
token_ids_to_score = [empty_token_ids]
|
||||
token_ids_to_score.extend([
|
||||
full_spec_token_ids[:i + 1]
|
||||
for i in range(len(full_spec_token_ids))
|
||||
])
|
||||
return token_ids_to_score
|
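The same prefix expansion as a standalone sketch; the helper name is illustrative and not part of vLLM:

from typing import List

def expand_proposal_prefixes(proposal_token_ids: List[int]) -> List[List[int]]:
    # k proposal tokens yield k + 1 prefixes: the empty prefix scores the
    # first proposal token, and the full prefix produces the bonus token.
    return [proposal_token_ids[:i] for i in range(len(proposal_token_ids) + 1)]

assert expand_proposal_prefixes([0, 1, 2, 3]) == [
    [], [0], [0, 1], [0, 1, 2], [0, 1, 2, 3]]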
77
vllm/spec_decode/interfaces.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from typing import List, Tuple, Optional, Dict
|
||||
from dataclasses import dataclass
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.sequence import SequenceGroupMetadata
|
||||
|
||||
|
||||
@dataclass
|
||||
class SpeculativeProposals:
|
||||
"""Datastructure used to represent proposal tokens from some proposer. It
|
||||
also tracks how many speculative tokens each sequence has.
|
||||
"""
|
||||
|
||||
# Speculative proposal tokens.
|
||||
proposal_token_ids: torch.Tensor
|
||||
|
||||
# Probabilities of the proposal tokens according to the proposer.
|
||||
proposal_probs: torch.Tensor
|
||||
|
||||
# The valid length of each proposal; can be zero.
|
||||
proposal_lens: torch.Tensor
|
||||
|
||||
def __repr__(self):
|
||||
return (f"SpeculativeProposals("
|
||||
f"proposal_token_ids={self.proposal_token_ids.shape}, "
|
||||
f"proposal_probs={self.proposal_probs.shape}, "
|
||||
f"proposal_lens={self.proposal_lens.shape})")
|
||||
|
||||
|
||||
@dataclass
|
||||
class SpeculativeScores:
|
||||
"""Datastructure used to represent the scores of speculative tokens
|
||||
according to the scoring model.
|
||||
"""
|
||||
|
||||
# Probabilities of the speculative tokens according to the scoring model.
|
||||
probs: torch.Tensor
|
||||
|
||||
# Token ids sampled from the scoring model. Used for speculative bonus
|
||||
# tokens and also non-speculative normal decoding.
|
||||
token_ids: torch.Tensor
|
||||
|
||||
def __repr__(self):
|
||||
return (f"SpeculativeScores("
|
||||
f"probs={self.probs.shape}, "
|
||||
f"token_ids={self.token_ids.shape})")
|
||||
|
||||
|
||||
class SpeculativeProposer(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def get_proposals(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
blocks_to_swap_in: Dict[int, int],
|
||||
blocks_to_swap_out: Dict[int, int],
|
||||
blocks_to_copy: Dict[int, List[int]],
|
||||
max_proposal_len: int,
|
||||
) -> SpeculativeProposals:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SpeculativeScorer(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def score_proposals(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
blocks_to_swap_in: Optional[Dict[int, int]],
|
||||
blocks_to_swap_out: Optional[Dict[int, int]],
|
||||
blocks_to_copy: Optional[Dict[int, List[int]]],
|
||||
k: int,
|
||||
proposals: SpeculativeProposals,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
raise NotImplementedError
|
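A toy construction of the proposal container defined above, assuming vLLM at this commit is importable; the shapes mirror those produced by DraftModelTop1Proposer later in this diff:

import torch
from vllm.spec_decode.interfaces import SpeculativeProposals

batch_size, k, vocab_size = 2, 3, 11
proposals = SpeculativeProposals(
    proposal_token_ids=torch.zeros(batch_size, k, dtype=torch.long),
    proposal_probs=torch.zeros(batch_size, k, vocab_size),
    proposal_lens=torch.full((batch_size,), k, dtype=torch.long),
)
print(proposals)    # the __repr__ above prints only the tensor shapes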
174
vllm/spec_decode/metrics.py
Normal file
@@ -0,0 +1,174 @@
|
||||
import torch
|
||||
from dataclasses import dataclass
|
||||
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
|
||||
from typing import Optional
|
||||
from vllm.utils import in_wsl
|
||||
import time
|
||||
from typing import Callable
|
||||
|
||||
|
||||
@dataclass
|
||||
class SpecDecodeWorkerMetrics:
|
||||
"""Dataclass holding metrics emitted from the spec decode worker.
|
||||
"""
|
||||
|
||||
# The empirical acceptance rate of the proposal method on a per-token basis.
|
||||
# This is useful for evaluating how well the proposal method aligns with the
|
||||
# scoring method.
|
||||
draft_acceptance_rate: float
|
||||
|
||||
# The empirical efficiency, measured as the number of tokens emitted by the
|
||||
# system divided by the number of tokens that could be emitted by the system
|
||||
# if the proposal method were perfect.
|
||||
system_efficiency: float
|
||||
|
||||
# The number of speculative tokens produced by the proposal method.
|
||||
draft_tokens: int
|
||||
|
||||
# The number of tokens emitted by the entire system.
|
||||
emitted_tokens: int
|
||||
|
||||
# The number of tokens accepted by the scoring model and verification
|
||||
# routine, e.g. Llama2-70B and lossless rejection sampling.
|
||||
#
|
||||
# NOTE: Any token accepted by the verification routine is considered
|
||||
# accepted (regardless of whether the speculative prefix is also accepted). The
|
||||
# user will usually see fewer accepted tokens. This metric is helpful when
|
||||
# evaluating alignment of the proposal method with the scoring model.
|
||||
accepted_tokens: int
|
||||
|
||||
# The number of speculative tokens per sequence.
|
||||
num_spec_tokens: int
|
||||
|
||||
|
||||
Timer = Callable[[], float]
|
||||
|
||||
|
||||
class AsyncMetricsCollector:
|
||||
"""Class which copies rejection sampler metrics from the device to CPU on a
|
||||
non-default Torch stream.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
rejection_sampler: RejectionSampler,
|
||||
timer: Optional[Timer] = None,
|
||||
collect_interval_s: float = 5.0):
|
||||
self._rejection_sampler = rejection_sampler
|
||||
self._timer = time.time if timer is None else timer
|
||||
|
||||
self._rank: Optional[int] = None
|
||||
|
||||
# We don't have a device set yet.
|
||||
self._copy_stream: Optional[torch.cuda.Stream] = None
|
||||
|
||||
self._in_flight_copy: Optional[torch.cuda.Event] = None
|
||||
|
||||
pin_memory = not in_wsl()
|
||||
self._aggregate_num_accepted_tokens = torch.tensor(
|
||||
0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
|
||||
self._aggregate_num_emitted_tokens = torch.tensor(
|
||||
0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
|
||||
self._aggregate_num_draft_tokens = 0
|
||||
|
||||
self._rejsample_metrics_collect_interval_s = collect_interval_s
|
||||
self._last_metrics_collect_time = self._timer()
|
||||
|
||||
def init_gpu_tensors(self, rank: int) -> None:
|
||||
self._rank = rank
|
||||
self._copy_stream = torch.cuda.Stream()
|
||||
|
||||
def maybe_collect_rejsample_metrics(
|
||||
self, k: int) -> Optional[SpecDecodeWorkerMetrics]:
|
||||
|
||||
# If a copy was initiated in the previous call, collect and return.
|
||||
if self._in_flight_copy is not None:
|
||||
ready_event = self._in_flight_copy
|
||||
self._in_flight_copy = None
|
||||
return self._collect_rejsample_metrics(k, ready_event)
|
||||
|
||||
# Otherwise, check if we should start a new copy.
|
||||
if self._should_collect_rejsample_metrics(self._timer()):
|
||||
assert self._in_flight_copy is None
|
||||
self._in_flight_copy = self._copy_rejsample_metrics_async()
|
||||
|
||||
return None
|
||||
|
||||
def _should_collect_rejsample_metrics(self, now: float) -> bool:
|
||||
"""Return whether or not this iteration should print rejection sampling
|
||||
metrics.
|
||||
"""
|
||||
if self._rank != 0:
|
||||
return False
|
||||
|
||||
if (now - self._last_metrics_collect_time <
|
||||
self._rejsample_metrics_collect_interval_s):
|
||||
return False
|
||||
return True
|
||||
|
||||
def _copy_rejsample_metrics_async(self) -> torch.cuda.Event:
|
||||
"""Copy rejection sampling metrics (number of accepted tokens, etc) to
|
||||
CPU asynchronously.
|
||||
|
||||
Returns a CUDA event recording when the copy is complete.
|
||||
"""
|
||||
self._copy_stream.wait_stream(torch.cuda.current_stream())
|
||||
|
||||
with torch.cuda.stream(self._copy_stream):
|
||||
self._aggregate_num_accepted_tokens.copy_(
|
||||
self._rejection_sampler.num_accepted_tokens, non_blocking=True)
|
||||
self._aggregate_num_emitted_tokens.copy_(
|
||||
self._rejection_sampler.num_emitted_tokens, non_blocking=True)
|
||||
# Number of draft tokens is calculated on CPU, so no copy is
|
||||
# required.
|
||||
self._aggregate_num_draft_tokens = (
|
||||
self._rejection_sampler.num_draft_tokens)
|
||||
|
||||
aggregate_metrics_ready = torch.cuda.Event()
|
||||
aggregate_metrics_ready.record(self._copy_stream)
|
||||
|
||||
return aggregate_metrics_ready
|
||||
|
||||
def _collect_rejsample_metrics(
|
||||
self, k: int,
|
||||
ready_event: torch.cuda.Event) -> SpecDecodeWorkerMetrics:
|
||||
"""Create metrics object from statistics copied asynchronously.
|
||||
|
||||
Args:
|
||||
k: int. The number of speculative tokens; used to determine system
|
||||
efficiency.
|
||||
ready_event: torch.cuda.Event. The CUDA event recording when the
|
||||
async GPU->CPU copy is complete.
|
||||
"""
|
||||
|
||||
ready_event.synchronize()
|
||||
accepted_tokens = self._aggregate_num_accepted_tokens.item()
|
||||
emitted_tokens = self._aggregate_num_emitted_tokens.item()
|
||||
draft_tokens = self._aggregate_num_draft_tokens
|
||||
|
||||
num_possible_tokens = self.get_max_num_accepted_tokens(draft_tokens, k)
|
||||
|
||||
if draft_tokens > 0:
|
||||
draft_acceptance_rate = accepted_tokens / draft_tokens
|
||||
else:
|
||||
draft_acceptance_rate = float("nan")
|
||||
|
||||
if num_possible_tokens > 0:
|
||||
system_efficiency = emitted_tokens / num_possible_tokens
|
||||
else:
|
||||
system_efficiency = float("nan")
|
||||
|
||||
return SpecDecodeWorkerMetrics(
|
||||
num_spec_tokens=k,
|
||||
draft_acceptance_rate=draft_acceptance_rate,
|
||||
system_efficiency=system_efficiency,
|
||||
accepted_tokens=accepted_tokens,
|
||||
draft_tokens=draft_tokens,
|
||||
emitted_tokens=emitted_tokens,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_max_num_accepted_tokens(draft_tokens: int, k: int) -> int:
|
||||
# Divide by k since batch size can be variable.
|
||||
total_num_spec_seqs = draft_tokens / k
|
||||
num_accepted_per_seq_if_all_accepted = k + 1
|
||||
return int(total_num_spec_seqs * num_accepted_per_seq_if_all_accepted)
|
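A quick numeric check of the metric definitions above (toy numbers, not measured data):

k = 4
draft_tokens = 40                 # 10 speculative sequences * k draft tokens
accepted_tokens = 25
emitted_tokens = 33               # accepted tokens plus bonus/recovered tokens

draft_acceptance_rate = accepted_tokens / draft_tokens            # 0.625
max_emitted_tokens = int(draft_tokens / k) * (k + 1)              # 10 * 5 == 50
system_efficiency = emitted_tokens / max_emitted_tokens           # 0.66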
366
vllm/spec_decode/multi_step_worker.py
Normal file
@@ -0,0 +1,366 @@
|
||||
from typing import List, Dict, Optional, Tuple
|
||||
import copy
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
|
||||
from vllm.worker.worker import Worker
|
||||
from vllm.spec_decode.interfaces import SpeculativeProposals, SpeculativeProposer
|
||||
from vllm.spec_decode.util import sampler_output_to_torch
|
||||
|
||||
|
||||
class MultiStepWorker(Worker):
|
||||
"""The MultiStepWorker is equivalent to a Worker except that it allows
|
||||
multiple forward passes in a single call, assuming the scheduler has
|
||||
allocated enough space to store the additional KV. This reduces overhead
|
||||
by invoking the scheduler less.
|
||||
|
||||
The MultiStepWorker does not support cache swap operations, or beam search.
|
||||
Cache swap operations do not require large modifications. On the other hand,
|
||||
beam search requires memory allocations during sequence forks and thus
|
||||
requires more thought for MultiStepWorker support.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self._proposer: Optional[DraftModelTop1Proposer] = None
|
||||
|
||||
def init_model(self):
|
||||
super().init_model()
|
||||
|
||||
self._proposer = DraftModelTop1Proposer(
|
||||
self,
|
||||
self.device,
|
||||
self.max_model_len,
|
||||
self.vocab_size,
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model_multi_step(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
blocks_to_swap_in: Dict[int, int],
|
||||
blocks_to_swap_out: Dict[int, int],
|
||||
blocks_to_copy: Dict[int, List[int]],
|
||||
num_steps: int,
|
||||
) -> List[SamplerOutput]:
|
||||
"""Run the model forward pass num_steps times. Returns the list of
|
||||
sampler output, one per model forward pass.
|
||||
"""
|
||||
self._raise_if_unsupported(seq_group_metadata_list, blocks_to_swap_in,
|
||||
blocks_to_swap_out, blocks_to_copy)
|
||||
|
||||
# Shallow copy input data so modifications (such as appending tokens)
|
||||
# do not cause side-effects.
|
||||
copied_seq_group_metadata_list = self._shallow_copy_inputs(
|
||||
seq_group_metadata_list)
|
||||
|
||||
# Assert enough KV space for num_steps tokens per sequence.
|
||||
self._assert_enough_kv_space(seq_group_metadata_list, num_steps)
|
||||
|
||||
# Run model num_steps times.
|
||||
model_outputs = []
|
||||
for _ in range(num_steps):
|
||||
model_output = super().execute_model(
|
||||
seq_group_metadata_list=copied_seq_group_metadata_list,
|
||||
blocks_to_swap_in=blocks_to_swap_in,
|
||||
blocks_to_swap_out=blocks_to_swap_out,
|
||||
blocks_to_copy=blocks_to_copy,
|
||||
)
|
||||
|
||||
self._append_new_tokens(model_output,
|
||||
copied_seq_group_metadata_list)
|
||||
model_outputs.append(model_output)
|
||||
|
||||
return model_outputs
|
||||
|
||||
def get_spec_proposals(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
blocks_to_swap_in: Dict[int, int],
|
||||
blocks_to_swap_out: Dict[int, int],
|
||||
blocks_to_copy: Dict[int, List[int]],
|
||||
max_proposal_len: int,
|
||||
) -> SpeculativeProposals:
|
||||
"""Produce speculations given an input batch of sequences. The number of
|
||||
speculative tokens per sequence is determined by max_proposal_len.
|
||||
"""
|
||||
|
||||
return self._proposer.get_proposals(
|
||||
seq_group_metadata_list,
|
||||
blocks_to_swap_in,
|
||||
blocks_to_swap_out,
|
||||
blocks_to_copy,
|
||||
max_proposal_len,
|
||||
)
|
||||
|
||||
def _append_new_tokens(
|
||||
self, model_output: SamplerOutput,
|
||||
seq_group_metadata_list: SequenceGroupMetadata) -> None:
|
||||
"""Given model output from a single run, append the tokens to the
|
||||
sequences. This is normally done outside of the worker, but it is
|
||||
required if the worker is to perform multiple forward passes.
|
||||
"""
|
||||
for seq_group_metadata, sequence_group_outputs in zip(
|
||||
seq_group_metadata_list, model_output):
|
||||
seq_group_metadata.is_prompt = False
|
||||
|
||||
for seq_output in sequence_group_outputs.samples:
|
||||
# NOTE: Beam search is not supported, so we can assume that
|
||||
# parent_seq_id == seq_id.
|
||||
seq = seq_group_metadata.seq_data[seq_output.parent_seq_id]
|
||||
|
||||
token_id = seq_output.output_token
|
||||
token_logprob = seq_output.logprobs[token_id]
|
||||
|
||||
seq.append_token_id(token_id, token_logprob.logprob)
|
||||
|
||||
def _shallow_copy_inputs(
|
||||
self, seq_group_metadata_list: List[SequenceGroupMetadata]
|
||||
) -> List[SequenceGroupMetadata]:
|
||||
"""Copy input data structures to remove side-effects when input data
|
||||
structures are shared with other modules.
|
||||
|
||||
Helpful when the vLLM scheduler runs in the same process as the worker.
|
||||
The alternative is deep-copying (or other form of deep copy); this has
|
||||
performance downsides.
|
||||
"""
|
||||
|
||||
# Shallow-copy the list of SequenceGroupMetadata. This allows us to
|
||||
# append tokens and change is_prompt without external side-effects.
|
||||
new_seq_group_metadata_list = []
|
||||
|
||||
for old_seq_group_metadata in seq_group_metadata_list:
|
||||
# We must shallow-copy seq_group_metadata as is_prompt could change.
|
||||
seq_group_metadata = copy.copy(old_seq_group_metadata)
|
||||
new_seq_group_metadata_list.append(seq_group_metadata)
|
||||
|
||||
# We must shallow-copy seq_data as we will append token ids
|
||||
new_seq_data = {}
|
||||
for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
|
||||
new_seq_data[seq_id] = copy.copy(old_seq_data)
|
||||
new_seq_data[
|
||||
seq_id].output_token_ids = old_seq_data.output_token_ids[:]
|
||||
|
||||
seq_group_metadata.seq_data = new_seq_data
|
||||
|
||||
return new_seq_group_metadata_list
|
||||
|
||||
def _assert_enough_kv_space(
|
||||
self, seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
num_steps: int) -> None:
|
||||
"""Assert there are enough physical blocks per sequence to store the
|
||||
current KV plus additional KV from num_steps tokens.
|
||||
"""
|
||||
assert self.model_runner.block_size is not None
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
# Only one seq_id is guaranteed because there is no beam search.
|
||||
seq_id = list(seq_group_metadata.seq_data.keys())[0]
|
||||
seq = seq_group_metadata.seq_data[seq_id]
|
||||
|
||||
# After num_steps, the seq len will be the current seq len
|
||||
# plus one token per step.
|
||||
final_seq_len = seq.get_len() + num_steps
|
||||
|
||||
# We will have final_seq_len - 1 KV because vLLM saves KV for a
|
||||
# token in the iteration after the token was generated.
|
||||
required_num_kv_slots = final_seq_len - 1
|
||||
|
||||
# The allocated number of kv slots is the number of allocated blocks
|
||||
# times the number of slots per block.
|
||||
number_physical_blocks = len(
|
||||
seq_group_metadata.block_tables[seq_id])
|
||||
allocated_kv_slots = (number_physical_blocks *
|
||||
self.model_runner.block_size)
|
||||
|
||||
if required_num_kv_slots > allocated_kv_slots:
|
||||
request_id = seq_group_metadata.request_id
|
||||
raise ValueError(
|
||||
"The worker attempted to run "
|
||||
f"{num_steps} times but found insufficient KV space for "
|
||||
f"{request_id=} {seq_id=}. ({allocated_kv_slots=} "
|
||||
f"{required_num_kv_slots=}).")
|
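The KV-slot accounting above in a toy example (block size and counts are made up):

block_size = 16
num_steps = 5
current_seq_len = 40
num_allocated_blocks = 3

final_seq_len = current_seq_len + num_steps
required_num_kv_slots = final_seq_len - 1      # KV for the last token is written
                                               # on the following iteration
allocated_kv_slots = num_allocated_blocks * block_size

assert required_num_kv_slots <= allocated_kv_slots    # 44 <= 48, enough space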
||||
|
||||
def _raise_if_unsupported(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
blocks_to_swap_in: Dict[int, int],
|
||||
blocks_to_swap_out: Dict[int, int],
|
||||
blocks_to_copy: Dict[int, List[int]],
|
||||
) -> None:
|
||||
"""MultiStepWorker does not yet implement support for cache swap
|
||||
operations or beam search.
|
||||
"""
|
||||
if any([blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy]):
|
||||
raise NotImplementedError(
|
||||
"MultiStepWorker does not support cache operations")
|
||||
|
||||
if any(
|
||||
len(seq_group_metadata.seq_data.keys()) != 1
|
||||
for seq_group_metadata in seq_group_metadata_list):
|
||||
raise NotImplementedError(
|
||||
"MultiStepWorker does not support beam search.")
|
||||
|
||||
|
||||
class DraftModelTop1Proposer(SpeculativeProposer):
|
||||
"""Helper class which separates out sequences which would exceed the max
|
||||
model length when speculated upon.
|
||||
|
||||
This allows combinations of models such as JackFram/llama-68m draft with
|
||||
meta-llama/Llama2-13b-chat-hf, as llama-68m has max_position_embeddings of
|
||||
2048 while Llama2-13b has max_position_embeddings of 4096.
|
||||
|
||||
We treat the sequences which exceed the proposal draft model length as
|
||||
"non-spec sequences". Essentially they skip the draft model and go through
|
||||
normal decoding in the target model.
|
||||
|
||||
Currently, only proposal_lens of 0 and k are supported, where k is a global
|
||||
batch proposal length. In the future vLLM should support per-sequence
|
||||
proposal lengths.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
draft_worker: MultiStepWorker,
|
||||
device: str,
|
||||
max_model_len: int,
|
||||
vocab_size: int,
|
||||
):
|
||||
self._draft_worker = draft_worker
|
||||
self._device = device
|
||||
self._max_model_len = max_model_len
|
||||
self._vocab_size = vocab_size
|
||||
|
||||
def get_proposals(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
blocks_to_swap_in: Dict[int, int],
|
||||
blocks_to_swap_out: Dict[int, int],
|
||||
blocks_to_copy: Dict[int, List[int]],
|
||||
max_proposal_len: int,
|
||||
) -> SpeculativeProposals:
|
||||
"""Get speculative proposals given the input batch.
|
||||
|
||||
Sequences which would exceed the max model length are skipped during
|
||||
speculation.
|
||||
"""
|
||||
|
||||
# Split speculative- and non-speculative- sequences.
|
||||
proposal_lens, nonzero_proposal_len_seqs, nonzero_proposal_len_indices = self._split_by_max_model_len(
|
||||
seq_group_metadata_list, max_proposal_len)
|
||||
|
||||
if nonzero_proposal_len_seqs:
|
||||
# Speculate tokens using the draft worker for the speculative
|
||||
# sequences.
|
||||
maybe_sampler_output = self._draft_worker.execute_model_multi_step(
|
||||
seq_group_metadata_list=nonzero_proposal_len_seqs,
|
||||
blocks_to_swap_in=blocks_to_swap_in,
|
||||
blocks_to_swap_out=blocks_to_swap_out,
|
||||
blocks_to_copy=blocks_to_copy,
|
||||
num_steps=max_proposal_len,
|
||||
)
|
||||
else:
|
||||
# If no sequences can be speculated, set sampler output to None.
|
||||
maybe_sampler_output = None
|
||||
|
||||
# Combine speculative- and non-speculative sequences into the same
|
||||
# representation.
|
||||
proposal_tokens, proposal_probs, proposal_lens = self._merge_outputs(
|
||||
batch_size=len(seq_group_metadata_list),
|
||||
max_proposal_len=max_proposal_len,
|
||||
maybe_sampler_output=maybe_sampler_output,
|
||||
proposal_lens=proposal_lens,
|
||||
nonzero_proposal_len_indices=nonzero_proposal_len_indices,
|
||||
)
|
||||
|
||||
proposals = SpeculativeProposals(
|
||||
proposal_token_ids=proposal_tokens,
|
||||
proposal_probs=proposal_probs,
|
||||
proposal_lens=proposal_lens,
|
||||
)
|
||||
|
||||
return proposals
|
||||
|
||||
def _split_by_max_model_len(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
max_proposal_len: int,
|
||||
) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]:
|
||||
"""Determine which sequences would exceed the max model length.
|
||||
"""
|
||||
|
||||
proposal_lens: List[int] = []
|
||||
nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = []
|
||||
nonzero_proposal_len_indices: List[int] = []
|
||||
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
|
||||
seq_data = next(iter(seq_group_metadata.seq_data.values()))
|
||||
seq_len = seq_data.get_len()
|
||||
|
||||
# Currently only proposal lens of 0 or the global batch proposal len
|
||||
# are supported.
|
||||
if seq_len + max_proposal_len < self._max_model_len:
|
||||
proposal_lens.append(max_proposal_len)
|
||||
nonzero_proposal_len_seqs.append(seq_group_metadata)
|
||||
nonzero_proposal_len_indices.append(i)
|
||||
else:
|
||||
proposal_lens.append(0)
|
||||
|
||||
return proposal_lens, nonzero_proposal_len_seqs, nonzero_proposal_len_indices
|
||||
|
||||
def _merge_outputs(
|
||||
self,
|
||||
batch_size: int,
|
||||
max_proposal_len: int,
|
||||
maybe_sampler_output: Optional[SamplerOutput],
|
||||
proposal_lens: List[int],
|
||||
nonzero_proposal_len_indices: List[int],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""After speculations are produced, merge the speculation results with
|
||||
the skipped sequences.
|
||||
"""
|
||||
if maybe_sampler_output is None:
|
||||
# If no speculative tokens, the sampler output will be None.
|
||||
# In this case we return empty tensors.
|
||||
proposal_tokens = torch.zeros(0,
|
||||
max_proposal_len,
|
||||
dtype=torch.long,
|
||||
device=self._device)
|
||||
proposal_probs = torch.zeros(0,
|
||||
max_proposal_len,
|
||||
self._vocab_size,
|
||||
dtype=torch.float32,
|
||||
device=self._device)
|
||||
proposal_lens = torch.zeros(len(proposal_lens),
|
||||
dtype=torch.long,
|
||||
device=self._device)
|
||||
return proposal_tokens, proposal_probs, proposal_lens
|
||||
|
||||
sampler_output = maybe_sampler_output
|
||||
|
||||
proposal_tokens, proposal_probs = sampler_output_to_torch(
|
||||
sampler_output)
|
||||
|
||||
# Now, reformat the output GPU tensors such that each sequence has
|
||||
# a proposal. The proposal can be empty, e.g. [-1, -1, -1].
|
||||
|
||||
entire_proposal_tokens = torch.full(size=(batch_size,
|
||||
*proposal_tokens.shape[1:]),
|
||||
fill_value=-1,
|
||||
dtype=torch.long,
|
||||
device=self._device)
|
||||
entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens
|
||||
entire_proposal_probs = torch.zeros(batch_size,
|
||||
*proposal_probs.shape[1:],
|
||||
dtype=torch.float32,
|
||||
device=self._device)
|
||||
entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs
|
||||
|
||||
proposal_tokens, proposal_probs = entire_proposal_tokens, entire_proposal_probs
|
||||
|
||||
proposal_lens = torch.zeros(batch_size,
|
||||
dtype=torch.long,
|
||||
device=self._device)
|
||||
proposal_lens[nonzero_proposal_len_indices] = max_proposal_len
|
||||
|
||||
return proposal_tokens, proposal_probs, proposal_lens
|
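A minimal sketch of the merge step above: sequences that skipped speculation keep an all -1 proposal and a proposal length of zero (toy tensors only):

import torch

batch_size, max_proposal_len = 4, 3
nonzero_proposal_len_indices = [0, 2]          # sequences that were speculated

proposal_tokens = torch.arange(
    len(nonzero_proposal_len_indices) * max_proposal_len,
    dtype=torch.long).reshape(-1, max_proposal_len)

entire_proposal_tokens = torch.full((batch_size, max_proposal_len), -1,
                                    dtype=torch.long)
entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens

proposal_lens = torch.zeros(batch_size, dtype=torch.long)
proposal_lens[nonzero_proposal_len_indices] = max_proposal_len
# Rows 1 and 3 keep [-1, -1, -1] and length 0: they go through normal decoding.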
372
vllm/spec_decode/spec_decode_worker.py
Normal file
@@ -0,0 +1,372 @@
|
||||
from typing import List, Tuple, Optional, Dict
|
||||
from functools import cached_property
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.spec_decode.metrics import AsyncMetricsCollector
|
||||
from vllm.sequence import (SamplerOutput, SequenceGroupMetadata,
|
||||
SequenceGroupOutput, SequenceOutput)
|
||||
from vllm.worker.worker import Worker
|
||||
from vllm.spec_decode.multi_step_worker import MultiStepWorker
|
||||
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.spec_decode.util import nvtx_range, get_all_seq_ids, split_batch_by_proposal_len
|
||||
from vllm.spec_decode.interfaces import SpeculativeProposals, SpeculativeScores
|
||||
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
|
||||
from vllm.spec_decode.interfaces import SpeculativeScorer
|
||||
|
||||
|
||||
class SpecDecodeWorker:
|
||||
"""Worker which implements speculative decoding.
|
||||
|
||||
Speculative decoding reduces decoding per-token latency by using a proposal
|
||||
method, such as a small draft model, to speculate ahead of a larger LLM. The
|
||||
probabilities of the speculative tokens are then determined by the larger
|
||||
LLM, after which some verification routine determines which (if any) of the
|
||||
speculative tokens are accepted by the larger LLM.
|
||||
|
||||
See https://github.com/vllm-project/vllm/pull/2188 and
|
||||
https://github.com/vllm-project/vllm/pull/3103 for more info.
|
||||
|
||||
The current implementation has the following limitations:
|
||||
* Only draft-model proposal is implemented (contributions for more forms are
|
||||
welcome!).
|
||||
* Only top-1 proposal and scoring are implemented. Tree-attention is left as
|
||||
future work.
|
||||
* Only lossless rejection sampling is supported. Contributions adding lossy
|
||||
verification routines are welcome (e.g. Medusa's typical acceptance).
|
||||
* All sequences in a batch must have the same proposal length, or zero. This
|
||||
can be improved by having per-sequence speculation in the future.
|
||||
* The scoring forward pass is done without an MQA kernel, which is
|
||||
suboptimal especially as the batch size, proposal length, and sequence
|
||||
lengths grow. Contributions to add MQA scoring are welcome once
|
||||
correctness tests pass.
|
||||
More info here https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit.
|
||||
"""
|
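For orientation, the lossless verification mentioned above boils down to the standard speculative-decoding acceptance rule; a standalone per-token sketch (not the RejectionSampler API):

import torch

def accept_draft_token(target_probs: torch.Tensor, draft_probs: torch.Tensor,
                       token_id: int) -> bool:
    # Accept the draft token with probability min(1, p_target / p_draft); on
    # rejection, a token is resampled from the adjusted target distribution,
    # which keeps the overall output distribution equal to the target model's.
    ratio = (target_probs[token_id] / draft_probs[token_id]).item()
    return torch.rand(()).item() < min(1.0, ratio)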
||||
|
||||
def __init__(
|
||||
self,
|
||||
proposer_worker: MultiStepWorker,
|
||||
scorer_worker: Worker,
|
||||
rejection_sampler: RejectionSampler,
|
||||
metrics_collector: Optional[AsyncMetricsCollector] = None,
|
||||
):
|
||||
"""
|
||||
Create a SpecDecodeWorker.
|
||||
|
||||
Args:
|
||||
proposer_worker: A worker that can produce speculative tokens for
|
||||
sequences.
|
||||
scorer_worker: A worker that produces probabilities of speculative
|
||||
tokens according to some base model. Typically a vanilla vLLM
|
||||
Worker.
|
||||
rejection_sampler: A Torch module used to perform modified rejection
|
||||
sampling for speculative decoding.
|
||||
metrics_collector: Helper class for collecting metrics; can be set
|
||||
for testing purposes.
|
||||
"""
|
||||
self.proposer_worker = proposer_worker
|
||||
self.scorer_worker = scorer_worker
|
||||
self.rejection_sampler = rejection_sampler
|
||||
|
||||
self._metrics = AsyncMetricsCollector(
|
||||
rejection_sampler
|
||||
) if metrics_collector is None else metrics_collector
|
||||
|
||||
self.probs_dtype = self.rejection_sampler.probs_dtype
|
||||
self.token_id_dtype = self.rejection_sampler.token_id_dtype
|
||||
|
||||
self.scorer: SpeculativeScorer = None
|
||||
|
||||
def init_model(self) -> None:
|
||||
"""Initialize both scorer and proposer models.
|
||||
"""
|
||||
# The scorer worker model is initialized first in case the proposer
|
||||
# model has a smaller TP degree than the target worker.
|
||||
self.scorer_worker.init_model()
|
||||
self.proposer_worker.init_model()
|
||||
|
||||
self._metrics.init_gpu_tensors(self.rank)
|
||||
self.rejection_sampler.init_gpu_tensors(self.rank)
|
||||
self.scorer = BatchExpansionTop1Scorer(
|
||||
scorer_worker=self.scorer_worker,
|
||||
device=self.device,
|
||||
vocab_size=self._vocab_size)
|
||||
|
||||
def profile_num_available_blocks(self, block_size: int,
|
||||
gpu_memory_utilization: float,
|
||||
cpu_swap_space: int,
|
||||
cache_dtype: str) -> Tuple[int, int]:
|
||||
"""Determine the number of cache blocks to use.
|
||||
|
||||
This is done by profiling the scorer model (which is typically the
|
||||
larger of the two). Then the total memory which would be used by the
|
||||
scorer cache is divided evenly between the proposer and scorer model KV,
|
||||
such that the number of blocks is equal in both KV caches.
|
||||
"""
|
||||
num_gpu_blocks, num_cpu_blocks = (
|
||||
self.scorer_worker.profile_num_available_blocks(
|
||||
block_size, gpu_memory_utilization, cpu_swap_space,
|
||||
cache_dtype))
|
||||
|
||||
scorer_cache_block_size_bytes = self.scorer_worker.get_cache_block_size_bytes(
|
||||
block_size, cache_dtype)
|
||||
proposer_cache_block_size_bytes = self.proposer_worker.get_cache_block_size_bytes(
|
||||
block_size, cache_dtype)
|
||||
|
||||
new_num_gpu_blocks = split_num_cache_blocks_evenly(
|
||||
scorer_cache_block_size_bytes, proposer_cache_block_size_bytes,
|
||||
num_gpu_blocks)
|
||||
return new_num_gpu_blocks, num_cpu_blocks
|
||||
|
||||
def init_cache_engine(self, cache_config: CacheConfig):
|
||||
"""Initialize the cache engine of the scorer and proposer workers.
|
||||
"""
|
||||
self.scorer_worker.init_cache_engine(cache_config)
|
||||
self.proposer_worker.init_cache_engine(cache_config)
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
blocks_to_swap_in: Optional[Dict[int, int]],
|
||||
blocks_to_swap_out: Optional[Dict[int, int]],
|
||||
blocks_to_copy: Optional[Dict[int, List[int]]],
|
||||
num_spec_tokens: int,
|
||||
) -> List[SamplerOutput]:
|
||||
"""Perform speculative decoding on the input batch.
|
||||
"""
|
||||
|
||||
assert seq_group_metadata_list is not None, (
|
||||
"speculative decoding "
|
||||
"requires non-None seq_group_metadata_list")
|
||||
|
||||
# If no spec tokens, call the proposer and scorer workers normally.
|
||||
# Used for prefill.
|
||||
if num_spec_tokens == 0 or len(seq_group_metadata_list) == 0:
|
||||
return self._run_no_spec(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
blocks_to_swap_in=blocks_to_swap_in,
|
||||
blocks_to_swap_out=blocks_to_swap_out,
|
||||
blocks_to_copy=blocks_to_copy,
|
||||
)
|
||||
|
||||
return self._run_speculative_decoding_step(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
blocks_to_swap_in=blocks_to_swap_in,
|
||||
blocks_to_swap_out=blocks_to_swap_out,
|
||||
blocks_to_copy=blocks_to_copy,
|
||||
k=num_spec_tokens,
|
||||
)
|
||||
|
||||
@nvtx_range("spec_decode_worker._run_no_spec")
|
||||
def _run_no_spec(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
blocks_to_swap_in: Optional[Dict[int, int]],
|
||||
blocks_to_swap_out: Optional[Dict[int, int]],
|
||||
blocks_to_copy: Optional[Dict[int, List[int]]],
|
||||
) -> List[SamplerOutput]:
|
||||
"""Run a prefill step, without any speculation. The input is sent to the
|
||||
proposer and scorer model so that the KV cache is consistent between the
|
||||
two.
|
||||
"""
|
||||
|
||||
self.proposer_worker.execute_model(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
blocks_to_swap_in=blocks_to_swap_in,
|
||||
blocks_to_swap_out=blocks_to_swap_out,
|
||||
blocks_to_copy=blocks_to_copy,
|
||||
return_python_output=False)
|
||||
|
||||
sampler_output = self.scorer_worker.execute_model(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
blocks_to_swap_in=blocks_to_swap_in,
|
||||
blocks_to_swap_out=blocks_to_swap_out,
|
||||
blocks_to_copy=blocks_to_copy,
|
||||
)
|
||||
|
||||
# Clear device tensors from sampler output. This reduces communication
|
||||
# overhead when the engine runs in a different process than the workers.
|
||||
sampler_output.sampled_token_probs = None
|
||||
sampler_output.sampled_token_ids = None
|
||||
return [sampler_output]
|
||||
|
||||
@nvtx_range("spec_decode_worker._run_speculative_decoding_step")
|
||||
def _run_speculative_decoding_step(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
blocks_to_swap_in: Optional[Dict[int, int]],
|
||||
blocks_to_swap_out: Optional[Dict[int, int]],
|
||||
blocks_to_copy: Optional[Dict[int, List[int]]],
|
||||
k: int,
|
||||
) -> List[SamplerOutput]:
|
||||
"""Execute a single step of speculative decoding.
|
||||
|
||||
This invokes the proposer worker to get k speculative tokens for each
|
||||
sequence, then scores each speculative token using the scoring worker.
|
||||
|
||||
Returns a list of SamplerOutput, each containing a single token per
|
||||
sequence.
|
||||
"""
|
||||
|
||||
# Generate proposals using draft worker.
|
||||
proposals = self.proposer_worker.get_spec_proposals(
|
||||
seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out,
|
||||
blocks_to_copy, k)
|
||||
|
||||
proposal_scores = self.scorer.score_proposals(
|
||||
seq_group_metadata_list,
|
||||
blocks_to_swap_in,
|
||||
blocks_to_swap_out,
|
||||
blocks_to_copy,
|
||||
k,
|
||||
proposals,
|
||||
)
|
||||
|
||||
accepted_token_ids = self._verify_tokens(seq_group_metadata_list,
|
||||
proposal_scores, proposals, k)
|
||||
|
||||
return self._create_output_sampler_list(seq_group_metadata_list,
|
||||
accepted_token_ids, k)
|
||||
|
||||
@nvtx_range("spec_decode_worker._verify_tokens")
|
||||
def _verify_tokens(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
proposal_scores: SpeculativeScores,
|
||||
proposals: SpeculativeProposals,
|
||||
max_proposal_len: int,
|
||||
) -> torch.Tensor:
|
||||
"""Determine which speculative tokens are accepted using the
|
||||
probabilities of each token according to the proposer and scorer models.
|
||||
"""
|
||||
proposal_lens_list = proposals.proposal_lens.tolist()
|
||||
|
||||
# vLLM currently only supports proposal lens equal to zero or the batch
|
||||
# proposal len. This adds some complexity (splitting the batch into spec
|
||||
# and non spec sequences) and should be removed in the future. It can be
|
||||
# done by supporting per-sequence proposal lens.
|
||||
_, spec_indices = split_batch_by_proposal_len(
|
||||
seq_group_metadata_list,
|
||||
proposal_lens_list,
|
||||
select_proposal_len_zero=False)
|
||||
_, non_spec_indices = split_batch_by_proposal_len(
|
||||
seq_group_metadata_list,
|
||||
proposal_lens_list,
|
||||
select_proposal_len_zero=True)
|
||||
original_indices = spec_indices + non_spec_indices
|
||||
|
||||
proposal_probs = proposal_scores.probs[spec_indices, :-1]
|
||||
bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:]
|
||||
non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]
|
||||
|
||||
accepted_token_ids = self.rejection_sampler(
|
||||
proposal_probs,
|
||||
bonus_token_ids,
|
||||
proposals.proposal_probs,
|
||||
proposals.proposal_token_ids,
|
||||
)
|
||||
|
||||
# Append output tokens from non-speculative sequences to
|
||||
# the accepted token ids tensor.
|
||||
non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len +
|
||||
1).clone()
|
||||
non_spec_token_ids[:, 1:] = -1
|
||||
accepted_token_ids = torch.cat(
|
||||
[accepted_token_ids, non_spec_token_ids])
|
||||
|
||||
# Rearrange so that results are in the order of the original seq group
|
||||
# metadata.
|
||||
accepted_token_ids[original_indices] = accepted_token_ids.clone()
|
||||
|
||||
return accepted_token_ids
|
||||
|
||||
def _create_output_sampler_list(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
accepted_token_ids: torch.Tensor, # shape: [batch_size, k+1]
|
||||
k: int,
|
||||
) -> List[SamplerOutput]:
|
||||
"""Given the accepted token ids, create a list of SamplerOutput.
|
||||
|
||||
The output is padded with -1 tokens such that each sequence has
|
||||
the same number of outputs.
|
||||
"""
|
||||
seq_ids = get_all_seq_ids(seq_group_metadata_list)
|
||||
|
||||
# shape: [k+1, batch_size]
|
||||
accepted_token_ids_by_step = accepted_token_ids.transpose(0,
|
||||
1).tolist()
|
||||
sampler_output_list = []
|
||||
for token_ids_by_step in accepted_token_ids_by_step:
|
||||
if all(token_id == -1 for token_id in token_ids_by_step):
|
||||
break
|
||||
|
||||
step_output_token_ids = []
|
||||
for token_id, seq_id in zip(token_ids_by_step, seq_ids):
|
||||
step_output_token_ids.append(
|
||||
SequenceGroupOutput(
|
||||
samples=[
|
||||
SequenceOutput(
|
||||
parent_seq_id=seq_id,
|
||||
output_token=token_id,
|
||||
# TODO Add verifier logprobs.
|
||||
logprobs={token_id: 0.0},
|
||||
)
|
||||
],
|
||||
prompt_logprobs=None,
|
||||
))
|
||||
sampler_output_list.append(
|
||||
SamplerOutput(outputs=step_output_token_ids))
|
||||
|
||||
maybe_rejsample_metrics = self._metrics.maybe_collect_rejsample_metrics(
|
||||
k)
|
||||
if maybe_rejsample_metrics is not None:
|
||||
sampler_output_list[
|
||||
0].spec_decode_worker_metrics = maybe_rejsample_metrics
|
||||
|
||||
return sampler_output_list
|
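A toy walk-through of the per-step emission above; -1 marks positions rejected by the verifier:

import torch

accepted_token_ids = torch.tensor([
    [11, 12, -1, -1],      # sequence 0 emits two tokens this iteration
    [21, -1, -1, -1],      # sequence 1 emits only one
])

steps = []
for token_ids_by_step in accepted_token_ids.transpose(0, 1).tolist():
    if all(token_id == -1 for token_id in token_ids_by_step):
        break              # no sequence produced a token at this step
    steps.append(token_ids_by_step)

assert steps == [[11, 21], [12, -1]]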
||||
|
||||
@cached_property
|
||||
def _vocab_size(self) -> int:
|
||||
"""Get the vocab size of the model and make sure it's consistent between
|
||||
draft and target workers.
|
||||
"""
|
||||
vocab_sizes = [
|
||||
worker.vocab_size
|
||||
for worker in [self.proposer_worker, self.scorer_worker]
|
||||
]
|
||||
assert all(vocab_sizes[0] == vocab_size for vocab_size in vocab_sizes)
|
||||
return vocab_sizes[0]
|
||||
|
||||
@property
|
||||
def rank(self):
|
||||
return self.scorer_worker.rank
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self.scorer_worker.device
|
||||
|
||||
|
||||
def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int,
|
||||
proposer_cache_block_size_bytes: int,
|
||||
total_num_gpu_blocks: int) -> int:
|
||||
"""Given total_num_gpu_blocks, the number of GPU blocks that could be
|
||||
allocated to the target model, this function calculates how many blocks
|
||||
should be given to the draft and target model.
|
||||
|
||||
Note that usually the block size, in bytes, of each model is different,
|
||||
as it's a function of number of KV/layer, number of heads, and hidden
|
||||
dimension size.
|
||||
|
||||
Since the target and draft models allocate the same number of blocks, we
|
||||
simply calculate the number of blocks which, if allocated to both models,
|
||||
keeps the total KV cache memory usage no larger than the memory that the
|
||||
target model alone could have used for its cache.
|
||||
"""
|
||||
new_num_gpu_blocks = int(
|
||||
total_num_gpu_blocks * scorer_cache_block_size_bytes /
|
||||
(proposer_cache_block_size_bytes + scorer_cache_block_size_bytes))
|
||||
|
||||
return new_num_gpu_blocks
|
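A worked example of the block split above (illustrative sizes):

scorer_cache_block_size_bytes = 3      # e.g. the scorer block is 3x larger
proposer_cache_block_size_bytes = 1
total_num_gpu_blocks = 1000            # blocks the scorer alone could allocate

new_num_gpu_blocks = int(
    total_num_gpu_blocks * scorer_cache_block_size_bytes /
    (proposer_cache_block_size_bytes + scorer_cache_block_size_bytes))

assert new_num_gpu_blocks == 750       # 750*3 + 750*1 == 1000*3 bytes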
99
vllm/spec_decode/util.py
Normal file
@@ -0,0 +1,99 @@
|
||||
import torch
|
||||
from typing import List, Tuple
|
||||
from vllm.sequence import SequenceGroupMetadata, SamplerOutput
|
||||
from contextlib import contextmanager
|
||||
from itertools import chain
|
||||
|
||||
SeqId = int
|
||||
|
||||
|
||||
def get_all_seq_ids(
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[SeqId]:
|
||||
"""Given a list of SequenceGroupMetadata, create a list of all
|
||||
sequence ids.
|
||||
"""
|
||||
return list(
|
||||
chain.from_iterable([
|
||||
seq_group_metadata.seq_data.keys()
|
||||
for seq_group_metadata in seq_group_metadata_list
|
||||
]))
|
||||
|
||||
|
||||
def split_batch_by_proposal_len(
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
proposal_lens: List[int], select_proposal_len_zero: bool
|
||||
) -> Tuple[List[SequenceGroupMetadata], List[int]]:
|
||||
"""Utility function that splits a batch based on whether the proposal len is
|
||||
zero or not. We should remove this once vLLM supports per-sequence proposal
|
||||
lens in a batch.
|
||||
"""
|
||||
|
||||
if select_proposal_len_zero:
|
||||
predicate = lambda proposal_len: proposal_len == 0
|
||||
else:
|
||||
predicate = lambda proposal_len: proposal_len != 0
|
||||
|
||||
indices = [
|
||||
i for i, (_, proposal_len
|
||||
) in enumerate(zip(seq_group_metadata_list, proposal_lens))
|
||||
if predicate(proposal_len)
|
||||
]
|
||||
seq_groups = [
|
||||
seq_group for seq_group, proposal_len in zip(
|
||||
seq_group_metadata_list, proposal_lens) if predicate(proposal_len)
|
||||
]
|
||||
|
||||
return seq_groups, indices
|
||||
|
||||
|
||||
def sampler_output_to_torch(
|
||||
sampler_output_list: List[SamplerOutput],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Utility function which converts a list of SamplerOutput to tensors.
|
||||
|
||||
Returns:
|
||||
sampled_token_ids: torch.Tensor
|
||||
shape: [batch_size, len(sampler_output_list)]
|
||||
|
||||
sampled_token_probs: torch.Tensor
|
||||
shape: [batch_size, len(sampler_output_list), vocab_size]
|
||||
"""
|
||||
|
||||
# shape: [batch_size, num_sampler_output, vocab_size]
|
||||
sampled_token_probs = torch.stack(
|
||||
[
|
||||
sampler_output.sampled_token_probs
|
||||
for sampler_output in sampler_output_list
|
||||
],
|
||||
dim=0,
|
||||
).transpose(0, 1)
|
||||
|
||||
# shape: [batch_size, num_sampler_output]
|
||||
sampled_token_ids = torch.stack(
|
||||
[
|
||||
sampler_output.sampled_token_ids.flatten()
|
||||
for sampler_output in sampler_output_list
|
||||
],
|
||||
dim=0,
|
||||
).transpose(0, 1)
|
||||
|
||||
return sampled_token_ids, sampled_token_probs
|
||||
|
||||
|
||||
@contextmanager
|
||||
def nvtx_range(msg, *args, **kwargs):
|
||||
"""
|
||||
Context manager / decorator that pushes an NVTX range at the beginning
|
||||
of its scope, and pops it at the end. If extra arguments are given,
|
||||
they are passed as arguments to msg.format().
|
||||
|
||||
If running with cuda graphs, you must enable nsys cuda graph profiling.
|
||||
|
||||
Arguments:
|
||||
msg (string): message to associate with the range
|
||||
"""
|
||||
torch.cuda.nvtx.range_push(msg.format(*args, **kwargs))
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
torch.cuda.nvtx.range_pop()
|
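Minimal use of the helper above, assuming vLLM at this commit and a CUDA build of PyTorch; the range shows up in an Nsight Systems timeline:

from vllm.spec_decode.util import nvtx_range

with nvtx_range("score step {}", 3):
    pass    # profiled work goes here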
@@ -97,8 +97,6 @@ class ModelRunner:
|
||||
f"Loading model weights took {self.model_memory_usage / float(2**30):.4f} GB"
|
||||
)
|
||||
|
||||
vocab_size = self.model.config.vocab_size
|
||||
|
||||
if self.lora_config:
|
||||
assert hasattr(
|
||||
self.model, "supported_lora_modules"
|
||||
@@ -111,7 +109,7 @@ class ModelRunner:
|
||||
self.lora_manager = LRUCacheWorkerLoRAManager(
|
||||
self.scheduler_config.max_num_seqs,
|
||||
self.scheduler_config.max_num_batched_tokens +
|
||||
self.scheduler_config.max_paddings, vocab_size,
|
||||
self.scheduler_config.max_paddings, self.vocab_size,
|
||||
self.lora_config, self.device, self.model.embedding_modules,
|
||||
self.model.embedding_padding_modules)
|
||||
self.model = self.lora_manager.create_lora_manager(self.model)
|
||||
@@ -607,8 +605,7 @@ class ModelRunner:
|
||||
@torch.inference_mode()
|
||||
def profile_run(self) -> None:
|
||||
# Enable top-k sampling to reflect the accurate memory usage.
|
||||
vocab_size = self.model_config.get_vocab_size()
|
||||
sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1)
|
||||
sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
|
||||
max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
|
||||
max_num_seqs = self.scheduler_config.max_num_seqs
|
||||
|
||||
@@ -774,6 +771,10 @@ class ModelRunner:
|
||||
self.graph_runners.clear()
|
||||
self.cupy_nccl_backend = None
|
||||
|
||||
@property
|
||||
def vocab_size(self) -> int:
|
||||
return self.model_config.get_vocab_size()
|
||||
|
||||
|
||||
class CUDAGraphRunner:
|
||||
|
||||
|
@@ -1,178 +0,0 @@
|
||||
from typing import List, Dict
|
||||
import copy
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
|
||||
from vllm.worker.worker import Worker
|
||||
|
||||
|
||||
class MultiStepWorker(Worker):
|
||||
"""The MultiStepWorker is equivalent to a Worker except that it allows
|
||||
multiple forward passes in a single call, assuming the scheduler has
|
||||
allocated enough space to store the additional KV. This reduces overhead
|
||||
by invoking the scheduler less.
|
||||
|
||||
The MultiStepWorker does not support cache swap operations, or beam search.
|
||||
Cache swap operations do not require large modifications. On the other hand,
|
||||
beam search requires memory allocations during sequence forks and thus
|
||||
requires more thought for MultiStepWorker support.
|
||||
"""
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model_multi_step(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
blocks_to_swap_in: Dict[int, int],
|
||||
blocks_to_swap_out: Dict[int, int],
|
||||
blocks_to_copy: Dict[int, List[int]],
|
||||
num_steps: int,
|
||||
) -> List[SamplerOutput]:
|
||||
"""Run the model forward pass num_steps times. Returns the list of
|
||||
sampler output, one per model forward pass.
|
||||
"""
|
||||
self._raise_if_unsupported(seq_group_metadata_list, blocks_to_swap_in,
|
||||
blocks_to_swap_out, blocks_to_copy)
|
||||
|
||||
# Shallow copy input data so modifications (such as appending tokens)
|
||||
# do not cause side-effects.
|
||||
copied_seq_group_metadata_list = self._shallow_copy_inputs(
|
||||
seq_group_metadata_list)
|
||||
|
||||
# Assert enough KV space for num_steps tokens per sequence.
|
||||
self._assert_enough_kv_space(seq_group_metadata_list, num_steps)
|
||||
|
||||
# Run model num_steps times.
|
||||
model_outputs = []
|
||||
for _ in range(num_steps):
|
||||
model_output = super().execute_model(
|
||||
seq_group_metadata_list=copied_seq_group_metadata_list,
|
||||
blocks_to_swap_in=blocks_to_swap_in,
|
||||
blocks_to_swap_out=blocks_to_swap_out,
|
||||
blocks_to_copy=blocks_to_copy,
|
||||
)
|
||||
|
||||
self._append_new_tokens(model_output,
|
||||
copied_seq_group_metadata_list)
|
||||
model_outputs.append(model_output)
|
||||
|
||||
return model_outputs
|
||||
|
||||
    def _append_new_tokens(
            self, model_output: SamplerOutput,
            seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
        """Given model output from a single run, append the tokens to the
        sequences. This is normally done outside of the worker, but it is
        required if the worker is to perform multiple forward passes.
        """
        for seq_group_metadata, sequence_group_outputs in zip(
                seq_group_metadata_list, model_output):
            seq_group_metadata.is_prompt = False

            for seq_output in sequence_group_outputs.samples:
                # NOTE: Beam search is not supported, so we can assume that
                # parent_seq_id == seq_id.
                seq = seq_group_metadata.seq_data[seq_output.parent_seq_id]

                token_id = seq_output.output_token
                token_logprob = seq_output.logprobs[token_id]

                seq.append_token_id(token_id, token_logprob.logprob)

    def _shallow_copy_inputs(
        self, seq_group_metadata_list: List[SequenceGroupMetadata]
    ) -> List[SequenceGroupMetadata]:
        """Copy input data structures to remove side-effects when input data
        structures are shared with other modules.

        The multi-step worker must be able to append tokens to sequences after
        a forward pass. This necessitates modification of the data structures
        used by the worker. Since these data structures are shared with other
        parts of vLLM, like the scheduler, we must take care not to introduce
        unexpected side-effects.

        When Ray is used to orchestrate worker processes (such as when the
        tensor-parallel degree is >1), this is not a problem because the input
        data structures will be serialized and created anew in the worker
        process.

        However, when Ray is not used to orchestrate the worker processes (such
        as when the tensor-parallel degree is 1), this is a problem. We avoid
        the problem by shallow-copying the input data structures (specifically,
        the parts that will change across multiple steps).
        """

        # Shallow-copy the list of SequenceGroupMetadata. This allows us to
        # append tokens and change is_prompt without external side-effects.
        new_seq_group_metadata_list = []

        for old_seq_group_metadata in seq_group_metadata_list:
            # We must shallow-copy seq_group_metadata as is_prompt could change.
            seq_group_metadata = copy.copy(old_seq_group_metadata)
            new_seq_group_metadata_list.append(seq_group_metadata)

            # We must shallow-copy seq_data as we will append token ids
            new_seq_data = {}
            for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
                new_seq_data[seq_id] = copy.copy(old_seq_data)
                new_seq_data[
                    seq_id].output_token_ids = old_seq_data.output_token_ids[:]

            seq_group_metadata.seq_data = new_seq_data

        return new_seq_group_metadata_list

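    # Illustrative aside on the method above (values are hypothetical): the
    # two-level shallow copy matters because output_token_ids is a plain list.
    # Without the `[:]` copy, appending a draft token here would also mutate
    # the scheduler's view of the sequence:
    #
    #   shared = [4, 5, 6]      # output_token_ids shared with the scheduler
    #   worker_view = shared    # no copy: same list object
    #   worker_view.append(7)   # scheduler would now see [4, 5, 6, 7] as well
    #   worker_copy = shared[:] # the slice copy above avoids this aliasing
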
    def _assert_enough_kv_space(
            self, seq_group_metadata_list: List[SequenceGroupMetadata],
            num_steps: int) -> None:
        """Assert there are enough physical blocks per sequence to store the
        current KV plus additional KV from num_steps tokens.
        """
        assert self.model_runner.block_size is not None
        for seq_group_metadata in seq_group_metadata_list:
            # Only one seq_id is guaranteed because there is no beam search.
            seq_id = list(seq_group_metadata.seq_data.keys())[0]
            seq = seq_group_metadata.seq_data[seq_id]

            # After num_steps, the seq len will be the current seq len
            # plus one token per step.
            final_seq_len = seq.get_len() + num_steps

            # We will have final_seq_len - 1 KV because vLLM saves KV for a
            # token in the iteration after the token was generated.
            required_num_kv_slots = final_seq_len - 1

            # The allocated number of kv slots is the number of allocated
            # blocks times the number of slots per block.
            number_physical_blocks = len(
                seq_group_metadata.block_tables[seq_id])
            allocated_kv_slots = (number_physical_blocks *
                                  self.model_runner.block_size)

            if required_num_kv_slots > allocated_kv_slots:
                request_id = seq_group_metadata.request_id
                raise ValueError(
                    "The worker attempted to run "
                    f"{num_steps} times but found insufficient KV space for "
                    f"{request_id=} {seq_id=}. ({allocated_kv_slots=} "
                    f"{required_num_kv_slots=}).")

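    # Worked example of the check above (illustrative numbers): with
    # block_size=16 and 2 physical blocks allocated, allocated_kv_slots is 32.
    # A sequence of current length 30 asked to run num_steps=4 reaches
    # final_seq_len=34 and needs 34 - 1 = 33 KV slots, so 33 > 32 raises the
    # ValueError; with num_steps=3 it needs only 32 slots and passes.
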
    def _raise_if_unsupported(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> None:
        """MultiStepWorker does not yet implement support for cache swap
        operations or beam search.
        """
        if any([blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy]):
            raise NotImplementedError(
                "MultiStepWorker does not support cache operations")

        if any(
                len(seq_group_metadata.seq_data.keys()) != 1
                for seq_group_metadata in seq_group_metadata_list):
            raise NotImplementedError(
                "MultiStepWorker does not support beam search.")
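
The deleted file above defines the draft side of speculative decoding: k model forward passes per scheduler call, with draft tokens appended between passes. A minimal call sketch, assuming an already-initialized MultiStepWorker and the metadata the scheduler would normally supply (both placeholders, not values from this diff; the empty dicts are fine because the class rejects cache operations anyway):

from typing import List

from vllm.sequence import SamplerOutput, SequenceGroupMetadata


def run_draft_steps(worker,
                    seq_group_metadata_list: List[SequenceGroupMetadata],
                    k: int) -> List[SamplerOutput]:
    """Ask a MultiStepWorker for k draft tokens per sequence."""
    return worker.execute_model_multi_step(
        seq_group_metadata_list=seq_group_metadata_list,
        blocks_to_swap_in={},
        blocks_to_swap_out={},
        blocks_to_copy={},
        num_steps=k,
    )
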
@ -130,8 +130,8 @@ class Worker:
        # GPU did not change their memory usage during the profiling.
        peak_memory = self.init_gpu_memory - free_gpu_memory

        cache_block_size = CacheEngine.get_cache_block_size(
            block_size, cache_dtype, self.model_config, self.parallel_config)
        cache_block_size = self.get_cache_block_size_bytes(
            block_size, cache_dtype)
        num_gpu_blocks = int(
            (total_gpu_memory * gpu_memory_utilization - peak_memory) //
            cache_block_size)
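
The hunk above only changes where the per-block size comes from; the sizing arithmetic is unchanged. A worked example with illustrative numbers (none taken from this diff):

# Assume a 24 GiB GPU, a 0.9 utilization target, a 6 GiB profiling peak, and a
# 2 MiB KV-cache block (all illustrative).
GiB = 1 << 30
total_gpu_memory = 24 * GiB
gpu_memory_utilization = 0.9
peak_memory = 6 * GiB
cache_block_size = 2 * (1 << 20)

num_gpu_blocks = int(
    (total_gpu_memory * gpu_memory_utilization - peak_memory) //
    cache_block_size)
print(num_gpu_blocks)  # 7987 KV-cache blocks fit in the remaining memory
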
@ -232,6 +232,22 @@ class Worker:
    def list_loras(self) -> Set[int]:
        return self.model_runner.list_loras()

    @property
    def max_model_len(self) -> int:
        return self.model_config.max_model_len

    @property
    def vocab_size(self) -> int:
        return self.model_runner.vocab_size

    def get_cache_block_size_bytes(self, block_size: int,
                                   cache_dtype: str) -> int:
        """Return the size of a single KV cache block, in bytes.
        """
        return CacheEngine.get_cache_block_size(block_size, cache_dtype,
                                                self.model_config,
                                                self.parallel_config)


def init_distributed_environment(
    parallel_config: ParallelConfig,