From 8437bae6ef47a690d18c72f0da02c7e5abe83866 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Fri, 8 Mar 2024 23:32:46 -0800 Subject: [PATCH] [Speculative decoding 3/9] Worker which speculates, scores, and applies rejection sampling (#3103) --- .buildkite/test-pipeline.yaml | 5 +- tests/{worker => }/spec_decode/__init__.py | 0 tests/spec_decode/test_batch_expansion.py | 95 +++ tests/spec_decode/test_metrics.py | 157 +++++ .../spec_decode/test_multi_step_worker.py | 162 ++++- tests/spec_decode/test_spec_decode_worker.py | 591 ++++++++++++++++++ tests/spec_decode/test_utils.py | 111 ++++ tests/{worker => }/spec_decode/utils.py | 115 +++- tests/test_sequence.py | 50 ++ .../layers/rejection_sampler.py | 10 +- vllm/model_executor/layers/sampler.py | 2 +- vllm/sequence.py | 55 +- vllm/spec_decode/batch_expansion.py | 351 +++++++++++ vllm/spec_decode/interfaces.py | 77 +++ vllm/spec_decode/metrics.py | 174 ++++++ vllm/spec_decode/multi_step_worker.py | 366 +++++++++++ vllm/spec_decode/spec_decode_worker.py | 372 +++++++++++ vllm/spec_decode/util.py | 99 +++ vllm/worker/model_runner.py | 11 +- vllm/worker/spec_decode/multi_step_worker.py | 178 ------ vllm/worker/worker.py | 20 +- 21 files changed, 2786 insertions(+), 215 deletions(-) rename tests/{worker => }/spec_decode/__init__.py (100%) create mode 100644 tests/spec_decode/test_batch_expansion.py create mode 100644 tests/spec_decode/test_metrics.py rename tests/{worker => }/spec_decode/test_multi_step_worker.py (61%) create mode 100644 tests/spec_decode/test_spec_decode_worker.py create mode 100644 tests/spec_decode/test_utils.py rename tests/{worker => }/spec_decode/utils.py (60%) create mode 100644 tests/test_sequence.py create mode 100644 vllm/spec_decode/batch_expansion.py create mode 100644 vllm/spec_decode/interfaces.py create mode 100644 vllm/spec_decode/metrics.py create mode 100644 vllm/spec_decode/multi_step_worker.py create mode 100644 vllm/spec_decode/spec_decode_worker.py create mode 100644 vllm/spec_decode/util.py delete mode 100644 vllm/worker/spec_decode/multi_step_worker.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 15f971b6..42a1eacb 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -28,7 +28,7 @@ steps: num_gpus: 2 # only support 1 or 2 for now. - label: Engine Test - command: pytest -v -s engine + command: pytest -v -s engine test_sequence.py - label: Entrypoints Test command: pytest -v -s entrypoints @@ -52,6 +52,9 @@ steps: - label: Worker Test command: pytest -v -s worker +- label: Speculative decoding tests + command: pytest -v -s spec_decode + - label: LoRA Test command: pytest -v -s lora --forked diff --git a/tests/worker/spec_decode/__init__.py b/tests/spec_decode/__init__.py similarity index 100% rename from tests/worker/spec_decode/__init__.py rename to tests/spec_decode/__init__.py diff --git a/tests/spec_decode/test_batch_expansion.py b/tests/spec_decode/test_batch_expansion.py new file mode 100644 index 00000000..fddc3995 --- /dev/null +++ b/tests/spec_decode/test_batch_expansion.py @@ -0,0 +1,95 @@ +import torch +import pytest + +from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer + +from .utils import mock_worker, create_seq_group_metadata_from_prompts + + +@pytest.mark.parametrize('num_target_seq_ids', [100]) +def test_create_target_seq_id_iterator(num_target_seq_ids: int): + """Verify all new sequence ids are greater than all input + seq ids. 
+ """ + scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000) + + all_seq_ids = [ + [1, 3, 5, 7], + list(range(100)) + [0], + [100], + ] + + for seq_ids in all_seq_ids: + max_seq_id = max(seq_ids) + iterator = scorer._create_target_seq_id_iterator(seq_ids) # pylint: disable=protected-access + for _ in range(num_target_seq_ids): + assert next(iterator) > max_seq_id + + +@pytest.mark.parametrize('k', [1, 2, 6]) +def test_get_token_ids_to_score(k: int): + """Verify correct tokens are selected for scoring. + """ + proposal_token_ids = torch.tensor( + list(range(k)), + dtype=torch.int64, + device='cuda', + ) + + expected_output = [ + [], + ] + for i in range(proposal_token_ids.shape[0]): + expected_output.append(proposal_token_ids[:i + 1].tolist()) + + scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000) + actual_output = scorer._get_token_ids_to_score(proposal_token_ids) # pylint: disable=protected-access + + actual_output = [ + x.tolist() if isinstance(x, torch.Tensor) else x for x in actual_output + ] + + assert actual_output == expected_output + + +@pytest.mark.parametrize('k', [1, 2, 6]) +def test_create_single_target_seq_group_metadata(k: int): + """Verify correct creation of a batch-expanded seq group metadata. + """ + + prompt_tokens = [1, 2, 3] + prev_output_tokens = [4, 5, 6] + + token_ids = list(range(k)) + + num_tokens_processed = len(prompt_tokens) + len(prev_output_tokens) - 1 + + final_seq_len = len(prompt_tokens) + len(prev_output_tokens) + len( + token_ids) + + block_size = 32 + input_seq_group_metadata = create_seq_group_metadata_from_prompts( + [prompt_tokens], 2048 // block_size, block_size, [final_seq_len], + [prev_output_tokens], [num_tokens_processed])[0] + + input_seq_id = list(input_seq_group_metadata.seq_data.keys())[0] + target_seq_id = 100 + + scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000) + output = scorer._create_single_target_seq_group_metadata( # pylint: disable=protected-access + input_seq_group_metadata, + input_seq_id, + target_seq_id, + token_ids, + ) + + assert output.request_id == input_seq_group_metadata.request_id + assert len(output.seq_data) == 1 + assert output.seq_data[target_seq_id].get_prompt_token_ids( + ) == prompt_tokens + assert output.seq_data[target_seq_id].get_output_token_ids( + ) == prev_output_tokens + token_ids + + assert len(output.block_tables) == 1 + assert output.block_tables[ + target_seq_id] == input_seq_group_metadata.block_tables[input_seq_id] diff --git a/tests/spec_decode/test_metrics.py b/tests/spec_decode/test_metrics.py new file mode 100644 index 00000000..941ea37a --- /dev/null +++ b/tests/spec_decode/test_metrics.py @@ -0,0 +1,157 @@ +import torch +import math +import pytest + +from unittest.mock import MagicMock + +from vllm.spec_decode.metrics import AsyncMetricsCollector + + +def test_initial_call_returns_none(): + """Expect first call to get metrics to return None. + """ + rej_sampler = MagicMock() + rej_sampler.num_accepted_tokens = torch.tensor(0, + dtype=torch.long, + device='cuda') + rej_sampler.num_emitted_tokens = torch.tensor(0, + dtype=torch.long, + device='cuda') + rej_sampler.num_draft_tokens = 0 + + collector = AsyncMetricsCollector(rej_sampler) + collector.init_gpu_tensors(rank=0) + maybe_metrics = collector.maybe_collect_rejsample_metrics(k=5) + assert maybe_metrics is None + + +def test_second_call_returns_metrics(): + """Expect second call to not return None. 
+ """ + rej_sampler = MagicMock() + rej_sampler.num_accepted_tokens = torch.tensor(0, + dtype=torch.long, + device='cuda') + rej_sampler.num_emitted_tokens = torch.tensor(0, + dtype=torch.long, + device='cuda') + rej_sampler.num_draft_tokens = 0 + + collect_interval_s = 5.0 + timer = MagicMock() + timer.side_effect = [ + 0.0, collect_interval_s + 0.1, collect_interval_s + 0.2 + ] + + collector = AsyncMetricsCollector(rejection_sampler=rej_sampler, + timer=timer, + collect_interval_s=collect_interval_s) + collector.init_gpu_tensors(rank=0) + _ = collector.maybe_collect_rejsample_metrics(k=5) + metrics = collector.maybe_collect_rejsample_metrics(k=5) + assert metrics is not None + + +@pytest.mark.parametrize("rank", [1, 2, 3, 4]) +def test_nonzero_rank_noop(rank): + """Verify nonzero ranks don't collect metrics. + """ + rej_sampler = MagicMock() + rej_sampler.num_accepted_tokens = torch.tensor(0, + dtype=torch.long, + device='cuda') + rej_sampler.num_emitted_tokens = torch.tensor(0, + dtype=torch.long, + device='cuda') + rej_sampler.num_draft_tokens = 0 + + collector = AsyncMetricsCollector(rej_sampler) + collector.init_gpu_tensors(rank=rank) + _ = collector.maybe_collect_rejsample_metrics(k=5) + metrics = collector.maybe_collect_rejsample_metrics(k=5) + assert metrics is None + + +def test_noop_until_time(): + """Verify metrics aren't collected until enough time passes. + """ + rej_sampler = MagicMock() + rej_sampler.num_accepted_tokens = torch.tensor(0, + dtype=torch.long, + device='cuda') + rej_sampler.num_emitted_tokens = torch.tensor(0, + dtype=torch.long, + device='cuda') + rej_sampler.num_draft_tokens = 0 + + collect_interval_s = 5.0 + timer = MagicMock() + timer.side_effect = [ + 0.0, collect_interval_s - 0.1, collect_interval_s - 0.1, + collect_interval_s + 0.1, collect_interval_s + 0.1 + ] + + collector = AsyncMetricsCollector(rejection_sampler=rej_sampler, + timer=timer, + collect_interval_s=collect_interval_s) + collector.init_gpu_tensors(rank=0) + + _ = collector.maybe_collect_rejsample_metrics(k=5) + metrics = collector.maybe_collect_rejsample_metrics(k=5) + assert metrics is None + + _ = collector.maybe_collect_rejsample_metrics(k=5) + metrics = collector.maybe_collect_rejsample_metrics(k=5) + assert metrics is not None + + +@pytest.mark.parametrize("has_data", [True, False]) +def test_initial_metrics_has_correct_values(has_data: bool): + """Test correctness of metrics data. 
+ """ + if has_data: + num_accepted_tokens = 103 + num_emitted_tokens = 104 + num_draft_tokens = 105 + else: + num_accepted_tokens = 0 + num_emitted_tokens = 0 + num_draft_tokens = 0 + k = 5 + + num_possible_tokens = AsyncMetricsCollector.get_max_num_accepted_tokens( + num_draft_tokens, k) + + rej_sampler = MagicMock() + rej_sampler.num_accepted_tokens = torch.tensor(num_accepted_tokens, + dtype=torch.long, + device='cuda') + rej_sampler.num_emitted_tokens = torch.tensor(num_emitted_tokens, + dtype=torch.long, + device='cuda') + rej_sampler.num_draft_tokens = num_draft_tokens + + collect_interval_s = 5.0 + timer = MagicMock() + timer.side_effect = [ + 0.0, collect_interval_s + 0.1, collect_interval_s + 0.2 + ] + + collector = AsyncMetricsCollector(rejection_sampler=rej_sampler, + timer=timer, + collect_interval_s=collect_interval_s) + collector.init_gpu_tensors(rank=0) + _ = collector.maybe_collect_rejsample_metrics(k) + metrics = collector.maybe_collect_rejsample_metrics(k) + + assert metrics.num_spec_tokens == k + assert metrics.accepted_tokens == num_accepted_tokens + assert metrics.draft_tokens == num_draft_tokens + assert metrics.emitted_tokens == num_emitted_tokens + + if has_data: + assert metrics.draft_acceptance_rate == num_accepted_tokens / num_draft_tokens + assert metrics.system_efficiency == num_emitted_tokens / num_possible_tokens + else: + assert math.isnan(metrics.draft_acceptance_rate) + assert math.isnan(metrics.system_efficiency) diff --git a/tests/worker/spec_decode/test_multi_step_worker.py b/tests/spec_decode/test_multi_step_worker.py similarity index 61% rename from tests/worker/spec_decode/test_multi_step_worker.py rename to tests/spec_decode/test_multi_step_worker.py index ea548029..88bb7c29 100644 --- a/tests/worker/spec_decode/test_multi_step_worker.py +++ b/tests/spec_decode/test_multi_step_worker.py @@ -3,14 +3,15 @@ import random import pytest from unittest.mock import MagicMock -from vllm.worker.spec_decode.multi_step_worker import MultiStepWorker +from vllm.spec_decode.multi_step_worker import MultiStepWorker, DraftModelTop1Proposer from vllm.worker.worker import Worker from vllm.model_executor.utils import set_random_seed +from vllm.sequence import SamplerOutput from .utils import (create_execute_model_data, create_worker, create_seq_group_metadata_from_prompts, zero_kv_cache, patch_execute_model_with_seeds, - assert_logprobs_dict_allclose) + assert_logprobs_dict_allclose, create_batch) @pytest.mark.parametrize('num_steps', list(range(1, 17))) @@ -259,3 +260,160 @@ def test_same_output_for_multi_step(): multi_step_output_logprobs, single_step_output_logprobs): assert_logprobs_dict_allclose(multi_step_logprobs, single_step_logprobs) + + +@torch.inference_mode() +def test_draft_proposals_full_speculation_len(): + """Verify DraftModelTop1Proposer correctly handles case where all sequences + can speculate. 
+ """ + k = 10 + batch_size = 32 + vocab_size = 32_000 + device = 'cuda:0' + + draft_worker = MagicMock() + proposer = DraftModelTop1Proposer( + draft_worker=draft_worker, + device=device, + max_model_len=2048, + vocab_size=vocab_size, + ) + draft_worker.execute_model_multi_step.return_value = [ + SamplerOutput( + outputs=[], + sampled_token_probs=torch.rand(batch_size, + vocab_size, + device=device, + dtype=torch.float32), + sampled_token_ids=torch.randint(low=0, + high=vocab_size, + size=(batch_size, ), + device=device, + dtype=torch.long), + ) for _ in range(k) + ] + + execute_model_data, _, _ = create_batch(batch_size, k) + + proposals = proposer.get_proposals( + **execute_model_data.to_dict(), + max_proposal_len=k, + ) + + assert torch.is_tensor(proposals.proposal_token_ids) + assert torch.is_tensor(proposals.proposal_probs) + + assert proposals.proposal_token_ids.shape == torch.Size([batch_size, k]) + assert proposals.proposal_probs.shape[:-1] == torch.Size([batch_size, k]) + + assert proposals.proposal_lens.shape == torch.Size([batch_size]) + assert proposals.proposal_lens.tolist() == [k for _ in range(batch_size)] + + +@torch.inference_mode() +def test_draft_proposals_no_speculations(): + """Verify DraftModelTop1Proposer correctly handles case where no sequences + can speculate. + """ + k = 10 + batch_size = 32 + vocab_size = 32_000 + device = 'cuda:0' + prompt_len = 10 + + draft_worker = MagicMock() + proposer = DraftModelTop1Proposer( + draft_worker=draft_worker, + device=device, + max_model_len=prompt_len + k - 1, + vocab_size=vocab_size, + ) + + execute_model_data, _, _ = create_batch(batch_size, + k, + prompt_len=prompt_len) + + proposals = proposer.get_proposals( + **execute_model_data.to_dict(), + max_proposal_len=k, + ) + + assert torch.is_tensor(proposals.proposal_token_ids) + assert torch.is_tensor(proposals.proposal_probs) + + assert proposals.proposal_token_ids.shape == torch.Size([0, k]) + assert proposals.proposal_probs.shape[:-1] == torch.Size([0, k]) + + assert proposals.proposal_lens.shape == torch.Size([batch_size]) + assert proposals.proposal_lens.tolist() == [0 for _ in range(batch_size)] + + +@torch.inference_mode() +def test_draft_proposals_mixed_k(): + """Verify DraftModelTop1Proposer correctly handles case some sequences can + speculate and some can't. 
+ """ + k = 10 + batch_size = 32 + vocab_size = 32_000 + device = 'cuda:0' + + small_prompt_len = 5 + long_prompt_len = 10 + prev_output_token_len = 20 + + expected_num_proposal_seqs = 6 + expected_num_no_proposal_seqs = batch_size - expected_num_proposal_seqs + + prompt_len = [ + small_prompt_len for _ in range(expected_num_proposal_seqs - 1) + ] + [long_prompt_len + for _ in range(expected_num_no_proposal_seqs)] + [small_prompt_len] + + draft_worker = MagicMock() + proposer = DraftModelTop1Proposer( + draft_worker=draft_worker, + device=device, + max_model_len=long_prompt_len + prev_output_token_len + k - 1, + vocab_size=vocab_size, + ) + + draft_worker.execute_model_multi_step.return_value = [ + SamplerOutput( + outputs=[], + sampled_token_probs=torch.rand(expected_num_proposal_seqs, + vocab_size, + device=device, + dtype=torch.float32), + sampled_token_ids=torch.randint( + low=0, + high=vocab_size, + size=(expected_num_proposal_seqs, ), + device=device, + dtype=torch.long), + ) for _ in range(k) + ] + + execute_model_data, _, _ = create_batch( + batch_size, + k, + prompt_len=prompt_len, + prev_output_token_len=prev_output_token_len, + ) + + proposals = proposer.get_proposals( + **execute_model_data.to_dict(), + max_proposal_len=k, + ) + + assert torch.is_tensor(proposals.proposal_token_ids) + assert torch.is_tensor(proposals.proposal_probs) + + assert proposals.proposal_token_ids.shape == torch.Size([batch_size, k]) + assert proposals.proposal_probs.shape[:-1] == torch.Size([batch_size, k]) + + assert proposals.proposal_lens.shape == torch.Size([batch_size]) + assert proposals.proposal_lens.tolist() == [ + k for _ in range(expected_num_proposal_seqs - 1) + ] + [0 for _ in range(expected_num_no_proposal_seqs)] + [k] diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py new file mode 100644 index 00000000..e919711c --- /dev/null +++ b/tests/spec_decode/test_spec_decode_worker.py @@ -0,0 +1,591 @@ +import torch +import random +import pytest +from unittest.mock import MagicMock + +from vllm.spec_decode.multi_step_worker import MultiStepWorker +from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker, split_num_cache_blocks_evenly +from vllm.spec_decode.interfaces import SpeculativeProposals +from vllm.model_executor.utils import set_random_seed +from vllm.model_executor.layers.rejection_sampler import RejectionSampler +from .utils import mock_worker, create_batch, ExecuteModelData, create_sampler_output_list +from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics, AsyncMetricsCollector + + +@pytest.mark.parametrize('k', [1, 2, 6]) +@pytest.mark.parametrize('batch_size', [1, 2, 32]) +@torch.inference_mode() +def test_correctly_calls_draft_model(k: int, batch_size: int): + """Verify SpecDecodeWorker calls the draft worker with correct + inputs. Everything else is mocked out. 
+ """ + draft_worker = mock_worker(cls=MultiStepWorker) + target_worker = mock_worker() + rejection_sampler = MagicMock(spec=RejectionSampler) + metrics_collector = MagicMock(spec=AsyncMetricsCollector) + worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, + metrics_collector) + + exception_secret = 'artifical stop' + draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret) + + execute_model_data, _, _ = create_batch(batch_size, k) + + with pytest.raises(ValueError, match=exception_secret): + worker.execute_model(**execute_model_data.to_dict(), num_spec_tokens=k) + + call_args_list = draft_worker.get_spec_proposals.call_args_list + assert len(call_args_list) == 1 + + for args, _ in call_args_list: + (seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out, + blocks_to_copy, actual_k) = args + actual_execute_model_data = ExecuteModelData(seq_group_metadata_list, + blocks_to_swap_in, + blocks_to_swap_out, + blocks_to_copy) + assert actual_execute_model_data == execute_model_data + assert actual_k == k + + +@pytest.mark.parametrize('k', [1, 2, 6]) +@pytest.mark.parametrize('batch_size', [1, 2, 32]) +@torch.inference_mode() +def test_correctly_calls_target_model(k: int, batch_size: int): + """Verify SpecDecodeWorker calls the target model with correct + inputs. Everything else is mocked out. + """ + draft_worker = mock_worker(cls=MultiStepWorker) + target_worker = mock_worker() + rejection_sampler = MagicMock(spec=RejectionSampler) + rejection_sampler.token_id_dtype = torch.int64 + metrics_collector = MagicMock(spec=AsyncMetricsCollector) + + draft_worker.device = 'cuda' + target_worker.device = 'cuda' + + set_random_seed(1) + + worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, + metrics_collector) + worker.init_model() + + vocab_size = 32_000 + + proposal_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, k), + dtype=torch.int64, + device='cuda') + proposal_probs = torch.rand(batch_size, + k, + vocab_size, + dtype=torch.float32, + device='cuda') + proposal_lens = torch.ones(batch_size, dtype=torch.int64, + device='cuda') * k + + execute_model_data, prompts, prev_output_tokens = create_batch( + batch_size, k) + + draft_worker.get_spec_proposals.return_value = SpeculativeProposals( + proposal_token_ids=proposal_token_ids, + proposal_probs=proposal_probs, + proposal_lens=proposal_lens) + + exception_secret = 'artifical stop' + target_worker.execute_model.side_effect = ValueError(exception_secret) + + with pytest.raises(ValueError, match=exception_secret): + worker.execute_model(**execute_model_data.to_dict(), num_spec_tokens=k) + + seen_contexts = [] + + call_args_list = target_worker.execute_model.call_args_list + assert len(call_args_list) == 1 + for args, kwargs in call_args_list: + target_execute_model_data = ExecuteModelData.from_dict(kwargs) + + assert len(target_execute_model_data.seq_group_metadata_list) == ( + k + 1) * batch_size + for seq_group_metadata in ( + target_execute_model_data.seq_group_metadata_list): + for seq_data in seq_group_metadata.seq_data.values(): + seen_contexts.append(seq_data.get_token_ids()) + + expected_seen_contexts = [] + + for prompt, prev_generated, draft_tokens in zip( + prompts, prev_output_tokens, proposal_token_ids.tolist()): + + for i in range(len(draft_tokens) + 1): + expected_seen_contexts.append(prompt + prev_generated + + draft_tokens[:i]) + + seen_contexts.sort() + expected_seen_contexts.sort() + assert expected_seen_contexts == seen_contexts + + 
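The expected_seen_contexts construction in the test above is the heart of batch expansion: each input sequence yields k + 1 scoring contexts, each one draft token longer than the previous. A minimal sketch of that pattern, using a helper name of our own rather than anything defined in this patch:

def expanded_contexts(prompt, prev_generated, draft_tokens):
    # One scoring context per prefix of the draft tokens, from the empty
    # prefix up to all k tokens (the full prefix is what yields the bonus
    # token after scoring).
    return [
        prompt + prev_generated + draft_tokens[:i]
        for i in range(len(draft_tokens) + 1)
    ]

# expanded_contexts([1, 2], [3, 4], [7, 8]) returns
# [[1, 2, 3, 4], [1, 2, 3, 4, 7], [1, 2, 3, 4, 7, 8]].
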
+@pytest.mark.parametrize('k', [1, 2, 6]) +@pytest.mark.parametrize('batch_size', [1, 2, 32]) +@torch.inference_mode() +def test_correctly_calls_rejection_sampler(k: int, batch_size: int): + """Verify SpecDecodeWorker calls the rejection sampler with + correct inputs. Everything else is mocked out. + """ + vocab_size = 32_000 + + draft_worker = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size) + target_worker = mock_worker(vocab_size=vocab_size) + rejection_sampler = MagicMock(spec=RejectionSampler) + rejection_sampler.token_id_dtype = torch.int64 + metrics_collector = MagicMock(spec=AsyncMetricsCollector) + draft_worker.device = 'cuda' + target_worker.device = 'cuda' + + set_random_seed(1) + + worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, + metrics_collector) + worker.init_model() + + proposal_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, k), + dtype=torch.int64, + device='cuda') + proposal_probs = torch.rand(batch_size, + k, + vocab_size, + dtype=torch.float32, + device='cuda') + + proposal_lens = torch.ones(batch_size, dtype=torch.int64, + device='cuda') * k + + execute_model_data, _, _ = create_batch(batch_size, k) + + draft_worker.get_spec_proposals.return_value = SpeculativeProposals( + proposal_token_ids=proposal_token_ids, + proposal_probs=proposal_probs, + proposal_lens=proposal_lens) + + target_token_ids = torch.randint(low=0, + high=vocab_size, + size=(1, batch_size * (k + 1)), + dtype=torch.int64, + device='cuda') + target_token_probs = torch.rand(1, + batch_size * (k + 1), + vocab_size, + dtype=torch.float32, + device='cuda') + target_output = create_sampler_output_list(target_token_ids, + target_token_probs) + + target_worker.execute_model.return_value = target_output[0] + + exception_secret = 'artifical stop' + rejection_sampler.side_effect = ValueError(exception_secret) + + with pytest.raises(ValueError, match=exception_secret): + worker.execute_model(**execute_model_data.to_dict(), num_spec_tokens=k) + + assert len(rejection_sampler.call_args_list) == 1 + args, _ = rejection_sampler.call_args_list[0] + (actual_proposal_scores, actual_bonus_token_ids, actual_proposal_probs, + actual_proposal_token_ids) = args + + assert torch.equal(actual_bonus_token_ids, + target_token_ids.reshape(batch_size, k + 1)[:, -1:]) + assert torch.equal( + actual_proposal_scores, + target_token_probs.reshape(batch_size, k + 1, -1)[:, :-1]) + assert torch.equal(actual_proposal_token_ids, proposal_token_ids) + assert torch.equal(actual_proposal_probs, proposal_probs) + + +@pytest.mark.parametrize('k', [1, 2, 6]) +@pytest.mark.parametrize('batch_size', [1, 2, 32]) +@torch.inference_mode() +def test_correctly_formats_output(k: int, batch_size: int): + """Verify SpecDecodeWorker formats sampler output correctly. + Everything else is mocked out. 
+ """ + vocab_size = 32_000 + + draft_worker = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size) + target_worker = mock_worker(vocab_size=vocab_size) + rejection_sampler = MagicMock(spec=RejectionSampler) + rejection_sampler.token_id_dtype = torch.int64 + metrics_collector = MagicMock(spec=AsyncMetricsCollector) + draft_worker.device = 'cuda' + target_worker.device = 'cuda' + + set_random_seed(1) + + worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, + metrics_collector) + worker.init_model() + + proposal_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, k), + dtype=torch.int64, + device='cuda') + proposal_probs = torch.rand(batch_size, + k, + vocab_size, + dtype=torch.float32, + device='cuda') + + proposal_lens = torch.ones(batch_size, dtype=torch.int64, + device='cuda') * k + + execute_model_data, _, _ = create_batch(batch_size, k) + + draft_worker.get_spec_proposals.return_value = SpeculativeProposals( + proposal_token_ids=proposal_token_ids, + proposal_probs=proposal_probs, + proposal_lens=proposal_lens) + + target_token_ids = torch.randint(low=0, + high=vocab_size, + size=(1, batch_size * (k + 1)), + dtype=torch.int64, + device='cuda') + target_token_probs = torch.rand(1, + batch_size * (k + 1), + vocab_size, + dtype=torch.float32, + device='cuda') + target_output = create_sampler_output_list(target_token_ids, + target_token_probs) + + target_worker.execute_model.return_value = target_output[0] + + rejection_sampler_output = torch.randint(low=0, + high=vocab_size, + size=(batch_size, k + 1), + dtype=torch.int64, + device='cuda') + for i in range(batch_size): + minimum_accepted_tokens = 1 + rejection_sampler_output[i][ + -random.randint(minimum_accepted_tokens, k + 1):] = -1 + + rejection_sampler.return_value = rejection_sampler_output + + output = worker.execute_model(**execute_model_data.to_dict(), + num_spec_tokens=k) + + expected_output = create_sampler_output_list( + rejection_sampler_output.transpose(0, 1), [None for _ in range(k + 1)]) + + seq_ids = [ + next(iter(seq_group_metadata.seq_data.keys())) + for seq_group_metadata in execute_model_data.seq_group_metadata_list + ] + actual_output_by_seq = {seq_id: [] for seq_id in seq_ids} + expected_output_by_seq = {seq_id: [] for seq_id in seq_ids} + + for step in output: + for seq_group in step: + for sample in seq_group.samples: + seq_id = sample.parent_seq_id + actual_output_by_seq[seq_id].append(sample) + + for step in expected_output: + for seq_group in step: + for sample in seq_group.samples: + seq_id = sample.parent_seq_id + expected_output_by_seq[seq_id].append(sample) + + all_seen_seq_ids = set( + list(actual_output_by_seq.keys()) + + list(expected_output_by_seq.keys())) + for seq_id in all_seen_seq_ids: + actual_by_step = actual_output_by_seq[seq_id] + expected_by_step = expected_output_by_seq[seq_id] + + for i in range(k + 1): + if i >= len(actual_by_step): + assert expected_by_step[i].output_token == -1 + continue + assert actual_by_step[i].output_token == expected_by_step[ + i].output_token + assert actual_by_step[i].logprobs == expected_by_step[i].logprobs + + +@pytest.mark.parametrize('k', [1, 2]) +@pytest.mark.parametrize('batch_size', [1]) +@pytest.mark.parametrize('returns_metrics', [True, False]) +@torch.inference_mode() +def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool): + """Verify SpecDecodeWorker collects metrics. 
+ """ + vocab_size = 32_000 + + draft_worker = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size) + target_worker = mock_worker(vocab_size=vocab_size) + rejection_sampler = MagicMock(spec=RejectionSampler) + rejection_sampler.token_id_dtype = torch.int64 + metrics_collector = MagicMock(spec=AsyncMetricsCollector) + draft_worker.device = 'cuda' + target_worker.device = 'cuda' + + set_random_seed(1) + + worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, + metrics_collector) + worker.init_model() + + proposal_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, k), + dtype=torch.int64, + device='cuda') + proposal_probs = torch.rand(batch_size, + k, + vocab_size, + dtype=torch.float32, + device='cuda') + + proposal_lens = torch.ones(batch_size, dtype=torch.int64, + device='cuda') * k + + execute_model_data, _, _ = create_batch(batch_size, k) + + draft_worker.get_spec_proposals.return_value = SpeculativeProposals( + proposal_token_ids=proposal_token_ids, + proposal_probs=proposal_probs, + proposal_lens=proposal_lens) + + target_token_ids = torch.randint(low=0, + high=vocab_size, + size=(1, batch_size * (k + 1)), + dtype=torch.int64, + device='cuda') + target_token_probs = torch.rand(1, + batch_size * (k + 1), + vocab_size, + dtype=torch.float32, + device='cuda') + target_output = create_sampler_output_list(target_token_ids, + target_token_probs) + + target_worker.execute_model.return_value = target_output[0] + + rejection_sampler_output = torch.randint(low=0, + high=vocab_size, + size=(batch_size, k + 1), + dtype=torch.int64, + device='cuda') + for i in range(batch_size): + minimum_accepted_tokens = 1 + rejection_sampler_output[i][ + -random.randint(minimum_accepted_tokens, k + 1):] = -1 + + rejection_sampler.return_value = rejection_sampler_output + + mock_rejsample_metrics = MagicMock( + spec=SpecDecodeWorkerMetrics) if returns_metrics else None + metrics_collector.maybe_collect_rejsample_metrics.return_value = mock_rejsample_metrics + + output = worker.execute_model(**execute_model_data.to_dict(), + num_spec_tokens=k) + assert output[0].spec_decode_worker_metrics == mock_rejsample_metrics + + call_args_list = metrics_collector.maybe_collect_rejsample_metrics.call_args_list + assert len(call_args_list) == 1 + args, kwargs = call_args_list[0] + assert args[0] == k or kwargs.get('k', -1) == k + + +@pytest.mark.parametrize('k', [0]) +@pytest.mark.parametrize('batch_size', [1, 2, 32]) +@torch.inference_mode() +def test_k_equals_zero(k: int, batch_size: int): + """Verify that the SpecDecodeWorker calls the draft and target workers + when k is zero. This happens during prefill. 
+ """ + draft_worker = mock_worker(cls=MultiStepWorker) + target_worker = mock_worker() + rejection_sampler = MagicMock(spec=RejectionSampler) + rejection_sampler.token_id_dtype = torch.int64 + metrics_collector = MagicMock(spec=AsyncMetricsCollector) + + draft_worker.device = 'cuda' + target_worker.device = 'cuda' + + set_random_seed(1) + + worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, + metrics_collector) + + execute_model_data, prompts, prev_output_tokens = create_batch( + batch_size, k, prev_output_token_len=0) + + out = worker.execute_model(**execute_model_data.to_dict(), + num_spec_tokens=k) + + assert len(out) == 1, f"expected only one token output when {k=}" + assert out[0].probs is None, "expect gpu tensor references to be None" + assert out[ + 0].sampled_tokens is None, "expect gpu tensor references to be None" + + draft_worker.execute_model.assert_called_once_with( + **execute_model_data.to_dict(), return_python_output=False) + target_worker.execute_model.assert_called_once_with( + **execute_model_data.to_dict()) + + +@pytest.mark.parametrize('k', [0, 5]) +@pytest.mark.parametrize('batch_size', [0]) +@torch.inference_mode() +def test_empty_input_batch(k: int, batch_size: int): + """Verify that the SpecDecodeWorker calls the draft and target workers + when the input batch is empty. This can happen if the engine communicates + to the workers information without scheduling a batch. + """ + draft_worker = mock_worker(cls=MultiStepWorker) + target_worker = mock_worker() + rejection_sampler = MagicMock(spec=RejectionSampler) + rejection_sampler.token_id_dtype = torch.int64 + metrics_collector = MagicMock(spec=AsyncMetricsCollector) + + draft_worker.device = 'cuda' + target_worker.device = 'cuda' + + set_random_seed(1) + + worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, + metrics_collector) + + execute_model_data, prompts, prev_output_tokens = create_batch( + batch_size, k, prev_output_token_len=0) + + out = worker.execute_model(**execute_model_data.to_dict(), + num_spec_tokens=k) + + assert len(out) == 1, f"expected only one token output when {k=}" + assert out[0].probs is None, "expect gpu tensor references to be None" + assert out[ + 0].sampled_tokens is None, "expect gpu tensor references to be None" + + draft_worker.execute_model.assert_called_once_with( + **execute_model_data.to_dict(), return_python_output=False) + target_worker.execute_model.assert_called_once_with( + **execute_model_data.to_dict()) + + +@torch.inference_mode() +def test_init_model(): + """Verify SpecDecodeWorker invokes proposer/scorer worker init_model, as + well as other GPU initialization. + """ + draft_worker = mock_worker(cls=MultiStepWorker) + target_worker = mock_worker() + rejection_sampler = MagicMock(spec=RejectionSampler) + rejection_sampler.token_id_dtype = torch.int64 + metrics_collector = MagicMock(spec=AsyncMetricsCollector) + + worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, + metrics_collector) + + worker.init_model() + + draft_worker.init_model.assert_called_once() + + target_worker.init_model.assert_called_once() + + metrics_collector.init_gpu_tensors.assert_called_once() + rejection_sampler.init_gpu_tensors.assert_called_once() + + +@torch.inference_mode() +def test_init_cache_engine(): + """Verify SpecDecodeWorker invokes init_cache_engine on proposer/scorer + workers. 
+ """ + draft_worker = mock_worker(cls=MultiStepWorker) + target_worker = mock_worker() + rejection_sampler = MagicMock(spec=RejectionSampler) + rejection_sampler.token_id_dtype = torch.int64 + metrics_collector = MagicMock(spec=AsyncMetricsCollector) + + worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, + metrics_collector) + + cache_config = MagicMock() + + worker.init_cache_engine(cache_config) + + draft_worker.init_cache_engine.assert_called_once_with(cache_config) + target_worker.init_cache_engine.assert_called_once_with(cache_config) + + +@pytest.mark.parametrize('available_gpu_blocks', [1, 1024]) +@pytest.mark.parametrize('available_cpu_blocks', [500]) +@pytest.mark.parametrize('target_cache_block_size_bytes', [2 * 2 * 4096]) +@pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096]) +@torch.inference_mode() +def test_profile_num_available_blocks(available_gpu_blocks: int, + available_cpu_blocks: int, + target_cache_block_size_bytes: int, + draft_kv_size_bytes: int): + """Verify SpecDecodeWorker correctly profiles num available GPU blocks. + Specifically, it should run profiling in the scorer worker, and then evenly + split the blocks between proposer and scorer worker. + """ + draft_worker = mock_worker(cls=MultiStepWorker) + target_worker = mock_worker() + rejection_sampler = MagicMock(spec=RejectionSampler) + rejection_sampler.token_id_dtype = torch.int64 + metrics_collector = MagicMock(spec=AsyncMetricsCollector) + + target_worker.profile_num_available_blocks.return_value = ( + available_gpu_blocks, available_cpu_blocks) + target_worker.get_cache_block_size_bytes.return_value = target_cache_block_size_bytes + draft_worker.get_cache_block_size_bytes.return_value = draft_kv_size_bytes + + worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, + metrics_collector) + + # These values do not directly impact the adjusted block size calculation, + # so they can be fixed. + gpu_memory_utilization = 0.9 + cpu_swap_space = 100 + block_size = 16 + + num_gpu_blocks, num_cpu_blocks = worker.profile_num_available_blocks( + block_size, gpu_memory_utilization, cpu_swap_space, cache_dtype="auto") + + target_worker.profile_num_available_blocks.assert_called_once_with( + block_size, gpu_memory_utilization, cpu_swap_space, "auto") + assert num_cpu_blocks == available_cpu_blocks + + assert num_gpu_blocks == split_num_cache_blocks_evenly( + target_cache_block_size_bytes, draft_kv_size_bytes, + available_gpu_blocks) + + +@pytest.mark.parametrize('available_gpu_blocks', + list(range(20)) + [1024, 1024**2]) +@pytest.mark.parametrize('target_cache_block_size_bytes', + [2 * 2 * 4096, 2 * 2 * 8192]) +@pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096]) +@torch.inference_mode() +def test_split_num_cache_blocks_evenly(available_gpu_blocks: int, + target_cache_block_size_bytes: int, + draft_kv_size_bytes: int): + """Verify split_num_cache_blocks_evenly does not exceed original memory + allocation in bytes. 
+ """ + num_blocks = split_num_cache_blocks_evenly(target_cache_block_size_bytes, + draft_kv_size_bytes, + available_gpu_blocks) + assert (num_blocks * target_cache_block_size_bytes) + ( + num_blocks * draft_kv_size_bytes) <= (available_gpu_blocks * + target_cache_block_size_bytes) diff --git a/tests/spec_decode/test_utils.py b/tests/spec_decode/test_utils.py new file mode 100644 index 00000000..19833ddb --- /dev/null +++ b/tests/spec_decode/test_utils.py @@ -0,0 +1,111 @@ +from vllm.spec_decode.util import get_all_seq_ids +from vllm.sequence import SequenceGroupMetadata +from vllm.spec_decode.util import split_batch_by_proposal_len + +import pytest +from unittest.mock import MagicMock + + +def test_get_all_seq_ids(): + """Verify get_all_seq_ids extracts all seq ids. + """ + expected_seq_ids = list(range(10)) + list(range(100, 110)) + + seq_group_metadata_list = [ + SequenceGroupMetadata( + request_id=str(seq_id), + is_prompt=True, + seq_data={ + seq_id: MagicMock(), + }, + sampling_params=MagicMock(), + block_tables={ + seq_id: MagicMock(), + }, + lora_request=None, + ) for seq_id in expected_seq_ids + ] + + actual_seq_ids = get_all_seq_ids(seq_group_metadata_list) + assert actual_seq_ids == expected_seq_ids + + +@pytest.fixture +def fake_sequence_group_metadata(): + seq_ids = list(range(3)) + return [ + SequenceGroupMetadata( + request_id=str(i), + is_prompt=True, + seq_data={ + i: MagicMock(), + }, + sampling_params=MagicMock(), + block_tables={ + i: MagicMock(), + }, + lora_request=None, + ) for i in seq_ids + ] + + +def test_filter_zero_length_proposals(fake_sequence_group_metadata): + proposal_lens = [0, 1, 0] + filtered_groups, indices = split_batch_by_proposal_len( + fake_sequence_group_metadata, + proposal_lens, + select_proposal_len_zero=True) + + expected_groups = [ + fake_sequence_group_metadata[0], fake_sequence_group_metadata[2] + ] + expected_indices = [0, 2] + + assert filtered_groups == expected_groups + assert indices == expected_indices + + +def test_filter_non_zero_length_proposals(fake_sequence_group_metadata): + proposal_lens = [0, 1, 2] + filtered_groups, indices = split_batch_by_proposal_len( + fake_sequence_group_metadata, + proposal_lens, + select_proposal_len_zero=False) + + expected_groups = [ + fake_sequence_group_metadata[1], fake_sequence_group_metadata[2] + ] + expected_indices = [1, 2] + + assert filtered_groups == expected_groups + assert indices == expected_indices + + +def test_empty_inputs(): + filtered_groups, indices = split_batch_by_proposal_len( + [], [], select_proposal_len_zero=True) + + assert filtered_groups == [] + assert indices == [] + + +def test_all_zero_with_non_zero_filter(fake_sequence_group_metadata): + proposal_lens = [0, 0, 0] + filtered_groups, indices = split_batch_by_proposal_len( + fake_sequence_group_metadata, + proposal_lens, + select_proposal_len_zero=False) + + assert filtered_groups == [] + assert indices == [] + + +def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata): + proposal_lens = [1, 1, 1] + filtered_groups, indices = split_batch_by_proposal_len( + fake_sequence_group_metadata, + proposal_lens, + select_proposal_len_zero=True) + + assert filtered_groups == [] + assert indices == [] diff --git a/tests/worker/spec_decode/utils.py b/tests/spec_decode/utils.py similarity index 60% rename from tests/worker/spec_decode/utils.py rename to tests/spec_decode/utils.py index fa8767cf..99709398 100644 --- a/tests/worker/spec_decode/utils.py +++ b/tests/spec_decode/utils.py @@ -1,13 +1,16 @@ import torch -from 
typing import List, Optional, Dict +from typing import List, Optional, Dict, Iterable, Union +from unittest.mock import MagicMock from vllm.worker.worker import Worker from vllm.utils import get_distributed_init_method, get_ip, get_open_port from vllm.engine.arg_utils import EngineArgs -from vllm.sequence import Logprob, SequenceGroupMetadata, SequenceData +from vllm.sequence import (Logprob, SequenceGroupMetadata, SequenceData, + SamplerOutput, SequenceGroupOutput, SequenceOutput) from vllm.sampling_params import SamplingParams from vllm.worker.cache_engine import CacheEngine from vllm.model_executor.utils import set_random_seed +from itertools import count from dataclasses import dataclass, fields @@ -24,6 +27,11 @@ class ExecuteModelData: return dict( (field.name, getattr(self, field.name)) for field in fields(self)) + @classmethod + def from_dict(cls, d): + cleaned = dict((field.name, d[field.name]) for field in fields(cls)) + return cls(**cleaned) + def round_up_to_next_block(seq_len: int, block_size: int) -> int: return (seq_len + block_size - 1) // block_size @@ -50,6 +58,21 @@ def create_execute_model_data( ) +def mock_worker(cls=None, + vocab_size: int = 30_000, + max_model_len: int = 2048, + rank: int = 0) -> MagicMock: + if cls is None: + cls = Worker + + worker = MagicMock(spec=cls) + worker.vocab_size = vocab_size + worker.max_model_len = max_model_len + worker.rank = rank + worker.device = 'cuda:0' + return worker + + def patch_execute_model_with_seeds(worker: Worker, rand_seeds: List[int]): seed_iter = iter(rand_seeds) original_execute_model = worker.execute_model @@ -117,25 +140,12 @@ def create_seq_group_metadata_from_prompts( block_size: int, final_seq_lens: List[int], continuations: Optional[List[List[int]]] = None, - num_tokens_processed: Optional[List[int]] = None, seq_ids: Optional[List[int]] = None, ) -> List[SequenceGroupMetadata]: if continuations is None: continuations = [[] for _ in prompts] - if num_tokens_processed is None: - # Default to 1 token missing from kv cache for generation sequences. - num_tokens_processed = [] - for continuation, prompt in zip(continuations, prompts): - # If prefill, then default to zero tokens processed. - if not continuation: - num_tokens_processed.append(0) - else: - # If generation, then default to all but one tokens processed. 
- num_tokens_processed.append( - len(continuation) + len(prompt) - 1) - if seq_ids is None: seq_ids = list(i for i, _ in enumerate(prompts)) @@ -155,13 +165,15 @@ def create_seq_group_metadata_from_prompts( is_prompt=len(cont_token_ids) == 0, seq_data={ i: - SequenceData(prompt_token_ids=prompt_token_ids[:] + - cont_token_ids[:]) + SequenceData( + prompt_token_ids=prompt_token_ids[:], + output_token_ids=cont_token_ids[:], + ), }, sampling_params=SamplingParams(temperature=0.0, ), block_tables={i: block_allocations[i][:]}, - ) for i, (prompt_token_ids, cont_token_ids, num_tokens_saved) in - enumerate(zip(prompts, continuations, num_tokens_processed)) + ) for i, (prompt_token_ids, + cont_token_ids) in enumerate(zip(prompts, continuations)) ] @@ -178,3 +190,68 @@ def assert_logprobs_dict_allclose( expected = torch.tensor( single_step_expected_logprobs[token_id].logprob) assert torch.allclose(actual, expected) + + +def create_sampler_output_list( + token_ids: torch.Tensor, + probs: Iterable[Optional[torch.Tensor]], + seq_ids: Optional[List[int]] = None) -> List[SamplerOutput]: + num_steps, batch_size = token_ids.shape + token_ids_by_step = token_ids.tolist() + + if seq_ids is None: + seq_ids = list(range(batch_size)) + + return [ + SamplerOutput(outputs=[ + SequenceGroupOutput( + samples=[ + SequenceOutput( + output_token=token_id, + parent_seq_id=seq_ids[seq_index], + logprobs={token_id: 0}, + ) + ], + prompt_logprobs=None, + ) for seq_index, token_id in enumerate(token_ids_by_step[step]) + ], + sampled_token_probs=probs[step], + sampled_token_ids=token_ids[step]) + for step in range(num_steps) + ] + + +def create_batch(batch_size, + k, + prompt_len: Union[int, List[int]] = 10, + prev_output_token_len: int = 10, + seq_ids: Optional[List[int]] = None, + num_gpu_blocks: Optional[int] = None, + block_size: Optional[int] = None): + if block_size is None: + block_size = 8 + + if num_gpu_blocks is None: + num_gpu_blocks = 2048 // block_size + + iterator = count() + + if isinstance(prompt_len, int): + prompt_lens = [prompt_len for _ in range(batch_size)] + else: + prompt_lens = prompt_len + + prompts = [[next(iterator) for _ in range(p_len)] for p_len in prompt_lens] + prev_output_tokens = [[ + next(iterator) for _ in range(prev_output_token_len) + ] for _ in range(batch_size)] + final_seq_lens = [ + len(prompt) + len(prev_output_token) + k + 1 + for prompt, prev_output_token in zip(prompts, prev_output_tokens) + ] + + execute_model_data = create_execute_model_data( + create_seq_group_metadata_from_prompts(prompts, num_gpu_blocks, + block_size, final_seq_lens, + prev_output_tokens, seq_ids), ) + return execute_model_data, prompts, prev_output_tokens diff --git a/tests/test_sequence.py b/tests/test_sequence.py new file mode 100644 index 00000000..e18df059 --- /dev/null +++ b/tests/test_sequence.py @@ -0,0 +1,50 @@ +import pytest + +from vllm.sequence import SequenceGroupOutput, SamplerOutput, SequenceOutput + + +@pytest.fixture +def sample_outputs(): + return [ + SequenceGroupOutput(samples=[ + SequenceOutput(parent_seq_id=0, output_token=i, logprobs={}) + ], + prompt_logprobs=None) for i in range(5) + ] + + +@pytest.fixture +def sampler_output(sample_outputs): + return SamplerOutput(outputs=sample_outputs) + + +def test_sampler_output_initialization(sampler_output, sample_outputs): + assert len(sampler_output) == len(sample_outputs) + assert sampler_output.sampled_token_probs is None + assert sampler_output.sampled_token_ids is None + assert sampler_output.spec_decode_worker_metrics is None + + +def 
test_sampler_output_getitem(sampler_output, sample_outputs): + assert sampler_output[2] == sample_outputs[2] + + +def test_sampler_output_setitem(sampler_output): + new_output = SequenceGroupOutput(samples=[ + SequenceOutput(parent_seq_id=0, output_token=99, logprobs={}) + ], + prompt_logprobs=None) + sampler_output[2] = new_output + assert sampler_output[2] == new_output + + +def test_sampler_output_len(sampler_output, sample_outputs): + assert len(sampler_output) == len(sample_outputs) + + +def test_sampler_output_eq(sample_outputs): + sampler_output1 = SamplerOutput(outputs=sample_outputs) + sampler_output2 = SamplerOutput(outputs=sample_outputs.copy()) + sampler_output3 = SamplerOutput(outputs=sample_outputs[:-1]) + assert sampler_output1 == sampler_output2 + assert sampler_output1 != sampler_output3 diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py index 3e1cfc78..56434540 100644 --- a/vllm/model_executor/layers/rejection_sampler.py +++ b/vllm/model_executor/layers/rejection_sampler.py @@ -21,8 +21,6 @@ class RejectionSampler(nn.Module): nontrivial latency. """ super().__init__() - self.probs_dtype = torch.float32 - self.token_id_dtype = torch.int64 self._strict_mode = strict_mode # NOTE: A "bonus token" is accepted iff all proposal tokens are @@ -44,6 +42,14 @@ class RejectionSampler(nn.Module): dtype=torch.long, device=device) + @property + def probs_dtype(self): + return torch.float32 + + @property + def token_id_dtype(self): + return torch.int64 + def forward( self, target_probs: torch.Tensor, diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 320cb443..19e7f630 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -587,4 +587,4 @@ def _build_sampler_output( SequenceOutput(seq_ids[parent_id], next_token_id, logprobs)) sampler_output.append( SequenceGroupOutput(seq_outputs, group_prompt_logprobs)) - return sampler_output + return SamplerOutput(outputs=sampler_output) diff --git a/vllm/sequence.py b/vllm/sequence.py index fee96a87..37c10240 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -2,12 +2,16 @@ import copy import enum from dataclasses import dataclass -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Union, TYPE_CHECKING from vllm.block import LogicalTokenBlock from vllm.sampling_params import SamplingParams from vllm.lora.request import LoRARequest +if TYPE_CHECKING: + import torch + from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics + @dataclass class Logprob: @@ -81,6 +85,8 @@ class SequenceData: Args: prompt_token_ids: The token IDs of the prompt. + output_token_ids: The token IDs of the output. Set to an empty list if + None. Attributes: prompt_token_ids: The token IDs of the prompt. 
@@ -91,9 +97,13 @@ class SequenceData: def __init__( self, prompt_token_ids: List[int], + output_token_ids: Optional[List[int]] = None, ) -> None: + if output_token_ids is None: + output_token_ids = [] + self.prompt_token_ids = prompt_token_ids - self.output_token_ids: List[int] = [] + self.output_token_ids = output_token_ids self.cumulative_logprob = 0.0 def append_token_id(self, token_id: int, logprob: float) -> None: @@ -117,6 +127,12 @@ class SequenceData: return self.prompt_token_ids[-1] return self.output_token_ids[-1] + def get_prompt_token_ids(self) -> int: + return self.prompt_token_ids + + def get_output_token_ids(self) -> int: + return self.output_token_ids + def __repr__(self) -> str: return (f"SequenceData(" f"prompt_token_ids={self.prompt_token_ids}, " @@ -506,6 +522,35 @@ class SequenceGroupOutput: and self.prompt_logprobs == other.prompt_logprobs) -# For each sequence group, we generate a list of SequenceOutput object, -# each of which contains one possible candidate for the next token. -SamplerOutput = List[SequenceGroupOutput] +@dataclass +class SamplerOutput: + """For each sequence group, we generate a list of SequenceOutput object, + each of which contains one possible candidate for the next token. + + This datastructure implements methods so it can be used like a list, but + also has optional fields for device tensors. + """ + + outputs: List[SequenceGroupOutput] + + # On-device tensor containing probabilities of each token. + sampled_token_probs: Optional["torch.Tensor"] = None + + # On-device tensor containing the sampled token ids. + sampled_token_ids: Optional["torch.Tensor"] = None + + # Spec decode metrics populated by workers. + spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None + + def __getitem__(self, idx: int): + return self.outputs[idx] + + def __setitem__(self, idx: int, value): + self.outputs[idx] = value + + def __len__(self): + return len(self.outputs) + + def __eq__(self, other: object): + return isinstance(other, + self.__class__) and self.outputs == other.outputs diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py new file mode 100644 index 00000000..478c950f --- /dev/null +++ b/vllm/spec_decode/batch_expansion.py @@ -0,0 +1,351 @@ +from typing import Iterator, List, Tuple, Optional, Dict +from itertools import chain, count + +import torch + +from vllm.sequence import (SamplerOutput, SequenceGroupMetadata, SequenceData) +from vllm.worker.worker import Worker +from vllm.spec_decode.util import nvtx_range, sampler_output_to_torch, get_all_seq_ids, split_batch_by_proposal_len +from vllm.spec_decode.interfaces import SpeculativeScorer, SpeculativeProposals, SpeculativeScores + +SeqId = int +TargetSeqId = int +TokenId = int + + +class BatchExpansionTop1Scorer(SpeculativeScorer): + """Implements a speculative scorer that uses batch expansion to get + probabilities of speculative tokens according to the scoring model. + + Batch expansion converts a list of sequences and multiple query positions + to a new batch of sequences, each with a single query position. This allows + for MQA-like scoring in speculative decoding without requiring an MQA + kernel. + + It is strictly less efficient than MQA scoring. + + It only supports scoring the top1 proposal tokens of the proposer, instead + of topk/tree. 
+ """ + + def __init__(self, scorer_worker: Worker, device: str, vocab_size: int): + self._scorer_worker = scorer_worker + self._device = device + self._vocab_size = vocab_size + + @nvtx_range("BatchExpansionTop1Scorer.score_proposals") + def score_proposals( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + blocks_to_swap_in: Optional[Dict[int, int]], + blocks_to_swap_out: Optional[Dict[int, int]], + blocks_to_copy: Optional[Dict[int, List[int]]], + k: int, + proposals: SpeculativeProposals, + ) -> SpeculativeScores: + """Score the proposed tokens via the scorer model. + + This converts each input sequence to a set of k+1 target sequences. The + target sequences have the unique continuations to be scored and a + unique sequence ID that is different from all input sequence ids. + + If a speculative sequence length would exceed the max model length, then + no speculation is produced for that sequence. + + Args: + seq_group_metadata_list: The input sequence group metadata. + blocks_to_swap_in: This is passed to the worker during scoring. + blocks_to_swap_out: This is passed to the worker during scoring. + blocks_to_copy: This is passed to the worker during scoring. + k: The fixed proposal length. + proposals: The speculative proposals to score. + Returns: + SpeculativeScores: The scores of each speculative token, along with + which sequences were ignored during scoring. + """ + + # TODO(cade) perform this on GPU to remove blocking call. + proposal_lens_list = proposals.proposal_lens.tolist() + proposal_token_ids_list = proposals.proposal_token_ids.tolist() + + spec_indices, non_spec_indices, target_seq_group_metadata_list, num_scoring_tokens = self._expand_batch( + seq_group_metadata_list=seq_group_metadata_list, + proposal_token_ids_list=proposal_token_ids_list, + proposal_lens_list=proposal_lens_list, + ) + + target_sampler_output = self._scorer_worker.execute_model( + seq_group_metadata_list=target_seq_group_metadata_list, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_swap_out=blocks_to_swap_out, + blocks_to_copy=blocks_to_copy, + return_python_output=False) + + all_tokens, all_probs = self._contract_batch( + original_bs=len(seq_group_metadata_list), + target_sampler_output=target_sampler_output, + proposals=proposals, + num_scoring_tokens=num_scoring_tokens, + non_spec_indices=non_spec_indices, + spec_indices=spec_indices, + k=k, + ) + + return SpeculativeScores( + probs=all_probs, + token_ids=all_tokens, + ) + + def _expand_batch( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + proposal_token_ids_list: List[TokenId], + proposal_lens_list: List[int], + ) -> Tuple[List[int], List[int], List[SequenceGroupMetadata], int]: + """Given the input sequences and potentially multiple corresponding + proposal tokens, create a new batch where each sequence has a single + query token. + """ + + # vLLM currently only supports proposal lens equal to zero or the batch + # proposal len. This adds some complexity (splitting the batch into spec + # and non spec sequences) and should be removed in the future. It can be + # done by supporting per-sequence proposal lens. 
+ spec_seqs, spec_indices = split_batch_by_proposal_len( + seq_group_metadata_list, + proposal_lens_list, + select_proposal_len_zero=False) + non_spec_seqs, non_spec_indices = split_batch_by_proposal_len( + seq_group_metadata_list, + proposal_lens_list, + select_proposal_len_zero=True) + + target_seq_group_metadata_list = self._create_scoring_model_input( + spec_seqs, proposal_token_ids_list) + num_scoring_tokens = len(target_seq_group_metadata_list) + target_seq_group_metadata_list.extend(non_spec_seqs) + + return spec_indices, non_spec_indices, target_seq_group_metadata_list, num_scoring_tokens + + def _contract_batch(self, original_bs: int, + target_sampler_output: List[SamplerOutput], + proposals: SpeculativeProposals, + num_scoring_tokens: int, non_spec_indices: List[int], + spec_indices: List[int], + k: int) -> Tuple[torch.Tensor, torch.Tensor]: + """Contract the expanded batch back into its original size. + This maps the scores of speculative tokens back to their original + sequences. + """ + (target_token_ids, target_probs, non_spec_target_token_ids, + non_spec_target_probs) = self._split_scoring_output( + target_sampler_output, num_scoring_tokens) + + # Map distinct sequences used to score each token + # of shape [batch_size * k + 1] back to [batch_size, k + 1]. + batch_size, k = proposals.proposal_token_ids.shape + + target_token_ids = target_token_ids.squeeze().reshape( + batch_size, k + 1) + target_probs = target_probs.squeeze().reshape(batch_size, k + 1, + self._vocab_size) + + all_tokens = torch.full(size=(original_bs, k + 1), + fill_value=-1, + device=self._device, + dtype=torch.long) + all_probs = torch.zeros(original_bs, + k + 1, + self._vocab_size, + device=self._device, + dtype=torch.float32) + + if non_spec_indices: + all_tokens[non_spec_indices, 0] = non_spec_target_token_ids + all_probs[non_spec_indices, :1, :] = non_spec_target_probs + + if spec_indices: + all_tokens[spec_indices] = target_token_ids + all_probs[spec_indices] = target_probs + + return all_tokens, all_probs + + def _create_scoring_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + proposal_token_ids: List[List[TokenId]], # shape: [batch_size, k] + ) -> List[SequenceGroupMetadata]: + """Given the original input sequences and proposed tokens from the draft + model, create a list of target sequences that can be used for scoring. + """ + + if not seq_group_metadata_list: + return [] + + target_seq_ids_iter = self._create_target_seq_id_iterator( + get_all_seq_ids(seq_group_metadata_list)) + + target_seq_group_metadata = list( + chain.from_iterable( + self._create_target_seq_group_metadata( + seq_group_metadata, + proposal_token_ids, + i, + target_seq_ids_iter, + ) for i, seq_group_metadata in enumerate( + seq_group_metadata_list))) + + return target_seq_group_metadata + + def _create_target_seq_group_metadata( + self, + input_seq_group_metadata: SequenceGroupMetadata, + proposal_token_ids: List[TokenId], # shape: [batch_size, k] + batch_index: int, + target_seq_ids_iter: Iterator[TargetSeqId], + ) -> List[SequenceGroupMetadata]: + """Given an input sequence group metadata and a list of draft tokens, + create a list of target SequenceGroupMetadata, one for each + token id that needs to be scored. + + Naive speculative decoding requires K target model scores, one for each + draft model token. However one can add a bonus token such that if each + token is accepted, then a final token may be sampled from the model. 
+ This function creates K+1 target SequenceGroupMetadata to take + advantage of the bonus token. + """ + assert not input_seq_group_metadata.is_prompt, ( + "Speculating on " + "prompts not yet supported") + assert len(input_seq_group_metadata.seq_data) == 1, ( + "Beam search " + "not supported in speculative decoding") + input_seq_id = next(iter(input_seq_group_metadata.seq_data.keys())) + + token_ids_to_score = self._get_token_ids_to_score( + proposal_token_ids[batch_index]) + + target_seq_group_metadata_list: List[SequenceGroupMetadata] = [] + for token_ids in token_ids_to_score: + target_seq_group_metadata_list.append( + self._create_single_target_seq_group_metadata( + input_seq_group_metadata, + input_seq_id, + next(target_seq_ids_iter), + token_ids, + )) + + return target_seq_group_metadata_list + + def _create_single_target_seq_group_metadata( + self, + seq_group_metadata: SequenceGroupMetadata, + seq_id: SeqId, + target_seq_id: TargetSeqId, + token_ids: List[TokenId], + ) -> SequenceGroupMetadata: + """Create a single target SequenceGroupMetadata. + + Args: + seq_group_metadata: The metadata for the input sequence. + seq_id: The input sequence ID. + target_seq_id: The corresponding target sequence ID. + token_ids: The list of token ids that are to be appended to the + input sequence. + """ + seq_data = seq_group_metadata.seq_data[seq_id] + prompt_token_ids = seq_data.get_prompt_token_ids() + new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids] + + return SequenceGroupMetadata( + request_id=seq_group_metadata.request_id, + is_prompt=seq_group_metadata.is_prompt, + seq_data={ + target_seq_id: + SequenceData( + prompt_token_ids=prompt_token_ids, + output_token_ids=new_output_token_ids, + ), + }, + sampling_params=seq_group_metadata.sampling_params, + block_tables={ + target_seq_id: seq_group_metadata.block_tables[seq_id], + }, + lora_request=None, + ) + + def _split_scoring_output( + self, sampler_output: SamplerOutput, num_scoring_tokens: int + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Split the target model output into speculative and non-speculative + output. + """ + + # vLLM currently only supports proposal lens equal to zero or the batch + # proposal len. This adds some complexity (splitting the batch into spec + # and non spec sequences) and should be removed in the future. It can be + # done by supporting per-sequence proposal lens. + # + # First samples are from speculative scoring, latter samples are non- + # speculative samples. + split_sizes = [ + num_scoring_tokens, + sampler_output.sampled_token_ids.numel() - num_scoring_tokens + ] + (spec_probs, non_spec_probs + ) = sampler_output.sampled_token_probs.split(split_sizes) + (spec_sampled_tokens, non_spec_sampled_tokens + ) = sampler_output.sampled_token_ids.flatten().split(split_sizes) + + # Convert scores to tensors. + sampler_output.sampled_token_probs = spec_probs + sampler_output.sampled_token_ids = spec_sampled_tokens + target_token_ids, target_probs = sampler_output_to_torch( + [sampler_output]) + + # Convert non-speculative output tokens to tensors. 
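+        # The same SamplerOutput object is reused as a temporary container so
+        # that sampler_output_to_torch() can also be applied to the
+        # non-speculative half of the split.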
+ sampler_output.sampled_token_probs = non_spec_probs + sampler_output.sampled_token_ids = non_spec_sampled_tokens + non_spec_target_token_ids, non_spec_target_probs = sampler_output_to_torch( + [sampler_output]) + + return target_token_ids, target_probs, non_spec_target_token_ids, non_spec_target_probs + + def _create_target_seq_id_iterator( + self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]: + """Create an iterator for creating target sequence ids. + Target sequence ids are distinct from sequence ids because we create a + distinct target sequence id for each proposal token to be scored. + + This implementation increments a counter starting at 1 + max of all + provided input sequence ids. + """ + return count(start=max(seq_ids) + 1) + + def _get_token_ids_to_score( + self, + full_spec_token_ids: List[TokenId] # shape: [k] + ) -> List[List[TokenId]]: + """Given an int tensor of proposal token ids, return a list of + token ids that should be scored. + + Returns k+1 output lists. The additional one is used for generating the + bonus token. + + Example: + Input: [0, 1, 2, 3] (k=4) + Output: (k+1 lists) + [] + [0] + [0, 1] + [0, 1, 2] + [0, 1, 2, 3] + """ + empty_token_ids = [] + + token_ids_to_score = [empty_token_ids] + token_ids_to_score.extend([ + full_spec_token_ids[:i + 1] + for i in range(len(full_spec_token_ids)) + ]) + return token_ids_to_score diff --git a/vllm/spec_decode/interfaces.py b/vllm/spec_decode/interfaces.py new file mode 100644 index 00000000..9e53ffb6 --- /dev/null +++ b/vllm/spec_decode/interfaces.py @@ -0,0 +1,77 @@ +from typing import List, Tuple, Optional, Dict +from dataclasses import dataclass +from abc import ABC, abstractmethod + +import torch + +from vllm.sequence import SequenceGroupMetadata + + +@dataclass +class SpeculativeProposals: + """Datastructure used to represent proposal tokens from some proposer. It + also tracks how many speculative tokens each sequence has. + """ + + # Speculative proposal tokens. + proposal_token_ids: torch.Tensor + + # Probabilities of the proposal tokens according to the proposer. + proposal_probs: torch.Tensor + + # The valid length of each proposal; can be zero. + proposal_lens: torch.Tensor + + def __repr__(self): + return (f"SpeculativeProposals(" + f"proposal_token_ids={self.proposal_token_ids.shape}, " + f"proposal_probs={self.proposal_probs.shape}, " + f"proposal_lens={self.proposal_lens.shape})") + + +@dataclass +class SpeculativeScores: + """Datastructure used to represent the scores of speculative tokens + according to the scoring model. + """ + + # Probabilities of the speculative tokens according to the scoring model. + probs: torch.Tensor + + # Token ids sampled from the scoring model. Used for speculative bonus + # tokens and also non-speculative normal decoding. 
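+    # Expected shape: [batch_size, k + 1]; probs above additionally carries a
+    # trailing vocab_size dimension.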
+ token_ids: torch.Tensor + + def __repr__(self): + return (f"SpeculativeScores(" + f"probs={self.probs.shape}, " + f"token_ids={self.token_ids.shape})") + + +class SpeculativeProposer(ABC): + + @abstractmethod + def get_proposals( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + blocks_to_swap_in: Dict[int, int], + blocks_to_swap_out: Dict[int, int], + blocks_to_copy: Dict[int, List[int]], + max_proposal_len: int, + ) -> SpeculativeProposals: + raise NotImplementedError + + +class SpeculativeScorer(ABC): + + @abstractmethod + def score_proposals( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + blocks_to_swap_in: Optional[Dict[int, int]], + blocks_to_swap_out: Optional[Dict[int, int]], + blocks_to_copy: Optional[Dict[int, List[int]]], + k: int, + proposals: SpeculativeProposals, + ) -> Tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError diff --git a/vllm/spec_decode/metrics.py b/vllm/spec_decode/metrics.py new file mode 100644 index 00000000..65a2a4a6 --- /dev/null +++ b/vllm/spec_decode/metrics.py @@ -0,0 +1,174 @@ +import torch +from dataclasses import dataclass +from vllm.model_executor.layers.rejection_sampler import RejectionSampler +from typing import Optional +from vllm.utils import in_wsl +import time +from typing import Callable + + +@dataclass +class SpecDecodeWorkerMetrics: + """Dataclass holding metrics emitted from the spec decode worker. + """ + + # The empirical acceptance rate of the proposal method on a per-token basis. + # This is useful for evaluating how well the proposal method aligns with the + # scoring method. + draft_acceptance_rate: float + + # The empirical efficiency, measured as the number of tokens emitted by the + # system divided by the number of tokens that could be emitted by the system + # if the proposal method were perfect. + system_efficiency: float + + # The number of speculative tokens produced by the proposal method. + draft_tokens: int + + # The number of tokens emitted by the entire system. + emitted_tokens: int + + # The number of tokens accepted by the scoring model and verification + # routine, e.g. Llama2-70B and lossless rejection sampling. + # + # NOTE: Any token accepted by the verification routine is considered + # accepted (regardless of if the speculative prefix is also accepted). The + # user will usually see less accepted tokens. This metric is helpful when + # evaluating alignment of the proposal method with the scoring model. + accepted_tokens: int + + # The number of speculative tokens per sequence. + num_spec_tokens: int + + +Timer = Callable[[], float] + + +class AsyncMetricsCollector: + """Class which copies rejection sampler metrics from the device to CPU on a + non-default Torch stream. + """ + + def __init__(self, + rejection_sampler: RejectionSampler, + timer: Optional[Timer] = None, + collect_interval_s: float = 5.0): + self._rejection_sampler = rejection_sampler + self._timer = time.time if timer is None else timer + + self._rank: Optional[int] = None + + # We don't have a device set yet. 
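+        # The copy stream is created lazily in init_gpu_tensors(), once the
+        # worker knows its rank and CUDA device.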
+ self._copy_stream: Optional[torch.cuda.Stream] = None + + self._in_flight_copy: Optional[torch.cuda.Event] = None + + pin_memory = not in_wsl() + self._aggregate_num_accepted_tokens = torch.tensor( + 0, dtype=torch.long, device="cpu", pin_memory=pin_memory) + self._aggregate_num_emitted_tokens = torch.tensor( + 0, dtype=torch.long, device="cpu", pin_memory=pin_memory) + self._aggregate_num_draft_tokens = 0 + + self._rejsample_metrics_collect_interval_s = collect_interval_s + self._last_metrics_collect_time = self._timer() + + def init_gpu_tensors(self, rank: int) -> None: + self._rank = rank + self._copy_stream = torch.cuda.Stream() + + def maybe_collect_rejsample_metrics( + self, k: int) -> Optional[SpecDecodeWorkerMetrics]: + + # If a copy was initiated in the previous call, collect and return. + if self._in_flight_copy is not None: + ready_event = self._in_flight_copy + self._in_flight_copy = None + return self._collect_rejsample_metrics(k, ready_event) + + # Otherwise, check if we should start a new copy. + if self._should_collect_rejsample_metrics(self._timer()): + assert self._in_flight_copy is None + self._in_flight_copy = self._copy_rejsample_metrics_async() + + return None + + def _should_collect_rejsample_metrics(self, now: float) -> bool: + """Return whether or not this iteration should print rejection sampling + metrics. + """ + if self._rank != 0: + return False + + if (now - self._last_metrics_collect_time < + self._rejsample_metrics_collect_interval_s): + return False + return True + + def _copy_rejsample_metrics_async(self) -> torch.cuda.Event: + """Copy rejection sampling metrics (number of accepted tokens, etc) to + CPU asynchronously. + + Returns a CUDA event recording when the copy is complete. + """ + self._copy_stream.wait_stream(torch.cuda.current_stream()) + + with torch.cuda.stream(self._copy_stream): + self._aggregate_num_accepted_tokens.copy_( + self._rejection_sampler.num_accepted_tokens, non_blocking=True) + self._aggregate_num_emitted_tokens.copy_( + self._rejection_sampler.num_emitted_tokens, non_blocking=True) + # Number of draft tokens is calculated on CPU, so no copy is + # required. + self._aggregate_num_draft_tokens = ( + self._rejection_sampler.num_draft_tokens) + + aggregate_metrics_ready = torch.cuda.Event() + aggregate_metrics_ready.record(self._copy_stream) + + return aggregate_metrics_ready + + def _collect_rejsample_metrics( + self, k: int, + ready_event: torch.cuda.Event) -> SpecDecodeWorkerMetrics: + """Create metrics object from statistics copied asynchronously. + + Args: + k: int. The number of speculative tokens; used to determine system + efficiency. + ready_event: torch.cuda.Event. The CUDA event recording when the + async GPU->CPU copy is complete. 
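+
+        Returns:
+            SpecDecodeWorkerMetrics: The acceptance rate, system efficiency,
+                and token counts computed from the aggregated counters.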
+        """
+
+        ready_event.synchronize()
+        accepted_tokens = self._aggregate_num_accepted_tokens.item()
+        emitted_tokens = self._aggregate_num_emitted_tokens.item()
+        draft_tokens = self._aggregate_num_draft_tokens
+
+        num_possible_tokens = self.get_max_num_accepted_tokens(draft_tokens, k)
+
+        if draft_tokens > 0:
+            draft_acceptance_rate = accepted_tokens / draft_tokens
+        else:
+            draft_acceptance_rate = float("nan")
+
+        if num_possible_tokens > 0:
+            system_efficiency = emitted_tokens / num_possible_tokens
+        else:
+            system_efficiency = float("nan")
+
+        return SpecDecodeWorkerMetrics(
+            num_spec_tokens=k,
+            draft_acceptance_rate=draft_acceptance_rate,
+            system_efficiency=system_efficiency,
+            accepted_tokens=accepted_tokens,
+            draft_tokens=draft_tokens,
+            emitted_tokens=emitted_tokens,
+        )
+
+    @staticmethod
+    def get_max_num_accepted_tokens(draft_tokens: int, k: int) -> int:
+        # Divide by k since batch size can be variable.
+        total_num_spec_seqs = draft_tokens / k
+        num_accepted_per_seq_if_all_accepted = k + 1
+        return int(total_num_spec_seqs * num_accepted_per_seq_if_all_accepted)
diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py
new file mode 100644
index 00000000..f7be14d3
--- /dev/null
+++ b/vllm/spec_decode/multi_step_worker.py
@@ -0,0 +1,366 @@
+from typing import List, Dict, Optional, Tuple
+import copy
+
+import torch
+
+from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+from vllm.worker.worker import Worker
+from vllm.spec_decode.interfaces import SpeculativeProposals, SpeculativeProposer
+from vllm.spec_decode.util import sampler_output_to_torch
+
+
+class MultiStepWorker(Worker):
+    """The MultiStepWorker is equivalent to a Worker except that it allows
+    multiple forward passes in a single call, assuming the scheduler has
+    allocated enough space to store the additional KV. This reduces overhead
+    by invoking the scheduler less.
+
+    The MultiStepWorker does not support cache swap operations, or beam search.
+    Cache swap operations do not require large modifications. On the other hand,
+    beam search requires memory allocations during sequence forks and thus
+    requires more thought for MultiStepWorker support.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self._proposer: Optional[DraftModelTop1Proposer] = None
+
+    def init_model(self):
+        super().init_model()
+
+        self._proposer = DraftModelTop1Proposer(
+            self,
+            self.device,
+            self.max_model_len,
+            self.vocab_size,
+        )
+
+    @torch.inference_mode()
+    def execute_model_multi_step(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        blocks_to_swap_in: Dict[int, int],
+        blocks_to_swap_out: Dict[int, int],
+        blocks_to_copy: Dict[int, List[int]],
+        num_steps: int,
+    ) -> List[SamplerOutput]:
+        """Run the model forward pass num_steps times. Returns the list of
+        sampler output, one per model forward pass.
+        """
+        self._raise_if_unsupported(seq_group_metadata_list, blocks_to_swap_in,
+                                   blocks_to_swap_out, blocks_to_copy)
+
+        # Shallow copy input data so modifications (such as appending tokens)
+        # do not cause side-effects.
+        copied_seq_group_metadata_list = self._shallow_copy_inputs(
+            seq_group_metadata_list)
+
+        # Assert enough KV space for num_steps tokens per sequence.
+        self._assert_enough_kv_space(seq_group_metadata_list, num_steps)
+
+        # Run model num_steps times.
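+        # Each iteration decodes exactly one token per sequence; the sampled
+        # token is appended to the copied sequence data so the next iteration
+        # sees it as input.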
+ model_outputs = [] + for _ in range(num_steps): + model_output = super().execute_model( + seq_group_metadata_list=copied_seq_group_metadata_list, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_swap_out=blocks_to_swap_out, + blocks_to_copy=blocks_to_copy, + ) + + self._append_new_tokens(model_output, + copied_seq_group_metadata_list) + model_outputs.append(model_output) + + return model_outputs + + def get_spec_proposals( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + blocks_to_swap_in: Dict[int, int], + blocks_to_swap_out: Dict[int, int], + blocks_to_copy: Dict[int, List[int]], + max_proposal_len: int, + ) -> SpeculativeProposals: + """Produce speculations given an input batch of sequences. The number of + speculative tokens per sequence is determined by max_proposal_len. + """ + + return self._proposer.get_proposals( + seq_group_metadata_list, + blocks_to_swap_in, + blocks_to_swap_out, + blocks_to_copy, + max_proposal_len, + ) + + def _append_new_tokens( + self, model_output: SamplerOutput, + seq_group_metadata_list: SequenceGroupMetadata) -> None: + """Given model output from a single run, append the tokens to the + sequences. This is normally done outside of the worker, but it is + required if the worker is to perform multiple forward passes. + """ + for seq_group_metadata, sequence_group_outputs in zip( + seq_group_metadata_list, model_output): + seq_group_metadata.is_prompt = False + + for seq_output in sequence_group_outputs.samples: + # NOTE: Beam search is not supported, so we can assume that + # parent_seq_id == seq_id. + seq = seq_group_metadata.seq_data[seq_output.parent_seq_id] + + token_id = seq_output.output_token + token_logprob = seq_output.logprobs[token_id] + + seq.append_token_id(token_id, token_logprob.logprob) + + def _shallow_copy_inputs( + self, seq_group_metadata_list: List[SequenceGroupMetadata] + ) -> List[SequenceGroupMetadata]: + """Copy input data structures to remove side-effects when input data + structures are shared with other modules. + + Helpful when the vLLM scheduler runs in the same process as the worker. + The alternative is deep-copying (or other form of deep copy); this has + performance downsides. + """ + + # Shallow-copy the list of SequenceGroupMetadata. This allows us to + # append tokens and change is_prompt without external side-effects. + new_seq_group_metadata_list = [] + + for old_seq_group_metadata in seq_group_metadata_list: + # We must shallow-copy seq_group_metadata as is_prompt could change. + seq_group_metadata = copy.copy(old_seq_group_metadata) + new_seq_group_metadata_list.append(seq_group_metadata) + + # We must shallow-copy seq_data as we will append token ids + new_seq_data = {} + for seq_id, old_seq_data in seq_group_metadata.seq_data.items(): + new_seq_data[seq_id] = copy.copy(old_seq_data) + new_seq_data[ + seq_id].output_token_ids = old_seq_data.output_token_ids[:] + + seq_group_metadata.seq_data = new_seq_data + + return new_seq_group_metadata_list + + def _assert_enough_kv_space( + self, seq_group_metadata_list: List[SequenceGroupMetadata], + num_steps: int) -> None: + """Assert there are enough physical blocks per sequence to store the + current KV plus additional KV from num_steps tokens. + """ + assert self.model_runner.block_size is not None + for seq_group_metadata in seq_group_metadata_list: + # Only one seq_id is guaranteed because there is no beam search. 
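+            # Worked example (illustrative numbers): with block_size 16, a
+            # current length of 30 and num_steps 5, 30 + 5 - 1 = 34 KV slots
+            # are required, i.e. at least 3 physical blocks.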
+ seq_id = list(seq_group_metadata.seq_data.keys())[0] + seq = seq_group_metadata.seq_data[seq_id] + + # After num_steps, the seq len will be the current seq len + # plus one token per step. + final_seq_len = seq.get_len() + num_steps + + # We will have final_seq_len - 1 KV because vLLM saves KV for a + # token in the iteration after the token was generated. + required_num_kv_slots = final_seq_len - 1 + + # The allocated number of kv slots is the number of allocated blocks + # times the number of slots of block. + number_physical_blocks = len( + seq_group_metadata.block_tables[seq_id]) + allocated_kv_slots = (number_physical_blocks * + self.model_runner.block_size) + + if required_num_kv_slots > allocated_kv_slots: + request_id = seq_group_metadata.request_id + raise ValueError( + "The worker attempted to run " + f"{num_steps} times but found insufficient KV space for " + f"{request_id=} {seq_id=}. ({allocated_kv_slots=} " + f"{required_num_kv_slots=}).") + + def _raise_if_unsupported( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + blocks_to_swap_in: Dict[int, int], + blocks_to_swap_out: Dict[int, int], + blocks_to_copy: Dict[int, List[int]], + ) -> None: + """MultiStepWorker does not yet implement support for cache swap + operations or beam search. + """ + if any([blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy]): + raise NotImplementedError( + "MultiStepWorker does not support cache operations") + + if any( + len(seq_group_metadata.seq_data.keys()) != 1 + for seq_group_metadata in seq_group_metadata_list): + raise NotImplementedError( + "MultiStepWorker does not support beam search.") + + +class DraftModelTop1Proposer(SpeculativeProposer): + """Helper class which separates out sequences which would exceed the max + model length when speculated upon. + + This allows combinations of models such as JackFram/llama-68m draft with + meta-llama/Llama2-13b-chat-hf, as llama-68m has max_position_embeddings of + 2048 while Llama2-13b has max_position_embeddings of 4096. + + We treat the sequences which exceed the proposal draft model length as + "non-spec sequences". Essentially they skip the draft model and go through + normal decoding in the target model. + + Currently, only proposal_lens of 0 and k are supported, where k is a global + batch proposal length. In the future vLLM should support per-sequence + proposal lengths. + """ + + def __init__( + self, + draft_worker: MultiStepWorker, + device: str, + max_model_len: int, + vocab_size: int, + ): + self._draft_worker = draft_worker + self._device = device + self._max_model_len = max_model_len + self._vocab_size = vocab_size + + def get_proposals( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + blocks_to_swap_in: Dict[int, int], + blocks_to_swap_out: Dict[int, int], + blocks_to_copy: Dict[int, List[int]], + max_proposal_len: int, + ) -> SpeculativeProposals: + """Get speculative proposals given the input batch. + + Sequences which would exceed the max model length are skipped during + speculation. + """ + + # Split speculative- and non-speculative- sequences. + proposal_lens, nonzero_proposal_len_seqs, nonzero_proposal_len_indices = self._split_by_max_model_len( + seq_group_metadata_list, max_proposal_len) + + if nonzero_proposal_len_seqs: + # Speculate tokens using the draft worker for the speculative + # sequences. 
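+            # The draft worker runs max_proposal_len single-token steps,
+            # producing one SamplerOutput per step.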
+ maybe_sampler_output = self._draft_worker.execute_model_multi_step( + seq_group_metadata_list=nonzero_proposal_len_seqs, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_swap_out=blocks_to_swap_out, + blocks_to_copy=blocks_to_copy, + num_steps=max_proposal_len, + ) + else: + # If no sequences can be speculated, set sampler output to None. + maybe_sampler_output = None + + # Combine speculative- and non-speculative sequences into the same + # representation. + proposal_tokens, proposal_probs, proposal_lens = self._merge_outputs( + batch_size=len(seq_group_metadata_list), + max_proposal_len=max_proposal_len, + maybe_sampler_output=maybe_sampler_output, + proposal_lens=proposal_lens, + nonzero_proposal_len_indices=nonzero_proposal_len_indices, + ) + + proposals = SpeculativeProposals( + proposal_token_ids=proposal_tokens, + proposal_probs=proposal_probs, + proposal_lens=proposal_lens, + ) + + return proposals + + def _split_by_max_model_len( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + max_proposal_len: int, + ) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]: + """Determine which sequences would exceed the max model length. + """ + + proposal_lens: List[int] = [] + nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = [] + nonzero_proposal_len_indices: List[int] = [] + for i, seq_group_metadata in enumerate(seq_group_metadata_list): + seq_data = next(iter(seq_group_metadata.seq_data.values())) + seq_len = seq_data.get_len() + + # Currently only proposal lens of 0 or the global batch proposal len + # are supported. + if seq_len + max_proposal_len < self._max_model_len: + proposal_lens.append(max_proposal_len) + nonzero_proposal_len_seqs.append(seq_group_metadata) + nonzero_proposal_len_indices.append(i) + else: + proposal_lens.append(0) + + return proposal_lens, nonzero_proposal_len_seqs, nonzero_proposal_len_indices + + def _merge_outputs( + self, + batch_size: int, + max_proposal_len: int, + maybe_sampler_output: Optional[SamplerOutput], + proposal_lens: List[int], + nonzero_proposal_len_indices: List[int], + ) -> Tuple[torch.Tensor, torch.tensor, torch.Tensor]: + """After speculations are produced, merge the speculation results with + the skipped sequences. + """ + if maybe_sampler_output is None: + # If no speculative tokens, the sampler output will be None. + # In this case we return empty tensors. + proposal_tokens = torch.zeros(0, + max_proposal_len, + dtype=torch.long, + device=self._device) + proposal_probs = torch.zeros(0, + max_proposal_len, + self._vocab_size, + dtype=torch.float32, + device=self._device) + proposal_lens = torch.zeros(len(proposal_lens), + dtype=torch.long, + device=self._device) + return proposal_tokens, proposal_probs, proposal_lens + + sampler_output = maybe_sampler_output + + proposal_tokens, proposal_probs = sampler_output_to_torch( + sampler_output) + + # Now, reformat the output GPU tensors such that each sequence has + # a proposal. the proposal can be empty, e.g. 
[-1, -1, -1] + + entire_proposal_tokens = torch.full(size=(batch_size, + *proposal_tokens.shape[1:]), + fill_value=-1, + dtype=torch.long, + device=self._device) + entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens + entire_proposal_probs = torch.zeros(batch_size, + *proposal_probs.shape[1:], + dtype=torch.float32, + device=self._device) + entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs + + proposal_tokens, proposal_probs = entire_proposal_tokens, entire_proposal_probs + + proposal_lens = torch.zeros(batch_size, + dtype=torch.long, + device=self._device) + proposal_lens[nonzero_proposal_len_indices] = max_proposal_len + + return proposal_tokens, proposal_probs, proposal_lens diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py new file mode 100644 index 00000000..890e4792 --- /dev/null +++ b/vllm/spec_decode/spec_decode_worker.py @@ -0,0 +1,372 @@ +from typing import List, Tuple, Optional, Dict +from functools import cached_property + +import torch + +from vllm.spec_decode.metrics import AsyncMetricsCollector +from vllm.sequence import (SamplerOutput, SequenceGroupMetadata, + SequenceGroupOutput, SequenceOutput) +from vllm.worker.worker import Worker +from vllm.spec_decode.multi_step_worker import MultiStepWorker +from vllm.model_executor.layers.rejection_sampler import RejectionSampler +from vllm.config import CacheConfig +from vllm.spec_decode.util import nvtx_range, get_all_seq_ids, split_batch_by_proposal_len +from vllm.spec_decode.interfaces import SpeculativeProposals, SpeculativeScores +from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer +from vllm.spec_decode.interfaces import SpeculativeScorer + + +class SpecDecodeWorker: + """Worker which implements speculative decoding. + + Speculative decoding reduces decoding per-token latency by using a proposal + method, such as a small draft model, to speculate ahead of a larger LLM. The + probabilities of the speculative tokens are then determined by the larger + LLM, after which some verification routine determines which (if any) of the + speculative tokens are accepted by the larger LLM. + + See https://github.com/vllm-project/vllm/pull/2188 and + https://github.com/vllm-project/vllm/pull/3103 for more info. + + The current implementation has the following limitations: + * Only draft-model proposal is implemented (contributions for more forms are + welcome!). + * Only top-1 proposal and scoring are implemented. Tree-attention is left as + future work. + * Only lossless rejection sampling is supported. Contributions adding lossy + verification routines are welcome (e.g. Medusa's typical acceptance). + * All sequences in a batch must have the same proposal length, or zero. This + can be improved by having per-sequence speculation in the future. + * The scoring forward pass is done without an MQA kernel, which is + suboptimal especially as the batch size, proposal length, and sequence + lengths grow. Contributions to add a MQA scoring are welcome once + correctness tests pass. + More info here https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit. + """ + + def __init__( + self, + proposer_worker: MultiStepWorker, + scorer_worker: Worker, + rejection_sampler: RejectionSampler, + metrics_collector: Optional[AsyncMetricsCollector] = None, + ): + """ + Create a SpecDecodeWorker. + + Args: + proposer_worker: A worker that can produce speculative tokens for + sequences. 
+ scorer_worker: A worker that produces probabilities of speculative + tokens according to some base model. Typically a vanilla vLLM + Worker. + rejection_sampler: A Torch module used to perform modified rejection + sampling for speculative decoding. + metrics_collector: Helper class for collecting metrics; can be set + for testing purposes. + """ + self.proposer_worker = proposer_worker + self.scorer_worker = scorer_worker + self.rejection_sampler = rejection_sampler + + self._metrics = AsyncMetricsCollector( + rejection_sampler + ) if metrics_collector is None else metrics_collector + + self.probs_dtype = self.rejection_sampler.probs_dtype + self.token_id_dtype = self.rejection_sampler.token_id_dtype + + self.scorer: SpeculativeScorer = None + + def init_model(self) -> None: + """Initialize both scorer and proposer models. + """ + # The scorer worker model is initialized first in case the proposer + # model has a smaller TP degree than the target worker. + self.scorer_worker.init_model() + self.proposer_worker.init_model() + + self._metrics.init_gpu_tensors(self.rank) + self.rejection_sampler.init_gpu_tensors(self.rank) + self.scorer = BatchExpansionTop1Scorer( + scorer_worker=self.scorer_worker, + device=self.device, + vocab_size=self._vocab_size) + + def profile_num_available_blocks(self, block_size: int, + gpu_memory_utilization: float, + cpu_swap_space: int, + cache_dtype: str) -> Tuple[int, int]: + """Determine the number of cache blocks to use. + + This is done by profiling the scorer model (which is typically the + larger of the two). Then the total memory which would be used by the + scorer cache is divided evenly between the proposer and scorer model KV, + such that the number of blocks is equal in both KV caches. + """ + num_gpu_blocks, num_cpu_blocks = ( + self.scorer_worker.profile_num_available_blocks( + block_size, gpu_memory_utilization, cpu_swap_space, + cache_dtype)) + + scorer_cache_block_size_bytes = self.scorer_worker.get_cache_block_size_bytes( + block_size, cache_dtype) + proposer_cache_block_size_bytes = self.proposer_worker.get_cache_block_size_bytes( + block_size, cache_dtype) + + new_num_gpu_blocks = split_num_cache_blocks_evenly( + scorer_cache_block_size_bytes, proposer_cache_block_size_bytes, + num_gpu_blocks) + return new_num_gpu_blocks, num_cpu_blocks + + def init_cache_engine(self, cache_config: CacheConfig): + """Initialize the cache engine of the scorer and proposer workers. + """ + self.scorer_worker.init_cache_engine(cache_config) + self.proposer_worker.init_cache_engine(cache_config) + + @torch.inference_mode() + def execute_model( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + blocks_to_swap_in: Optional[Dict[int, int]], + blocks_to_swap_out: Optional[Dict[int, int]], + blocks_to_copy: Optional[Dict[int, List[int]]], + num_spec_tokens: int, + ) -> List[SamplerOutput]: + """Perform speculative decoding on the input batch. + """ + + assert seq_group_metadata_list is not None, ( + "speculative decoding " + "requires non-None seq_group_metadata_list") + + # If no spec tokens, call the proposer and scorer workers normally. + # Used for prefill. 
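+        # During prefill the batch is forwarded to both workers so that the
+        # proposer's KV cache stays consistent with the scorer's (see
+        # _run_no_spec).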
+ if num_spec_tokens == 0 or len(seq_group_metadata_list) == 0: + return self._run_no_spec( + seq_group_metadata_list=seq_group_metadata_list, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_swap_out=blocks_to_swap_out, + blocks_to_copy=blocks_to_copy, + ) + + return self._run_speculative_decoding_step( + seq_group_metadata_list=seq_group_metadata_list, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_swap_out=blocks_to_swap_out, + blocks_to_copy=blocks_to_copy, + k=num_spec_tokens, + ) + + @nvtx_range("spec_decode_worker._run_no_spec") + def _run_no_spec( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + blocks_to_swap_in: Optional[Dict[int, int]], + blocks_to_swap_out: Optional[Dict[int, int]], + blocks_to_copy: Optional[Dict[int, List[int]]], + ) -> List[SamplerOutput]: + """Run a prefill step, without any speculation. The input is sent to the + proposer and scorer model so that the KV cache is consistent between the + two. + """ + + self.proposer_worker.execute_model( + seq_group_metadata_list=seq_group_metadata_list, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_swap_out=blocks_to_swap_out, + blocks_to_copy=blocks_to_copy, + return_python_output=False) + + sampler_output = self.scorer_worker.execute_model( + seq_group_metadata_list=seq_group_metadata_list, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_swap_out=blocks_to_swap_out, + blocks_to_copy=blocks_to_copy, + ) + + # Clear device tensors from sampler output. This reduces communication + # overhead when the engine runs in a different process than the workers. + sampler_output.probs = None + sampler_output.sampled_tokens = None + return [sampler_output] + + @nvtx_range("spec_decode_worker._run_speculative_decoding_step") + def _run_speculative_decoding_step( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + blocks_to_swap_in: Optional[Dict[int, int]], + blocks_to_swap_out: Optional[Dict[int, int]], + blocks_to_copy: Optional[Dict[int, List[int]]], + k: int, + ) -> List[SamplerOutput]: + """Execute a single step of speculative decoding. + + This invokes the proposer worker to get k speculative tokens for each + sequence, then scores each speculative token using the scoring worker. + + Returns a list of SamplerOutput, each containing a single token per + sequence. + """ + + # Generate proposals using draft worker. + proposals = self.proposer_worker.get_spec_proposals( + seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out, + blocks_to_copy, k) + + proposal_scores = self.scorer.score_proposals( + seq_group_metadata_list, + blocks_to_swap_in, + blocks_to_swap_out, + blocks_to_copy, + k, + proposals, + ) + + accepted_token_ids = self._verify_tokens(seq_group_metadata_list, + proposal_scores, proposals, k) + + return self._create_output_sampler_list(seq_group_metadata_list, + accepted_token_ids, k) + + @nvtx_range("spec_decode_worker._verify_tokens") + def _verify_tokens( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + proposal_scores: SpeculativeScores, + proposals: SpeculativeProposals, + max_proposal_len: int, + ) -> torch.Tensor: + """Determine which speculative tokens are accepted using the + probabilities of each token according to the proposer and scorer models. + """ + proposal_lens_list = proposals.proposal_lens.tolist() + + # vLLM currently only supports proposal lens equal to zero or the batch + # proposal len. This adds some complexity (splitting the batch into spec + # and non spec sequences) and should be removed in the future. 
It can be + # done by supporting per-sequence proposal lens. + _, spec_indices = split_batch_by_proposal_len( + seq_group_metadata_list, + proposal_lens_list, + select_proposal_len_zero=False) + _, non_spec_indices = split_batch_by_proposal_len( + seq_group_metadata_list, + proposal_lens_list, + select_proposal_len_zero=True) + original_indices = spec_indices + non_spec_indices + + proposal_probs = proposal_scores.probs[spec_indices, :-1] + bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:] + non_spec_token_ids = proposal_scores.token_ids[non_spec_indices] + + accepted_token_ids = self.rejection_sampler( + proposal_probs, + bonus_token_ids, + proposals.proposal_probs, + proposals.proposal_token_ids, + ) + + # Append output tokens from non-speculative sequences to + # the accepted token ids tensor. + non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len + + 1).clone() + non_spec_token_ids[:, 1:] = -1 + accepted_token_ids = torch.cat( + [accepted_token_ids, non_spec_token_ids]) + + # Rearrange so that results are in the order of the original seq group + # metadata. + accepted_token_ids[original_indices] = accepted_token_ids.clone() + + return accepted_token_ids + + def _create_output_sampler_list( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + accepted_token_ids: torch.Tensor, # shape: [batch_size, k+1] + k: int, + ) -> List[SamplerOutput]: + """Given the accepted token ids, create a list of SamplerOutput. + + The output is padded with -1 tokens such that each sequence has + the same number of outputs. + """ + seq_ids = get_all_seq_ids(seq_group_metadata_list) + + # shape: [k+1, batch_size] + accepted_token_ids_by_step = accepted_token_ids.transpose(0, + 1).tolist() + sampler_output_list = [] + for token_ids_by_step in accepted_token_ids_by_step: + if all(token_id == -1 for token_id in token_ids_by_step): + break + + step_output_token_ids = [] + for token_id, seq_id in zip(token_ids_by_step, seq_ids): + step_output_token_ids.append( + SequenceGroupOutput( + samples=[ + SequenceOutput( + parent_seq_id=seq_id, + output_token=token_id, + # TODO Add verifier logprobs. + logprobs={token_id: 0.0}, + ) + ], + prompt_logprobs=None, + )) + sampler_output_list.append( + SamplerOutput(outputs=step_output_token_ids)) + + maybe_rejsample_metrics = self._metrics.maybe_collect_rejsample_metrics( + k) + if maybe_rejsample_metrics is not None: + sampler_output_list[ + 0].spec_decode_worker_metrics = maybe_rejsample_metrics + + return sampler_output_list + + @cached_property + def _vocab_size(self) -> int: + """Get the vocab size of the model and make sure it's consistent between + draft and target workers. + """ + vocab_sizes = [ + worker.vocab_size + for worker in [self.proposer_worker, self.scorer_worker] + ] + assert all(vocab_sizes[0] == vocab_size for vocab_size in vocab_sizes) + return vocab_sizes[0] + + @property + def rank(self): + return self.scorer_worker.rank + + @property + def device(self): + return self.scorer_worker.device + + +def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int, + proposer_cache_block_size_bytes: int, + total_num_gpu_blocks: int) -> int: + """Given total_num_gpu_blocks, the number of GPU blocks that could be + allocate to the target model, this function calculates how many blocks + should be given to the draft and target model. + + Note that usually the block size, in bytes, of each model is different, + as it's a function of number of KV/layer, number of heads, and hidden + dimension size. 
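+
+    For example (illustrative numbers): if a scorer block takes 4x the bytes
+    of a proposer block and the scorer alone could fit 1000 blocks, each model
+    gets int(1000 * 4 / (1 + 4)) = 800 blocks.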
+ + Since the target and draft models allocate the same number of blocks, we + simply calculate the number of blocks where if allocated by both models, + the total memory usage from KV cache is no larger than the number of + blocks allocatable by the target model alone. + """ + new_num_gpu_blocks = int( + total_num_gpu_blocks * scorer_cache_block_size_bytes / + (proposer_cache_block_size_bytes + scorer_cache_block_size_bytes)) + + return new_num_gpu_blocks diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py new file mode 100644 index 00000000..2c5f9545 --- /dev/null +++ b/vllm/spec_decode/util.py @@ -0,0 +1,99 @@ +import torch +from typing import List, Tuple +from vllm.sequence import SequenceGroupMetadata, SamplerOutput +from contextlib import contextmanager +from itertools import chain + +SeqId = int + + +def get_all_seq_ids( + seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[SeqId]: + """Given a list of SequenceGroupMetadata, create a list of all + sequence ids. + """ + return list( + chain.from_iterable([ + seq_group_metadata.seq_data.keys() + for seq_group_metadata in seq_group_metadata_list + ])) + + +def split_batch_by_proposal_len( + seq_group_metadata_list: List[SequenceGroupMetadata], + proposal_lens: List[int], select_proposal_len_zero: bool +) -> Tuple[List[SequenceGroupMetadata], List[int]]: + """Utility function that splits a batch based on whether the proposal len is + zero or not. We should remove this once vLLM supports per-sequence proposal + lens in a batch. + """ + + if select_proposal_len_zero: + predicate = lambda proposal_len: proposal_len == 0 + else: + predicate = lambda proposal_len: proposal_len != 0 + + indices = [ + i for i, (_, proposal_len + ) in enumerate(zip(seq_group_metadata_list, proposal_lens)) + if predicate(proposal_len) + ] + seq_groups = [ + seq_group for seq_group, proposal_len in zip( + seq_group_metadata_list, proposal_lens) if predicate(proposal_len) + ] + + return seq_groups, indices + + +def sampler_output_to_torch( + sampler_output_list: List[SamplerOutput], +) -> Tuple[torch.Tensor, torch.Tensor]: + """Utility function which converts a list of SamplerOutput to tensors. + + Returns: + sampled_token_ids: torch.Tensor + shape: [batch_size, len(sampler_output_list)] + + sampled_token_probs: torch.Tensor + shape: [batch_size, len(sampler_output_list), vocab_size] + """ + + # shape: [batch_size, num_sampler_output, vocab_size] + sampled_token_probs = torch.stack( + [ + sampler_output.sampled_token_probs + for sampler_output in sampler_output_list + ], + dim=0, + ).transpose(0, 1) + + # shape: [batch_size, num_sampler_output] + sampled_token_ids = torch.stack( + [ + sampler_output.sampled_token_ids.flatten() + for sampler_output in sampler_output_list + ], + dim=0, + ).transpose(0, 1) + + return sampled_token_ids, sampled_token_probs + + +@contextmanager +def nvtx_range(msg, *args, **kwargs): + """ + Context manager / decorator that pushes an NVTX range at the beginning + of its scope, and pops it at the end. If extra arguments are given, + they are passed as arguments to msg.format(). + + If running with cuda graphs, you must enable nsys cuda graph profiling. 
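+
+    Example (illustrative):
+        with nvtx_range("score_proposals k={}", k):
+            ...  # work to be profiled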
+ + Arguments: + msg (string): message to associate with the range + """ + torch.cuda.nvtx.range_push(msg.format(*args, **kwargs)) + try: + yield + finally: + torch.cuda.nvtx.range_pop() diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 9023b0c5..0dd23090 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -97,8 +97,6 @@ class ModelRunner: f"Loading model weights took {self.model_memory_usage / float(2**30):.4f} GB" ) - vocab_size = self.model.config.vocab_size - if self.lora_config: assert hasattr( self.model, "supported_lora_modules" @@ -111,7 +109,7 @@ class ModelRunner: self.lora_manager = LRUCacheWorkerLoRAManager( self.scheduler_config.max_num_seqs, self.scheduler_config.max_num_batched_tokens + - self.scheduler_config.max_paddings, vocab_size, + self.scheduler_config.max_paddings, self.vocab_size, self.lora_config, self.device, self.model.embedding_modules, self.model.embedding_padding_modules) self.model = self.lora_manager.create_lora_manager(self.model) @@ -607,8 +605,7 @@ class ModelRunner: @torch.inference_mode() def profile_run(self) -> None: # Enable top-k sampling to reflect the accurate memory usage. - vocab_size = self.model_config.get_vocab_size() - sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1) + sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1) max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens max_num_seqs = self.scheduler_config.max_num_seqs @@ -774,6 +771,10 @@ class ModelRunner: self.graph_runners.clear() self.cupy_nccl_backend = None + @property + def vocab_size(self) -> int: + return self.model_config.get_vocab_size() + class CUDAGraphRunner: diff --git a/vllm/worker/spec_decode/multi_step_worker.py b/vllm/worker/spec_decode/multi_step_worker.py deleted file mode 100644 index ab3e2838..00000000 --- a/vllm/worker/spec_decode/multi_step_worker.py +++ /dev/null @@ -1,178 +0,0 @@ -from typing import List, Dict -import copy - -import torch - -from vllm.sequence import SamplerOutput, SequenceGroupMetadata -from vllm.worker.worker import Worker - - -class MultiStepWorker(Worker): - """The MultiStepWorker is equivalent to a Worker except that it allows - multiple forward passes in a single call, assuming the scheduler has - allocated enough space to store the additional KV. This reduces overhead - by invoking the scheduler less. - - The MultiStepWorker does not support cache swap operations, or beam search. - Cache swap operations do not require large modifications. On the other hand, - beam search requires memory allocations during sequence forks and thus - requires more thought for MultiStepWorker support. - """ - - @torch.inference_mode() - def execute_model_multi_step( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], - num_steps: int, - ) -> List[SamplerOutput]: - """Run the model forward pass num_steps times. Returns the list of - sampler output, one per model forward pass. - """ - self._raise_if_unsupported(seq_group_metadata_list, blocks_to_swap_in, - blocks_to_swap_out, blocks_to_copy) - - # Shallow copy input data so modifications (such as appending tokens) - # do not cause side-effects. - copied_seq_group_metadata_list = self._shallow_copy_inputs( - seq_group_metadata_list) - - # Assert enough KV space for num_steps tokens per sequence. 
- self._assert_enough_kv_space(seq_group_metadata_list, num_steps) - - # Run model num_steps times. - model_outputs = [] - for _ in range(num_steps): - model_output = super().execute_model( - seq_group_metadata_list=copied_seq_group_metadata_list, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - ) - - self._append_new_tokens(model_output, - copied_seq_group_metadata_list) - model_outputs.append(model_output) - - return model_outputs - - def _append_new_tokens( - self, model_output: SamplerOutput, - seq_group_metadata_list: SequenceGroupMetadata) -> None: - """Given model output from a single run, append the tokens to the - sequences. This is normally done outside of the worker, but it is - required if the worker is to perform multiple forward passes. - """ - for seq_group_metadata, sequence_group_outputs in zip( - seq_group_metadata_list, model_output): - seq_group_metadata.is_prompt = False - - for seq_output in sequence_group_outputs.samples: - # NOTE: Beam search is not supported, so we can assume that - # parent_seq_id == seq_id. - seq = seq_group_metadata.seq_data[seq_output.parent_seq_id] - - token_id = seq_output.output_token - token_logprob = seq_output.logprobs[token_id] - - seq.append_token_id(token_id, token_logprob.logprob) - - def _shallow_copy_inputs( - self, seq_group_metadata_list: List[SequenceGroupMetadata] - ) -> List[SequenceGroupMetadata]: - """Copy input data structures to remove side-effects when input data - structures are shared with other modules. - - The multi-step worker must be able to append tokens to sequences after - a forward pass. This necessitates modification of the data structures - used by the worker. Since these data structures are shared with other - parts of vLLM, like the scheduler, we must take care not to introduce - unexpected side-effects. - - When Ray is used to orchestrate worker processes (such as when the - tensor-parallel degree is >1), this is not a problem because the input - datastructures will be serialized and created anew in the worker - process. - - However, when Ray is not used to orchestrate the worker processes (such - as when the tensor-parallel degree is 1), this is a problem. We avoid - the problem by shallow-copying the input datastructures (specifically, - the parts that will change in multiple steps). - """ - - # Shallow-copy the list of SequenceGroupMetadata. This allows us to - # append tokens and change is_prompt without external side-effects. - new_seq_group_metadata_list = [] - - for old_seq_group_metadata in seq_group_metadata_list: - # We must shallow-copy seq_group_metadata as is_prompt could change. - seq_group_metadata = copy.copy(old_seq_group_metadata) - new_seq_group_metadata_list.append(seq_group_metadata) - - # We must shallow-copy seq_data as we will append token ids - new_seq_data = {} - for seq_id, old_seq_data in seq_group_metadata.seq_data.items(): - new_seq_data[seq_id] = copy.copy(old_seq_data) - new_seq_data[ - seq_id].output_token_ids = old_seq_data.output_token_ids[:] - - seq_group_metadata.seq_data = new_seq_data - - return new_seq_group_metadata_list - - def _assert_enough_kv_space( - self, seq_group_metadata_list: List[SequenceGroupMetadata], - num_steps: int) -> None: - """Assert there are enough physical blocks per sequence to store the - current KV plus additional KV from num_steps tokens. 
- """ - assert self.model_runner.block_size is not None - for seq_group_metadata in seq_group_metadata_list: - # Only one seq_id is guaranteed because there is no beam search. - seq_id = list(seq_group_metadata.seq_data.keys())[0] - seq = seq_group_metadata.seq_data[seq_id] - - # After num_steps, the seq len will be the current seq len - # plus one token per step. - final_seq_len = seq.get_len() + num_steps - - # We will have final_seq_len - 1 KV because vLLM saves KV for a - # token in the iteration after the token was generated. - required_num_kv_slots = final_seq_len - 1 - - # The allocated number of kv slots is the number of allocated blocks - # times the number of slots of block. - number_physical_blocks = len( - seq_group_metadata.block_tables[seq_id]) - allocated_kv_slots = (number_physical_blocks * - self.model_runner.block_size) - - if required_num_kv_slots > allocated_kv_slots: - request_id = seq_group_metadata.request_id - raise ValueError( - "The worker attempted to run " - f"{num_steps} times but found insufficient KV space for " - f"{request_id=} {seq_id=}. ({allocated_kv_slots=} " - f"{required_num_kv_slots=}).") - - def _raise_if_unsupported( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], - ) -> None: - """MultiStepWorker does not yet implement support for cache swap - operations or beam search. - """ - if any([blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy]): - raise NotImplementedError( - "MultiStepWorker does not support cache operations") - - if any( - len(seq_group_metadata.seq_data.keys()) != 1 - for seq_group_metadata in seq_group_metadata_list): - raise NotImplementedError( - "MultiStepWorker does not support beam search.") diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 157e8c45..0dcd4018 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -130,8 +130,8 @@ class Worker: # GPU did not change their memory usage during the profiling. peak_memory = self.init_gpu_memory - free_gpu_memory - cache_block_size = CacheEngine.get_cache_block_size( - block_size, cache_dtype, self.model_config, self.parallel_config) + cache_block_size = self.get_cache_block_size_bytes( + block_size, cache_dtype) num_gpu_blocks = int( (total_gpu_memory * gpu_memory_utilization - peak_memory) // cache_block_size) @@ -232,6 +232,22 @@ class Worker: def list_loras(self) -> Set[int]: return self.model_runner.list_loras() + @property + def max_model_len(self) -> int: + return self.model_config.max_model_len + + @property + def vocab_size(self) -> int: + return self.model_runner.vocab_size + + def get_cache_block_size_bytes(self, block_size: int, + cache_dtype: str) -> int: + """Get the size of the KV cache block size in bytes. + """ + return CacheEngine.get_cache_block_size(block_size, cache_dtype, + self.model_config, + self.parallel_config) + def init_distributed_environment( parallel_config: ParallelConfig,