[Speculative Decoding] Enabling bonus token in speculative decoding for KV cache based models (#5765)

Author: sroy745 · 2024-07-10 16:02:47 -07:00 · committed by GitHub
Parent: 44cc76610d
Commit: ae151d73be
14 changed files with 645 additions and 80 deletions

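Background for this change: in speculative decoding, when the target model accepts all of a sequence's draft tokens it also samples one extra "bonus" token. A draft model that keeps its own KV cache (the MultiStepWorker path) never ran a forward pass over that bonus token, so on the next draft step the KV entry for the token just before it is missing. The sketch below only illustrates that bookkeeping idea with made-up names and toy data; it is not code from this commit.

# Illustration only (not vLLM code): a token's KV entry exists in the draft
# cache only if the draft model has processed that token as an input.
seq_tokens = [10, 11, 12, 13, 14]  # 14 is a bonus token sampled by the target
draft_kv_inputs = [10, 11, 12]     # tokens the draft model has processed so far

# Feeding only the newest token (14) on the next draft step would skip 13,
# leaving a hole in the draft model's KV cache:
missing = [t for t in seq_tokens[:-1] if t not in draft_kv_inputs]
assert missing == [13]

# The approach in this change: remember which sequences received a bonus
# token, and on the next draft step also run a copy of each such sequence
# without the bonus token, so the penultimate token (13) is processed and
# its KV entry is backfilled before the original sequence continues from 14.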
View File

@@ -70,14 +70,17 @@ def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int,
if queue_size < disable_by_batch_size:
# Should raise exception when executing the mocked draft model.
with pytest.raises(ValueError, match=exception_secret):
- proposer.get_spec_proposals(execute_model_req=ExecuteModelRequest(
- seq_group_metadata_list=seq_group_metadata_list,
- num_lookahead_slots=k), )
+ proposer.get_spec_proposals(
+ execute_model_req=ExecuteModelRequest(
+ seq_group_metadata_list=seq_group_metadata_list,
+ num_lookahead_slots=k),
+ seq_ids_with_bonus_token_in_last_step=set())
else:
# Should not execute the draft model because spec decode is disabled
# for all requests. Accordingly, the proposal length should be 0.
proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
- num_lookahead_slots=k), )
+ num_lookahead_slots=k),
+ seq_ids_with_bonus_token_in_last_step=set())
assert proposals.proposal_lens.tolist() == [0] * batch_size

View File

@@ -118,7 +118,8 @@ def test_same_output_for_single_step():
actual_output, _ = multi_step_worker.sampler_output(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=multi_step_seq_group),
- sample_len=num_steps)
+ sample_len=num_steps,
+ seq_ids_with_bonus_token_in_last_step=set())
assert len(actual_output) == num_steps
actual_output = actual_output[0]
@@ -210,7 +211,8 @@ def test_same_output_for_multi_step():
multi_step_output, _ = multi_step_worker.sampler_output(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list),
- sample_len=num_steps)
+ sample_len=num_steps,
+ seq_ids_with_bonus_token_in_last_step=set())
# Run single-step repeatedly.
zero_kv_cache(worker.cache_engine)
@@ -277,6 +279,203 @@ def test_same_output_for_multi_step():
single_step_logprobs)
@torch.inference_mode()
def test_multi_step_with_batch_expansion_correct_output():
"""
In this test we verify that the MultiStepWorker is able to handle bonus
tokens correctly. The test verifies that if a sequence has a
bonus token then the MultiStepWorker is able to expand the batch by adding
new sequences corresponding to the sequences with bonus tokens. The
expanded batch is then used for predicting the next tokens.
"""
seed = 100
model_name = 'JackFram/llama-68m'
block_size = 16
num_gpu_blocks = 2048 // block_size
batch_size = 128
multi_step_worker = create_worker(
MultiStepWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
model_runner_cls=TP1DraftModelRunner,
)
worker = create_worker(
Worker,
model_name,
block_size,
num_gpu_blocks,
seed,
)
random.seed(seed)
prompts = [[0] for _ in range(batch_size)]
num_steps = 2
final_prompt_lens = [(num_steps + 1) for prompt in prompts]
rand_seeds = list(random.randint(0, 100) for _ in range(num_steps))
multi_step_worker.execute_model = patch_execute_model_with_seeds(
multi_step_worker, rand_seeds)
worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)
# Create the test continuations
continuations = [[random.randint(0, 1000)] for _ in prompts]
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=continuations,
final_prompt_lens=final_prompt_lens)
# Run single-step twice to generate 2 tokens. This
# will simulate the bonus token case with the second token
# being the bonus token.
zero_kv_cache(worker.cache_engine)
single_step_output: List[SamplerOutput] = []
set_random_seed(seed)
for _ in range(num_steps):
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=continuations,
final_prompt_lens=final_prompt_lens)
single_step_output.extend(
worker.execute_model(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list)))
# Append output tokens to new sequence data.
for i, seq_group_output in enumerate(single_step_output[-1]):
continuations[i].append(seq_group_output.samples[0].output_token)
# Create continuations for the MultiStepWorker. The continuations have
# 2 tokens in order to simulate the bonus token case.
multi_step_continuations = []
for continuation in continuations:
multi_step_continuations.append(continuation[:2])
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=multi_step_continuations,
final_prompt_lens=final_prompt_lens)
# Run multi-step and verify that the third token prediction is accurate
# for all sequences.
zero_kv_cache(multi_step_worker.cache_engine)
all_seq_ids = {i for i in range(batch_size)}
multi_step_output, _ = multi_step_worker.sampler_output(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list),
sample_len=1,
seq_ids_with_bonus_token_in_last_step=all_seq_ids)
for index, output in enumerate(multi_step_output[-1].outputs):
assert (continuations[index][-1] == output.samples[0].output_token)
@torch.inference_mode()
def test_multi_step_with_batch_expansion_incorrect_output():
"""
Tests the MultiStepWorker's ability to handle batch expansion with bonus
tokens in a negative case scenario. This test provides the MultiStepWorker
with a batch containing sequences with bonus tokens but specifies the
sequence IDs with bonus tokens incorrectly. The test verifies that the
MultiStepWorker generates correct tokens for the sequences where the
sequence ID is specified correctly and incorrect tokens for those where
the sequence ID is specified incorrectly.
"""
seed = 100
model_name = 'JackFram/llama-68m'
block_size = 16
num_gpu_blocks = 2048 // block_size
batch_size = 128
multi_step_worker = create_worker(
MultiStepWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
model_runner_cls=TP1DraftModelRunner,
)
worker = create_worker(
Worker,
model_name,
block_size,
num_gpu_blocks,
seed,
)
random.seed(seed)
prompts = [[0] for _ in range(batch_size)]
num_steps = 2
final_prompt_lens = [(num_steps + 1) for prompt in prompts]
rand_seeds = list(random.randint(0, 100) for _ in range(num_steps))
multi_step_worker.execute_model = patch_execute_model_with_seeds(
multi_step_worker, rand_seeds)
worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)
# Create the test continuations
continuations = [[random.randint(0, 1000)] for _ in prompts]
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=continuations,
final_prompt_lens=final_prompt_lens)
# Run single-step twice to generate 2 tokens. This
# will simulate the bonus token case with the second token
# being the bonus token.
zero_kv_cache(worker.cache_engine)
single_step_output: List[SamplerOutput] = []
set_random_seed(seed)
for _ in range(num_steps):
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=continuations,
final_prompt_lens=final_prompt_lens)
single_step_output.extend(
worker.execute_model(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list)))
# Append output tokens to new sequence data.
for i, seq_group_output in enumerate(single_step_output[-1]):
continuations[i].append(seq_group_output.samples[0].output_token)
# Create continuations for the MultiStepWorker. The continuations have
# 2 tokens in order to simulate the bonus token case.
multi_step_continuations = []
for continuation in continuations:
multi_step_continuations.append(continuation[:2])
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=multi_step_continuations,
final_prompt_lens=final_prompt_lens)
# Run multi-step. In this run, INCORRECTLY specify that only the
# odd-numbered sequences have bonus tokens. Verify that with this setting
# the third-token prediction is accurate only for the odd-numbered
# sequences, and that it may be wrong for some of the even-numbered
# sequences.
zero_kv_cache(multi_step_worker.cache_engine)
set_random_seed(seed)
odd_seq_ids = {i for i in range(batch_size) if i % 2 != 0}
multi_step_output, _ = multi_step_worker.sampler_output(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list),
sample_len=1,
seq_ids_with_bonus_token_in_last_step=odd_seq_ids)
num_mismatch = 0
for index, output in enumerate(multi_step_output[-1].outputs):
if (index % 2) != 0:
assert (continuations[index][-1] == output.samples[0].output_token)
elif (continuations[index][-1] != output.samples[0].output_token):
num_mismatch += 1
# The prediction is accurate for some of the sequences even without proper
# handling of the bonus tokens. Hence verify that the number of sequences
# for which there is a mismatch is > 0.
assert (num_mismatch > 0)
@torch.inference_mode()
def test_draft_proposals_full_speculation_len():
"""Verify Top1Proposer correctly handles case where all sequences
@@ -318,7 +517,8 @@ def test_draft_proposals_full_speculation_len():
proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
- num_lookahead_slots=k), )
+ num_lookahead_slots=k),
+ seq_ids_with_bonus_token_in_last_step=set())
assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
@@ -356,7 +556,8 @@ def test_draft_proposals_no_speculations():
proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
- num_lookahead_slots=k), )
+ num_lookahead_slots=k),
+ seq_ids_with_bonus_token_in_last_step=set())
assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
@@ -428,7 +629,8 @@ def test_draft_proposals_mixed_k():
proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
- num_lookahead_slots=k), )
+ num_lookahead_slots=k),
+ seq_ids_with_bonus_token_in_last_step=set())
assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)

View File

@@ -53,7 +53,8 @@ def test_ngram_algo_correctness_for_single_no_match():
proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
- num_lookahead_slots=proposal_len), )
+ num_lookahead_slots=proposal_len),
+ seq_ids_with_bonus_token_in_last_step=None)
assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
@@ -121,7 +122,8 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
- num_lookahead_slots=proposal_len), )
+ num_lookahead_slots=proposal_len),
+ seq_ids_with_bonus_token_in_last_step=None)
assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
@@ -193,7 +195,8 @@ def test_ngram_algo_correctness_for_batches_match_all():
proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
- num_lookahead_slots=proposal_len), )
+ num_lookahead_slots=proposal_len),
+ seq_ids_with_bonus_token_in_last_step=None)
assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)

View File

@@ -1,6 +1,7 @@
import random
+ from collections import defaultdict
from types import SimpleNamespace
- from typing import Dict, List
+ from typing import Dict, List, Set
from unittest.mock import MagicMock
import pytest
@@ -377,8 +378,10 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool,
set_random_seed(1)
- worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler,
- metrics_collector)
+ worker = SpecDecodeWorker(draft_worker,
+ target_worker,
+ spec_decode_sampler,
+ metrics_collector=metrics_collector)
worker.init_device()
proposal_token_ids = torch.randint(low=0,
@@ -554,7 +557,6 @@ def test_init_device(acceptance_sampler_method: str):
worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler,
metrics_collector)
worker.init_device()
draft_worker.init_device.assert_called_once()
@@ -645,3 +647,140 @@ def test_split_num_cache_blocks_evenly(available_gpu_blocks: int,
assert (num_blocks * target_cache_block_size_bytes) + (
num_blocks * draft_kv_size_bytes) <= (available_gpu_blocks *
target_cache_block_size_bytes)
@torch.inference_mode()
def test_populate_seq_ids_with_bonus_tokens():
"""
Verify that a call to _create_output_sampler_list correctly updates
_seq_with_bonus_token_in_last_step.
_seq_with_bonus_token_in_last_step is an internal data structure in
SpecDecodeWorker that tracks the sequence IDs which are assigned bonus
tokens by the target model in their last forward pass. This state is
maintained only for models relying on the KV cache, such as those using
the MultiStepWorker.
"""
batch_size = 10
k = 5
vocab_size = 10000
num_sequences_with_bonus_tokens = 5
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)]
target_worker.device = 'cuda'
set_random_seed(1)
draft_worker = mock_worker(cls=MultiStepWorker)
draft_worker.device = 'cuda'
# The sequence_ids attached to each sequence in the batch.
# The sequence at index i has seq_id assigned_seq_ids[i]
assigned_seq_ids = list(range(batch_size))
seq_group_metadata_list, _, _ = create_batch(batch_size,
k,
seq_ids=assigned_seq_ids,
prev_output_token_len=10)
target_token_logprobs = torch.rand(batch_size, (k + 1),
vocab_size,
dtype=torch.float32,
device='cuda')
accepted_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, (k + 1)),
dtype=torch.int64,
device='cuda')
expected_request_id_seq_ids_mapping: Dict[str, Set[int]] = defaultdict(set)
for seq_group_metadata in seq_group_metadata_list:
for seq_id in seq_group_metadata.seq_data:
expected_request_id_seq_ids_mapping[
seq_group_metadata.request_id].add(seq_id)
# Generate a random sample of sequence indexes with bonus tokens
seq_indexes_with_bonus_tokens = random.sample(
range(batch_size), num_sequences_with_bonus_tokens)
# Create a mask that is True for indices in seq_indexes_with_bonus_tokens
mask = torch.ones(batch_size, dtype=torch.bool, device='cuda')
mask[seq_indexes_with_bonus_tokens] = False
# Set the last token ID to -1 for all indices not in
# seq_indexes_with_bonus_tokens to indicate the lack of bonus token in
# those indices.
accepted_token_ids[mask, -1:] = -1
worker = SpecDecodeWorker(draft_worker,
target_worker,
mock_spec_decode_sampler("rejection_sampler"),
metrics_collector=metrics_collector)
# Initialize _seq_with_bonus_token_in_last_step with a set of sequence IDs.
# This set includes all sequence IDs in the batch as well as an additional
# `num_extra_sequence_ids` sequence IDs. Note that the sequence IDs are in
# the range [0, batch_size + num_extra_sequence_ids).
num_extra_sequence_ids = 10
worker._seq_with_bonus_token_in_last_step = set(
range(batch_size + num_extra_sequence_ids))
worker._create_output_sampler_list(
seq_group_metadata_list=seq_group_metadata_list,
accepted_token_ids=accepted_token_ids,
target_logprobs=target_token_logprobs,
k=k)
# Verify that _seq_with_bonus_token_in_last_step contains the following:
# 1. Sequence IDs that were already present in
# _seq_with_bonus_token_in_last_step but were not part of the current
# batch are retained.
# 2. Of the sequence IDs present in the current batch, only those with a
# bonus token are retained in _seq_with_bonus_token_in_last_step.
# Sequence IDs that are present in the current batch but do not have
# bonus tokens are removed from _seq_with_bonus_token_in_last_step.
expected_seq_ids_with_bonus_tokens = \
set([assigned_seq_ids[i] for i in seq_indexes_with_bonus_tokens])
additional_sequence_ids = \
set(range(batch_size, batch_size + num_extra_sequence_ids))
assert worker._seq_with_bonus_token_in_last_step == \
expected_seq_ids_with_bonus_tokens.union(additional_sequence_ids)
assert worker._request_id_seq_id_mapping == \
expected_request_id_seq_ids_mapping
@torch.inference_mode()
def test_handle_finished_requests():
"""
Test to verify that finished request IDs are appropriately processed to
update the internal state of the SpecDecodeWorker.
This test initializes the SpecDecodeWorker with mock data, marks certain
requests as finished, and ensures that the corresponding sequence IDs are
correctly removed from the internal mappings.
"""
batch_size = 32
k = 3
draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker(draft_worker, target_worker,
mock_spec_decode_sampler("rejection_sampler"),
metrics_collector)
# Initialize the _request_id_seq_id_mapping dict with a few fake
# request ids and their corresponding sequence ids.
worker._request_id_seq_id_mapping = \
{'request-1': {1,2,3}, 'request-2': {4,5,6,7},
'request-3': {8,9}, 'request-4': {10,11}}
# Initialize seq_with_bonus_token_in_last_step with a few fake
# sequence ids.
worker._seq_with_bonus_token_in_last_step = {1, 4, 5, 8, 9, 10}
exception_secret = 'artificial stop'
draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
# Mark requests with ids request-1 and request-3 as finished.
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k,
finished_requests_ids=['request-1', 'request-3'])
with pytest.raises(ValueError, match=exception_secret):
worker.execute_model(execute_model_req=execute_model_req)
# Verify that request-1 and request-3 are removed from
# request_id_seq_id_mapping
assert worker._request_id_seq_id_mapping == \
{'request-2': {4,5,6,7}, 'request-4': {10,11}}
# Verify that all sequence ids corresponding to 'request-1'
# and 'request-3' are removed from seq_with_bonus_token_in_last_step.
assert worker._seq_with_bonus_token_in_last_step == \
{4,5,10}

View File

@@ -3,8 +3,9 @@ import copy
import enum
import math
from abc import ABC, abstractmethod
+ from collections import defaultdict
from dataclasses import dataclass, field
- from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+ from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union
import torch
@@ -916,6 +917,21 @@ def get_all_seq_ids(
return [seq_id for sg in seq_group_metadata_list for seq_id in sg.seq_data]
def get_all_seq_ids_and_request_ids(
seq_group_metadata_list: List[SequenceGroupMetadata]
) -> Tuple[List[int], Dict[str, Set[int]]]:
"""Given a list of SequenceGroupMetadata, create a list of all
sequence ids.
"""
seq_ids: List[int] = []
request_id_seq_ids_mapping: Dict[str, Set[int]] = defaultdict(set)
for sg in seq_group_metadata_list:
for seq_id in sg.seq_data:
seq_ids.append(seq_id)
request_id_seq_ids_mapping[sg.request_id].add(seq_id)
return seq_ids, request_id_seq_ids_mapping
class HiddenStates:
"""Hidden states corresponding to in-progress sequences.
Used in speculative decoding to pass hidden states from

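The following is a small, self-contained sketch of the contract of the new get_all_seq_ids_and_request_ids helper, using a simplified stand-in for SequenceGroupMetadata (the real class carries many more fields); it only illustrates the shape of the returned values, and all names in it are invented for the example.

from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Set, Tuple

@dataclass
class FakeSeqGroupMetadata:  # simplified stand-in, for illustration only
    request_id: str
    seq_data: Dict[int, object]

def seq_ids_and_request_ids(
        groups: List[FakeSeqGroupMetadata]
) -> Tuple[List[int], Dict[str, Set[int]]]:
    seq_ids: List[int] = []
    mapping: Dict[str, Set[int]] = defaultdict(set)
    for sg in groups:
        for seq_id in sg.seq_data:
            seq_ids.append(seq_id)
            mapping[sg.request_id].add(seq_id)
    return seq_ids, mapping

groups = [FakeSeqGroupMetadata("request-1", {1: None, 2: None}),
          FakeSeqGroupMetadata("request-2", {3: None})]
ids, mapping = seq_ids_and_request_ids(groups)
assert ids == [1, 2, 3]
assert mapping == {"request-1": {1, 2}, "request-2": {3}}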
View File

@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
- from typing import Optional
+ from typing import Optional, Set
import torch
@@ -62,6 +62,9 @@ class SpeculativeProposer(ABC):
def get_spec_proposals(
self,
execute_model_req: ExecuteModelRequest,
+ # If set, this contains all sequence IDs that were assigned
+ # bonus tokens in their last forward pass.
+ seq_ids_with_bonus_token_in_last_step: Set[int],
) -> SpeculativeProposals:
raise NotImplementedError

View File

@@ -1,5 +1,5 @@
import weakref
- from typing import List, Optional, Tuple
+ from typing import List, Optional, Set, Tuple
import torch
@@ -40,6 +40,8 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
self,
execute_model_req: ExecuteModelRequest,
sample_len: int,
+ # Unused parameter.
+ seq_ids_with_bonus_token_in_last_step: Set[int],
) -> Tuple[List[SamplerOutput], bool]:
"""Run the model forward pass to generate sample_len future tokens.
Returns the list of sampler output, one per layer, along with indicator
@@ -97,12 +99,14 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
def get_spec_proposals(
self,
execute_model_req: ExecuteModelRequest,
+ seq_ids_with_bonus_token_in_last_step: Set[int],
) -> SpeculativeProposals:
"""Produce speculations given an input batch of sequences. The number of
speculative tokens per sequence is determined by max_proposal_len.
"""
- return self._proposer.get_spec_proposals(execute_model_req)
+ return self._proposer.get_spec_proposals(
+ execute_model_req, seq_ids_with_bonus_token_in_last_step)
def _raise_if_unsupported(
self,

View File

@@ -1,4 +1,4 @@
- from typing import List, Optional, Tuple
+ from typing import List, Optional, Set, Tuple
import torch
@@ -20,6 +20,9 @@ class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker):
self,
execute_model_req: ExecuteModelRequest,
sample_len: int,
+ # Unused parameter. MLPSpeculatorWorker does not use the KV Cache and
+ # therefore does not need this parameter.
+ seq_ids_with_bonus_token_in_last_step: Set[int],
) -> Tuple[List[SamplerOutput], bool]:
"""Run the model forward pass to generate sample_len future tokens.
Returns the list of sampler output, one per layer, along with indicator

View File

@@ -1,6 +1,6 @@
import copy
import weakref
- from typing import Dict, List, Tuple
+ from typing import Dict, List, Set, Tuple
import torch
@@ -51,6 +51,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
self,
execute_model_req: ExecuteModelRequest,
sample_len: int,
+ seq_ids_with_bonus_token_in_last_step: Set[int],
) -> Tuple[List[SamplerOutput], bool]:
"""Run the model forward pass sample_len times. Returns the list of
sampler output, one per model forward pass, along with indicator of
@@ -60,44 +61,142 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
For multi step worker, this indicator shall be True.
"""
self._raise_if_unsupported(execute_model_req)
- # Shallow copy input data so modifications (such as appending tokens)
- # do not cause side-effects.
- copied_seq_group_metadata_list = self._shallow_copy_inputs(
- execute_model_req.seq_group_metadata_list)
- copied_execute_model_req = execute_model_req.clone(
- copied_seq_group_metadata_list)
+ # Expand the batch for sequences with a bonus token.
+ # Perform a forward pass on the expanded batch and filter the
+ # response to retain only the original sequences' responses.
+ expanded_request, indices_of_seq_with_bonus_tokens =\
+ self._expand_execute_model_request(
+ execute_model_req, seq_ids_with_bonus_token_in_last_step)
# Run model sample_len times.
model_outputs: List[SamplerOutput] = []
if isinstance(self.model_runner, TP1DraftModelRunner):
- copied_execute_model_req.num_steps = sample_len
+ expanded_request.num_steps = sample_len
model_outputs = self.execute_model(
- execute_model_req=copied_execute_model_req)
+ execute_model_req=expanded_request)
else:
# TODO: Remove this branch once DraftModelRunner supports TP>1.
for _ in range(sample_len):
model_output: List[SamplerOutput] = super().execute_model(
- execute_model_req=copied_execute_model_req)
+ execute_model_req=expanded_request)
assert (len(model_output) == 1
), "composing multistep workers not supported"
model_output = model_output[0]
- self._append_new_tokens(model_output,
- copied_seq_group_metadata_list)
+ self._append_new_tokens(
+ model_output, expanded_request.seq_group_metadata_list)
model_outputs.append(model_output)
- return model_outputs, True
+ filtered_model_outputs = self._filter_model_output(
+ model_outputs, indices_of_seq_with_bonus_tokens)
+ return filtered_model_outputs, True
@staticmethod
def _expand_execute_model_request(
execute_model_req: ExecuteModelRequest,
seq_with_bonus_token_in_last_step: set,
) -> Tuple[ExecuteModelRequest, List[int]]:
"""
Expands the execute model request based on sequences with bonus
tokens.
For each sequence with a bonus token, this method creates a new
sequence without the bonus token and adds it to the execute model
request. The original sequence groups are also retained. The indices
of the original sequence groups are returned for further processing.
Args:
execute_model_req (ExecuteModelRequest): The original execute
model request.
seq_with_bonus_token_in_last_step (set): Set of sequence IDs that
contain bonus tokens.
Returns:
Tuple[ExecuteModelRequest, List[int]]: The updated execute model
request with expanded sequences and a list of indices corresponding
to the original sequence groups.
"""
updated_seq_group_metadata_list: List[SequenceGroupMetadata] = []
updated_execute_model_req = execute_model_req.clone(
updated_seq_group_metadata_list)
indices_of_original_sequence_groups = []
for seq_group in execute_model_req.seq_group_metadata_list:
seq_group_has_bonus_tokens = False
for seq_id, _ in seq_group.seq_data.items():
# Identify sequences with bonus tokens in the sequence group.
if seq_id in seq_with_bonus_token_in_last_step:
seq_group_has_bonus_tokens = True
break
if seq_group_has_bonus_tokens:
# Create new sequences without the last bonus token. These new
# sequences have the same sequence ids as the original sequences;
# a new sequence group is created to hold them.
updated_seq_group_without_bonus_token = \
MultiStepWorker._copy_seq_metadata_excluding_last_token(
seq_group, seq_with_bonus_token_in_last_step)
updated_seq_group_metadata_list.append(
updated_seq_group_without_bonus_token)
# Add the original sequence group.
updated_seq_group_metadata_list.append(
MultiStepWorker._shallow_copy_seq_group_metadata(seq_group))
# Record the index of the original sequence group.
indices_of_original_sequence_groups.append(
len(updated_seq_group_metadata_list) - 1)
updated_execute_model_req.seq_group_metadata_list =\
updated_seq_group_metadata_list
return updated_execute_model_req, indices_of_original_sequence_groups
@staticmethod
def _filter_model_output(
expanded_batch_outputs: List[SamplerOutput],
output_indices_to_retain: List[int]) -> List[SamplerOutput]:
"""
Filters the model output to include only the specified sequence
outputs. This method contracts the expanded batch output from the
model to retain the outputs of only those sequences indicated by the
provided indices.
Args:
expanded_batch_outputs (List[SamplerOutput]): The expanded output
batch from the model.
output_indices_to_retain (List[int]): Indices of the model outputs
to retain.
Returns:
List[SamplerOutput]: A list containing the filtered model
outputs for the specified indices.
"""
return [
SamplerOutput(
outputs=[
expanded_batch_output.outputs[i]
for i in output_indices_to_retain
],
sampled_token_probs=(
expanded_batch_output.
sampled_token_probs[output_indices_to_retain]
if expanded_batch_output.sampled_token_probs is not None
else None),
logprobs=(
expanded_batch_output.logprobs[output_indices_to_retain]
if expanded_batch_output.logprobs is not None else None),
sampled_token_ids=(expanded_batch_output.
sampled_token_ids[output_indices_to_retain]
if expanded_batch_output.sampled_token_ids
is not None else None))
for expanded_batch_output in expanded_batch_outputs
]
def get_spec_proposals(
self,
execute_model_req: ExecuteModelRequest,
+ seq_ids_with_bonus_token_in_last_step: set,
) -> SpeculativeProposals:
"""Produce speculations given an input batch of sequences. The number of
speculative tokens per sequence is determined by max_proposal_len.
"""
- return self._proposer.get_spec_proposals(execute_model_req)
+ return self._proposer.get_spec_proposals(
+ execute_model_req, seq_ids_with_bonus_token_in_last_step)
@staticmethod
def _append_new_tokens(
@@ -123,9 +222,8 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
seq.update_num_computed_tokens(1)
@staticmethod
- def _shallow_copy_inputs(
- seq_group_metadata_list: List[SequenceGroupMetadata]
- ) -> List[SequenceGroupMetadata]:
+ def _shallow_copy_seq_group_metadata(
+ seq_group_metadata: SequenceGroupMetadata, ) -> SequenceGroupMetadata:
"""Copy input data structures to remove side-effects when input data
structures are shared with other modules.
@@ -133,26 +231,62 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
The alternative is deep-copying (or other form of deep copy); this has
performance downsides.
"""
- # Shallow-copy the list of SequenceGroupMetadata. This allows us to
- # append tokens and change is_prompt without external side-effects.
- new_seq_group_metadata_list: List[SequenceGroupMetadata] = []
- for old_seq_group_metadata in seq_group_metadata_list:
- # We must shallow-copy seq_group_metadata as is_prompt could change.
- seq_group_metadata = copy.copy(old_seq_group_metadata)
- new_seq_group_metadata_list.append(seq_group_metadata)
- # We must shallow-copy seq_data as we will append token ids
- new_seq_data: Dict[int, SequenceData] = {}
- for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
- new_seq_data[seq_id] = copy.copy(old_seq_data)
- new_seq_data[
- seq_id].output_token_ids = old_seq_data.output_token_ids[:]
- seq_group_metadata.seq_data = new_seq_data
- return new_seq_group_metadata_list
+ # Shallow-copy the SequenceGroupMetadata. This allows us to
+ # append tokens and change is_prompt without external side-effects.
+ # We must shallow-copy seq_group_metadata as is_prompt could change.
+ new_seq_group_metadata = copy.copy(seq_group_metadata)
+ # We must shallow-copy seq_data as we will append token ids
+ new_seq_data: Dict[int, SequenceData] = {}
+ for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
+ new_seq_data[seq_id] = copy.copy(old_seq_data)
+ new_seq_data[seq_id].output_token_ids =\
+ old_seq_data.output_token_ids[:]
+ new_seq_group_metadata.seq_data = new_seq_data
+ return new_seq_group_metadata
+ @staticmethod
+ def _copy_seq_metadata_excluding_last_token(
+ seq_group_metadata: SequenceGroupMetadata,
+ seq_ids_to_copy: Set[int],
+ ) -> SequenceGroupMetadata:
+ """
+ Creates a shallow copy of the given SequenceGroupMetadata, retaining
+ only the sequence IDs specified in seq_ids_to_copy. For each of these
+ sequence IDs, all output_token_ids except the last one are copied.
+ Sequence IDs not in seq_ids_to_copy are excluded from the copy.
+ Parameters:
+ seq_group_metadata (SequenceGroupMetadata): The original sequence
+ group metadata.
+ seq_ids_to_copy (Set[int]): The set of sequence IDs to include in the
+ copy.
+ Returns:
+ SequenceGroupMetadata: A shallow copy of the sequence group metadata
+ with the specified modifications.
+ """
+ # Shallow-copy the SequenceGroupMetadata.
+ new_seq_group_metadata = copy.copy(seq_group_metadata)
+ # Shallow-copy seq_data and modify the output_token_ids.
+ new_seq_data: Dict[int, SequenceData] = {}
+ for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
+ if (seq_id in seq_ids_to_copy):
+ new_seq_data[seq_id] = copy.copy(old_seq_data)
+ # Copy all the output token ids except the last.
+ # Also reduce num_computed_tokens by 1 since we are not
+ # including the last output token.
+ # NOTE: num_computed_tokens is not directly used by the
+ # speculative decoding workers, as it is only relevant for
+ # chunked prefill, which is disabled for speculative decoding.
+ # However, to maintain consistency in num_computed_tokens,
+ # we update it here.
+ new_seq_data[seq_id].output_token_ids =\
+ old_seq_data.output_token_ids[:-1]
+ new_seq_data[seq_id].update_num_computed_tokens(-1)
+ new_seq_group_metadata.seq_data = new_seq_data
+ return new_seq_group_metadata
def _assert_enough_kv_space(
self, seq_group_metadata_list: List[SequenceGroupMetadata],

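To make the index bookkeeping of _expand_execute_model_request and _filter_model_output concrete, here is a hedged, self-contained sketch over plain lists; the real methods operate on SequenceGroupMetadata and SamplerOutput objects, and the names below are illustrative only.

# Self-contained sketch of the expansion/filtering index bookkeeping.
def expand(groups, bonus_seq_ids):
    """groups: list of (seq_id, tokens). For every group whose seq_id is in
    bonus_seq_ids, insert a copy without the last (bonus) token before the
    original, and record the index of the original in the expanded list."""
    expanded, original_indices = [], []
    for seq_id, tokens in groups:
        if seq_id in bonus_seq_ids:
            expanded.append((seq_id, tokens[:-1]))
        expanded.append((seq_id, tokens))
        original_indices.append(len(expanded) - 1)
    return expanded, original_indices

def filter_outputs(outputs, indices_to_retain):
    """Keep only the outputs produced for the original (non-twin) sequences."""
    return [outputs[i] for i in indices_to_retain]

groups = [(0, [3, 5, 9]), (1, [4, 7]), (2, [6, 8, 2])]
expanded, keep = expand(groups, bonus_seq_ids={0, 2})
# Expanded order: 0-without-bonus, 0, 1, 2-without-bonus, 2
assert keep == [1, 2, 4]
assert filter_outputs(["o0'", "o0", "o1", "o2'", "o2"], keep) == ["o0", "o1", "o2"]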
View File

@@ -1,5 +1,5 @@
import weakref
- from typing import List, Optional, Tuple
+ from typing import List, Optional, Set, Tuple
import torch
@@ -48,6 +48,9 @@ class NGramWorker(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase):
self,
execute_model_req: ExecuteModelRequest,
sample_len: int,
+ # Unused parameter. NGramWorker does not use the KV Cache and
+ # therefore does not need this parameter.
+ seq_ids_with_bonus_token_in_last_step: Set[int],
) -> Tuple[Optional[List[Optional[SamplerOutput]]], bool]:
"""NGram match algo to pick proposal candidate. Returns the list of
sampler output, one per SequenceGroupMetadata.
@@ -133,12 +136,15 @@ class NGramWorker(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase):
def get_spec_proposals(
self,
execute_model_req: ExecuteModelRequest,
+ # Unused parameter. NGramWorker does not use the KV Cache and
+ # therefore does not need this parameter.
+ seq_ids_with_bonus_token_in_last_step: Set[int],
) -> SpeculativeProposals:
"""Produce speculations given an input batch of sequences. The number of
speculative tokens per sequence is determined by max_proposal_len.
"""
- return self._proposer.get_spec_proposals(execute_model_req)
+ return self._proposer.get_spec_proposals(
+ execute_model_req, seq_ids_with_bonus_token_in_last_step)
def _raise_if_unsupported(
self,

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
- from typing import List, Optional, Tuple
+ from typing import List, Optional, Set, Tuple
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.spec_decode.interfaces import SpeculativeProposer
@@ -14,6 +14,13 @@ class ProposerWorkerBase(LoraNotSupportedWorkerBase, SpeculativeProposer):
self,
execute_model_req: ExecuteModelRequest,
sample_len: int,
+ # A set containing all sequence IDs that were assigned bonus tokens
+ # in their last forward pass. This set is used to backfill the KV cache
+ # with the key-value pairs of the penultimate token in the sequences.
+ # This parameter is only used by the MultiStepWorker, which relies on
+ # the KV cache for token generation. It is not used by workers that
+ # do not utilize the KV cache.
+ seq_ids_with_bonus_token_in_last_step: Set[int]
) -> Tuple[Optional[List[SamplerOutput]], bool]:
raise NotImplementedError

View File

@@ -1,4 +1,4 @@
- from typing import List, Optional, Tuple
+ from typing import List, Optional, Set, Tuple
import torch
@@ -110,13 +110,17 @@ class SmallerTpProposerWorker(ProposerWorkerBase):
self,
execute_model_req: ExecuteModelRequest,
sample_len: int,
+ seq_ids_with_bonus_token_in_last_step: Set[int],
) -> Tuple[List[SamplerOutput], bool]:
# Do not check _is_dummy, as it's always called by get_spec_proposals
- return self._worker.sampler_output(execute_model_req, sample_len)
+ return self._worker.sampler_output(
+ execute_model_req, sample_len,
+ seq_ids_with_bonus_token_in_last_step)
def get_spec_proposals(
self,
execute_model_req: ExecuteModelRequest,
+ seq_ids_with_bonus_token_in_last_step: Set[int],
) -> SpeculativeProposals:
"""Produce speculations given an input batch of sequences. The number of
speculative tokens per sequence is determined by max_proposal_len.
@@ -125,7 +129,8 @@ class SmallerTpProposerWorker(ProposerWorkerBase):
return SpeculativeProposals(None, None, None)
with self._patch_tensor_parallel_group():
- return self._worker.get_spec_proposals(execute_model_req)
+ return self._worker.get_spec_proposals(
+ execute_model_req, seq_ids_with_bonus_token_in_last_step)
def execute_model(
self,

View File

@@ -1,5 +1,6 @@
+ from collections import defaultdict
from functools import cached_property
- from typing import Any, Dict, List, Optional, Tuple
+ from typing import Any, Dict, List, Optional, Set, Tuple
import torch
@@ -13,7 +14,7 @@ from vllm.model_executor.layers.typical_acceptance_sampler import (
TypicalAcceptanceSampler)
from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest,
HiddenStates, SamplerOutput, SequenceGroupMetadata,
- get_all_seq_ids)
+ get_all_seq_ids_and_request_ids)
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.interfaces import (SpeculativeProposals,
@@ -112,11 +113,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
ngram_prompt_lookup_min = (
draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
- disable_bonus_tokens = True
if ngram_prompt_lookup_max > 0:
- disable_bonus_tokens = False
proposer_worker = NGramWorker(**draft_worker_kwargs)
proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
ngram_prompt_lookup_max)
@@ -128,11 +125,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
if draft_worker_kwargs[
"model_config"].hf_config.model_type == "mlp_speculator":
- disable_bonus_tokens = False
proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
elif draft_worker_kwargs[
"model_config"].hf_config.model_type == "medusa":
- disable_bonus_tokens = False
proposer_worker = MedusaWorker(**draft_worker_kwargs)
else:
if draft_tp == 1:
@@ -149,10 +144,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
spec_decode_sampler: SpecDecodeBaseSampler = None
if draft_token_acceptance_method == "rejection_sampler":
spec_decode_sampler = RejectionSampler(
- disable_bonus_tokens=disable_bonus_tokens, )
+ disable_bonus_tokens=False, )
elif draft_token_acceptance_method == "typical_acceptance_sampler":
spec_decode_sampler = TypicalAcceptanceSampler(
- disable_bonus_tokens=disable_bonus_tokens,
+ disable_bonus_tokens=False,
posterior_threshold=\
typical_acceptance_sampler_posterior_threshold,
posterior_alpha=typical_acceptance_sampler_posterior_alpha,
@@ -200,6 +195,15 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self._metrics = AsyncMetricsCollector(
self.spec_decode_sampler
) if metrics_collector is None else metrics_collector
+ # Tracks the sequence IDs that received a bonus token ID in
+ # their last forward pass. Needed only if KV cache is being
+ # used for token generation such as in the case of MultiStepWorker.
+ self._seq_with_bonus_token_in_last_step: Set[int] = set()
+ # Tracks the currently active request ids and the sequence IDs
+ # corresponding to them
+ self._request_id_seq_id_mapping: Dict[str, Set[int]] = defaultdict(set)
+ # Tracks if the proposer worker uses the KV cache or not.
self.probs_dtype = self.spec_decode_sampler.probs_dtype
self.token_id_dtype = self.spec_decode_sampler.token_id_dtype
# Lazy initiazliation.
@@ -307,6 +311,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
broadcast_tensor_dict({}, src=0)
return []
+ self._track_finished_requests(execute_model_req)
disable_all_speculation = self._should_disable_all_speculation(
execute_model_req)
num_lookahead_slots = execute_model_req.num_lookahead_slots
@@ -453,7 +458,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self.previous_hidden_states = None
# Generate proposals using draft worker.
- proposals = self.proposer_worker.get_spec_proposals(execute_model_req)
+ proposals = self.proposer_worker.get_spec_proposals(
+ execute_model_req, self._seq_with_bonus_token_in_last_step)
proposal_scores = self.scorer.score_proposals(
execute_model_req,
@@ -585,7 +591,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# Get the sequence ids and num_logprobs (sampling parameter) in the
# batch.
- seq_ids = get_all_seq_ids(seq_group_metadata_list)
+ seq_ids, request_ids_seq_ids_mapping = get_all_seq_ids_and_request_ids(
+ seq_group_metadata_list)
num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)
# Serialize all tensors to CPU Python lists.
@@ -608,7 +616,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
for sequence_index in range(batch_size):
# Each sequence may have a different num_logprobs; retrieve it.
num_logprobs = num_logprobs_per_seq[sequence_index]
step_output_token_ids.append(
create_sequence_group_output(
token_id=accepted_token_ids_by_step[step_index]
@@ -623,18 +630,48 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
topk_logprobs=topk_logprobs_by_step[step_index]
[sequence_index][:num_logprobs],
))
sampler_output_list.append(
SamplerOutput(outputs=step_output_token_ids))
+ # Populate the data structures needed to keep track of sequences with
+ # bonus tokens.
+ self._track_sequences_with_bonus_tokens(seq_ids,
+ request_ids_seq_ids_mapping,
+ accepted_token_ids_by_step)
maybe_rejsample_metrics = (
self._metrics.maybe_collect_rejsample_metrics(k))
if maybe_rejsample_metrics is not None:
sampler_output_list[
0].spec_decode_worker_metrics = maybe_rejsample_metrics
return sampler_output_list
def _track_finished_requests(self, execute_model_req: ExecuteModelRequest):
"""
Removes the finished requests and their associated sequence ids from
the internal bookkeeping data structures.
"""
for finished_request in execute_model_req.finished_requests_ids:
for seq_id in self._request_id_seq_id_mapping[finished_request]:
self._seq_with_bonus_token_in_last_step.discard(seq_id)
del self._request_id_seq_id_mapping[finished_request]
def _track_sequences_with_bonus_tokens(
self, seq_ids: List[int],
request_ids_seq_ids_mapping: Dict[str, Set[int]],
accepted_token_ids_by_step: List[List[int]]):
"""
Updates the internal data structures which keep track of sequences
which have been assigned bonus tokens in their last forward pass.
"""
for seq_index, seq_id in enumerate(seq_ids):
last_token_id = accepted_token_ids_by_step[-1][seq_index]
if last_token_id == -1:
self._seq_with_bonus_token_in_last_step.discard(seq_id)
else:
self._seq_with_bonus_token_in_last_step.add(seq_id)
for request_id, sequences in request_ids_seq_ids_mapping.items():
self._request_id_seq_id_mapping[request_id].update(sequences)
@cached_property
def _vocab_size(self) -> int:
"""Get the vocab size of the model and make sure it's consistent between

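A small, self-contained sketch of the bookkeeping that _track_sequences_with_bonus_tokens and _track_finished_requests perform, using plain Python containers. The -1 sentinel mirrors the convention above that a sequence whose last accepted token id is -1 did not receive a bonus token; the names and values are illustrative only.

from collections import defaultdict

seq_with_bonus = {0, 1, 7}   # carried over from earlier steps
request_to_seqs = defaultdict(set, {"req-A": {0, 1}, "req-B": {2}})

# After a step, sequence 0 got a bonus token (last accepted id != -1),
# sequence 1 did not (-1), and sequence 2 got one.
seq_ids = [0, 1, 2]
last_accepted = [42, -1, 17]
for seq_id, token_id in zip(seq_ids, last_accepted):
    if token_id == -1:
        seq_with_bonus.discard(seq_id)
    else:
        seq_with_bonus.add(seq_id)
assert seq_with_bonus == {0, 2, 7}   # 7 was not in this batch, so it is kept

# When a request finishes, drop all of its sequences from both structures.
for seq_id in request_to_seqs.pop("req-A", set()):
    seq_with_bonus.discard(seq_id)
assert seq_with_bonus == {2, 7}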
View File

@@ -1,4 +1,4 @@
- from typing import List, Optional, Tuple
+ from typing import List, Optional, Set, Tuple
import torch
@@ -42,6 +42,7 @@ class Top1Proposer(SpeculativeProposer):
def get_spec_proposals(
self,
execute_model_req: ExecuteModelRequest,
+ seq_ids_with_bonus_token_in_last_step: Set[int],
) -> SpeculativeProposals:
"""Get speculative proposals given the input batch.
@@ -76,6 +77,8 @@ class Top1Proposer(SpeculativeProposer):
maybe_sampler_output, transposed = self._worker.sampler_output(
execute_model_req=nonzero_execute_model_req,
sample_len=proposal_len,
+ seq_ids_with_bonus_token_in_last_step=\
+ seq_ids_with_bonus_token_in_last_step,
)
(
proposal_lens,