[Speculative Decoding] Enabling bonus token in speculative decoding for KV cache based models (#5765)
parent 44cc76610d
commit ae151d73be
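At a high level, the commit threads a set of sequence IDs that received a bonus token from the target model in the last step through every proposer, so that KV-cache based draft workers can expand the batch for exactly those sequences. The sketch below is illustrative only and is not part of this commit; `ToyProposer` and its return value are hypothetical stand-ins, while in vLLM the real entry points are `SpecDecodeWorker.execute_model` and `ProposerWorkerBase.get_spec_proposals`.

# Illustrative sketch only (not part of this commit): how a caller might
# thread the new bonus-token bookkeeping through a proposer.
from typing import Dict, List, Set


class ToyProposer:
    """Hypothetical stand-in for a KV-cache based draft worker."""

    def get_spec_proposals(self, seq_ids: List[int],
                           seq_ids_with_bonus_token_in_last_step: Set[int]):
        # A real proposer would re-run the penultimate token for sequences
        # whose bonus token is not yet in the draft model's KV cache.
        return {
            seq_id: ("expand" if seq_id in seq_ids_with_bonus_token_in_last_step
                     else "reuse-kv")
            for seq_id in seq_ids
        }


seq_with_bonus: Set[int] = set()

# Step 1: the target model accepted a bonus token for sequence 0 but not 1.
accepted_last_token: Dict[int, int] = {0: 123, 1: -1}  # -1 == no bonus token
for seq_id, token in accepted_last_token.items():
    (seq_with_bonus.add if token != -1 else seq_with_bonus.discard)(seq_id)

# Step 2: on the next iteration the proposer learns which sequences need
# batch expansion.
print(ToyProposer().get_spec_proposals([0, 1], seq_with_bonus))
# {0: 'expand', 1: 'reuse-kv'}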
@@ -70,14 +70,17 @@ def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int,
     if queue_size < disable_by_batch_size:
         # Should raise exception when executing the mocked draft model.
         with pytest.raises(ValueError, match=exception_secret):
-            proposer.get_spec_proposals(execute_model_req=ExecuteModelRequest(
+            proposer.get_spec_proposals(
+                execute_model_req=ExecuteModelRequest(
                     seq_group_metadata_list=seq_group_metadata_list,
-                num_lookahead_slots=k), )
+                    num_lookahead_slots=k),
+                seq_ids_with_bonus_token_in_last_step=set())
     else:
         # Should not execute the draft model because spec decode is disabled
         # for all requests. Accordingly, the proposal length should be 0.
         proposals = proposer.get_spec_proposals(
             execute_model_req=ExecuteModelRequest(
                 seq_group_metadata_list=seq_group_metadata_list,
-                num_lookahead_slots=k), )
+                num_lookahead_slots=k),
+            seq_ids_with_bonus_token_in_last_step=set())
         assert proposals.proposal_lens.tolist() == [0] * batch_size
@@ -118,7 +118,8 @@ def test_same_output_for_single_step():
     actual_output, _ = multi_step_worker.sampler_output(
         execute_model_req=ExecuteModelRequest(
             seq_group_metadata_list=multi_step_seq_group),
-        sample_len=num_steps)
+        sample_len=num_steps,
+        seq_ids_with_bonus_token_in_last_step=set())
     assert len(actual_output) == num_steps
     actual_output = actual_output[0]

@@ -210,7 +211,8 @@ def test_same_output_for_multi_step():
     multi_step_output, _ = multi_step_worker.sampler_output(
         execute_model_req=ExecuteModelRequest(
             seq_group_metadata_list=seq_group_metadata_list),
-        sample_len=num_steps)
+        sample_len=num_steps,
+        seq_ids_with_bonus_token_in_last_step=set())

     # Run single-step repeatedly.
     zero_kv_cache(worker.cache_engine)
@@ -277,6 +279,203 @@ def test_same_output_for_multi_step():
                                       single_step_logprobs)


+@torch.inference_mode()
+def test_multi_step_with_batch_expansion_correct_output():
+    """
+    In this test we verify that the MultiStepWorker is able to handle bonus
+    tokens correctly. The test verifies that if a sequence has a
+    bonus token then the MultiStepWorker is able to expand the batch by adding
+    new sequences corresponding to the sequences with bonus tokens. The
+    expanded batch is then used for predicting the next tokens.
+    """
+    seed = 100
+    model_name = 'JackFram/llama-68m'
+
+    block_size = 16
+    num_gpu_blocks = 2048 // block_size
+    batch_size = 128
+    multi_step_worker = create_worker(
+        MultiStepWorker,
+        model_name,
+        block_size,
+        num_gpu_blocks,
+        seed,
+        model_runner_cls=TP1DraftModelRunner,
+    )
+    worker = create_worker(
+        Worker,
+        model_name,
+        block_size,
+        num_gpu_blocks,
+        seed,
+    )
+    random.seed(seed)
+    prompts = [[0] for _ in range(batch_size)]
+    num_steps = 2
+    final_prompt_lens = [(num_steps + 1) for prompt in prompts]
+    rand_seeds = list(random.randint(0, 100) for _ in range(num_steps))
+    multi_step_worker.execute_model = patch_execute_model_with_seeds(
+        multi_step_worker, rand_seeds)
+    worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)
+    # Create the test continuations
+    continuations = [[random.randint(0, 1000)] for _ in prompts]
+    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+        prompts,
+        num_gpu_blocks,
+        block_size,
+        continuations=continuations,
+        final_prompt_lens=final_prompt_lens)
+
+    # Run single-step twice to generate 2 tokens. This
+    # will simulate the bonus token case with the second token
+    # being the bonus token.
+    zero_kv_cache(worker.cache_engine)
+    single_step_output: List[SamplerOutput] = []
+    set_random_seed(seed)
+    for _ in range(num_steps):
+        seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+            prompts,
+            num_gpu_blocks,
+            block_size,
+            continuations=continuations,
+            final_prompt_lens=final_prompt_lens)
+        single_step_output.extend(
+            worker.execute_model(execute_model_req=ExecuteModelRequest(
+                seq_group_metadata_list=seq_group_metadata_list)))
+        # Append output tokens to new sequence data.
+        for i, seq_group_output in enumerate(single_step_output[-1]):
+            continuations[i].append(seq_group_output.samples[0].output_token)
+
+    # Create continuations for the MultiStepWorker. The continuations have
+    # 2 tokens in order to simulate the bonus token case.
+    multi_step_continuations = []
+    for continuation in continuations:
+        multi_step_continuations.append(continuation[:2])
+    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+        prompts,
+        num_gpu_blocks,
+        block_size,
+        continuations=multi_step_continuations,
+        final_prompt_lens=final_prompt_lens)
+
+    # Run multi-step and verify that the third token prediction is accurate
+    # for all sequences.
+    zero_kv_cache(multi_step_worker.cache_engine)
+    all_seq_ids = {i for i in range(batch_size)}
+    multi_step_output, _ = multi_step_worker.sampler_output(
+        execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=seq_group_metadata_list),
+        sample_len=1,
+        seq_ids_with_bonus_token_in_last_step=all_seq_ids)
+    for index, output in enumerate(multi_step_output[-1].outputs):
+        assert (continuations[index][-1] == output.samples[0].output_token)
+
+
+@torch.inference_mode()
+def test_multi_step_with_batch_expansion_incorrect_output():
+    """
+    Tests the MultiStepWorker's ability to handle batch expansion with bonus
+    tokens in a negative case scenario. This test provides the MultiStepWorker
+    with a batch containing sequences with bonus tokens but specifies the
+    sequence IDs with bonus tokens incorrectly. The test verifies that the
+    MultiStepWorker generates correct tokens for the sequences where the
+    sequence ID is specified correctly and incorrect tokens for those where
+    the sequence ID is specified incorrectly.
+    """
+    seed = 100
+    model_name = 'JackFram/llama-68m'
+
+    block_size = 16
+    num_gpu_blocks = 2048 // block_size
+    batch_size = 128
+    multi_step_worker = create_worker(
+        MultiStepWorker,
+        model_name,
+        block_size,
+        num_gpu_blocks,
+        seed,
+        model_runner_cls=TP1DraftModelRunner,
+    )
+    worker = create_worker(
+        Worker,
+        model_name,
+        block_size,
+        num_gpu_blocks,
+        seed,
+    )
+    random.seed(seed)
+    prompts = [[0] for _ in range(batch_size)]
+    num_steps = 2
+    final_prompt_lens = [(num_steps + 1) for prompt in prompts]
+    rand_seeds = list(random.randint(0, 100) for _ in range(num_steps))
+    multi_step_worker.execute_model = patch_execute_model_with_seeds(
+        multi_step_worker, rand_seeds)
+    worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)
+    # Create the test continuations
+    continuations = [[random.randint(0, 1000)] for _ in prompts]
+    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+        prompts,
+        num_gpu_blocks,
+        block_size,
+        continuations=continuations,
+        final_prompt_lens=final_prompt_lens)
+    # Run single-step twice to generate 2 tokens. This
+    # will simulate the bonus token case with the second token
+    # being the bonus token.
+    zero_kv_cache(worker.cache_engine)
+    single_step_output: List[SamplerOutput] = []
+    set_random_seed(seed)
+    for _ in range(num_steps):
+        seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+            prompts,
+            num_gpu_blocks,
+            block_size,
+            continuations=continuations,
+            final_prompt_lens=final_prompt_lens)
+        single_step_output.extend(
+            worker.execute_model(execute_model_req=ExecuteModelRequest(
+                seq_group_metadata_list=seq_group_metadata_list)))
+        # Append output tokens to new sequence data.
+        for i, seq_group_output in enumerate(single_step_output[-1]):
+            continuations[i].append(seq_group_output.samples[0].output_token)
+
+    # Create continuations for the MultiStepWorker. The continuations have
+    # 2 tokens in order to simulate the bonus token case.
+    multi_step_continuations = []
+    for continuation in continuations:
+        multi_step_continuations.append(continuation[:2])
+    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+        prompts,
+        num_gpu_blocks,
+        block_size,
+        continuations=multi_step_continuations,
+        final_prompt_lens=final_prompt_lens)
+
+    # Run multi-step. In this run INCORRECTLY specify that only the odd number
+    # sequences have bonus tokens. Verify that with this setting the third token
+    # prediction is accurate only for the odd numbered sequences. Also verify
+    # that the prediction might be wrong for some of the even numbered
+    # sequences.
+    zero_kv_cache(multi_step_worker.cache_engine)
+    set_random_seed(seed)
+    odd_seq_ids = {i for i in range(batch_size) if i % 2 != 0}
+    multi_step_output, _ = multi_step_worker.sampler_output(
+        execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=seq_group_metadata_list),
+        sample_len=1,
+        seq_ids_with_bonus_token_in_last_step=odd_seq_ids)
+    num_mismatch = 0
+    for index, output in enumerate(multi_step_output[-1].outputs):
+        if (index % 2) != 0:
+            assert (continuations[index][-1] == output.samples[0].output_token)
+        elif (continuations[index][-1] != output.samples[0].output_token):
+            num_mismatch += 1
+    # The prediction is accurate for some of the sequences even without proper
+    # handling of the bonus tokens. Hence verify that the number of sequences
+    # for which there is a mismatch is > 0.
+    assert (num_mismatch > 0)
+
+
 @torch.inference_mode()
 def test_draft_proposals_full_speculation_len():
     """Verify Top1Proposer correctly handles case where all sequences
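The two tests above exercise the batch expansion that this commit adds to the MultiStepWorker. The snippet below is a minimal sketch of that idea (not part of this commit), using plain token-id lists instead of SequenceGroupMetadata; `expand_batch` and its return convention are illustrative assumptions.

# Minimal sketch (not part of this commit) of the batch expansion the tests
# above exercise, using plain lists instead of SequenceGroupMetadata.
from typing import Dict, List, Set, Tuple


def expand_batch(token_ids_per_seq: Dict[int, List[int]],
                 seq_ids_with_bonus: Set[int]
                 ) -> Tuple[List[List[int]], List[int]]:
    """For every sequence with a bonus token, prepend a twin sequence that
    drops the last (bonus) token so its KV entry can be recomputed; keep the
    original too and remember its index so outputs can be filtered later."""
    expanded: List[List[int]] = []
    original_indices: List[int] = []
    for seq_id, tokens in token_ids_per_seq.items():
        if seq_id in seq_ids_with_bonus:
            expanded.append(tokens[:-1])  # twin without the bonus token
        expanded.append(tokens)           # the original sequence
        original_indices.append(len(expanded) - 1)
    return expanded, original_indices


batch = {0: [5, 9, 17], 1: [4, 8]}
expanded, keep = expand_batch(batch, seq_ids_with_bonus={0})
assert expanded == [[5, 9], [5, 9, 17], [4, 8]]
assert keep == [1, 2]  # only outputs at these indices go back to the caller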
@@ -318,7 +517,8 @@ def test_draft_proposals_full_speculation_len():
     proposals = proposer.get_spec_proposals(
         execute_model_req=ExecuteModelRequest(
             seq_group_metadata_list=seq_group_metadata_list,
-            num_lookahead_slots=k), )
+            num_lookahead_slots=k),
+        seq_ids_with_bonus_token_in_last_step=set())

     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)

@@ -356,7 +556,8 @@ def test_draft_proposals_no_speculations():
     proposals = proposer.get_spec_proposals(
         execute_model_req=ExecuteModelRequest(
             seq_group_metadata_list=seq_group_metadata_list,
-            num_lookahead_slots=k), )
+            num_lookahead_slots=k),
+        seq_ids_with_bonus_token_in_last_step=set())

     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)

@@ -428,7 +629,8 @@ def test_draft_proposals_mixed_k():
     proposals = proposer.get_spec_proposals(
         execute_model_req=ExecuteModelRequest(
             seq_group_metadata_list=seq_group_metadata_list,
-            num_lookahead_slots=k), )
+            num_lookahead_slots=k),
+        seq_ids_with_bonus_token_in_last_step=set())

     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
@@ -53,7 +53,8 @@ def test_ngram_algo_correctness_for_single_no_match():
     proposals = proposer.get_spec_proposals(
         execute_model_req=ExecuteModelRequest(
             seq_group_metadata_list=seq_group_metadata_list,
-            num_lookahead_slots=proposal_len), )
+            num_lookahead_slots=proposal_len),
+        seq_ids_with_bonus_token_in_last_step=None)

     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)

@@ -121,7 +122,8 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
     proposals = proposer.get_spec_proposals(
         execute_model_req=ExecuteModelRequest(
             seq_group_metadata_list=seq_group_metadata_list,
-            num_lookahead_slots=proposal_len), )
+            num_lookahead_slots=proposal_len),
+        seq_ids_with_bonus_token_in_last_step=None)

     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)

@@ -193,7 +195,8 @@ def test_ngram_algo_correctness_for_batches_match_all():
     proposals = proposer.get_spec_proposals(
         execute_model_req=ExecuteModelRequest(
             seq_group_metadata_list=seq_group_metadata_list,
-            num_lookahead_slots=proposal_len), )
+            num_lookahead_slots=proposal_len),
+        seq_ids_with_bonus_token_in_last_step=None)

     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
@@ -1,6 +1,7 @@
 import random
+from collections import defaultdict
 from types import SimpleNamespace
-from typing import Dict, List
+from typing import Dict, List, Set
 from unittest.mock import MagicMock

 import pytest

@@ -377,8 +378,10 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool,

     set_random_seed(1)

-    worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler,
-                              metrics_collector)
+    worker = SpecDecodeWorker(draft_worker,
+                              target_worker,
+                              spec_decode_sampler,
+                              metrics_collector=metrics_collector)
     worker.init_device()

     proposal_token_ids = torch.randint(low=0,

@@ -554,7 +557,6 @@ def test_init_device(acceptance_sampler_method: str):

     worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler,
                               metrics_collector)
-
     worker.init_device()

     draft_worker.init_device.assert_called_once()
@@ -645,3 +647,140 @@ def test_split_num_cache_blocks_evenly(available_gpu_blocks: int,
     assert (num_blocks * target_cache_block_size_bytes) + (
         num_blocks * draft_kv_size_bytes) <= (available_gpu_blocks *
                                               target_cache_block_size_bytes)
+
+
+@torch.inference_mode()
+def test_populate_seq_ids_with_bonus_tokens():
+    """
+    Verify that a call to _create_output_sampler_list correctly updates
+    seq_with_bonus_token_in_last_step.
+
+    seq_with_bonus_token_in_last_step is an internal data structure in
+    SpecDecodeWorker that tracks the sequence IDs which are assigned bonus
+    tokens by the target model in their last forward pass. This state is
+    maintained only for models relying on the KV cache, such as those using
+    the MultiStepWorker.
+    """
+    batch_size = 10
+    k = 5
+    vocab_size = 10000
+    num_sequences_with_bonus_tokens = 5
+    target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
+    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
+    target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)]
+    target_worker.device = 'cuda'
+
+    set_random_seed(1)
+    draft_worker = mock_worker(cls=MultiStepWorker)
+    draft_worker.device = 'cuda'
+    # The sequence_ids attached to each sequence in the batch.
+    # The sequence at index i has seq_id assigned_seq_ids[i]
+    assigned_seq_ids = list(range(batch_size))
+    seq_group_metadata_list, _, _ = create_batch(batch_size,
+                                                 k,
+                                                 seq_ids=assigned_seq_ids,
+                                                 prev_output_token_len=10)
+    target_token_logprobs = torch.rand(batch_size, (k + 1),
+                                       vocab_size,
+                                       dtype=torch.float32,
+                                       device='cuda')
+    accepted_token_ids = torch.randint(low=0,
+                                       high=vocab_size,
+                                       size=(batch_size, (k + 1)),
+                                       dtype=torch.int64,
+                                       device='cuda')
+    expected_request_id_seq_ids_mapping: Dict[str, Set[int]] = defaultdict(set)
+    for seq_group_metadata in seq_group_metadata_list:
+        for seq_id in seq_group_metadata.seq_data:
+            expected_request_id_seq_ids_mapping[
+                seq_group_metadata.request_id].add(seq_id)
+    # Generate a random sample of sequence indexes with bonus tokens
+    seq_indexes_with_bonus_tokens = random.sample(
+        range(batch_size), num_sequences_with_bonus_tokens)
+    # Create a mask that is True for indices in seq_indexes_with_bonus_tokens
+    mask = torch.ones(batch_size, dtype=torch.bool, device='cuda')
+    mask[seq_indexes_with_bonus_tokens] = False
+    # Set the last token ID to -1 for all indices not in
+    # seq_indexes_with_bonus_tokens to indicate the lack of bonus token in
+    # those indices.
+    accepted_token_ids[mask, -1:] = -1
+    worker = SpecDecodeWorker(draft_worker,
+                              target_worker,
+                              mock_spec_decode_sampler("rejection_sampler"),
+                              metrics_collector=metrics_collector)
+    # Initialize _seq_with_bonus_token_in_last_step with a set of sequence IDs.
+    # This set includes all sequence IDs in the batch as well as an additional
+    # `num_extra_sequence_ids` sequence IDs. Note that the sequence IDs are in
+    # the range [0, batch_size + num_extra_sequence_ids).
+    num_extra_sequence_ids = 10
+    worker._seq_with_bonus_token_in_last_step = set(
+        range(batch_size + num_extra_sequence_ids))
+    worker._create_output_sampler_list(
+        seq_group_metadata_list=seq_group_metadata_list,
+        accepted_token_ids=accepted_token_ids,
+        target_logprobs=target_token_logprobs,
+        k=k)
+    # Verify that _seq_with_bonus_token_in_last_step contains the following:
+    # 1. Sequence IDs that were already present in
+    #    _seq_with_bonus_token_in_last_step but were not part of the current
+    #    batch are retained.
+    # 2. Of the sequence IDs present in the current batch, only those with a
+    #    bonus token are retained in _seq_with_bonus_token_in_last_step.
+    #    Sequence IDs that are present in the current batch but do not have
+    #    bonus tokens are removed from _seq_with_bonus_token_in_last_step.
+    expected_seq_ids_with_bonus_tokens = \
+        set([assigned_seq_ids[i] for i in seq_indexes_with_bonus_tokens])
+    additional_sequence_ids = \
+        set(range(batch_size, batch_size + num_extra_sequence_ids))
+    assert worker._seq_with_bonus_token_in_last_step == \
+        expected_seq_ids_with_bonus_tokens.union(additional_sequence_ids)
+    assert worker._request_id_seq_id_mapping == \
+        expected_request_id_seq_ids_mapping
+
+
+@torch.inference_mode()
+def test_handle_finished_requests():
+    """
+    Test to verify that finished request IDs are appropriately processed to
+    update the internal state of the SpecDecodeWorker.
+
+    This test initializes the SpecDecodeWorker with mock data, marks certain
+    requests as finished, and ensures that the corresponding sequence IDs are
+    correctly removed from the internal mappings.
+    """
+    batch_size = 32
+    k = 3
+    draft_worker = mock_worker(cls=MultiStepWorker)
+    target_worker = mock_worker()
+    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
+    worker = SpecDecodeWorker(draft_worker, target_worker,
+                              mock_spec_decode_sampler("rejection_sampler"),
+                              metrics_collector)
+    # Initialize the request_id_seq_id_mapping mapping dict with a few fake
+    # request ids and corresponding sequence ids.
+    worker._request_id_seq_id_mapping = \
+        {'request-1': {1,2,3}, 'request-2': {4,5,6,7},
+         'request-3': {8,9}, 'request-4': {10,11}}
+    # Initialize seq_with_bonus_token_in_last_step with a few fake
+    # sequence ids.
+    worker._seq_with_bonus_token_in_last_step = {1, 4, 5, 8, 9, 10}
+    exception_secret = 'artificial stop'
+    draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)
+
+    seq_group_metadata_list, _, _ = create_batch(batch_size, k)
+    # Mark requests with ids request-1 and request-3 as finished.
+    execute_model_req = ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=k,
+        finished_requests_ids=['request-1', 'request-3'])
+
+    with pytest.raises(ValueError, match=exception_secret):
+        worker.execute_model(execute_model_req=execute_model_req)
+    # Verify that request-1 and request-3 are removed from
+    # request_id_seq_id_mapping
+    assert worker._request_id_seq_id_mapping == \
+        {'request-2': {4,5,6,7}, 'request-4': {10,11}}
+    # Verify that all sequence ids corresponding to 'request-1'
+    # and 'request-3' are removed from seq_with_bonus_token_in_last_step.
+    assert worker._seq_with_bonus_token_in_last_step == \
+        {4,5,10}
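For context, the test above relies on a small convention: in `accepted_token_ids`, a value of -1 in the last column of a row means that sequence did not get a bonus token in this step. The snippet below is a sketch of that convention only (not part of this commit); the tensor shapes and sequence IDs are made up for illustration.

# Sketch (not part of this commit): the -1 marker for "no bonus token".
import torch

batch_size, k = 4, 2
accepted = torch.randint(0, 100, (batch_size, k + 1))
no_bonus_rows = torch.tensor([True, False, True, False])
accepted[no_bonus_rows, -1] = -1

seq_ids = [10, 11, 12, 13]
with_bonus = {seq_ids[i] for i in range(batch_size) if accepted[i, -1] != -1}
assert with_bonus == {11, 13}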
@@ -3,8 +3,9 @@ import copy
 import enum
 import math
 from abc import ABC, abstractmethod
+from collections import defaultdict
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union

 import torch
@@ -916,6 +917,21 @@ def get_all_seq_ids(
     return [seq_id for sg in seq_group_metadata_list for seq_id in sg.seq_data]


+def get_all_seq_ids_and_request_ids(
+        seq_group_metadata_list: List[SequenceGroupMetadata]
+) -> Tuple[List[int], Dict[str, Set[int]]]:
+    """Given a list of SequenceGroupMetadata, create a list of all
+    sequence ids and a mapping from each request id to its sequence ids.
+    """
+    seq_ids: List[int] = []
+    request_id_seq_ids_mapping: Dict[str, Set[int]] = defaultdict(set)
+    for sg in seq_group_metadata_list:
+        for seq_id in sg.seq_data:
+            seq_ids.append(seq_id)
+            request_id_seq_ids_mapping[sg.request_id].add(seq_id)
+    return seq_ids, request_id_seq_ids_mapping
+
+
 class HiddenStates:
     """Hidden states corresponding to in-progress sequences.
     Used in speculative decoding to pass hidden states from
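A usage sketch of the helper added above (not part of this commit). The real function consumes SequenceGroupMetadata objects; here plain (request_id, seq_data) pairs stand in, and `toy_get_all_seq_ids_and_request_ids` is a hypothetical mirror used only to show the shape of the return value.

# Usage sketch (not part of this commit).
from collections import defaultdict
from typing import Dict, List, Set, Tuple


def toy_get_all_seq_ids_and_request_ids(
        groups: List[Tuple[str, Dict[int, object]]]
) -> Tuple[List[int], Dict[str, Set[int]]]:
    seq_ids: List[int] = []
    mapping: Dict[str, Set[int]] = defaultdict(set)
    for request_id, seq_data in groups:
        for seq_id in seq_data:
            seq_ids.append(seq_id)
            mapping[request_id].add(seq_id)
    return seq_ids, mapping


groups = [("request-1", {1: None, 2: None}), ("request-2", {7: None})]
seq_ids, mapping = toy_get_all_seq_ids_and_request_ids(groups)
assert seq_ids == [1, 2, 7]
assert mapping == {"request-1": {1, 2}, "request-2": {7}}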
@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Set

 import torch

@@ -62,6 +62,9 @@ class SpeculativeProposer(ABC):
     def get_spec_proposals(
         self,
        execute_model_req: ExecuteModelRequest,
+        # If set, this contains all sequence IDs that were assigned
+        # bonus tokens in their last forward pass.
+        seq_ids_with_bonus_token_in_last_step: Set[int],
     ) -> SpeculativeProposals:
         raise NotImplementedError
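A minimal sketch of what a conforming proposer looks like under the updated interface (not part of this commit); `NoopProposer` is hypothetical and deliberately does not subclass the real ABC so the snippet stays self-contained. Proposers that do not rely on the KV cache can simply ignore the new argument, as the NGramWorker changes later in this diff do.

# Sketch (not part of this commit): smallest shape conforming to the
# updated get_spec_proposals signature.
from typing import Set


class NoopProposer:
    def get_spec_proposals(self, execute_model_req,
                           seq_ids_with_bonus_token_in_last_step: Set[int]):
        # A KV-cache based proposer would use the set to decide which
        # sequences need their penultimate token re-run; a stateless one
        # can ignore it and return empty proposals.
        return []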
@@ -1,5 +1,5 @@
 import weakref
-from typing import List, Optional, Tuple
+from typing import List, Optional, Set, Tuple

 import torch

@@ -40,6 +40,8 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
         self,
         execute_model_req: ExecuteModelRequest,
         sample_len: int,
+        # Unused parameter.
+        seq_ids_with_bonus_token_in_last_step: Set[int],
     ) -> Tuple[List[SamplerOutput], bool]:
         """Run the model forward pass to generate sample_len future tokens.
         Returns the list of sampler output, one per layer, along with indicator

@@ -97,12 +99,14 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
     def get_spec_proposals(
         self,
         execute_model_req: ExecuteModelRequest,
+        seq_ids_with_bonus_token_in_last_step: Set[int],
     ) -> SpeculativeProposals:
         """Produce speculations given an input batch of sequences. The number of
         speculative tokens per sequence is determined by max_proposal_len.
         """

-        return self._proposer.get_spec_proposals(execute_model_req)
+        return self._proposer.get_spec_proposals(
+            execute_model_req, seq_ids_with_bonus_token_in_last_step)

     def _raise_if_unsupported(
         self,
@@ -1,4 +1,4 @@
-from typing import List, Optional, Tuple
+from typing import List, Optional, Set, Tuple

 import torch

@@ -20,6 +20,9 @@ class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker):
         self,
         execute_model_req: ExecuteModelRequest,
         sample_len: int,
+        # Unused parameter. MLPSpeculatorWorker does not use the KV Cache and
+        # therefore does not need this parameter.
+        seq_ids_with_bonus_token_in_last_step: Set[int],
     ) -> Tuple[List[SamplerOutput], bool]:
         """Run the model forward pass to generate sample_len future tokens.
         Returns the list of sampler output, one per layer, along with indicator
@@ -1,6 +1,6 @@
 import copy
 import weakref
-from typing import Dict, List, Tuple
+from typing import Dict, List, Set, Tuple

 import torch

@@ -51,6 +51,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
         self,
         execute_model_req: ExecuteModelRequest,
         sample_len: int,
+        seq_ids_with_bonus_token_in_last_step: Set[int],
     ) -> Tuple[List[SamplerOutput], bool]:
         """Run the model forward pass sample_len times. Returns the list of
         sampler output, one per model forward pass, along with indicator of
@@ -60,44 +61,142 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
         For multi step worker, this indicator shall be True.
         """
         self._raise_if_unsupported(execute_model_req)
-
-        # Shallow copy input data so modifications (such as appending tokens)
-        # do not cause side-effects.
-        copied_seq_group_metadata_list = self._shallow_copy_inputs(
-            execute_model_req.seq_group_metadata_list)
-        copied_execute_model_req = execute_model_req.clone(
-            copied_seq_group_metadata_list)
-
+        # Expand the batch for sequences with a bonus token.
+        # Perform a forward pass on the expanded batch and filter the
+        # response to retain only the original sequences' responses.
+        expanded_request, indices_of_seq_with_bonus_tokens =\
+            self._expand_execute_model_request(
+                execute_model_req, seq_ids_with_bonus_token_in_last_step)
         # Run model sample_len times.
         model_outputs: List[SamplerOutput] = []
         if isinstance(self.model_runner, TP1DraftModelRunner):
-            copied_execute_model_req.num_steps = sample_len
+            expanded_request.num_steps = sample_len
             model_outputs = self.execute_model(
-                execute_model_req=copied_execute_model_req)
+                execute_model_req=expanded_request)
         else:
             # TODO: Remove this branch once DraftModelRunner supports TP>1.
             for _ in range(sample_len):
                 model_output: List[SamplerOutput] = super().execute_model(
-                    execute_model_req=copied_execute_model_req)
+                    execute_model_req=expanded_request)
                 assert (len(model_output) == 1
                         ), "composing multistep workers not supported"
                 model_output = model_output[0]

-                self._append_new_tokens(model_output,
-                                        copied_seq_group_metadata_list)
+                self._append_new_tokens(
+                    model_output, expanded_request.seq_group_metadata_list)
                 model_outputs.append(model_output)

-        return model_outputs, True
+        filtered_model_outputs = self._filter_model_output(
+            model_outputs, indices_of_seq_with_bonus_tokens)
+        return filtered_model_outputs, True

+    @staticmethod
+    def _expand_execute_model_request(
+        execute_model_req: ExecuteModelRequest,
+        seq_with_bonus_token_in_last_step: set,
+    ) -> Tuple[ExecuteModelRequest, List[int]]:
+        """
+        Expands the execute model request based on sequences with bonus
+        tokens.
+
+        For each sequence with a bonus token, this method creates a new
+        sequence without the bonus token and adds it to the execute model
+        request. The original sequence groups are also retained. The indices
+        of the original sequence groups are returned for further processing.
+
+        Args:
+            execute_model_req (ExecuteModelRequest): The original execute
+            model request.
+            seq_with_bonus_token_in_last_step (set): Set of sequence IDs that
+            contain bonus tokens.
+
+        Returns:
+            Tuple[ExecuteModelRequest, List[int]]: The updated execute model
+            request with expanded sequences and a list of indices corresponding
+            to the original sequence groups.
+        """
+        updated_seq_group_metadata_list: List[SequenceGroupMetadata] = []
+        updated_execute_model_req = execute_model_req.clone(
+            updated_seq_group_metadata_list)
+        indices_of_original_sequence_groups = []
+        for seq_group in execute_model_req.seq_group_metadata_list:
+            seq_group_has_bonus_tokens = False
+            for seq_id, _ in seq_group.seq_data.items():
+                # Identify sequences with bonus tokens in the sequence group.
+                if seq_id in seq_with_bonus_token_in_last_step:
+                    seq_group_has_bonus_tokens = True
+                    break
+            if seq_group_has_bonus_tokens:
+                # Create new sequences without the last bonus token. These new
+                # sequences have the same sequence id as the original sequence.
+                # We create a new sequence group and add them there.
+                updated_seq_group_without_bonus_token = \
+                    MultiStepWorker._copy_seq_metadata_excluding_last_token(
+                        seq_group, seq_with_bonus_token_in_last_step)
+                updated_seq_group_metadata_list.append(
+                    updated_seq_group_without_bonus_token)
+            # Add the original sequence group.
+            updated_seq_group_metadata_list.append(
+                MultiStepWorker._shallow_copy_seq_group_metadata(seq_group))
+            # Record the index of the original sequence group.
+            indices_of_original_sequence_groups.append(
+                len(updated_seq_group_metadata_list) - 1)
+
+        updated_execute_model_req.seq_group_metadata_list =\
+            updated_seq_group_metadata_list
+        return updated_execute_model_req, indices_of_original_sequence_groups
+
+    @staticmethod
+    def _filter_model_output(
+            expanded_batch_outputs: List[SamplerOutput],
+            output_indices_to_retain: List[int]) -> List[SamplerOutput]:
+        """
+        Filters the model output to include only the specified sequence
+        outputs. This method contracts the expanded batch output from the
+        model to retain the outputs of only those sequences indicated by the
+        provided indices.
+
+        Args:
+            expanded_batch_outputs (List[SamplerOutput]): The expanded output
+            batch from the model.
+            output_indices_to_retain (List[int]): Indices of the model outputs
+            to retain.
+
+        Returns:
+            List[SamplerOutput]: A list containing the filtered model
+            outputs for the specified indices.
+        """
+        return [
+            SamplerOutput(
+                outputs=[
+                    expanded_batch_output.outputs[i]
+                    for i in output_indices_to_retain
+                ],
+                sampled_token_probs=(
+                    expanded_batch_output.
+                    sampled_token_probs[output_indices_to_retain]
+                    if expanded_batch_output.sampled_token_probs is not None
+                    else None),
+                logprobs=(
+                    expanded_batch_output.logprobs[output_indices_to_retain]
+                    if expanded_batch_output.logprobs is not None else None),
+                sampled_token_ids=(expanded_batch_output.
+                                   sampled_token_ids[output_indices_to_retain]
+                                   if expanded_batch_output.sampled_token_ids
+                                   is not None else None))
+            for expanded_batch_output in expanded_batch_outputs
+        ]
+
     def get_spec_proposals(
         self,
         execute_model_req: ExecuteModelRequest,
+        seq_ids_with_bonus_token_in_last_step: set,
     ) -> SpeculativeProposals:
         """Produce speculations given an input batch of sequences. The number of
         speculative tokens per sequence is determined by max_proposal_len.
         """

-        return self._proposer.get_spec_proposals(execute_model_req)
+        return self._proposer.get_spec_proposals(
+            execute_model_req, seq_ids_with_bonus_token_in_last_step)

     @staticmethod
     def _append_new_tokens(
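A worked example of the contract between `_expand_execute_model_request` and `_filter_model_output` added above (not part of this commit): the expanded batch has one extra row per bonus-token sequence, and only the rows at `indices_of_original_sequence_groups` are handed back to the caller. The row labels below are illustrative.

# Worked example (not part of this commit).
from typing import List

# Suppose the original batch had sequences A (bonus token) and B (no bonus).
expanded_rows = ["A-without-bonus", "A", "B"]
indices_of_original_sequence_groups: List[int] = [1, 2]

filtered = [expanded_rows[i] for i in indices_of_original_sequence_groups]
assert filtered == ["A", "B"]  # same order and length as the original batch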
@@ -123,9 +222,8 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
             seq.update_num_computed_tokens(1)

     @staticmethod
-    def _shallow_copy_inputs(
-        seq_group_metadata_list: List[SequenceGroupMetadata]
-    ) -> List[SequenceGroupMetadata]:
+    def _shallow_copy_seq_group_metadata(
+            seq_group_metadata: SequenceGroupMetadata, ) -> SequenceGroupMetadata:
         """Copy input data structures to remove side-effects when input data
         structures are shared with other modules.
@@ -133,26 +231,62 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
         The alternative is deep-copying (or other form of deep copy); this has
         performance downsides.
         """
-
-        # Shallow-copy the list of SequenceGroupMetadata. This allows us to
+        # Shallow-copy the SequenceGroupMetadata. This allows us to
         # append tokens and change is_prompt without external side-effects.
-        new_seq_group_metadata_list: List[SequenceGroupMetadata] = []
-
-        for old_seq_group_metadata in seq_group_metadata_list:
-            # We must shallow-copy seq_group_metadata as is_prompt could change.
-            seq_group_metadata = copy.copy(old_seq_group_metadata)
-            new_seq_group_metadata_list.append(seq_group_metadata)
+        new_seq_group_metadata = copy.copy(seq_group_metadata)

-            # We must shallow-copy seq_data as we will append token ids
-            new_seq_data: Dict[int, SequenceData] = {}
-            for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
-                new_seq_data[seq_id] = copy.copy(old_seq_data)
-                new_seq_data[
-                    seq_id].output_token_ids = old_seq_data.output_token_ids[:]
+        # We must shallow-copy seq_data as we will append token ids
+        new_seq_data: Dict[int, SequenceData] = {}
+        for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
+            new_seq_data[seq_id] = copy.copy(old_seq_data)
+            new_seq_data[seq_id].output_token_ids =\
+                old_seq_data.output_token_ids[:]

-            seq_group_metadata.seq_data = new_seq_data
-
-        return new_seq_group_metadata_list
+        new_seq_group_metadata.seq_data = new_seq_data
+        return new_seq_group_metadata
+
+    @staticmethod
+    def _copy_seq_metadata_excluding_last_token(
+        seq_group_metadata: SequenceGroupMetadata,
+        seq_ids_to_copy: Set[int],
+    ) -> SequenceGroupMetadata:
+        """
+        Creates a shallow copy of the given SequenceGroupMetadata, retaining
+        only the sequence IDs specified in seq_ids_to_copy. For each of these
+        sequence IDs, all output_token_ids except the last one are copied.
+        Sequence IDs not in seq_ids_to_copy are excluded from the copy.
+
+        Parameters:
+        seq_group_metadata (SequenceGroupMetadata): The original sequence
+            group metadata.
+        seq_ids_to_copy (Set[int]): The set of sequence IDs to include in the
+            copy.
+
+        Returns:
+        SequenceGroupMetadata: A shallow copy of the sequence group metadata
+            with the specified modifications.
+        """
+        # Shallow-copy the SequenceGroupMetadata.
+        new_seq_group_metadata = copy.copy(seq_group_metadata)
+        # Shallow-copy seq_data and modify the output_token_ids.
+        new_seq_data: Dict[int, SequenceData] = {}
+        for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
+            if (seq_id in seq_ids_to_copy):
+                new_seq_data[seq_id] = copy.copy(old_seq_data)
+                # Copy all the output token ids except the last.
+                # Also reduce num_computed_tokens by 1 since we are not
+                # including the last output token.
+                # NOTE: num_computed_tokens is not directly used by the
+                # speculative decoding workers, as it is only relevant for
+                # chunked prefill, which is disabled for speculative decoding.
+                # However, to maintain consistency in num_computed_tokens,
+                # we update it here.
+                new_seq_data[seq_id].output_token_ids =\
+                    old_seq_data.output_token_ids[:-1]
+                new_seq_data[seq_id].update_num_computed_tokens(-1)
+        new_seq_group_metadata.seq_data = new_seq_data
+        return new_seq_group_metadata

     def _assert_enough_kv_space(
             self, seq_group_metadata_list: List[SequenceGroupMetadata],
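A sketch of what the trimmed twin produced by `_copy_seq_metadata_excluding_last_token` looks like (not part of this commit), with a plain dict standing in for SequenceData; the field names used here are illustrative.

# Sketch (not part of this commit).
import copy

seq_data = {7: {"output_token_ids": [11, 22, 33], "num_computed_tokens": 3}}
seq_ids_to_copy = {7}

trimmed = {}
for seq_id, data in seq_data.items():
    if seq_id in seq_ids_to_copy:
        new_data = copy.copy(data)
        # Drop the last (bonus) token and decrement the computed-token count.
        new_data["output_token_ids"] = data["output_token_ids"][:-1]
        new_data["num_computed_tokens"] = data["num_computed_tokens"] - 1
        trimmed[seq_id] = new_data

assert trimmed[7]["output_token_ids"] == [11, 22]
assert trimmed[7]["num_computed_tokens"] == 2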
@@ -1,5 +1,5 @@
 import weakref
-from typing import List, Optional, Tuple
+from typing import List, Optional, Set, Tuple

 import torch

@@ -48,6 +48,9 @@ class NGramWorker(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase):
         self,
         execute_model_req: ExecuteModelRequest,
         sample_len: int,
+        # Unused parameter. NGramWorker does not use the KV Cache and
+        # therefore does not need this parameter.
+        seq_ids_with_bonus_token_in_last_step: Set[int],
     ) -> Tuple[Optional[List[Optional[SamplerOutput]]], bool]:
         """NGram match algo to pick proposal candidate. Returns the list of
         sampler output, one per SequenceGroupMetadata.

@@ -133,12 +136,15 @@ class NGramWorker(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase):
     def get_spec_proposals(
         self,
         execute_model_req: ExecuteModelRequest,
+        # Unused parameter. NGramWorker does not use the KV Cache and
+        # therefore does not need this parameter.
+        seq_ids_with_bonus_token_in_last_step: Set[int],
     ) -> SpeculativeProposals:
         """Produce speculations given an input batch of sequences. The number of
         speculative tokens per sequence is determined by max_proposal_len.
         """

-        return self._proposer.get_spec_proposals(execute_model_req)
+        return self._proposer.get_spec_proposals(
+            execute_model_req, seq_ids_with_bonus_token_in_last_step)

     def _raise_if_unsupported(
         self,
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import List, Optional, Tuple
+from typing import List, Optional, Set, Tuple

 from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.spec_decode.interfaces import SpeculativeProposer

@@ -14,6 +14,13 @@ class ProposerWorkerBase(LoraNotSupportedWorkerBase, SpeculativeProposer):
         self,
         execute_model_req: ExecuteModelRequest,
         sample_len: int,
+        # A set containing all sequence IDs that were assigned bonus tokens
+        # in their last forward pass. This set is used to backfill the KV cache
+        # with the key-value pairs of the penultimate token in the sequences.
+        # This parameter is only used by the MultiStepWorker, which relies on
+        # the KV cache for token generation. It is not used by workers that
+        # do not utilize the KV cache.
+        seq_ids_with_bonus_token_in_last_step: Set[int]
     ) -> Tuple[Optional[List[SamplerOutput]], bool]:
         raise NotImplementedError
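The comment above explains why the set exists: when a bonus token is accepted, the draft model never ran a forward pass for the token just before it, so that token's key-value pairs are missing from the draft KV cache. The snippet below is only an illustrative sketch of that timeline (not part of this commit); the position numbers are made up.

# Sketch (not part of this commit) of the penultimate-token backfill idea.
draft_kv_positions = [0, 1, 2]   # draft model has KV up to position 2
accepted_by_target = [3, 4]      # position 3 accepted, bonus token at 4

# The draft only sampled position 3, it never processed it as input, so its
# KV entry is missing. Expanding the batch with a copy of the sequence that
# ends at the penultimate token lets the draft backfill that entry before
# proposing from position 4.
needs_backfill = accepted_by_target[:-1]
assert needs_backfill == [3]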
@@ -1,4 +1,4 @@
-from typing import List, Optional, Tuple
+from typing import List, Optional, Set, Tuple

 import torch

@@ -110,13 +110,17 @@ class SmallerTpProposerWorker(ProposerWorkerBase):
         self,
         execute_model_req: ExecuteModelRequest,
         sample_len: int,
+        seq_ids_with_bonus_token_in_last_step: Set[int],
     ) -> Tuple[List[SamplerOutput], bool]:
         # Do not check _is_dummy, as it's always called by get_spec_proposals
-        return self._worker.sampler_output(execute_model_req, sample_len)
+        return self._worker.sampler_output(
+            execute_model_req, sample_len,
+            seq_ids_with_bonus_token_in_last_step)

     def get_spec_proposals(
         self,
         execute_model_req: ExecuteModelRequest,
+        seq_ids_with_bonus_token_in_last_step: Set[int],
     ) -> SpeculativeProposals:
         """Produce speculations given an input batch of sequences. The number of
         speculative tokens per sequence is determined by max_proposal_len.

@@ -125,7 +129,8 @@ class SmallerTpProposerWorker(ProposerWorkerBase):
             return SpeculativeProposals(None, None, None)

         with self._patch_tensor_parallel_group():
-            return self._worker.get_spec_proposals(execute_model_req)
+            return self._worker.get_spec_proposals(
+                execute_model_req, seq_ids_with_bonus_token_in_last_step)

     def execute_model(
         self,
@@ -1,5 +1,6 @@
+from collections import defaultdict
 from functools import cached_property
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Set, Tuple

 import torch

@@ -13,7 +14,7 @@ from vllm.model_executor.layers.typical_acceptance_sampler import (
     TypicalAcceptanceSampler)
 from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest,
                            HiddenStates, SamplerOutput, SequenceGroupMetadata,
-                           get_all_seq_ids)
+                           get_all_seq_ids_and_request_ids)
 from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
 from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
@@ -112,11 +113,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
         ngram_prompt_lookup_min = (
             draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
-
-        disable_bonus_tokens = True
-
         if ngram_prompt_lookup_max > 0:
-            disable_bonus_tokens = False
             proposer_worker = NGramWorker(**draft_worker_kwargs)
             proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
                                                   ngram_prompt_lookup_max)

@@ -128,11 +125,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):

             if draft_worker_kwargs[
                     "model_config"].hf_config.model_type == "mlp_speculator":
-                disable_bonus_tokens = False
                 proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
             elif draft_worker_kwargs[
                     "model_config"].hf_config.model_type == "medusa":
-                disable_bonus_tokens = False
                 proposer_worker = MedusaWorker(**draft_worker_kwargs)
             else:
                 if draft_tp == 1:

@@ -149,10 +144,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         spec_decode_sampler: SpecDecodeBaseSampler = None
         if draft_token_acceptance_method == "rejection_sampler":
             spec_decode_sampler = RejectionSampler(
-                disable_bonus_tokens=disable_bonus_tokens, )
+                disable_bonus_tokens=False, )
         elif draft_token_acceptance_method == "typical_acceptance_sampler":
             spec_decode_sampler = TypicalAcceptanceSampler(
-                disable_bonus_tokens=disable_bonus_tokens,
+                disable_bonus_tokens=False,
                 posterior_threshold=\
                     typical_acceptance_sampler_posterior_threshold,
                 posterior_alpha=typical_acceptance_sampler_posterior_alpha,
@@ -200,6 +195,15 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         self._metrics = AsyncMetricsCollector(
             self.spec_decode_sampler
         ) if metrics_collector is None else metrics_collector
+        # Tracks the sequence IDs that received a bonus token ID in
+        # their last forward pass. Needed only if KV cache is being
+        # used for token generation such as in the case of MultiStepWorker.
+        self._seq_with_bonus_token_in_last_step: Set[int] = set()
+        # Tracks the currently active request ids and the sequence IDs
+        # corresponding to them
+        self._request_id_seq_id_mapping: Dict[str, Set[int]] = defaultdict(set)
+        # Tracks if the proposer worker uses the KV cache or not.

         self.probs_dtype = self.spec_decode_sampler.probs_dtype
         self.token_id_dtype = self.spec_decode_sampler.token_id_dtype
         # Lazy initialization.
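A sketch of the bookkeeping these two new attributes enable (not part of this commit), using plain built-ins; the request IDs and sequence IDs below are made up.

# Sketch (not part of this commit).
from collections import defaultdict
from typing import Dict, Set

seq_with_bonus_token_in_last_step: Set[int] = {1, 4, 8}
request_id_seq_id_mapping: Dict[str, Set[int]] = defaultdict(set)
request_id_seq_id_mapping.update({"request-1": {1, 2}, "request-2": {4, 8}})

# When "request-1" finishes, its sequence IDs must not linger in the bonus set.
for seq_id in request_id_seq_id_mapping.pop("request-1"):
    seq_with_bonus_token_in_last_step.discard(seq_id)

assert seq_with_bonus_token_in_last_step == {4, 8}
assert dict(request_id_seq_id_mapping) == {"request-2": {4, 8}}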
@@ -307,6 +311,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             broadcast_tensor_dict({}, src=0)
             return []

+        self._track_finished_requests(execute_model_req)
         disable_all_speculation = self._should_disable_all_speculation(
             execute_model_req)
         num_lookahead_slots = execute_model_req.num_lookahead_slots
@@ -453,7 +458,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             self.previous_hidden_states = None

         # Generate proposals using draft worker.
-        proposals = self.proposer_worker.get_spec_proposals(execute_model_req)
+        proposals = self.proposer_worker.get_spec_proposals(
+            execute_model_req, self._seq_with_bonus_token_in_last_step)

         proposal_scores = self.scorer.score_proposals(
             execute_model_req,
@@ -585,7 +591,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):

         # Get the sequence ids and num_logprobs (sampling parameter) in the
         # batch.
-        seq_ids = get_all_seq_ids(seq_group_metadata_list)
+        seq_ids, request_ids_seq_ids_mapping = get_all_seq_ids_and_request_ids(
+            seq_group_metadata_list)
+
         num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)

         # Serialize all tensors to CPU Python lists.
@@ -608,7 +616,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         for sequence_index in range(batch_size):
             # Each sequence may have a different num_logprobs; retrieve it.
             num_logprobs = num_logprobs_per_seq[sequence_index]
-
             step_output_token_ids.append(
                 create_sequence_group_output(
                     token_id=accepted_token_ids_by_step[step_index]
@@ -623,18 +630,48 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
                     topk_logprobs=topk_logprobs_by_step[step_index]
                     [sequence_index][:num_logprobs],
                 ))

             sampler_output_list.append(
                 SamplerOutput(outputs=step_output_token_ids))

+        # Populate the data structures needed to keep track of sequences with
+        # bonus tokens.
+        self._track_sequences_with_bonus_tokens(seq_ids,
+                                                request_ids_seq_ids_mapping,
+                                                accepted_token_ids_by_step)
         maybe_rejsample_metrics = (
             self._metrics.maybe_collect_rejsample_metrics(k))
         if maybe_rejsample_metrics is not None:
             sampler_output_list[
                 0].spec_decode_worker_metrics = maybe_rejsample_metrics

         return sampler_output_list

+    def _track_finished_requests(self, execute_model_req: ExecuteModelRequest):
+        """
+        Removes the finished requests and their associated sequence ids from
+        internal book keeping data structures.
+        """
+        for finished_request in execute_model_req.finished_requests_ids:
+            for seq_id in self._request_id_seq_id_mapping[finished_request]:
+                self._seq_with_bonus_token_in_last_step.discard(seq_id)
+            del self._request_id_seq_id_mapping[finished_request]
+
+    def _track_sequences_with_bonus_tokens(
+            self, seq_ids: List[int],
+            request_ids_seq_ids_mapping: Dict[str, Set[int]],
+            accepted_token_ids_by_step: List[List[int]]):
+        """
+        Updates the internal data structures which keep track of sequences
+        which have been assigned bonus tokens in their last forward pass.
+        """
+        for seq_index, seq_id in enumerate(seq_ids):
+            last_token_id = accepted_token_ids_by_step[-1][seq_index]
+            if last_token_id == -1:
+                self._seq_with_bonus_token_in_last_step.discard(seq_id)
+            else:
+                self._seq_with_bonus_token_in_last_step.add(seq_id)
+        for request_id, sequences in request_ids_seq_ids_mapping.items():
+            self._request_id_seq_id_mapping[request_id].update(sequences)
+
     @cached_property
     def _vocab_size(self) -> int:
         """Get the vocab size of the model and make sure it's consistent between
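A worked example of the update rule implemented by `_track_sequences_with_bonus_tokens` above (not part of this commit): the last step of `accepted_token_ids_by_step` decides membership, with -1 meaning "no bonus token this step", and IDs outside the current batch are retained. The IDs and token values below are made up.

# Worked example (not part of this commit).
from typing import List, Set

seq_ids: List[int] = [10, 11, 12]
accepted_token_ids_by_step = [
    [5, 6, 7],    # step 0
    [8, -1, 9],   # last step: sequence 11 got no bonus token
]

seq_with_bonus: Set[int] = {11, 99}  # 99 is not in this batch and is retained
for i, seq_id in enumerate(seq_ids):
    if accepted_token_ids_by_step[-1][i] == -1:
        seq_with_bonus.discard(seq_id)
    else:
        seq_with_bonus.add(seq_id)

assert seq_with_bonus == {10, 12, 99}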
@@ -1,4 +1,4 @@
-from typing import List, Optional, Tuple
+from typing import List, Optional, Set, Tuple

 import torch

@@ -42,6 +42,7 @@ class Top1Proposer(SpeculativeProposer):
     def get_spec_proposals(
         self,
         execute_model_req: ExecuteModelRequest,
+        seq_ids_with_bonus_token_in_last_step: Set[int],
     ) -> SpeculativeProposals:
         """Get speculative proposals given the input batch.

@@ -76,6 +77,8 @@ class Top1Proposer(SpeculativeProposer):
             maybe_sampler_output, transposed = self._worker.sampler_output(
                 execute_model_req=nonzero_execute_model_req,
                 sample_len=proposal_len,
+                seq_ids_with_bonus_token_in_last_step=\
+                    seq_ids_with_bonus_token_in_last_step,
             )
             (
                 proposal_lens,