[Speculative decoding 3/9] Worker which speculates, scores, and applies rejection sampling (#3103)
parent f48c6791b7
commit 8437bae6ef
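The change set below adds a SpecDecodeWorker and its tests: a draft (proposer) worker speculates k tokens per sequence, a target (scorer) worker scores the expanded batch, and a rejection sampler decides which speculated tokens survive. As orientation before the diffs, here is a minimal, self-contained sketch of that propose/score/accept loop; the class names and random "probabilities" are illustrative stand-ins, not the vLLM implementation.

import random


class ToyDraftWorker:
    def propose(self, context, k):
        # Propose k draft tokens (here: random token ids).
        return [random.randrange(100) for _ in range(k)]


class ToyTargetWorker:
    def score(self, context, draft_tokens):
        # Score every draft position; here a random acceptance
        # probability stands in for the target-model probabilities.
        return [random.random() for _ in draft_tokens]


def toy_rejection_sample(draft_tokens, accept_probs):
    # Keep draft tokens until the first rejection. The real rejection
    # sampler also emits a corrected/bonus token, omitted in this sketch.
    accepted = []
    for token, prob in zip(draft_tokens, accept_probs):
        if random.random() > prob:
            break
        accepted.append(token)
    return accepted


def speculative_step(context, k=4):
    draft = ToyDraftWorker().propose(context, k)
    scores = ToyTargetWorker().score(context, draft)
    return toy_rejection_sample(draft, scores)


if __name__ == "__main__":
    print(speculative_step(context=[1, 2, 3]))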
@@ -28,7 +28,7 @@ steps:
   num_gpus: 2 # only support 1 or 2 for now.
 
 - label: Engine Test
-  command: pytest -v -s engine
+  command: pytest -v -s engine test_sequence.py
 
 - label: Entrypoints Test
   command: pytest -v -s entrypoints

@@ -52,6 +52,9 @@ steps:
 - label: Worker Test
   command: pytest -v -s worker
 
+- label: Speculative decoding tests
+  command: pytest -v -s spec_decode
+
 - label: LoRA Test
   command: pytest -v -s lora --forked
 
95  tests/spec_decode/test_batch_expansion.py  Normal file
@@ -0,0 +1,95 @@
import torch
import pytest

from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer

from .utils import mock_worker, create_seq_group_metadata_from_prompts


@pytest.mark.parametrize('num_target_seq_ids', [100])
def test_create_target_seq_id_iterator(num_target_seq_ids: int):
    """Verify all new sequence ids are greater than all input
    seq ids.
    """
    scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000)

    all_seq_ids = [
        [1, 3, 5, 7],
        list(range(100)) + [0],
        [100],
    ]

    for seq_ids in all_seq_ids:
        max_seq_id = max(seq_ids)
        iterator = scorer._create_target_seq_id_iterator(seq_ids)  # pylint: disable=protected-access
        for _ in range(num_target_seq_ids):
            assert next(iterator) > max_seq_id


@pytest.mark.parametrize('k', [1, 2, 6])
def test_get_token_ids_to_score(k: int):
    """Verify correct tokens are selected for scoring.
    """
    proposal_token_ids = torch.tensor(
        list(range(k)),
        dtype=torch.int64,
        device='cuda',
    )

    expected_output = [
        [],
    ]
    for i in range(proposal_token_ids.shape[0]):
        expected_output.append(proposal_token_ids[:i + 1].tolist())

    scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000)
    actual_output = scorer._get_token_ids_to_score(proposal_token_ids)  # pylint: disable=protected-access

    actual_output = [
        x.tolist() if isinstance(x, torch.Tensor) else x for x in actual_output
    ]

    assert actual_output == expected_output


@pytest.mark.parametrize('k', [1, 2, 6])
def test_create_single_target_seq_group_metadata(k: int):
    """Verify correct creation of a batch-expanded seq group metadata.
    """

    prompt_tokens = [1, 2, 3]
    prev_output_tokens = [4, 5, 6]

    token_ids = list(range(k))

    num_tokens_processed = len(prompt_tokens) + len(prev_output_tokens) - 1

    final_seq_len = len(prompt_tokens) + len(prev_output_tokens) + len(
        token_ids)

    block_size = 32
    input_seq_group_metadata = create_seq_group_metadata_from_prompts(
        [prompt_tokens], 2048 // block_size, block_size, [final_seq_len],
        [prev_output_tokens], [num_tokens_processed])[0]

    input_seq_id = list(input_seq_group_metadata.seq_data.keys())[0]
    target_seq_id = 100

    scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000)
    output = scorer._create_single_target_seq_group_metadata(  # pylint: disable=protected-access
        input_seq_group_metadata,
        input_seq_id,
        target_seq_id,
        token_ids,
    )

    assert output.request_id == input_seq_group_metadata.request_id
    assert len(output.seq_data) == 1
    assert output.seq_data[target_seq_id].get_prompt_token_ids(
    ) == prompt_tokens
    assert output.seq_data[target_seq_id].get_output_token_ids(
    ) == prev_output_tokens + token_ids

    assert len(output.block_tables) == 1
    assert output.block_tables[
        target_seq_id] == input_seq_group_metadata.block_tables[input_seq_id]
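The scorer tests above rely on one property of batch expansion: a proposal of k tokens is scored as k + 1 growing prefixes, starting from the empty continuation. A tiny sketch of that expectation, written against plain lists rather than the real BatchExpansionTop1Scorer:

def token_ids_to_score_sketch(proposal_token_ids):
    # Mirrors the expectation in test_get_token_ids_to_score: score the
    # empty continuation plus every prefix of the proposed tokens, so a
    # k-token proposal expands into k + 1 scoring entries per sequence.
    prefixes = [[]]
    for i in range(len(proposal_token_ids)):
        prefixes.append(list(proposal_token_ids[:i + 1]))
    return prefixes


assert token_ids_to_score_sketch([7, 8, 9]) == [[], [7], [7, 8], [7, 8, 9]]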
157  tests/spec_decode/test_metrics.py  Normal file
@@ -0,0 +1,157 @@
import torch
import math
import pytest

from unittest.mock import MagicMock

from vllm.spec_decode.metrics import AsyncMetricsCollector


def test_initial_call_returns_none():
    """Expect first call to get metrics to return None.
    """
    rej_sampler = MagicMock()
    rej_sampler.num_accepted_tokens = torch.tensor(0,
                                                   dtype=torch.long,
                                                   device='cuda')
    rej_sampler.num_emitted_tokens = torch.tensor(0,
                                                  dtype=torch.long,
                                                  device='cuda')
    rej_sampler.num_draft_tokens = 0

    collector = AsyncMetricsCollector(rej_sampler)
    collector.init_gpu_tensors(rank=0)
    maybe_metrics = collector.maybe_collect_rejsample_metrics(k=5)
    assert maybe_metrics is None


def test_second_call_returns_metrics():
    """Expect second call to not return None.
    """
    rej_sampler = MagicMock()
    rej_sampler.num_accepted_tokens = torch.tensor(0,
                                                   dtype=torch.long,
                                                   device='cuda')
    rej_sampler.num_emitted_tokens = torch.tensor(0,
                                                  dtype=torch.long,
                                                  device='cuda')
    rej_sampler.num_draft_tokens = 0

    collect_interval_s = 5.0
    timer = MagicMock()
    timer.side_effect = [
        0.0, collect_interval_s + 0.1, collect_interval_s + 0.2
    ]

    collector = AsyncMetricsCollector(rejection_sampler=rej_sampler,
                                      timer=timer,
                                      collect_interval_s=collect_interval_s)
    collector.init_gpu_tensors(rank=0)
    _ = collector.maybe_collect_rejsample_metrics(k=5)
    metrics = collector.maybe_collect_rejsample_metrics(k=5)
    assert metrics is not None


@pytest.mark.parametrize("rank", [1, 2, 3, 4])
def test_nonzero_rank_noop(rank):
    """Verify nonzero ranks don't collect metrics.
    """
    rej_sampler = MagicMock()
    rej_sampler.num_accepted_tokens = torch.tensor(0,
                                                   dtype=torch.long,
                                                   device='cuda')
    rej_sampler.num_emitted_tokens = torch.tensor(0,
                                                  dtype=torch.long,
                                                  device='cuda')
    rej_sampler.num_draft_tokens = 0

    collector = AsyncMetricsCollector(rej_sampler)
    collector.init_gpu_tensors(rank=rank)
    _ = collector.maybe_collect_rejsample_metrics(k=5)
    metrics = collector.maybe_collect_rejsample_metrics(k=5)
    assert metrics is None


def test_noop_until_time():
    """Verify metrics aren't collected until enough time passes.
    """
    rej_sampler = MagicMock()
    rej_sampler.num_accepted_tokens = torch.tensor(0,
                                                   dtype=torch.long,
                                                   device='cuda')
    rej_sampler.num_emitted_tokens = torch.tensor(0,
                                                  dtype=torch.long,
                                                  device='cuda')
    rej_sampler.num_draft_tokens = 0

    collect_interval_s = 5.0
    timer = MagicMock()
    timer.side_effect = [
        0.0, collect_interval_s - 0.1, collect_interval_s - 0.1,
        collect_interval_s + 0.1, collect_interval_s + 0.1
    ]

    collector = AsyncMetricsCollector(rejection_sampler=rej_sampler,
                                      timer=timer,
                                      collect_interval_s=collect_interval_s)
    collector.init_gpu_tensors(rank=0)

    _ = collector.maybe_collect_rejsample_metrics(k=5)
    metrics = collector.maybe_collect_rejsample_metrics(k=5)
    assert metrics is None

    _ = collector.maybe_collect_rejsample_metrics(k=5)
    metrics = collector.maybe_collect_rejsample_metrics(k=5)
    assert metrics is not None


@pytest.mark.parametrize("has_data", [True, False])
def test_initial_metrics_has_correct_values(has_data: bool):
    """Test correctness of metrics data.
    """
    if has_data:
        num_accepted_tokens = 103
        num_emitted_tokens = 104
        num_draft_tokens = 105
    else:
        num_accepted_tokens = 0
        num_emitted_tokens = 0
        num_draft_tokens = 0
    k = 5

    num_possible_tokens = AsyncMetricsCollector.get_max_num_accepted_tokens(
        num_draft_tokens, k)

    rej_sampler = MagicMock()
    rej_sampler.num_accepted_tokens = torch.tensor(num_accepted_tokens,
                                                   dtype=torch.long,
                                                   device='cuda')
    rej_sampler.num_emitted_tokens = torch.tensor(num_emitted_tokens,
                                                  dtype=torch.long,
                                                  device='cuda')
    rej_sampler.num_draft_tokens = num_draft_tokens

    collect_interval_s = 5.0
    timer = MagicMock()
    timer.side_effect = [
        0.0, collect_interval_s + 0.1, collect_interval_s + 0.2
    ]

    collector = AsyncMetricsCollector(rejection_sampler=rej_sampler,
                                      timer=timer,
                                      collect_interval_s=collect_interval_s)
    collector.init_gpu_tensors(rank=0)
    _ = collector.maybe_collect_rejsample_metrics(k)
    metrics = collector.maybe_collect_rejsample_metrics(k)

    assert metrics.num_spec_tokens == k
    assert metrics.accepted_tokens == num_accepted_tokens
    assert metrics.draft_tokens == num_draft_tokens
    assert metrics.emitted_tokens == num_emitted_tokens

    if has_data:
        assert metrics.draft_acceptance_rate == num_accepted_tokens / num_draft_tokens
        assert metrics.system_efficiency == num_emitted_tokens / num_possible_tokens
    else:
        assert math.isnan(metrics.draft_acceptance_rate)
        assert math.isnan(metrics.system_efficiency)
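The metrics assertions above boil down to two ratios. A small illustrative helper showing how they relate; the real num_possible value comes from AsyncMetricsCollector.get_max_num_accepted_tokens, which is not shown in this diff, and the numbers in the example call are arbitrary:

import math


def toy_rejsample_metrics(num_accepted, num_emitted, num_draft, num_possible):
    # Mirrors the ratios asserted in test_initial_metrics_has_correct_values:
    # acceptance rate over draft tokens, efficiency over the maximum number
    # of tokens the system could have emitted; NaN when there is no data yet.
    draft_acceptance_rate = num_accepted / num_draft if num_draft else math.nan
    system_efficiency = num_emitted / num_possible if num_possible else math.nan
    return draft_acceptance_rate, system_efficiency


print(toy_rejsample_metrics(103, 104, 105, 126))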
@@ -3,14 +3,15 @@ import random
 import pytest
 from unittest.mock import MagicMock
 
-from vllm.worker.spec_decode.multi_step_worker import MultiStepWorker
+from vllm.spec_decode.multi_step_worker import MultiStepWorker, DraftModelTop1Proposer
 from vllm.worker.worker import Worker
 from vllm.model_executor.utils import set_random_seed
+from vllm.sequence import SamplerOutput
 
 from .utils import (create_execute_model_data, create_worker,
                     create_seq_group_metadata_from_prompts, zero_kv_cache,
                     patch_execute_model_with_seeds,
-                    assert_logprobs_dict_allclose)
+                    assert_logprobs_dict_allclose, create_batch)
 
 
 @pytest.mark.parametrize('num_steps', list(range(1, 17)))

@@ -259,3 +260,160 @@ def test_same_output_for_multi_step():
             multi_step_output_logprobs, single_step_output_logprobs):
         assert_logprobs_dict_allclose(multi_step_logprobs,
                                       single_step_logprobs)
+
+
+@torch.inference_mode()
+def test_draft_proposals_full_speculation_len():
+    """Verify DraftModelTop1Proposer correctly handles case where all sequences
+    can speculate.
+    """
+    k = 10
+    batch_size = 32
+    vocab_size = 32_000
+    device = 'cuda:0'
+
+    draft_worker = MagicMock()
+    proposer = DraftModelTop1Proposer(
+        draft_worker=draft_worker,
+        device=device,
+        max_model_len=2048,
+        vocab_size=vocab_size,
+    )
+    draft_worker.execute_model_multi_step.return_value = [
+        SamplerOutput(
+            outputs=[],
+            sampled_token_probs=torch.rand(batch_size,
+                                           vocab_size,
+                                           device=device,
+                                           dtype=torch.float32),
+            sampled_token_ids=torch.randint(low=0,
+                                            high=vocab_size,
+                                            size=(batch_size, ),
+                                            device=device,
+                                            dtype=torch.long),
+        ) for _ in range(k)
+    ]
+
+    execute_model_data, _, _ = create_batch(batch_size, k)
+
+    proposals = proposer.get_proposals(
+        **execute_model_data.to_dict(),
+        max_proposal_len=k,
+    )
+
+    assert torch.is_tensor(proposals.proposal_token_ids)
+    assert torch.is_tensor(proposals.proposal_probs)
+
+    assert proposals.proposal_token_ids.shape == torch.Size([batch_size, k])
+    assert proposals.proposal_probs.shape[:-1] == torch.Size([batch_size, k])
+
+    assert proposals.proposal_lens.shape == torch.Size([batch_size])
+    assert proposals.proposal_lens.tolist() == [k for _ in range(batch_size)]
+
+
+@torch.inference_mode()
+def test_draft_proposals_no_speculations():
+    """Verify DraftModelTop1Proposer correctly handles case where no sequences
+    can speculate.
+    """
+    k = 10
+    batch_size = 32
+    vocab_size = 32_000
+    device = 'cuda:0'
+    prompt_len = 10
+
+    draft_worker = MagicMock()
+    proposer = DraftModelTop1Proposer(
+        draft_worker=draft_worker,
+        device=device,
+        max_model_len=prompt_len + k - 1,
+        vocab_size=vocab_size,
+    )
+
+    execute_model_data, _, _ = create_batch(batch_size,
+                                            k,
+                                            prompt_len=prompt_len)
+
+    proposals = proposer.get_proposals(
+        **execute_model_data.to_dict(),
+        max_proposal_len=k,
+    )
+
+    assert torch.is_tensor(proposals.proposal_token_ids)
+    assert torch.is_tensor(proposals.proposal_probs)
+
+    assert proposals.proposal_token_ids.shape == torch.Size([0, k])
+    assert proposals.proposal_probs.shape[:-1] == torch.Size([0, k])
+
+    assert proposals.proposal_lens.shape == torch.Size([batch_size])
+    assert proposals.proposal_lens.tolist() == [0 for _ in range(batch_size)]
+
+
+@torch.inference_mode()
+def test_draft_proposals_mixed_k():
+    """Verify DraftModelTop1Proposer correctly handles case some sequences can
+    speculate and some can't.
+    """
+    k = 10
+    batch_size = 32
+    vocab_size = 32_000
+    device = 'cuda:0'
+
+    small_prompt_len = 5
+    long_prompt_len = 10
+    prev_output_token_len = 20
+
+    expected_num_proposal_seqs = 6
+    expected_num_no_proposal_seqs = batch_size - expected_num_proposal_seqs
+
+    prompt_len = [
+        small_prompt_len for _ in range(expected_num_proposal_seqs - 1)
+    ] + [long_prompt_len
+         for _ in range(expected_num_no_proposal_seqs)] + [small_prompt_len]
+
+    draft_worker = MagicMock()
+    proposer = DraftModelTop1Proposer(
+        draft_worker=draft_worker,
+        device=device,
+        max_model_len=long_prompt_len + prev_output_token_len + k - 1,
+        vocab_size=vocab_size,
+    )
+
+    draft_worker.execute_model_multi_step.return_value = [
+        SamplerOutput(
+            outputs=[],
+            sampled_token_probs=torch.rand(expected_num_proposal_seqs,
+                                           vocab_size,
+                                           device=device,
+                                           dtype=torch.float32),
+            sampled_token_ids=torch.randint(
+                low=0,
+                high=vocab_size,
+                size=(expected_num_proposal_seqs, ),
+                device=device,
+                dtype=torch.long),
+        ) for _ in range(k)
+    ]
+
+    execute_model_data, _, _ = create_batch(
+        batch_size,
+        k,
+        prompt_len=prompt_len,
+        prev_output_token_len=prev_output_token_len,
+    )
+
+    proposals = proposer.get_proposals(
+        **execute_model_data.to_dict(),
+        max_proposal_len=k,
+    )
+
+    assert torch.is_tensor(proposals.proposal_token_ids)
+    assert torch.is_tensor(proposals.proposal_probs)
+
+    assert proposals.proposal_token_ids.shape == torch.Size([batch_size, k])
+    assert proposals.proposal_probs.shape[:-1] == torch.Size([batch_size, k])
+
+    assert proposals.proposal_lens.shape == torch.Size([batch_size])
+    assert proposals.proposal_lens.tolist() == [
+        k for _ in range(expected_num_proposal_seqs - 1)
+    ] + [0 for _ in range(expected_num_no_proposal_seqs)] + [k]
591  tests/spec_decode/test_spec_decode_worker.py  Normal file
@@ -0,0 +1,591 @@
import torch
import random
import pytest
from unittest.mock import MagicMock

from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker, split_num_cache_blocks_evenly
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.model_executor.utils import set_random_seed
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from .utils import mock_worker, create_batch, ExecuteModelData, create_sampler_output_list
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics, AsyncMetricsCollector


@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@torch.inference_mode()
def test_correctly_calls_draft_model(k: int, batch_size: int):
    """Verify SpecDecodeWorker calls the draft worker with correct
    inputs. Everything else is mocked out.
    """
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    rejection_sampler = MagicMock(spec=RejectionSampler)
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
                              metrics_collector)

    exception_secret = 'artifical stop'
    draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)

    execute_model_data, _, _ = create_batch(batch_size, k)

    with pytest.raises(ValueError, match=exception_secret):
        worker.execute_model(**execute_model_data.to_dict(), num_spec_tokens=k)

    call_args_list = draft_worker.get_spec_proposals.call_args_list
    assert len(call_args_list) == 1

    for args, _ in call_args_list:
        (seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out,
         blocks_to_copy, actual_k) = args
        actual_execute_model_data = ExecuteModelData(seq_group_metadata_list,
                                                     blocks_to_swap_in,
                                                     blocks_to_swap_out,
                                                     blocks_to_copy)
        assert actual_execute_model_data == execute_model_data
        assert actual_k == k


@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@torch.inference_mode()
def test_correctly_calls_target_model(k: int, batch_size: int):
    """Verify SpecDecodeWorker calls the target model with correct
    inputs. Everything else is mocked out.
    """
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    rejection_sampler = MagicMock(spec=RejectionSampler)
    rejection_sampler.token_id_dtype = torch.int64
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)

    draft_worker.device = 'cuda'
    target_worker.device = 'cuda'

    set_random_seed(1)

    worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
                              metrics_collector)
    worker.init_model()

    vocab_size = 32_000

    proposal_token_ids = torch.randint(low=0,
                                       high=vocab_size,
                                       size=(batch_size, k),
                                       dtype=torch.int64,
                                       device='cuda')
    proposal_probs = torch.rand(batch_size,
                                k,
                                vocab_size,
                                dtype=torch.float32,
                                device='cuda')
    proposal_lens = torch.ones(batch_size, dtype=torch.int64,
                               device='cuda') * k

    execute_model_data, prompts, prev_output_tokens = create_batch(
        batch_size, k)

    draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
        proposal_token_ids=proposal_token_ids,
        proposal_probs=proposal_probs,
        proposal_lens=proposal_lens)

    exception_secret = 'artifical stop'
    target_worker.execute_model.side_effect = ValueError(exception_secret)

    with pytest.raises(ValueError, match=exception_secret):
        worker.execute_model(**execute_model_data.to_dict(), num_spec_tokens=k)

    seen_contexts = []

    call_args_list = target_worker.execute_model.call_args_list
    assert len(call_args_list) == 1
    for args, kwargs in call_args_list:
        target_execute_model_data = ExecuteModelData.from_dict(kwargs)

        assert len(target_execute_model_data.seq_group_metadata_list) == (
            k + 1) * batch_size
        for seq_group_metadata in (
                target_execute_model_data.seq_group_metadata_list):
            for seq_data in seq_group_metadata.seq_data.values():
                seen_contexts.append(seq_data.get_token_ids())

    expected_seen_contexts = []

    for prompt, prev_generated, draft_tokens in zip(
            prompts, prev_output_tokens, proposal_token_ids.tolist()):

        for i in range(len(draft_tokens) + 1):
            expected_seen_contexts.append(prompt + prev_generated +
                                          draft_tokens[:i])

    seen_contexts.sort()
    expected_seen_contexts.sort()
    assert expected_seen_contexts == seen_contexts


@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@torch.inference_mode()
def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
    """Verify SpecDecodeWorker calls the rejection sampler with
    correct inputs. Everything else is mocked out.
    """
    vocab_size = 32_000

    draft_worker = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size)
    target_worker = mock_worker(vocab_size=vocab_size)
    rejection_sampler = MagicMock(spec=RejectionSampler)
    rejection_sampler.token_id_dtype = torch.int64
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    draft_worker.device = 'cuda'
    target_worker.device = 'cuda'

    set_random_seed(1)

    worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
                              metrics_collector)
    worker.init_model()

    proposal_token_ids = torch.randint(low=0,
                                       high=vocab_size,
                                       size=(batch_size, k),
                                       dtype=torch.int64,
                                       device='cuda')
    proposal_probs = torch.rand(batch_size,
                                k,
                                vocab_size,
                                dtype=torch.float32,
                                device='cuda')

    proposal_lens = torch.ones(batch_size, dtype=torch.int64,
                               device='cuda') * k

    execute_model_data, _, _ = create_batch(batch_size, k)

    draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
        proposal_token_ids=proposal_token_ids,
        proposal_probs=proposal_probs,
        proposal_lens=proposal_lens)

    target_token_ids = torch.randint(low=0,
                                     high=vocab_size,
                                     size=(1, batch_size * (k + 1)),
                                     dtype=torch.int64,
                                     device='cuda')
    target_token_probs = torch.rand(1,
                                    batch_size * (k + 1),
                                    vocab_size,
                                    dtype=torch.float32,
                                    device='cuda')
    target_output = create_sampler_output_list(target_token_ids,
                                               target_token_probs)

    target_worker.execute_model.return_value = target_output[0]

    exception_secret = 'artifical stop'
    rejection_sampler.side_effect = ValueError(exception_secret)

    with pytest.raises(ValueError, match=exception_secret):
        worker.execute_model(**execute_model_data.to_dict(), num_spec_tokens=k)

    assert len(rejection_sampler.call_args_list) == 1
    args, _ = rejection_sampler.call_args_list[0]
    (actual_proposal_scores, actual_bonus_token_ids, actual_proposal_probs,
     actual_proposal_token_ids) = args

    assert torch.equal(actual_bonus_token_ids,
                       target_token_ids.reshape(batch_size, k + 1)[:, -1:])
    assert torch.equal(
        actual_proposal_scores,
        target_token_probs.reshape(batch_size, k + 1, -1)[:, :-1])
    assert torch.equal(actual_proposal_token_ids, proposal_token_ids)
    assert torch.equal(actual_proposal_probs, proposal_probs)


@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@torch.inference_mode()
def test_correctly_formats_output(k: int, batch_size: int):
    """Verify SpecDecodeWorker formats sampler output correctly.
    Everything else is mocked out.
    """
    vocab_size = 32_000

    draft_worker = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size)
    target_worker = mock_worker(vocab_size=vocab_size)
    rejection_sampler = MagicMock(spec=RejectionSampler)
    rejection_sampler.token_id_dtype = torch.int64
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    draft_worker.device = 'cuda'
    target_worker.device = 'cuda'

    set_random_seed(1)

    worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
                              metrics_collector)
    worker.init_model()

    proposal_token_ids = torch.randint(low=0,
                                       high=vocab_size,
                                       size=(batch_size, k),
                                       dtype=torch.int64,
                                       device='cuda')
    proposal_probs = torch.rand(batch_size,
                                k,
                                vocab_size,
                                dtype=torch.float32,
                                device='cuda')

    proposal_lens = torch.ones(batch_size, dtype=torch.int64,
                               device='cuda') * k

    execute_model_data, _, _ = create_batch(batch_size, k)

    draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
        proposal_token_ids=proposal_token_ids,
        proposal_probs=proposal_probs,
        proposal_lens=proposal_lens)

    target_token_ids = torch.randint(low=0,
                                     high=vocab_size,
                                     size=(1, batch_size * (k + 1)),
                                     dtype=torch.int64,
                                     device='cuda')
    target_token_probs = torch.rand(1,
                                    batch_size * (k + 1),
                                    vocab_size,
                                    dtype=torch.float32,
                                    device='cuda')
    target_output = create_sampler_output_list(target_token_ids,
                                               target_token_probs)

    target_worker.execute_model.return_value = target_output[0]

    rejection_sampler_output = torch.randint(low=0,
                                             high=vocab_size,
                                             size=(batch_size, k + 1),
                                             dtype=torch.int64,
                                             device='cuda')
    for i in range(batch_size):
        minimum_accepted_tokens = 1
        rejection_sampler_output[i][
            -random.randint(minimum_accepted_tokens, k + 1):] = -1

    rejection_sampler.return_value = rejection_sampler_output

    output = worker.execute_model(**execute_model_data.to_dict(),
                                  num_spec_tokens=k)

    expected_output = create_sampler_output_list(
        rejection_sampler_output.transpose(0, 1), [None for _ in range(k + 1)])

    seq_ids = [
        next(iter(seq_group_metadata.seq_data.keys()))
        for seq_group_metadata in execute_model_data.seq_group_metadata_list
    ]
    actual_output_by_seq = {seq_id: [] for seq_id in seq_ids}
    expected_output_by_seq = {seq_id: [] for seq_id in seq_ids}

    for step in output:
        for seq_group in step:
            for sample in seq_group.samples:
                seq_id = sample.parent_seq_id
                actual_output_by_seq[seq_id].append(sample)

    for step in expected_output:
        for seq_group in step:
            for sample in seq_group.samples:
                seq_id = sample.parent_seq_id
                expected_output_by_seq[seq_id].append(sample)

    all_seen_seq_ids = set(
        list(actual_output_by_seq.keys()) +
        list(expected_output_by_seq.keys()))
    for seq_id in all_seen_seq_ids:
        actual_by_step = actual_output_by_seq[seq_id]
        expected_by_step = expected_output_by_seq[seq_id]

        for i in range(k + 1):
            if i >= len(actual_by_step):
                assert expected_by_step[i].output_token == -1
                continue
            assert actual_by_step[i].output_token == expected_by_step[
                i].output_token
            assert actual_by_step[i].logprobs == expected_by_step[i].logprobs


@pytest.mark.parametrize('k', [1, 2])
@pytest.mark.parametrize('batch_size', [1])
@pytest.mark.parametrize('returns_metrics', [True, False])
@torch.inference_mode()
def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
    """Verify SpecDecodeWorker collects metrics.
    """
    vocab_size = 32_000

    draft_worker = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size)
    target_worker = mock_worker(vocab_size=vocab_size)
    rejection_sampler = MagicMock(spec=RejectionSampler)
    rejection_sampler.token_id_dtype = torch.int64
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    draft_worker.device = 'cuda'
    target_worker.device = 'cuda'

    set_random_seed(1)

    worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
                              metrics_collector)
    worker.init_model()

    proposal_token_ids = torch.randint(low=0,
                                       high=vocab_size,
                                       size=(batch_size, k),
                                       dtype=torch.int64,
                                       device='cuda')
    proposal_probs = torch.rand(batch_size,
                                k,
                                vocab_size,
                                dtype=torch.float32,
                                device='cuda')

    proposal_lens = torch.ones(batch_size, dtype=torch.int64,
                               device='cuda') * k

    execute_model_data, _, _ = create_batch(batch_size, k)

    draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
        proposal_token_ids=proposal_token_ids,
        proposal_probs=proposal_probs,
        proposal_lens=proposal_lens)

    target_token_ids = torch.randint(low=0,
                                     high=vocab_size,
                                     size=(1, batch_size * (k + 1)),
                                     dtype=torch.int64,
                                     device='cuda')
    target_token_probs = torch.rand(1,
                                    batch_size * (k + 1),
                                    vocab_size,
                                    dtype=torch.float32,
                                    device='cuda')
    target_output = create_sampler_output_list(target_token_ids,
                                               target_token_probs)

    target_worker.execute_model.return_value = target_output[0]

    rejection_sampler_output = torch.randint(low=0,
                                             high=vocab_size,
                                             size=(batch_size, k + 1),
                                             dtype=torch.int64,
                                             device='cuda')
    for i in range(batch_size):
        minimum_accepted_tokens = 1
        rejection_sampler_output[i][
            -random.randint(minimum_accepted_tokens, k + 1):] = -1

    rejection_sampler.return_value = rejection_sampler_output

    mock_rejsample_metrics = MagicMock(
        spec=SpecDecodeWorkerMetrics) if returns_metrics else None
    metrics_collector.maybe_collect_rejsample_metrics.return_value = mock_rejsample_metrics

    output = worker.execute_model(**execute_model_data.to_dict(),
                                  num_spec_tokens=k)
    assert output[0].spec_decode_worker_metrics == mock_rejsample_metrics

    call_args_list = metrics_collector.maybe_collect_rejsample_metrics.call_args_list
    assert len(call_args_list) == 1
    args, kwargs = call_args_list[0]
    assert args[0] == k or kwargs.get('k', -1) == k


@pytest.mark.parametrize('k', [0])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@torch.inference_mode()
def test_k_equals_zero(k: int, batch_size: int):
    """Verify that the SpecDecodeWorker calls the draft and target workers
    when k is zero. This happens during prefill.
    """
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    rejection_sampler = MagicMock(spec=RejectionSampler)
    rejection_sampler.token_id_dtype = torch.int64
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)

    draft_worker.device = 'cuda'
    target_worker.device = 'cuda'

    set_random_seed(1)

    worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
                              metrics_collector)

    execute_model_data, prompts, prev_output_tokens = create_batch(
        batch_size, k, prev_output_token_len=0)

    out = worker.execute_model(**execute_model_data.to_dict(),
                               num_spec_tokens=k)

    assert len(out) == 1, f"expected only one token output when {k=}"
    assert out[0].probs is None, "expect gpu tensor references to be None"
    assert out[
        0].sampled_tokens is None, "expect gpu tensor references to be None"

    draft_worker.execute_model.assert_called_once_with(
        **execute_model_data.to_dict(), return_python_output=False)
    target_worker.execute_model.assert_called_once_with(
        **execute_model_data.to_dict())


@pytest.mark.parametrize('k', [0, 5])
@pytest.mark.parametrize('batch_size', [0])
@torch.inference_mode()
def test_empty_input_batch(k: int, batch_size: int):
    """Verify that the SpecDecodeWorker calls the draft and target workers
    when the input batch is empty. This can happen if the engine communicates
    to the workers information without scheduling a batch.
    """
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    rejection_sampler = MagicMock(spec=RejectionSampler)
    rejection_sampler.token_id_dtype = torch.int64
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)

    draft_worker.device = 'cuda'
    target_worker.device = 'cuda'

    set_random_seed(1)

    worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
                              metrics_collector)

    execute_model_data, prompts, prev_output_tokens = create_batch(
        batch_size, k, prev_output_token_len=0)

    out = worker.execute_model(**execute_model_data.to_dict(),
                               num_spec_tokens=k)

    assert len(out) == 1, f"expected only one token output when {k=}"
    assert out[0].probs is None, "expect gpu tensor references to be None"
    assert out[
        0].sampled_tokens is None, "expect gpu tensor references to be None"

    draft_worker.execute_model.assert_called_once_with(
        **execute_model_data.to_dict(), return_python_output=False)
    target_worker.execute_model.assert_called_once_with(
        **execute_model_data.to_dict())


@torch.inference_mode()
def test_init_model():
    """Verify SpecDecodeWorker invokes proposer/scorer worker init_model, as
    well as other GPU initialization.
    """
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    rejection_sampler = MagicMock(spec=RejectionSampler)
    rejection_sampler.token_id_dtype = torch.int64
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)

    worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
                              metrics_collector)

    worker.init_model()

    draft_worker.init_model.assert_called_once()

    target_worker.init_model.assert_called_once()

    metrics_collector.init_gpu_tensors.assert_called_once()
    rejection_sampler.init_gpu_tensors.assert_called_once()


@torch.inference_mode()
def test_init_cache_engine():
    """Verify SpecDecodeWorker invokes init_cache_engine on proposer/scorer
    workers.
    """
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    rejection_sampler = MagicMock(spec=RejectionSampler)
    rejection_sampler.token_id_dtype = torch.int64
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)

    worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
                              metrics_collector)

    cache_config = MagicMock()

    worker.init_cache_engine(cache_config)

    draft_worker.init_cache_engine.assert_called_once_with(cache_config)
    target_worker.init_cache_engine.assert_called_once_with(cache_config)


@pytest.mark.parametrize('available_gpu_blocks', [1, 1024])
@pytest.mark.parametrize('available_cpu_blocks', [500])
@pytest.mark.parametrize('target_cache_block_size_bytes', [2 * 2 * 4096])
@pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096])
@torch.inference_mode()
def test_profile_num_available_blocks(available_gpu_blocks: int,
                                      available_cpu_blocks: int,
                                      target_cache_block_size_bytes: int,
                                      draft_kv_size_bytes: int):
    """Verify SpecDecodeWorker correctly profiles num available GPU blocks.
    Specifically, it should run profiling in the scorer worker, and then evenly
    split the blocks between proposer and scorer worker.
    """
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    rejection_sampler = MagicMock(spec=RejectionSampler)
    rejection_sampler.token_id_dtype = torch.int64
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)

    target_worker.profile_num_available_blocks.return_value = (
        available_gpu_blocks, available_cpu_blocks)
    target_worker.get_cache_block_size_bytes.return_value = target_cache_block_size_bytes
    draft_worker.get_cache_block_size_bytes.return_value = draft_kv_size_bytes

    worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
                              metrics_collector)

    # These values do not directly impact the adjusted block size calculation,
    # so they can be fixed.
    gpu_memory_utilization = 0.9
    cpu_swap_space = 100
    block_size = 16

    num_gpu_blocks, num_cpu_blocks = worker.profile_num_available_blocks(
        block_size, gpu_memory_utilization, cpu_swap_space, cache_dtype="auto")

    target_worker.profile_num_available_blocks.assert_called_once_with(
        block_size, gpu_memory_utilization, cpu_swap_space, "auto")
    assert num_cpu_blocks == available_cpu_blocks

    assert num_gpu_blocks == split_num_cache_blocks_evenly(
        target_cache_block_size_bytes, draft_kv_size_bytes,
        available_gpu_blocks)


@pytest.mark.parametrize('available_gpu_blocks',
                         list(range(20)) + [1024, 1024**2])
@pytest.mark.parametrize('target_cache_block_size_bytes',
                         [2 * 2 * 4096, 2 * 2 * 8192])
@pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096])
@torch.inference_mode()
def test_split_num_cache_blocks_evenly(available_gpu_blocks: int,
                                       target_cache_block_size_bytes: int,
                                       draft_kv_size_bytes: int):
    """Verify split_num_cache_blocks_evenly does not exceed original memory
    allocation in bytes.
    """
    num_blocks = split_num_cache_blocks_evenly(target_cache_block_size_bytes,
                                               draft_kv_size_bytes,
                                               available_gpu_blocks)
    assert (num_blocks * target_cache_block_size_bytes) + (
        num_blocks * draft_kv_size_bytes) <= (available_gpu_blocks *
                                              target_cache_block_size_bytes)
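test_split_num_cache_blocks_evenly only pins down an invariant: the blocks handed to both workers must fit in the memory that was profiled for the scorer alone. One way to satisfy that invariant is shown below as a hedged sketch; it is not necessarily the exact formula used by split_num_cache_blocks_evenly in vllm.spec_decode.spec_decode_worker.

def split_num_cache_blocks_evenly_sketch(scorer_block_size_bytes: int,
                                         proposer_block_size_bytes: int,
                                         total_num_gpu_blocks: int) -> int:
    # Give both workers the same number of blocks, sized so that the combined
    # footprint fits in total_num_gpu_blocks * scorer_block_size_bytes.
    return int(total_num_gpu_blocks * scorer_block_size_bytes /
               (scorer_block_size_bytes + proposer_block_size_bytes))


# The tested invariant holds by construction of the floor division above.
num_blocks = split_num_cache_blocks_evenly_sketch(2 * 2 * 4096, 2 * 2 * 768, 1024)
assert num_blocks * (2 * 2 * 4096 + 2 * 2 * 768) <= 1024 * 2 * 2 * 4096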
111  tests/spec_decode/test_utils.py  Normal file
@@ -0,0 +1,111 @@
from vllm.spec_decode.util import get_all_seq_ids
from vllm.sequence import SequenceGroupMetadata
from vllm.spec_decode.util import split_batch_by_proposal_len

import pytest
from unittest.mock import MagicMock


def test_get_all_seq_ids():
    """Verify get_all_seq_ids extracts all seq ids.
    """
    expected_seq_ids = list(range(10)) + list(range(100, 110))

    seq_group_metadata_list = [
        SequenceGroupMetadata(
            request_id=str(seq_id),
            is_prompt=True,
            seq_data={
                seq_id: MagicMock(),
            },
            sampling_params=MagicMock(),
            block_tables={
                seq_id: MagicMock(),
            },
            lora_request=None,
        ) for seq_id in expected_seq_ids
    ]

    actual_seq_ids = get_all_seq_ids(seq_group_metadata_list)
    assert actual_seq_ids == expected_seq_ids


@pytest.fixture
def fake_sequence_group_metadata():
    seq_ids = list(range(3))
    return [
        SequenceGroupMetadata(
            request_id=str(i),
            is_prompt=True,
            seq_data={
                i: MagicMock(),
            },
            sampling_params=MagicMock(),
            block_tables={
                i: MagicMock(),
            },
            lora_request=None,
        ) for i in seq_ids
    ]


def test_filter_zero_length_proposals(fake_sequence_group_metadata):
    proposal_lens = [0, 1, 0]
    filtered_groups, indices = split_batch_by_proposal_len(
        fake_sequence_group_metadata,
        proposal_lens,
        select_proposal_len_zero=True)

    expected_groups = [
        fake_sequence_group_metadata[0], fake_sequence_group_metadata[2]
    ]
    expected_indices = [0, 2]

    assert filtered_groups == expected_groups
    assert indices == expected_indices


def test_filter_non_zero_length_proposals(fake_sequence_group_metadata):
    proposal_lens = [0, 1, 2]
    filtered_groups, indices = split_batch_by_proposal_len(
        fake_sequence_group_metadata,
        proposal_lens,
        select_proposal_len_zero=False)

    expected_groups = [
        fake_sequence_group_metadata[1], fake_sequence_group_metadata[2]
    ]
    expected_indices = [1, 2]

    assert filtered_groups == expected_groups
    assert indices == expected_indices


def test_empty_inputs():
    filtered_groups, indices = split_batch_by_proposal_len(
        [], [], select_proposal_len_zero=True)

    assert filtered_groups == []
    assert indices == []


def test_all_zero_with_non_zero_filter(fake_sequence_group_metadata):
    proposal_lens = [0, 0, 0]
    filtered_groups, indices = split_batch_by_proposal_len(
        fake_sequence_group_metadata,
        proposal_lens,
        select_proposal_len_zero=False)

    assert filtered_groups == []
    assert indices == []


def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata):
    proposal_lens = [1, 1, 1]
    filtered_groups, indices = split_batch_by_proposal_len(
        fake_sequence_group_metadata,
        proposal_lens,
        select_proposal_len_zero=True)

    assert filtered_groups == []
    assert indices == []
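The split_batch_by_proposal_len tests above fully describe the expected behavior: select the sequence groups whose proposal length is zero (or non-zero) and return them together with their original batch indices. A stand-alone sketch with that behavior, for reference only, not the vLLM implementation itself:

from typing import List, Tuple, TypeVar

T = TypeVar("T")


def split_by_proposal_len_sketch(
        items: List[T], proposal_lens: List[int],
        select_proposal_len_zero: bool) -> Tuple[List[T], List[int]]:
    # Keep items whose proposal length is zero (or non-zero), along with
    # their original indices, preserving batch order.
    keep = (lambda n: n == 0) if select_proposal_len_zero else (lambda n: n != 0)
    selected = [(i, item)
                for i, (item, n) in enumerate(zip(items, proposal_lens))
                if keep(n)]
    return [item for _, item in selected], [i for i, _ in selected]


# e.g. proposal_lens [0, 1, 0] with select_proposal_len_zero=True -> indices [0, 2]
assert split_by_proposal_len_sketch(['a', 'b', 'c'], [0, 1, 0], True)[1] == [0, 2]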
@@ -1,13 +1,16 @@
 import torch
-from typing import List, Optional, Dict
+from typing import List, Optional, Dict, Iterable, Union
+from unittest.mock import MagicMock
 
 from vllm.worker.worker import Worker
 from vllm.utils import get_distributed_init_method, get_ip, get_open_port
 from vllm.engine.arg_utils import EngineArgs
-from vllm.sequence import Logprob, SequenceGroupMetadata, SequenceData
+from vllm.sequence import (Logprob, SequenceGroupMetadata, SequenceData,
+                           SamplerOutput, SequenceGroupOutput, SequenceOutput)
 from vllm.sampling_params import SamplingParams
 from vllm.worker.cache_engine import CacheEngine
 from vllm.model_executor.utils import set_random_seed
+from itertools import count
 from dataclasses import dataclass, fields
 
 

@@ -24,6 +27,11 @@ class ExecuteModelData:
         return dict(
             (field.name, getattr(self, field.name)) for field in fields(self))
 
+    @classmethod
+    def from_dict(cls, d):
+        cleaned = dict((field.name, d[field.name]) for field in fields(cls))
+        return cls(**cleaned)
+
 
 def round_up_to_next_block(seq_len: int, block_size: int) -> int:
     return (seq_len + block_size - 1) // block_size

@@ -50,6 +58,21 @@ def create_execute_model_data(
     )
 
 
+def mock_worker(cls=None,
+                vocab_size: int = 30_000,
+                max_model_len: int = 2048,
+                rank: int = 0) -> MagicMock:
+    if cls is None:
+        cls = Worker
+
+    worker = MagicMock(spec=cls)
+    worker.vocab_size = vocab_size
+    worker.max_model_len = max_model_len
+    worker.rank = rank
+    worker.device = 'cuda:0'
+    return worker
+
+
 def patch_execute_model_with_seeds(worker: Worker, rand_seeds: List[int]):
     seed_iter = iter(rand_seeds)
     original_execute_model = worker.execute_model

@@ -117,25 +140,12 @@ def create_seq_group_metadata_from_prompts(
         block_size: int,
         final_seq_lens: List[int],
         continuations: Optional[List[List[int]]] = None,
-        num_tokens_processed: Optional[List[int]] = None,
         seq_ids: Optional[List[int]] = None,
 ) -> List[SequenceGroupMetadata]:
 
     if continuations is None:
         continuations = [[] for _ in prompts]
 
-    if num_tokens_processed is None:
-        # Default to 1 token missing from kv cache for generation sequences.
-        num_tokens_processed = []
-        for continuation, prompt in zip(continuations, prompts):
-            # If prefill, then default to zero tokens processed.
-            if not continuation:
-                num_tokens_processed.append(0)
-            else:
-                # If generation, then default to all but one tokens processed.
-                num_tokens_processed.append(
-                    len(continuation) + len(prompt) - 1)
-
     if seq_ids is None:
         seq_ids = list(i for i, _ in enumerate(prompts))
 

@@ -155,13 +165,15 @@ def create_seq_group_metadata_from_prompts(
             is_prompt=len(cont_token_ids) == 0,
             seq_data={
                 i:
-                SequenceData(prompt_token_ids=prompt_token_ids[:] +
-                             cont_token_ids[:])
+                SequenceData(
+                    prompt_token_ids=prompt_token_ids[:],
+                    output_token_ids=cont_token_ids[:],
+                ),
             },
             sampling_params=SamplingParams(temperature=0.0, ),
             block_tables={i: block_allocations[i][:]},
-        ) for i, (prompt_token_ids, cont_token_ids, num_tokens_saved) in
-        enumerate(zip(prompts, continuations, num_tokens_processed))
+        ) for i, (prompt_token_ids,
+                  cont_token_ids) in enumerate(zip(prompts, continuations))
     ]
 
 

@@ -178,3 +190,68 @@ def assert_logprobs_dict_allclose(
         expected = torch.tensor(
             single_step_expected_logprobs[token_id].logprob)
         assert torch.allclose(actual, expected)
+
+
+def create_sampler_output_list(
+        token_ids: torch.Tensor,
+        probs: Iterable[Optional[torch.Tensor]],
+        seq_ids: Optional[List[int]] = None) -> List[SamplerOutput]:
+    num_steps, batch_size = token_ids.shape
+    token_ids_by_step = token_ids.tolist()
+
+    if seq_ids is None:
+        seq_ids = list(range(batch_size))
+
+    return [
+        SamplerOutput(outputs=[
+            SequenceGroupOutput(
+                samples=[
+                    SequenceOutput(
+                        output_token=token_id,
+                        parent_seq_id=seq_ids[seq_index],
+                        logprobs={token_id: 0},
+                    )
+                ],
+                prompt_logprobs=None,
+            ) for seq_index, token_id in enumerate(token_ids_by_step[step])
+        ],
+                      sampled_token_probs=probs[step],
+                      sampled_token_ids=token_ids[step])
+        for step in range(num_steps)
+    ]
+
+
+def create_batch(batch_size,
+                 k,
+                 prompt_len: Union[int, List[int]] = 10,
+                 prev_output_token_len: int = 10,
+                 seq_ids: Optional[List[int]] = None,
+                 num_gpu_blocks: Optional[int] = None,
+                 block_size: Optional[int] = None):
+    if block_size is None:
+        block_size = 8
+
+    if num_gpu_blocks is None:
+        num_gpu_blocks = 2048 // block_size
+
+    iterator = count()
+
+    if isinstance(prompt_len, int):
+        prompt_lens = [prompt_len for _ in range(batch_size)]
+    else:
+        prompt_lens = prompt_len
+
+    prompts = [[next(iterator) for _ in range(p_len)] for p_len in prompt_lens]
+    prev_output_tokens = [[
+        next(iterator) for _ in range(prev_output_token_len)
+    ] for _ in range(batch_size)]
+    final_seq_lens = [
+        len(prompt) + len(prev_output_token) + k + 1
+        for prompt, prev_output_token in zip(prompts, prev_output_tokens)
+    ]
+
+    execute_model_data = create_execute_model_data(
+        create_seq_group_metadata_from_prompts(prompts, num_gpu_blocks,
+                                               block_size, final_seq_lens,
+                                               prev_output_tokens, seq_ids), )
+    return execute_model_data, prompts, prev_output_tokens
50  tests/test_sequence.py  Normal file
@@ -0,0 +1,50 @@
import pytest

from vllm.sequence import SequenceGroupOutput, SamplerOutput, SequenceOutput


@pytest.fixture
def sample_outputs():
    return [
        SequenceGroupOutput(samples=[
            SequenceOutput(parent_seq_id=0, output_token=i, logprobs={})
        ],
                            prompt_logprobs=None) for i in range(5)
    ]


@pytest.fixture
def sampler_output(sample_outputs):
    return SamplerOutput(outputs=sample_outputs)


def test_sampler_output_initialization(sampler_output, sample_outputs):
    assert len(sampler_output) == len(sample_outputs)
    assert sampler_output.sampled_token_probs is None
    assert sampler_output.sampled_token_ids is None
    assert sampler_output.spec_decode_worker_metrics is None


def test_sampler_output_getitem(sampler_output, sample_outputs):
    assert sampler_output[2] == sample_outputs[2]


def test_sampler_output_setitem(sampler_output):
    new_output = SequenceGroupOutput(samples=[
        SequenceOutput(parent_seq_id=0, output_token=99, logprobs={})
    ],
                                     prompt_logprobs=None)
    sampler_output[2] = new_output
    assert sampler_output[2] == new_output


def test_sampler_output_len(sampler_output, sample_outputs):
    assert len(sampler_output) == len(sample_outputs)


def test_sampler_output_eq(sample_outputs):
    sampler_output1 = SamplerOutput(outputs=sample_outputs)
    sampler_output2 = SamplerOutput(outputs=sample_outputs.copy())
    sampler_output3 = SamplerOutput(outputs=sample_outputs[:-1])
    assert sampler_output1 == sampler_output2
    assert sampler_output1 != sampler_output3
@@ -21,8 +21,6 @@ class RejectionSampler(nn.Module):
                 nontrivial latency.
         """
         super().__init__()
-        self.probs_dtype = torch.float32
-        self.token_id_dtype = torch.int64
         self._strict_mode = strict_mode

         # NOTE: A "bonus token" is accepted iff all proposal tokens are
@@ -44,6 +42,14 @@ class RejectionSampler(nn.Module):
                                               dtype=torch.long,
                                               device=device)

+    @property
+    def probs_dtype(self):
+        return torch.float32
+
+    @property
+    def token_id_dtype(self):
+        return torch.int64
+
     def forward(
         self,
         target_probs: torch.Tensor,
@@ -587,4 +587,4 @@ def _build_sampler_output(
                 SequenceOutput(seq_ids[parent_id], next_token_id, logprobs))
         sampler_output.append(
             SequenceGroupOutput(seq_outputs, group_prompt_logprobs))
-    return sampler_output
+    return SamplerOutput(outputs=sampler_output)
@@ -2,12 +2,16 @@
 import copy
 import enum
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Union, TYPE_CHECKING

 from vllm.block import LogicalTokenBlock
 from vllm.sampling_params import SamplingParams
 from vllm.lora.request import LoRARequest

+if TYPE_CHECKING:
+    import torch
+    from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
+

 @dataclass
 class Logprob:
@@ -81,6 +85,8 @@ class SequenceData:

     Args:
         prompt_token_ids: The token IDs of the prompt.
+        output_token_ids: The token IDs of the output. Set to an empty list if
+            None.

     Attributes:
         prompt_token_ids: The token IDs of the prompt.
@@ -91,9 +97,13 @@ class SequenceData:
     def __init__(
         self,
         prompt_token_ids: List[int],
+        output_token_ids: Optional[List[int]] = None,
     ) -> None:
+        if output_token_ids is None:
+            output_token_ids = []
+
         self.prompt_token_ids = prompt_token_ids
-        self.output_token_ids: List[int] = []
+        self.output_token_ids = output_token_ids
         self.cumulative_logprob = 0.0

     def append_token_id(self, token_id: int, logprob: float) -> None:
@@ -117,6 +127,12 @@ class SequenceData:
             return self.prompt_token_ids[-1]
         return self.output_token_ids[-1]

+    def get_prompt_token_ids(self) -> int:
+        return self.prompt_token_ids
+
+    def get_output_token_ids(self) -> int:
+        return self.output_token_ids
+
     def __repr__(self) -> str:
         return (f"SequenceData("
                 f"prompt_token_ids={self.prompt_token_ids}, "
@@ -506,6 +522,35 @@ class SequenceGroupOutput:
                 and self.prompt_logprobs == other.prompt_logprobs)


-# For each sequence group, we generate a list of SequenceOutput object,
-# each of which contains one possible candidate for the next token.
-SamplerOutput = List[SequenceGroupOutput]
+@dataclass
+class SamplerOutput:
+    """For each sequence group, we generate a list of SequenceOutput object,
+    each of which contains one possible candidate for the next token.
+
+    This datastructure implements methods so it can be used like a list, but
+    also has optional fields for device tensors.
+    """
+
+    outputs: List[SequenceGroupOutput]
+
+    # On-device tensor containing probabilities of each token.
+    sampled_token_probs: Optional["torch.Tensor"] = None
+
+    # On-device tensor containing the sampled token ids.
+    sampled_token_ids: Optional["torch.Tensor"] = None
+
+    # Spec decode metrics populated by workers.
+    spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
+
+    def __getitem__(self, idx: int):
+        return self.outputs[idx]
+
+    def __setitem__(self, idx: int, value):
+        self.outputs[idx] = value
+
+    def __len__(self):
+        return len(self.outputs)
+
+    def __eq__(self, other: object):
+        return isinstance(other,
+                          self.__class__) and self.outputs == other.outputs
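For reviewers of the vllm/sequence.py change: the new SamplerOutput dataclass is intended as a drop-in replacement for the old list alias. A minimal sketch (not part of the diff) of the intended call-site behavior, mirroring tests/test_sequence.py:

from vllm.sequence import SamplerOutput, SequenceGroupOutput, SequenceOutput

outputs = [
    SequenceGroupOutput(
        samples=[SequenceOutput(parent_seq_id=0, output_token=i, logprobs={})],
        prompt_logprobs=None,
    ) for i in range(3)
]
sampler_output = SamplerOutput(outputs=outputs)

# Call sites that treated SamplerOutput as a list keep working.
assert len(sampler_output) == 3
assert sampler_output[1] == outputs[1]

# The new device-tensor fields default to None until a worker fills them in.
assert sampler_output.sampled_token_probs is None
assert sampler_output.sampled_token_ids is None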
351
vllm/spec_decode/batch_expansion.py
Normal file
@@ -0,0 +1,351 @@
from typing import Iterator, List, Tuple, Optional, Dict
from itertools import chain, count

import torch

from vllm.sequence import (SamplerOutput, SequenceGroupMetadata, SequenceData)
from vllm.worker.worker import Worker
from vllm.spec_decode.util import nvtx_range, sampler_output_to_torch, get_all_seq_ids, split_batch_by_proposal_len
from vllm.spec_decode.interfaces import SpeculativeScorer, SpeculativeProposals, SpeculativeScores

SeqId = int
TargetSeqId = int
TokenId = int


class BatchExpansionTop1Scorer(SpeculativeScorer):
    """Implements a speculative scorer that uses batch expansion to get
    probabilities of speculative tokens according to the scoring model.

    Batch expansion converts a list of sequences and multiple query positions
    to a new batch of sequences, each with a single query position. This allows
    for MQA-like scoring in speculative decoding without requiring an MQA
    kernel.

    It is strictly less efficient than MQA scoring.

    It only supports scoring the top1 proposal tokens of the proposer, instead
    of topk/tree.
    """

    def __init__(self, scorer_worker: Worker, device: str, vocab_size: int):
        self._scorer_worker = scorer_worker
        self._device = device
        self._vocab_size = vocab_size

    @nvtx_range("BatchExpansionTop1Scorer.score_proposals")
    def score_proposals(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Optional[Dict[int, int]],
        blocks_to_swap_out: Optional[Dict[int, int]],
        blocks_to_copy: Optional[Dict[int, List[int]]],
        k: int,
        proposals: SpeculativeProposals,
    ) -> SpeculativeScores:
        """Score the proposed tokens via the scorer model.

        This converts each input sequence to a set of k+1 target sequences. The
        target sequences have the unique continuations to be scored and a
        unique sequence ID that is different from all input sequence ids.

        If a speculative sequence length would exceed the max model length, then
        no speculation is produced for that sequence.

        Args:
            seq_group_metadata_list: The input sequence group metadata.
            blocks_to_swap_in: This is passed to the worker during scoring.
            blocks_to_swap_out: This is passed to the worker during scoring.
            blocks_to_copy: This is passed to the worker during scoring.
            k: The fixed proposal length.
            proposals: The speculative proposals to score.
        Returns:
            SpeculativeScores: The scores of each speculative token, along with
                which sequences were ignored during scoring.
        """

        # TODO(cade) perform this on GPU to remove blocking call.
        proposal_lens_list = proposals.proposal_lens.tolist()
        proposal_token_ids_list = proposals.proposal_token_ids.tolist()

        spec_indices, non_spec_indices, target_seq_group_metadata_list, num_scoring_tokens = self._expand_batch(
            seq_group_metadata_list=seq_group_metadata_list,
            proposal_token_ids_list=proposal_token_ids_list,
            proposal_lens_list=proposal_lens_list,
        )

        target_sampler_output = self._scorer_worker.execute_model(
            seq_group_metadata_list=target_seq_group_metadata_list,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
            return_python_output=False)

        all_tokens, all_probs = self._contract_batch(
            original_bs=len(seq_group_metadata_list),
            target_sampler_output=target_sampler_output,
            proposals=proposals,
            num_scoring_tokens=num_scoring_tokens,
            non_spec_indices=non_spec_indices,
            spec_indices=spec_indices,
            k=k,
        )

        return SpeculativeScores(
            probs=all_probs,
            token_ids=all_tokens,
        )

    def _expand_batch(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_token_ids_list: List[TokenId],
        proposal_lens_list: List[int],
    ) -> Tuple[List[int], List[int], List[SequenceGroupMetadata], int]:
        """Given the input sequences and potentially multiple corresponding
        proposal tokens, create a new batch where each sequence has a single
        query token.
        """

        # vLLM currently only supports proposal lens equal to zero or the batch
        # proposal len. This adds some complexity (splitting the batch into spec
        # and non spec sequences) and should be removed in the future. It can be
        # done by supporting per-sequence proposal lens.
        spec_seqs, spec_indices = split_batch_by_proposal_len(
            seq_group_metadata_list,
            proposal_lens_list,
            select_proposal_len_zero=False)
        non_spec_seqs, non_spec_indices = split_batch_by_proposal_len(
            seq_group_metadata_list,
            proposal_lens_list,
            select_proposal_len_zero=True)

        target_seq_group_metadata_list = self._create_scoring_model_input(
            spec_seqs, proposal_token_ids_list)
        num_scoring_tokens = len(target_seq_group_metadata_list)
        target_seq_group_metadata_list.extend(non_spec_seqs)

        return spec_indices, non_spec_indices, target_seq_group_metadata_list, num_scoring_tokens

    def _contract_batch(self, original_bs: int,
                        target_sampler_output: List[SamplerOutput],
                        proposals: SpeculativeProposals,
                        num_scoring_tokens: int, non_spec_indices: List[int],
                        spec_indices: List[int],
                        k: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Contract the expanded batch back into its original size.
        This maps the scores of speculative tokens back to their original
        sequences.
        """
        (target_token_ids, target_probs, non_spec_target_token_ids,
         non_spec_target_probs) = self._split_scoring_output(
             target_sampler_output, num_scoring_tokens)

        # Map distinct sequences used to score each token
        # of shape [batch_size * k + 1] back to [batch_size, k + 1].
        batch_size, k = proposals.proposal_token_ids.shape

        target_token_ids = target_token_ids.squeeze().reshape(
            batch_size, k + 1)
        target_probs = target_probs.squeeze().reshape(batch_size, k + 1,
                                                      self._vocab_size)

        all_tokens = torch.full(size=(original_bs, k + 1),
                                fill_value=-1,
                                device=self._device,
                                dtype=torch.long)
        all_probs = torch.zeros(original_bs,
                                k + 1,
                                self._vocab_size,
                                device=self._device,
                                dtype=torch.float32)

        if non_spec_indices:
            all_tokens[non_spec_indices, 0] = non_spec_target_token_ids
            all_probs[non_spec_indices, :1, :] = non_spec_target_probs

        if spec_indices:
            all_tokens[spec_indices] = target_token_ids
            all_probs[spec_indices] = target_probs

        return all_tokens, all_probs

    def _create_scoring_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_token_ids: List[List[TokenId]],  # shape: [batch_size, k]
    ) -> List[SequenceGroupMetadata]:
        """Given the original input sequences and proposed tokens from the draft
        model, create a list of target sequences that can be used for scoring.
        """

        if not seq_group_metadata_list:
            return []

        target_seq_ids_iter = self._create_target_seq_id_iterator(
            get_all_seq_ids(seq_group_metadata_list))

        target_seq_group_metadata = list(
            chain.from_iterable(
                self._create_target_seq_group_metadata(
                    seq_group_metadata,
                    proposal_token_ids,
                    i,
                    target_seq_ids_iter,
                ) for i, seq_group_metadata in enumerate(
                    seq_group_metadata_list)))

        return target_seq_group_metadata

    def _create_target_seq_group_metadata(
        self,
        input_seq_group_metadata: SequenceGroupMetadata,
        proposal_token_ids: List[TokenId],  # shape: [batch_size, k]
        batch_index: int,
        target_seq_ids_iter: Iterator[TargetSeqId],
    ) -> List[SequenceGroupMetadata]:
        """Given an input sequence group metadata and a list of draft tokens,
        create a list of target SequenceGroupMetadata, one for each
        token id that needs to be scored.

        Naive speculative decoding requires K target model scores, one for each
        draft model token. However one can add a bonus token such that if each
        token is accepted, then a final token may be sampled from the model.
        This function creates K+1 target SequenceGroupMetadata to take
        advantage of the bonus token.
        """
        assert not input_seq_group_metadata.is_prompt, (
            "Speculating on "
            "prompts not yet supported")
        assert len(input_seq_group_metadata.seq_data) == 1, (
            "Beam search "
            "not supported in speculative decoding")
        input_seq_id = next(iter(input_seq_group_metadata.seq_data.keys()))

        token_ids_to_score = self._get_token_ids_to_score(
            proposal_token_ids[batch_index])

        target_seq_group_metadata_list: List[SequenceGroupMetadata] = []
        for token_ids in token_ids_to_score:
            target_seq_group_metadata_list.append(
                self._create_single_target_seq_group_metadata(
                    input_seq_group_metadata,
                    input_seq_id,
                    next(target_seq_ids_iter),
                    token_ids,
                ))

        return target_seq_group_metadata_list

    def _create_single_target_seq_group_metadata(
        self,
        seq_group_metadata: SequenceGroupMetadata,
        seq_id: SeqId,
        target_seq_id: TargetSeqId,
        token_ids: List[TokenId],
    ) -> SequenceGroupMetadata:
        """Create a single target SequenceGroupMetadata.

        Args:
            seq_group_metadata: The metadata for the input sequence.
            seq_id: The input sequence ID.
            target_seq_id: The corresponding target sequence ID.
            token_ids: The list of token ids that are to be appended to the
                input sequence.
        """
        seq_data = seq_group_metadata.seq_data[seq_id]
        prompt_token_ids = seq_data.get_prompt_token_ids()
        new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids]

        return SequenceGroupMetadata(
            request_id=seq_group_metadata.request_id,
            is_prompt=seq_group_metadata.is_prompt,
            seq_data={
                target_seq_id:
                SequenceData(
                    prompt_token_ids=prompt_token_ids,
                    output_token_ids=new_output_token_ids,
                ),
            },
            sampling_params=seq_group_metadata.sampling_params,
            block_tables={
                target_seq_id: seq_group_metadata.block_tables[seq_id],
            },
            lora_request=None,
        )

    def _split_scoring_output(
        self, sampler_output: SamplerOutput, num_scoring_tokens: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Split the target model output into speculative and non-speculative
        output.
        """

        # vLLM currently only supports proposal lens equal to zero or the batch
        # proposal len. This adds some complexity (splitting the batch into spec
        # and non spec sequences) and should be removed in the future. It can be
        # done by supporting per-sequence proposal lens.
        #
        # First samples are from speculative scoring, latter samples are non-
        # speculative samples.
        split_sizes = [
            num_scoring_tokens,
            sampler_output.sampled_token_ids.numel() - num_scoring_tokens
        ]
        (spec_probs, non_spec_probs
         ) = sampler_output.sampled_token_probs.split(split_sizes)
        (spec_sampled_tokens, non_spec_sampled_tokens
         ) = sampler_output.sampled_token_ids.flatten().split(split_sizes)

        # Convert scores to tensors.
        sampler_output.sampled_token_probs = spec_probs
        sampler_output.sampled_token_ids = spec_sampled_tokens
        target_token_ids, target_probs = sampler_output_to_torch(
            [sampler_output])

        # Convert non-speculative output tokens to tensors.
        sampler_output.sampled_token_probs = non_spec_probs
        sampler_output.sampled_token_ids = non_spec_sampled_tokens
        non_spec_target_token_ids, non_spec_target_probs = sampler_output_to_torch(
            [sampler_output])

        return target_token_ids, target_probs, non_spec_target_token_ids, non_spec_target_probs

    def _create_target_seq_id_iterator(
            self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:
        """Create an iterator for creating target sequence ids.
        Target sequence ids are distinct from sequence ids because we create a
        distinct target sequence id for each proposal token to be scored.

        This implementation increments a counter starting at 1 + max of all
        provided input sequence ids.
        """
        return count(start=max(seq_ids) + 1)

    def _get_token_ids_to_score(
        self,
        full_spec_token_ids: List[TokenId]  # shape: [k]
    ) -> List[List[TokenId]]:
        """Given an int tensor of proposal token ids, return a list of
        token ids that should be scored.

        Returns k+1 output lists. The additional one is used for generating the
        bonus token.

        Example:
            Input: [0, 1, 2, 3] (k=4)
            Output: (k+1 lists)
                []
                [0]
                [0, 1]
                [0, 1, 2]
                [0, 1, 2, 3]
        """
        empty_token_ids = []

        token_ids_to_score = [empty_token_ids]
        token_ids_to_score.extend([
            full_spec_token_ids[:i + 1]
            for i in range(len(full_spec_token_ids))
        ])
        return token_ids_to_score
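The arithmetic behind batch expansion is simple but easy to lose in the code above: a batch of batch_size sequences with k proposal tokens each becomes batch_size * (k + 1) single-query target sequences, one per proposal prefix plus one for the bonus token. A standalone sketch of that bookkeeping (not vLLM code; the function names are illustrative):

from typing import List

def expanded_batch_size(batch_size: int, k: int) -> int:
    # One target sequence per proposal prefix ([], [t0], [t0, t1], ...),
    # so a bonus token can be sampled when every proposal token is accepted.
    return batch_size * (k + 1)

def token_ids_to_score(proposal_token_ids: List[int]) -> List[List[int]]:
    # Mirrors _get_token_ids_to_score: the k+1 prefixes of the proposal.
    return [
        proposal_token_ids[:i] for i in range(len(proposal_token_ids) + 1)
    ]

assert expanded_batch_size(batch_size=4, k=3) == 16
assert token_ids_to_score([0, 1, 2]) == [[], [0], [0, 1], [0, 1, 2]]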
77
vllm/spec_decode/interfaces.py
Normal file
@@ -0,0 +1,77 @@
from typing import List, Tuple, Optional, Dict
from dataclasses import dataclass
from abc import ABC, abstractmethod

import torch

from vllm.sequence import SequenceGroupMetadata


@dataclass
class SpeculativeProposals:
    """Datastructure used to represent proposal tokens from some proposer. It
    also tracks how many speculative tokens each sequence has.
    """

    # Speculative proposal tokens.
    proposal_token_ids: torch.Tensor

    # Probabilities of the proposal tokens according to the proposer.
    proposal_probs: torch.Tensor

    # The valid length of each proposal; can be zero.
    proposal_lens: torch.Tensor

    def __repr__(self):
        return (f"SpeculativeProposals("
                f"proposal_token_ids={self.proposal_token_ids.shape}, "
                f"proposal_probs={self.proposal_probs.shape}, "
                f"proposal_lens={self.proposal_lens.shape})")


@dataclass
class SpeculativeScores:
    """Datastructure used to represent the scores of speculative tokens
    according to the scoring model.
    """

    # Probabilities of the speculative tokens according to the scoring model.
    probs: torch.Tensor

    # Token ids sampled from the scoring model. Used for speculative bonus
    # tokens and also non-speculative normal decoding.
    token_ids: torch.Tensor

    def __repr__(self):
        return (f"SpeculativeScores("
                f"probs={self.probs.shape}, "
                f"token_ids={self.token_ids.shape})")


class SpeculativeProposer(ABC):

    @abstractmethod
    def get_proposals(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
        max_proposal_len: int,
    ) -> SpeculativeProposals:
        raise NotImplementedError


class SpeculativeScorer(ABC):

    @abstractmethod
    def score_proposals(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Optional[Dict[int, int]],
        blocks_to_swap_out: Optional[Dict[int, int]],
        blocks_to_copy: Optional[Dict[int, List[int]]],
        k: int,
        proposals: SpeculativeProposals,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        raise NotImplementedError
174
vllm/spec_decode/metrics.py
Normal file
@@ -0,0 +1,174 @@
import torch
from dataclasses import dataclass
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from typing import Optional
from vllm.utils import in_wsl
import time
from typing import Callable


@dataclass
class SpecDecodeWorkerMetrics:
    """Dataclass holding metrics emitted from the spec decode worker.
    """

    # The empirical acceptance rate of the proposal method on a per-token basis.
    # This is useful for evaluating how well the proposal method aligns with the
    # scoring method.
    draft_acceptance_rate: float

    # The empirical efficiency, measured as the number of tokens emitted by the
    # system divided by the number of tokens that could be emitted by the system
    # if the proposal method were perfect.
    system_efficiency: float

    # The number of speculative tokens produced by the proposal method.
    draft_tokens: int

    # The number of tokens emitted by the entire system.
    emitted_tokens: int

    # The number of tokens accepted by the scoring model and verification
    # routine, e.g. Llama2-70B and lossless rejection sampling.
    #
    # NOTE: Any token accepted by the verification routine is considered
    # accepted (regardless of if the speculative prefix is also accepted). The
    # user will usually see less accepted tokens. This metric is helpful when
    # evaluating alignment of the proposal method with the scoring model.
    accepted_tokens: int

    # The number of speculative tokens per sequence.
    num_spec_tokens: int


Timer = Callable[[], float]


class AsyncMetricsCollector:
    """Class which copies rejection sampler metrics from the device to CPU on a
    non-default Torch stream.
    """

    def __init__(self,
                 rejection_sampler: RejectionSampler,
                 timer: Optional[Timer] = None,
                 collect_interval_s: float = 5.0):
        self._rejection_sampler = rejection_sampler
        self._timer = time.time if timer is None else timer

        self._rank: Optional[int] = None

        # We don't have a device set yet.
        self._copy_stream: Optional[torch.cuda.Stream] = None

        self._in_flight_copy: Optional[torch.cuda.Event] = None

        pin_memory = not in_wsl()
        self._aggregate_num_accepted_tokens = torch.tensor(
            0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
        self._aggregate_num_emitted_tokens = torch.tensor(
            0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
        self._aggregate_num_draft_tokens = 0

        self._rejsample_metrics_collect_interval_s = collect_interval_s
        self._last_metrics_collect_time = self._timer()

    def init_gpu_tensors(self, rank: int) -> None:
        self._rank = rank
        self._copy_stream = torch.cuda.Stream()

    def maybe_collect_rejsample_metrics(
            self, k: int) -> Optional[SpecDecodeWorkerMetrics]:

        # If a copy was initiated in the previous call, collect and return.
        if self._in_flight_copy is not None:
            ready_event = self._in_flight_copy
            self._in_flight_copy = None
            return self._collect_rejsample_metrics(k, ready_event)

        # Otherwise, check if we should start a new copy.
        if self._should_collect_rejsample_metrics(self._timer()):
            assert self._in_flight_copy is None
            self._in_flight_copy = self._copy_rejsample_metrics_async()

        return None

    def _should_collect_rejsample_metrics(self, now: float) -> bool:
        """Return whether or not this iteration should print rejection sampling
        metrics.
        """
        if self._rank != 0:
            return False

        if (now - self._last_metrics_collect_time <
                self._rejsample_metrics_collect_interval_s):
            return False
        return True

    def _copy_rejsample_metrics_async(self) -> torch.cuda.Event:
        """Copy rejection sampling metrics (number of accepted tokens, etc) to
        CPU asynchronously.

        Returns a CUDA event recording when the copy is complete.
        """
        self._copy_stream.wait_stream(torch.cuda.current_stream())

        with torch.cuda.stream(self._copy_stream):
            self._aggregate_num_accepted_tokens.copy_(
                self._rejection_sampler.num_accepted_tokens, non_blocking=True)
            self._aggregate_num_emitted_tokens.copy_(
                self._rejection_sampler.num_emitted_tokens, non_blocking=True)
            # Number of draft tokens is calculated on CPU, so no copy is
            # required.
            self._aggregate_num_draft_tokens = (
                self._rejection_sampler.num_draft_tokens)

        aggregate_metrics_ready = torch.cuda.Event()
        aggregate_metrics_ready.record(self._copy_stream)

        return aggregate_metrics_ready

    def _collect_rejsample_metrics(
            self, k: int,
            ready_event: torch.cuda.Event) -> SpecDecodeWorkerMetrics:
        """Create metrics object from statistics copied asynchronously.

        Args:
            k: int. The number of speculative tokens; used to determine system
                efficiency.
            ready_event: torch.cuda.Event. The CUDA event recording when the
                async GPU->CPU copy is complete.
        """

        ready_event.synchronize()
        accepted_tokens = self._aggregate_num_accepted_tokens.item()
        emitted_tokens = self._aggregate_num_emitted_tokens.item()
        draft_tokens = self._aggregate_num_draft_tokens

        num_possible_tokens = self.get_max_num_accepted_tokens(draft_tokens, k)

        if draft_tokens > 0:
            draft_acceptance_rate = accepted_tokens / draft_tokens
        else:
            draft_acceptance_rate = float("nan")

        if num_possible_tokens > 0:
            system_efficiency = emitted_tokens / num_possible_tokens
        else:
            system_efficiency = float("nan")

        return SpecDecodeWorkerMetrics(
            num_spec_tokens=k,
            draft_acceptance_rate=draft_acceptance_rate,
            system_efficiency=system_efficiency,
            accepted_tokens=accepted_tokens,
            draft_tokens=draft_tokens,
            emitted_tokens=emitted_tokens,
        )

    @staticmethod
    def get_max_num_accepted_tokens(draft_tokens: int, k: int) -> int:
        # Divide by k since batch size can be variable.
        total_num_spec_seqs = draft_tokens / k
        num_accepted_per_seq_if_all_accepted = k + 1
        return int(total_num_spec_seqs / num_accepted_per_seq_if_all_accepted)
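The two headline metrics reduce to simple ratios. A plain-Python sketch (not part of the diff) of the definitions used by _collect_rejsample_metrics above, with made-up numbers:

def draft_acceptance_rate(accepted_tokens: int, draft_tokens: int) -> float:
    # Fraction of proposed tokens that the verification routine accepted.
    return accepted_tokens / draft_tokens if draft_tokens > 0 else float("nan")

def system_efficiency(emitted_tokens: int, num_possible_tokens: int) -> float:
    # Tokens actually emitted vs. tokens emitted if every proposal were perfect.
    return (emitted_tokens / num_possible_tokens
            if num_possible_tokens > 0 else float("nan"))

assert draft_acceptance_rate(accepted_tokens=300, draft_tokens=400) == 0.75
assert system_efficiency(emitted_tokens=380, num_possible_tokens=500) == 0.76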
366
vllm/spec_decode/multi_step_worker.py
Normal file
@@ -0,0 +1,366 @@
from typing import List, Dict, Optional, Tuple
import copy

import torch

from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.worker.worker import Worker
from vllm.spec_decode.interfaces import SpeculativeProposals, SpeculativeProposer
from vllm.spec_decode.util import sampler_output_to_torch


class MultiStepWorker(Worker):
    """The MultiStepWorker is equivalent to a Worker except that it allows
    multiple forward passes in a single call, assuming the scheduler has
    allocated enough space to store the additional KV. This reduces overhead
    by invoking the scheduler less.

    The MultiStepWorker does not support cache swap operations, or beam search.
    Cache swap operations do not require large modifications. On the other hand,
    beam search requires memory allocations during sequence forks and thus
    requires more thought for MultiStepWorker support.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self._proposer: Optional[DraftModelTop1Proposer] = None

    def init_model(self):
        super().init_model()

        self._proposer = DraftModelTop1Proposer(
            self,
            self.device,
            self.max_model_len,
            self.vocab_size,
        )

    @torch.inference_mode()
    def execute_model_multi_step(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
        num_steps: int,
    ) -> List[SamplerOutput]:
        """Run the model forward pass num_steps times. Returns the list of
        sampler output, one per model forward pass.
        """
        self._raise_if_unsupported(seq_group_metadata_list, blocks_to_swap_in,
                                   blocks_to_swap_out, blocks_to_copy)

        # Shallow copy input data so modifications (such as appending tokens)
        # do not cause side-effects.
        copied_seq_group_metadata_list = self._shallow_copy_inputs(
            seq_group_metadata_list)

        # Assert enough KV space for num_steps tokens per sequence.
        self._assert_enough_kv_space(seq_group_metadata_list, num_steps)

        # Run model num_steps times.
        model_outputs = []
        for _ in range(num_steps):
            model_output = super().execute_model(
                seq_group_metadata_list=copied_seq_group_metadata_list,
                blocks_to_swap_in=blocks_to_swap_in,
                blocks_to_swap_out=blocks_to_swap_out,
                blocks_to_copy=blocks_to_copy,
            )

            self._append_new_tokens(model_output,
                                    copied_seq_group_metadata_list)
            model_outputs.append(model_output)

        return model_outputs

    def get_spec_proposals(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
        max_proposal_len: int,
    ) -> SpeculativeProposals:
        """Produce speculations given an input batch of sequences. The number of
        speculative tokens per sequence is determined by max_proposal_len.
        """

        return self._proposer.get_proposals(
            seq_group_metadata_list,
            blocks_to_swap_in,
            blocks_to_swap_out,
            blocks_to_copy,
            max_proposal_len,
        )

    def _append_new_tokens(
            self, model_output: SamplerOutput,
            seq_group_metadata_list: SequenceGroupMetadata) -> None:
        """Given model output from a single run, append the tokens to the
        sequences. This is normally done outside of the worker, but it is
        required if the worker is to perform multiple forward passes.
        """
        for seq_group_metadata, sequence_group_outputs in zip(
                seq_group_metadata_list, model_output):
            seq_group_metadata.is_prompt = False

            for seq_output in sequence_group_outputs.samples:
                # NOTE: Beam search is not supported, so we can assume that
                # parent_seq_id == seq_id.
                seq = seq_group_metadata.seq_data[seq_output.parent_seq_id]

                token_id = seq_output.output_token
                token_logprob = seq_output.logprobs[token_id]

                seq.append_token_id(token_id, token_logprob.logprob)

    def _shallow_copy_inputs(
        self, seq_group_metadata_list: List[SequenceGroupMetadata]
    ) -> List[SequenceGroupMetadata]:
        """Copy input data structures to remove side-effects when input data
        structures are shared with other modules.

        Helpful when the vLLM scheduler runs in the same process as the worker.
        The alternative is deep-copying (or other form of deep copy); this has
        performance downsides.
        """

        # Shallow-copy the list of SequenceGroupMetadata. This allows us to
        # append tokens and change is_prompt without external side-effects.
        new_seq_group_metadata_list = []

        for old_seq_group_metadata in seq_group_metadata_list:
            # We must shallow-copy seq_group_metadata as is_prompt could change.
            seq_group_metadata = copy.copy(old_seq_group_metadata)
            new_seq_group_metadata_list.append(seq_group_metadata)

            # We must shallow-copy seq_data as we will append token ids
            new_seq_data = {}
            for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
                new_seq_data[seq_id] = copy.copy(old_seq_data)
                new_seq_data[
                    seq_id].output_token_ids = old_seq_data.output_token_ids[:]

            seq_group_metadata.seq_data = new_seq_data

        return new_seq_group_metadata_list

    def _assert_enough_kv_space(
            self, seq_group_metadata_list: List[SequenceGroupMetadata],
            num_steps: int) -> None:
        """Assert there are enough physical blocks per sequence to store the
        current KV plus additional KV from num_steps tokens.
        """
        assert self.model_runner.block_size is not None
        for seq_group_metadata in seq_group_metadata_list:
            # Only one seq_id is guaranteed because there is no beam search.
            seq_id = list(seq_group_metadata.seq_data.keys())[0]
            seq = seq_group_metadata.seq_data[seq_id]

            # After num_steps, the seq len will be the current seq len
            # plus one token per step.
            final_seq_len = seq.get_len() + num_steps

            # We will have final_seq_len - 1 KV because vLLM saves KV for a
            # token in the iteration after the token was generated.
            required_num_kv_slots = final_seq_len - 1

            # The allocated number of kv slots is the number of allocated blocks
            # times the number of slots of block.
            number_physical_blocks = len(
                seq_group_metadata.block_tables[seq_id])
            allocated_kv_slots = (number_physical_blocks *
                                  self.model_runner.block_size)

            if required_num_kv_slots > allocated_kv_slots:
                request_id = seq_group_metadata.request_id
                raise ValueError(
                    "The worker attempted to run "
                    f"{num_steps} times but found insufficient KV space for "
                    f"{request_id=} {seq_id=}. ({allocated_kv_slots=} "
                    f"{required_num_kv_slots=}).")

    def _raise_if_unsupported(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> None:
        """MultiStepWorker does not yet implement support for cache swap
        operations or beam search.
        """
        if any([blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy]):
            raise NotImplementedError(
                "MultiStepWorker does not support cache operations")

        if any(
                len(seq_group_metadata.seq_data.keys()) != 1
                for seq_group_metadata in seq_group_metadata_list):
            raise NotImplementedError(
                "MultiStepWorker does not support beam search.")


class DraftModelTop1Proposer(SpeculativeProposer):
    """Helper class which separates out sequences which would exceed the max
    model length when speculated upon.

    This allows combinations of models such as JackFram/llama-68m draft with
    meta-llama/Llama2-13b-chat-hf, as llama-68m has max_position_embeddings of
    2048 while Llama2-13b has max_position_embeddings of 4096.

    We treat the sequences which exceed the proposal draft model length as
    "non-spec sequences". Essentially they skip the draft model and go through
    normal decoding in the target model.

    Currently, only proposal_lens of 0 and k are supported, where k is a global
    batch proposal length. In the future vLLM should support per-sequence
    proposal lengths.
    """

    def __init__(
        self,
        draft_worker: MultiStepWorker,
        device: str,
        max_model_len: int,
        vocab_size: int,
    ):
        self._draft_worker = draft_worker
        self._device = device
        self._max_model_len = max_model_len
        self._vocab_size = vocab_size

    def get_proposals(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
        max_proposal_len: int,
    ) -> SpeculativeProposals:
        """Get speculative proposals given the input batch.

        Sequences which would exceed the max model length are skipped during
        speculation.
        """

        # Split speculative- and non-speculative- sequences.
        proposal_lens, nonzero_proposal_len_seqs, nonzero_proposal_len_indices = self._split_by_max_model_len(
            seq_group_metadata_list, max_proposal_len)

        if nonzero_proposal_len_seqs:
            # Speculate tokens using the draft worker for the speculative
            # sequences.
            maybe_sampler_output = self._draft_worker.execute_model_multi_step(
                seq_group_metadata_list=nonzero_proposal_len_seqs,
                blocks_to_swap_in=blocks_to_swap_in,
                blocks_to_swap_out=blocks_to_swap_out,
                blocks_to_copy=blocks_to_copy,
                num_steps=max_proposal_len,
            )
        else:
            # If no sequences can be speculated, set sampler output to None.
            maybe_sampler_output = None

        # Combine speculative- and non-speculative sequences into the same
        # representation.
        proposal_tokens, proposal_probs, proposal_lens = self._merge_outputs(
            batch_size=len(seq_group_metadata_list),
            max_proposal_len=max_proposal_len,
            maybe_sampler_output=maybe_sampler_output,
            proposal_lens=proposal_lens,
            nonzero_proposal_len_indices=nonzero_proposal_len_indices,
        )

        proposals = SpeculativeProposals(
            proposal_token_ids=proposal_tokens,
            proposal_probs=proposal_probs,
            proposal_lens=proposal_lens,
        )

        return proposals

    def _split_by_max_model_len(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        max_proposal_len: int,
    ) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]:
        """Determine which sequences would exceed the max model length.
        """

        proposal_lens: List[int] = []
        nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = []
        nonzero_proposal_len_indices: List[int] = []
        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
            seq_data = next(iter(seq_group_metadata.seq_data.values()))
            seq_len = seq_data.get_len()

            # Currently only proposal lens of 0 or the global batch proposal len
            # are supported.
            if seq_len + max_proposal_len < self._max_model_len:
                proposal_lens.append(max_proposal_len)
                nonzero_proposal_len_seqs.append(seq_group_metadata)
                nonzero_proposal_len_indices.append(i)
            else:
                proposal_lens.append(0)

        return proposal_lens, nonzero_proposal_len_seqs, nonzero_proposal_len_indices

    def _merge_outputs(
        self,
        batch_size: int,
        max_proposal_len: int,
        maybe_sampler_output: Optional[SamplerOutput],
        proposal_lens: List[int],
        nonzero_proposal_len_indices: List[int],
    ) -> Tuple[torch.Tensor, torch.tensor, torch.Tensor]:
        """After speculations are produced, merge the speculation results with
        the skipped sequences.
        """
        if maybe_sampler_output is None:
            # If no speculative tokens, the sampler output will be None.
            # In this case we return empty tensors.
            proposal_tokens = torch.zeros(0,
                                          max_proposal_len,
                                          dtype=torch.long,
                                          device=self._device)
            proposal_probs = torch.zeros(0,
                                         max_proposal_len,
                                         self._vocab_size,
                                         dtype=torch.float32,
                                         device=self._device)
            proposal_lens = torch.zeros(len(proposal_lens),
                                        dtype=torch.long,
                                        device=self._device)
            return proposal_tokens, proposal_probs, proposal_lens

        sampler_output = maybe_sampler_output

        proposal_tokens, proposal_probs = sampler_output_to_torch(
            sampler_output)

        # Now, reformat the output GPU tensors such that each sequence has
        # a proposal. the proposal can be empty, e.g. [-1, -1, -1]

        entire_proposal_tokens = torch.full(size=(batch_size,
                                                  *proposal_tokens.shape[1:]),
                                            fill_value=-1,
                                            dtype=torch.long,
                                            device=self._device)
        entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens
        entire_proposal_probs = torch.zeros(batch_size,
                                            *proposal_probs.shape[1:],
                                            dtype=torch.float32,
                                            device=self._device)
        entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs

        proposal_tokens, proposal_probs = entire_proposal_tokens, entire_proposal_probs

        proposal_lens = torch.zeros(batch_size,
                                    dtype=torch.long,
                                    device=self._device)
        proposal_lens[nonzero_proposal_len_indices] = max_proposal_len

        return proposal_tokens, proposal_probs, proposal_lens
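To make the draft/target length interplay concrete, here is a minimal sketch (not part of the diff, and working on sequence lengths only) of the split performed by DraftModelTop1Proposer._split_by_max_model_len:

from typing import List, Tuple

def split_by_max_model_len(seq_lens: List[int], k: int,
                           max_model_len: int) -> Tuple[List[int], List[int]]:
    proposal_lens: List[int] = []
    spec_indices: List[int] = []
    for i, seq_len in enumerate(seq_lens):
        if seq_len + k < max_model_len:
            proposal_lens.append(k)
            spec_indices.append(i)
        else:
            # Too close to the draft model's max length: skip speculation and
            # let the sequence decode normally in the target model.
            proposal_lens.append(0)
    return proposal_lens, spec_indices

# With a llama-68m draft capped at 2048 positions, the 2045-token sequence is
# skipped while the short sequence gets k=5 speculative tokens.
assert split_by_max_model_len([100, 2045], k=5,
                              max_model_len=2048) == ([5, 0], [0])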
372
vllm/spec_decode/spec_decode_worker.py
Normal file
372
vllm/spec_decode/spec_decode_worker.py
Normal file
@ -0,0 +1,372 @@
|
|||||||
|
from typing import List, Tuple, Optional, Dict
|
||||||
|
from functools import cached_property
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.spec_decode.metrics import AsyncMetricsCollector
|
||||||
|
from vllm.sequence import (SamplerOutput, SequenceGroupMetadata,
|
||||||
|
SequenceGroupOutput, SequenceOutput)
|
||||||
|
from vllm.worker.worker import Worker
|
||||||
|
from vllm.spec_decode.multi_step_worker import MultiStepWorker
|
||||||
|
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
|
||||||
|
from vllm.config import CacheConfig
|
||||||
|
from vllm.spec_decode.util import nvtx_range, get_all_seq_ids, split_batch_by_proposal_len
|
||||||
|
from vllm.spec_decode.interfaces import SpeculativeProposals, SpeculativeScores
|
||||||
|
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
|
||||||
|
from vllm.spec_decode.interfaces import SpeculativeScorer
|
||||||
|
|
||||||
|
|
||||||
|
class SpecDecodeWorker:
|
||||||
|
"""Worker which implements speculative decoding.
|
||||||
|
|
||||||
|
Speculative decoding reduces decoding per-token latency by using a proposal
|
||||||
|
method, such as a small draft model, to speculate ahead of a larger LLM. The
|
||||||
|
probabilities of the speculative tokens are then determined by the larger
|
||||||
|
LLM, after which some verification routine determines which (if any) of the
|
||||||
|
speculative tokens are accepted by the larger LLM.
|
||||||
|
|
||||||
|
See https://github.com/vllm-project/vllm/pull/2188 and
|
||||||
|
https://github.com/vllm-project/vllm/pull/3103 for more info.
|
||||||
|
|
||||||
|
The current implementation has the following limitations:
|
||||||
|
* Only draft-model proposal is implemented (contributions for more forms are
|
||||||
|
welcome!).
|
||||||
|
* Only top-1 proposal and scoring are implemented. Tree-attention is left as
|
||||||
|
future work.
|
||||||
|
* Only lossless rejection sampling is supported. Contributions adding lossy
|
||||||
|
verification routines are welcome (e.g. Medusa's typical acceptance).
|
||||||
|
* All sequences in a batch must have the same proposal length, or zero. This
|
||||||
|
can be improved by having per-sequence speculation in the future.
|
||||||
|
* The scoring forward pass is done without an MQA kernel, which is
|
||||||
|
suboptimal especially as the batch size, proposal length, and sequence
|
||||||
|
lengths grow. Contributions to add a MQA scoring are welcome once
|
||||||
|
correctness tests pass.
|
||||||
|
More info here https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
proposer_worker: MultiStepWorker,
|
||||||
|
scorer_worker: Worker,
|
||||||
|
rejection_sampler: RejectionSampler,
|
||||||
|
metrics_collector: Optional[AsyncMetricsCollector] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Create a SpecDecodeWorker.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
proposer_worker: A worker that can produce speculative tokens for
|
||||||
|
sequences.
|
||||||
|
scorer_worker: A worker that produces probabilities of speculative
|
||||||
|
tokens according to some base model. Typically a vanilla vLLM
|
||||||
|
Worker.
|
||||||
|
rejection_sampler: A Torch module used to perform modified rejection
|
||||||
|
sampling for speculative decoding.
|
||||||
|
metrics_collector: Helper class for collecting metrics; can be set
|
||||||
|
for testing purposes.
|
||||||
|
"""
|
||||||
|
self.proposer_worker = proposer_worker
|
||||||
|
self.scorer_worker = scorer_worker
|
||||||
|
self.rejection_sampler = rejection_sampler
|
||||||
|
|
||||||
|
self._metrics = AsyncMetricsCollector(
|
||||||
|
rejection_sampler
|
||||||
|
) if metrics_collector is None else metrics_collector
|
||||||
|
|
||||||
|
self.probs_dtype = self.rejection_sampler.probs_dtype
|
||||||
|
self.token_id_dtype = self.rejection_sampler.token_id_dtype
|
||||||
|
|
||||||
|
self.scorer: SpeculativeScorer = None
|
||||||
|
|
||||||
|
def init_model(self) -> None:
|
||||||
|
"""Initialize both scorer and proposer models.
|
||||||
|
"""
|
||||||
|
# The scorer worker model is initialized first in case the proposer
|
||||||
|
# model has a smaller TP degree than the target worker.
|
||||||
|
self.scorer_worker.init_model()
|
||||||
|
self.proposer_worker.init_model()
|
||||||
|
|
||||||
|
self._metrics.init_gpu_tensors(self.rank)
|
||||||
|
self.rejection_sampler.init_gpu_tensors(self.rank)
|
||||||
|
self.scorer = BatchExpansionTop1Scorer(
|
||||||
|
scorer_worker=self.scorer_worker,
|
||||||
|
device=self.device,
|
||||||
|
            vocab_size=self._vocab_size)

    def profile_num_available_blocks(self, block_size: int,
                                     gpu_memory_utilization: float,
                                     cpu_swap_space: int,
                                     cache_dtype: str) -> Tuple[int, int]:
        """Determine the number of cache blocks to use.

        This is done by profiling the scorer model (which is typically the
        larger of the two). Then the total memory which would be used by the
        scorer cache is divided evenly between the proposer and scorer model KV,
        such that the number of blocks is equal in both KV caches.
        """
        num_gpu_blocks, num_cpu_blocks = (
            self.scorer_worker.profile_num_available_blocks(
                block_size, gpu_memory_utilization, cpu_swap_space,
                cache_dtype))

        scorer_cache_block_size_bytes = self.scorer_worker.get_cache_block_size_bytes(
            block_size, cache_dtype)
        proposer_cache_block_size_bytes = self.proposer_worker.get_cache_block_size_bytes(
            block_size, cache_dtype)

        new_num_gpu_blocks = split_num_cache_blocks_evenly(
            scorer_cache_block_size_bytes, proposer_cache_block_size_bytes,
            num_gpu_blocks)
        return new_num_gpu_blocks, num_cpu_blocks

    def init_cache_engine(self, cache_config: CacheConfig):
        """Initialize the cache engine of the scorer and proposer workers.
        """
        self.scorer_worker.init_cache_engine(cache_config)
        self.proposer_worker.init_cache_engine(cache_config)

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Optional[Dict[int, int]],
        blocks_to_swap_out: Optional[Dict[int, int]],
        blocks_to_copy: Optional[Dict[int, List[int]]],
        num_spec_tokens: int,
    ) -> List[SamplerOutput]:
        """Perform speculative decoding on the input batch.
        """

        assert seq_group_metadata_list is not None, (
            "speculative decoding "
            "requires non-None seq_group_metadata_list")

        # If no spec tokens, call the proposer and scorer workers normally.
        # Used for prefill.
        if num_spec_tokens == 0 or len(seq_group_metadata_list) == 0:
            return self._run_no_spec(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=blocks_to_swap_in,
                blocks_to_swap_out=blocks_to_swap_out,
                blocks_to_copy=blocks_to_copy,
            )

        return self._run_speculative_decoding_step(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
            k=num_spec_tokens,
        )

    @nvtx_range("spec_decode_worker._run_no_spec")
    def _run_no_spec(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Optional[Dict[int, int]],
        blocks_to_swap_out: Optional[Dict[int, int]],
        blocks_to_copy: Optional[Dict[int, List[int]]],
    ) -> List[SamplerOutput]:
        """Run a prefill step, without any speculation. The input is sent to the
        proposer and scorer model so that the KV cache is consistent between the
        two.
        """

        self.proposer_worker.execute_model(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
            return_python_output=False)

        sampler_output = self.scorer_worker.execute_model(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
        )

        # Clear device tensors from sampler output. This reduces communication
        # overhead when the engine runs in a different process than the workers.
        sampler_output.probs = None
        sampler_output.sampled_tokens = None
        return [sampler_output]
@nvtx_range("spec_decode_worker._run_speculative_decoding_step")
|
||||||
|
def _run_speculative_decoding_step(
|
||||||
|
self,
|
||||||
|
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||||
|
blocks_to_swap_in: Optional[Dict[int, int]],
|
||||||
|
blocks_to_swap_out: Optional[Dict[int, int]],
|
||||||
|
blocks_to_copy: Optional[Dict[int, List[int]]],
|
||||||
|
k: int,
|
||||||
|
) -> List[SamplerOutput]:
|
||||||
|
"""Execute a single step of speculative decoding.
|
||||||
|
|
||||||
|
This invokes the proposer worker to get k speculative tokens for each
|
||||||
|
sequence, then scores each speculative token using the scoring worker.
|
||||||
|
|
||||||
|
Returns a list of SamplerOutput, each containing a single token per
|
||||||
|
sequence.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Generate proposals using draft worker.
|
||||||
|
proposals = self.proposer_worker.get_spec_proposals(
|
||||||
|
seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out,
|
||||||
|
blocks_to_copy, k)
|
||||||
|
|
||||||
|
proposal_scores = self.scorer.score_proposals(
|
||||||
|
seq_group_metadata_list,
|
||||||
|
blocks_to_swap_in,
|
||||||
|
blocks_to_swap_out,
|
||||||
|
blocks_to_copy,
|
||||||
|
k,
|
||||||
|
proposals,
|
||||||
|
)
|
||||||
|
|
||||||
|
accepted_token_ids = self._verify_tokens(seq_group_metadata_list,
|
||||||
|
proposal_scores, proposals, k)
|
||||||
|
|
||||||
|
return self._create_output_sampler_list(seq_group_metadata_list,
|
||||||
|
accepted_token_ids, k)
|
||||||
|
|
||||||
|
@nvtx_range("spec_decode_worker._verify_tokens")
|
||||||
|
def _verify_tokens(
|
||||||
|
self,
|
||||||
|
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||||
|
proposal_scores: SpeculativeScores,
|
||||||
|
proposals: SpeculativeProposals,
|
||||||
|
max_proposal_len: int,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Determine which speculative tokens are accepted using the
|
||||||
|
probabilities of each token according to the proposer and scorer models.
|
||||||
|
"""
|
||||||
|
proposal_lens_list = proposals.proposal_lens.tolist()
|
||||||
|
|
||||||
|
# vLLM currently only supports proposal lens equal to zero or the batch
|
||||||
|
# proposal len. This adds some complexity (splitting the batch into spec
|
||||||
|
# and non spec sequences) and should be removed in the future. It can be
|
||||||
|
# done by supporting per-sequence proposal lens.
|
||||||
|
_, spec_indices = split_batch_by_proposal_len(
|
||||||
|
seq_group_metadata_list,
|
||||||
|
proposal_lens_list,
|
||||||
|
select_proposal_len_zero=False)
|
||||||
|
_, non_spec_indices = split_batch_by_proposal_len(
|
||||||
|
seq_group_metadata_list,
|
||||||
|
proposal_lens_list,
|
||||||
|
select_proposal_len_zero=True)
|
||||||
|
original_indices = spec_indices + non_spec_indices
|
||||||
|
|
||||||
|
proposal_probs = proposal_scores.probs[spec_indices, :-1]
|
||||||
|
bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:]
|
||||||
|
non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]
|
||||||
|
|
||||||
|
accepted_token_ids = self.rejection_sampler(
|
||||||
|
proposal_probs,
|
||||||
|
bonus_token_ids,
|
||||||
|
proposals.proposal_probs,
|
||||||
|
proposals.proposal_token_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Append output tokens from non-speculative sequences to
|
||||||
|
# the accepted token ids tensor.
|
||||||
|
non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len +
|
||||||
|
1).clone()
|
||||||
|
non_spec_token_ids[:, 1:] = -1
|
||||||
|
accepted_token_ids = torch.cat(
|
||||||
|
[accepted_token_ids, non_spec_token_ids])
|
||||||
|
|
||||||
|
# Rearrange so that results are in the order of the original seq group
|
||||||
|
# metadata.
|
||||||
|
accepted_token_ids[original_indices] = accepted_token_ids.clone()
|
||||||
|
|
||||||
|
return accepted_token_ids
|
||||||
|
|
||||||
|
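The index bookkeeping is the subtle part of _verify_tokens: speculative and non-speculative sequences are scored together, the non-speculative rows are padded with -1, and everything is scattered back to its original batch position. A standalone sketch with made-up token ids (not part of the diff) walks through the same steps:

    import torch

    # Toy batch: sequences 0 and 2 ran with speculation, sequence 1 did not.
    spec_indices = [0, 2]
    non_spec_indices = [1]
    original_indices = spec_indices + non_spec_indices  # [0, 2, 1]
    k = 3  # max proposal length

    # Pretend the rejection sampler accepted these tokens, shape [num_spec, k+1].
    accepted_spec = torch.tensor([[11, 12, -1, -1],
                                  [21, 22, 23, 24]])

    # Non-speculative sequences contribute a single scored token each.
    non_spec_token_ids = torch.tensor([[99]])
    non_spec_token_ids = non_spec_token_ids.expand(-1, k + 1).clone()
    non_spec_token_ids[:, 1:] = -1  # only the first position is a real token

    # Concatenate spec rows first, then non-spec rows, then undo the reordering.
    accepted = torch.cat([accepted_spec, non_spec_token_ids])
    accepted[original_indices] = accepted.clone()

    print(accepted)
    # tensor([[11, 12, -1, -1],
    #         [99, -1, -1, -1],
    #         [21, 22, 23, 24]])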
    def _create_output_sampler_list(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        accepted_token_ids: torch.Tensor,  # shape: [batch_size, k+1]
        k: int,
    ) -> List[SamplerOutput]:
        """Given the accepted token ids, create a list of SamplerOutput.

        The output is padded with -1 tokens such that each sequence has
        the same number of outputs.
        """
        seq_ids = get_all_seq_ids(seq_group_metadata_list)

        # shape: [k+1, batch_size]
        accepted_token_ids_by_step = accepted_token_ids.transpose(0,
                                                                  1).tolist()
        sampler_output_list = []
        for token_ids_by_step in accepted_token_ids_by_step:
            if all(token_id == -1 for token_id in token_ids_by_step):
                break

            step_output_token_ids = []
            for token_id, seq_id in zip(token_ids_by_step, seq_ids):
                step_output_token_ids.append(
                    SequenceGroupOutput(
                        samples=[
                            SequenceOutput(
                                parent_seq_id=seq_id,
                                output_token=token_id,
                                # TODO Add verifier logprobs.
                                logprobs={token_id: 0.0},
                            )
                        ],
                        prompt_logprobs=None,
                    ))
            sampler_output_list.append(
                SamplerOutput(outputs=step_output_token_ids))

        maybe_rejsample_metrics = self._metrics.maybe_collect_rejsample_metrics(
            k)
        if maybe_rejsample_metrics is not None:
            sampler_output_list[
                0].spec_decode_worker_metrics = maybe_rejsample_metrics

        return sampler_output_list
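For intuition on how a [batch_size, k+1] tensor of accepted ids becomes one SamplerOutput per decoding step, here is a small sketch with invented values; it mirrors the transpose-and-break loop above without constructing real SequenceGroupOutput objects:

    import torch

    # Toy accepted token ids, shape [batch_size=2, k+1=4]; -1 marks "no token".
    accepted_token_ids = torch.tensor([[5, 6, -1, -1],
                                       [7, 8, 9, -1]])

    # shape: [k+1, batch_size] -- one row per decoding step.
    accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1).tolist()

    steps = []
    for token_ids_by_step in accepted_token_ids_by_step:
        # Stop once no sequence produced a token at this step.
        if all(token_id == -1 for token_id in token_ids_by_step):
            break
        steps.append(token_ids_by_step)

    print(steps)  # [[5, 7], [6, 8], [-1, 9]]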
    @cached_property
    def _vocab_size(self) -> int:
        """Get the vocab size of the model and make sure it's consistent between
        draft and target workers.
        """
        vocab_sizes = [
            worker.vocab_size
            for worker in [self.proposer_worker, self.scorer_worker]
        ]
        assert all(vocab_sizes[0] == vocab_size for vocab_size in vocab_sizes)
        return vocab_sizes[0]

    @property
    def rank(self):
        return self.scorer_worker.rank

    @property
    def device(self):
        return self.scorer_worker.device
def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int,
                                  proposer_cache_block_size_bytes: int,
                                  total_num_gpu_blocks: int) -> int:
    """Given total_num_gpu_blocks, the number of GPU blocks that could be
    allocated to the target model, this function calculates how many blocks
    should be given to the draft and target model.

    Note that usually the block size, in bytes, of each model is different,
    as it's a function of number of KV/layer, number of heads, and hidden
    dimension size.

    Since the target and draft models allocate the same number of blocks, we
    simply calculate the number of blocks where if allocated by both models,
    the total memory usage from KV cache is no larger than the number of
    blocks allocatable by the target model alone.
    """
    new_num_gpu_blocks = int(
        total_num_gpu_blocks * scorer_cache_block_size_bytes /
        (proposer_cache_block_size_bytes + scorer_cache_block_size_bytes))

    return new_num_gpu_blocks
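A quick numeric check of the split arithmetic (the figures are made up): if a target-model block is four times the size of a draft-model block and the target alone could hold 1000 blocks, both models end up with 800 blocks each, since 800 * (4 + 1) units of memory equals 1000 * 4 units.

    def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int,
                                      proposer_cache_block_size_bytes: int,
                                      total_num_gpu_blocks: int) -> int:
        # Same formula as the helper above, restated so the check runs standalone.
        return int(total_num_gpu_blocks * scorer_cache_block_size_bytes /
                   (proposer_cache_block_size_bytes +
                    scorer_cache_block_size_bytes))

    assert split_num_cache_blocks_evenly(4096, 1024, 1000) == 800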
99
vllm/spec_decode/util.py
Normal file
@ -0,0 +1,99 @@
import torch
from typing import List, Tuple
from vllm.sequence import SequenceGroupMetadata, SamplerOutput
from contextlib import contextmanager
from itertools import chain

SeqId = int


def get_all_seq_ids(
        seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[SeqId]:
    """Given a list of SequenceGroupMetadata, create a list of all
    sequence ids.
    """
    return list(
        chain.from_iterable([
            seq_group_metadata.seq_data.keys()
            for seq_group_metadata in seq_group_metadata_list
        ]))
def split_batch_by_proposal_len(
    seq_group_metadata_list: List[SequenceGroupMetadata],
    proposal_lens: List[int], select_proposal_len_zero: bool
) -> Tuple[List[SequenceGroupMetadata], List[int]]:
    """Utility function that splits a batch based on whether the proposal len is
    zero or not. We should remove this once vLLM supports per-sequence proposal
    lens in a batch.
    """

    if select_proposal_len_zero:
        predicate = lambda proposal_len: proposal_len == 0
    else:
        predicate = lambda proposal_len: proposal_len != 0

    indices = [
        i for i, (_, proposal_len
                  ) in enumerate(zip(seq_group_metadata_list, proposal_lens))
        if predicate(proposal_len)
    ]
    seq_groups = [
        seq_group for seq_group, proposal_len in zip(
            seq_group_metadata_list, proposal_lens) if predicate(proposal_len)
    ]

    return seq_groups, indices
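A minimal usage sketch for split_batch_by_proposal_len. The batch entries here are plain strings rather than real SequenceGroupMetadata objects, which only works because the helper never inspects them; it also assumes vLLM at this commit is importable.

    from vllm.spec_decode.util import split_batch_by_proposal_len

    # Stand-in "metadata" entries; only their positions matter to the helper.
    batch = ["seq_a", "seq_b", "seq_c", "seq_d"]
    proposal_lens = [3, 0, 3, 0]

    spec_groups, spec_indices = split_batch_by_proposal_len(
        batch, proposal_lens, select_proposal_len_zero=False)
    non_spec_groups, non_spec_indices = split_batch_by_proposal_len(
        batch, proposal_lens, select_proposal_len_zero=True)

    print(spec_indices, non_spec_indices)  # [0, 2] [1, 3]
    print(spec_groups, non_spec_groups)    # ['seq_a', 'seq_c'] ['seq_b', 'seq_d']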
def sampler_output_to_torch(
    sampler_output_list: List[SamplerOutput],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Utility function which converts a list of SamplerOutput to tensors.

    Returns:
        sampled_token_ids: torch.Tensor
            shape: [batch_size, len(sampler_output_list)]

        sampled_token_probs: torch.Tensor
            shape: [batch_size, len(sampler_output_list), vocab_size]
    """

    # shape: [batch_size, num_sampler_output, vocab_size]
    sampled_token_probs = torch.stack(
        [
            sampler_output.sampled_token_probs
            for sampler_output in sampler_output_list
        ],
        dim=0,
    ).transpose(0, 1)

    # shape: [batch_size, num_sampler_output]
    sampled_token_ids = torch.stack(
        [
            sampler_output.sampled_token_ids.flatten()
            for sampler_output in sampler_output_list
        ],
        dim=0,
    ).transpose(0, 1)

    return sampled_token_ids, sampled_token_probs
@contextmanager
def nvtx_range(msg, *args, **kwargs):
    """
    Context manager / decorator that pushes an NVTX range at the beginning
    of its scope, and pops it at the end. If extra arguments are given,
    they are passed as arguments to msg.format().

    If running with cuda graphs, you must enable nsys cuda graph profiling.

    Arguments:
        msg (string): message to associate with the range
    """
    torch.cuda.nvtx.range_push(msg.format(*args, **kwargs))
    try:
        yield
    finally:
        torch.cuda.nvtx.range_pop()
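A minimal usage sketch for nvtx_range, assuming a CUDA-enabled PyTorch build; the named range only shows up on a timeline when a profiler such as nsys is attached, and is otherwise cheap enough to leave in hot paths.

    import torch
    from vllm.spec_decode.util import nvtx_range

    x = torch.randn(1024, 1024)

    # Pushes "matmul step 0" onto the NVTX stack, pops it when the block exits.
    with nvtx_range("matmul step {}", 0):
        y = x @ x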
@ -97,8 +97,6 @@ class ModelRunner:
             f"Loading model weights took {self.model_memory_usage / float(2**30):.4f} GB"
         )
 
-        vocab_size = self.model.config.vocab_size
-
         if self.lora_config:
             assert hasattr(
                 self.model, "supported_lora_modules"
@ -111,7 +109,7 @@ class ModelRunner:
             self.lora_manager = LRUCacheWorkerLoRAManager(
                 self.scheduler_config.max_num_seqs,
                 self.scheduler_config.max_num_batched_tokens +
-                self.scheduler_config.max_paddings, vocab_size,
+                self.scheduler_config.max_paddings, self.vocab_size,
                 self.lora_config, self.device, self.model.embedding_modules,
                 self.model.embedding_padding_modules)
             self.model = self.lora_manager.create_lora_manager(self.model)
@ -607,8 +605,7 @@ class ModelRunner:
     @torch.inference_mode()
     def profile_run(self) -> None:
         # Enable top-k sampling to reflect the accurate memory usage.
-        vocab_size = self.model_config.get_vocab_size()
-        sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1)
+        sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
         max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
         max_num_seqs = self.scheduler_config.max_num_seqs
 
@ -774,6 +771,10 @@ class ModelRunner:
         self.graph_runners.clear()
         self.cupy_nccl_backend = None
 
+    @property
+    def vocab_size(self) -> int:
+        return self.model_config.get_vocab_size()
+
 
 class CUDAGraphRunner:
@ -1,178 +0,0 @@
from typing import List, Dict
import copy

import torch

from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.worker.worker import Worker


class MultiStepWorker(Worker):
    """The MultiStepWorker is equivalent to a Worker except that it allows
    multiple forward passes in a single call, assuming the scheduler has
    allocated enough space to store the additional KV. This reduces overhead
    by invoking the scheduler less.

    The MultiStepWorker does not support cache swap operations, or beam search.
    Cache swap operations do not require large modifications. On the other hand,
    beam search requires memory allocations during sequence forks and thus
    requires more thought for MultiStepWorker support.
    """

    @torch.inference_mode()
    def execute_model_multi_step(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
        num_steps: int,
    ) -> List[SamplerOutput]:
        """Run the model forward pass num_steps times. Returns the list of
        sampler output, one per model forward pass.
        """
        self._raise_if_unsupported(seq_group_metadata_list, blocks_to_swap_in,
                                   blocks_to_swap_out, blocks_to_copy)

        # Shallow copy input data so modifications (such as appending tokens)
        # do not cause side-effects.
        copied_seq_group_metadata_list = self._shallow_copy_inputs(
            seq_group_metadata_list)

        # Assert enough KV space for num_steps tokens per sequence.
        self._assert_enough_kv_space(seq_group_metadata_list, num_steps)

        # Run model num_steps times.
        model_outputs = []
        for _ in range(num_steps):
            model_output = super().execute_model(
                seq_group_metadata_list=copied_seq_group_metadata_list,
                blocks_to_swap_in=blocks_to_swap_in,
                blocks_to_swap_out=blocks_to_swap_out,
                blocks_to_copy=blocks_to_copy,
            )

            self._append_new_tokens(model_output,
                                    copied_seq_group_metadata_list)
            model_outputs.append(model_output)

        return model_outputs

    def _append_new_tokens(
            self, model_output: SamplerOutput,
            seq_group_metadata_list: SequenceGroupMetadata) -> None:
        """Given model output from a single run, append the tokens to the
        sequences. This is normally done outside of the worker, but it is
        required if the worker is to perform multiple forward passes.
        """
        for seq_group_metadata, sequence_group_outputs in zip(
                seq_group_metadata_list, model_output):
            seq_group_metadata.is_prompt = False

            for seq_output in sequence_group_outputs.samples:
                # NOTE: Beam search is not supported, so we can assume that
                # parent_seq_id == seq_id.
                seq = seq_group_metadata.seq_data[seq_output.parent_seq_id]

                token_id = seq_output.output_token
                token_logprob = seq_output.logprobs[token_id]

                seq.append_token_id(token_id, token_logprob.logprob)

    def _shallow_copy_inputs(
        self, seq_group_metadata_list: List[SequenceGroupMetadata]
    ) -> List[SequenceGroupMetadata]:
        """Copy input data structures to remove side-effects when input data
        structures are shared with other modules.

        The multi-step worker must be able to append tokens to sequences after
        a forward pass. This necessitates modification of the data structures
        used by the worker. Since these data structures are shared with other
        parts of vLLM, like the scheduler, we must take care not to introduce
        unexpected side-effects.

        When Ray is used to orchestrate worker processes (such as when the
        tensor-parallel degree is >1), this is not a problem because the input
        datastructures will be serialized and created anew in the worker
        process.

        However, when Ray is not used to orchestrate the worker processes (such
        as when the tensor-parallel degree is 1), this is a problem. We avoid
        the problem by shallow-copying the input datastructures (specifically,
        the parts that will change in multiple steps).
        """

        # Shallow-copy the list of SequenceGroupMetadata. This allows us to
        # append tokens and change is_prompt without external side-effects.
        new_seq_group_metadata_list = []

        for old_seq_group_metadata in seq_group_metadata_list:
            # We must shallow-copy seq_group_metadata as is_prompt could change.
            seq_group_metadata = copy.copy(old_seq_group_metadata)
            new_seq_group_metadata_list.append(seq_group_metadata)

            # We must shallow-copy seq_data as we will append token ids
            new_seq_data = {}
            for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
                new_seq_data[seq_id] = copy.copy(old_seq_data)
                new_seq_data[
                    seq_id].output_token_ids = old_seq_data.output_token_ids[:]

            seq_group_metadata.seq_data = new_seq_data

        return new_seq_group_metadata_list

    def _assert_enough_kv_space(
            self, seq_group_metadata_list: List[SequenceGroupMetadata],
            num_steps: int) -> None:
        """Assert there are enough physical blocks per sequence to store the
        current KV plus additional KV from num_steps tokens.
        """
        assert self.model_runner.block_size is not None
        for seq_group_metadata in seq_group_metadata_list:
            # Only one seq_id is guaranteed because there is no beam search.
            seq_id = list(seq_group_metadata.seq_data.keys())[0]
            seq = seq_group_metadata.seq_data[seq_id]

            # After num_steps, the seq len will be the current seq len
            # plus one token per step.
            final_seq_len = seq.get_len() + num_steps

            # We will have final_seq_len - 1 KV because vLLM saves KV for a
            # token in the iteration after the token was generated.
            required_num_kv_slots = final_seq_len - 1

            # The allocated number of kv slots is the number of allocated blocks
            # times the number of slots of block.
            number_physical_blocks = len(
                seq_group_metadata.block_tables[seq_id])
            allocated_kv_slots = (number_physical_blocks *
                                  self.model_runner.block_size)

            if required_num_kv_slots > allocated_kv_slots:
                request_id = seq_group_metadata.request_id
                raise ValueError(
                    "The worker attempted to run "
                    f"{num_steps} times but found insufficient KV space for "
                    f"{request_id=} {seq_id=}. ({allocated_kv_slots=} "
                    f"{required_num_kv_slots=}).")

    def _raise_if_unsupported(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> None:
        """MultiStepWorker does not yet implement support for cache swap
        operations or beam search.
        """
        if any([blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy]):
            raise NotImplementedError(
                "MultiStepWorker does not support cache operations")

        if any(
                len(seq_group_metadata.seq_data.keys()) != 1
                for seq_group_metadata in seq_group_metadata_list):
            raise NotImplementedError(
                "MultiStepWorker does not support beam search.")
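The _shallow_copy_inputs docstring above is about aliasing rather than anything vLLM-specific, so the idea can be shown with plain dictionaries and lists (made-up data, no vLLM types):

    import copy

    original = {"seq_id": 7, "output_token_ids": [1, 2, 3]}

    # A plain shallow copy still shares the token list with the original...
    aliased = copy.copy(original)
    aliased["output_token_ids"].append(4)
    assert original["output_token_ids"] == [1, 2, 3, 4]  # side-effect leaked

    # ...so the worker also copies the mutable token list itself, as
    # _shallow_copy_inputs does with `output_token_ids[:]`.
    original = {"seq_id": 7, "output_token_ids": [1, 2, 3]}
    isolated = copy.copy(original)
    isolated["output_token_ids"] = original["output_token_ids"][:]
    isolated["output_token_ids"].append(4)
    assert original["output_token_ids"] == [1, 2, 3]  # original untouched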
@ -130,8 +130,8 @@ class Worker:
         # GPU did not change their memory usage during the profiling.
         peak_memory = self.init_gpu_memory - free_gpu_memory
 
-        cache_block_size = CacheEngine.get_cache_block_size(
-            block_size, cache_dtype, self.model_config, self.parallel_config)
+        cache_block_size = self.get_cache_block_size_bytes(
+            block_size, cache_dtype)
         num_gpu_blocks = int(
             (total_gpu_memory * gpu_memory_utilization - peak_memory) //
             cache_block_size)
@ -232,6 +232,22 @@ class Worker:
     def list_loras(self) -> Set[int]:
         return self.model_runner.list_loras()
 
+    @property
+    def max_model_len(self) -> int:
+        return self.model_config.max_model_len
+
+    @property
+    def vocab_size(self) -> int:
+        return self.model_runner.vocab_size
+
+    def get_cache_block_size_bytes(self, block_size: int,
+                                   cache_dtype: str) -> int:
+        """Get the size of the KV cache block size in bytes.
+        """
+        return CacheEngine.get_cache_block_size(block_size, cache_dtype,
+                                                self.model_config,
+                                                self.parallel_config)
+
 
 def init_distributed_environment(
     parallel_config: ParallelConfig,
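For a rough sense of what get_cache_block_size_bytes returns: one block holds block_size tokens of K and V for every layer, so the byte count is approximately block_size * num_layers * num_kv_heads * head_size * 2 * dtype_bytes. A back-of-the-envelope sketch with made-up model dimensions (an approximation of the helper's intent, not the exact CacheEngine formula):

    def approx_cache_block_size_bytes(block_size: int, num_layers: int,
                                      num_kv_heads: int, head_size: int,
                                      dtype_bytes: int) -> int:
        # One block stores `block_size` tokens of K and V for every layer.
        return (block_size * num_layers * num_kv_heads * head_size * 2 *
                dtype_bytes)

    # Made-up 7B-ish dimensions: 32 layers, 32 KV heads of size 128, fp16.
    print(approx_cache_block_size_bytes(16, 32, 32, 128, 2))  # 8388608 (8 MiB)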