[Speculative Decoding] Enabling bonus token in speculative decoding for KV cache based models (#5765)
commit ae151d73be (parent 44cc76610d)
@@ -70,14 +70,17 @@ def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int,
     if queue_size < disable_by_batch_size:
         # Should raise exception when executing the mocked draft model.
         with pytest.raises(ValueError, match=exception_secret):
-            proposer.get_spec_proposals(execute_model_req=ExecuteModelRequest(
-                seq_group_metadata_list=seq_group_metadata_list,
-                num_lookahead_slots=k), )
+            proposer.get_spec_proposals(
+                execute_model_req=ExecuteModelRequest(
+                    seq_group_metadata_list=seq_group_metadata_list,
+                    num_lookahead_slots=k),
+                seq_ids_with_bonus_token_in_last_step=set())
     else:
         # Should not execute the draft model because spec decode is disabled
         # for all requests. Accordingly, the proposal length should be 0.
         proposals = proposer.get_spec_proposals(
             execute_model_req=ExecuteModelRequest(
                 seq_group_metadata_list=seq_group_metadata_list,
-                num_lookahead_slots=k), )
+                num_lookahead_slots=k),
+            seq_ids_with_bonus_token_in_last_step=set())
         assert proposals.proposal_lens.tolist() == [0] * batch_size
@@ -118,7 +118,8 @@ def test_same_output_for_single_step():
     actual_output, _ = multi_step_worker.sampler_output(
         execute_model_req=ExecuteModelRequest(
             seq_group_metadata_list=multi_step_seq_group),
-        sample_len=num_steps)
+        sample_len=num_steps,
+        seq_ids_with_bonus_token_in_last_step=set())
     assert len(actual_output) == num_steps
     actual_output = actual_output[0]
@@ -210,7 +211,8 @@ def test_same_output_for_multi_step():
     multi_step_output, _ = multi_step_worker.sampler_output(
         execute_model_req=ExecuteModelRequest(
             seq_group_metadata_list=seq_group_metadata_list),
-        sample_len=num_steps)
+        sample_len=num_steps,
+        seq_ids_with_bonus_token_in_last_step=set())

     # Run single-step repeatedly.
     zero_kv_cache(worker.cache_engine)
@@ -277,6 +279,203 @@ def test_same_output_for_multi_step():
                                       single_step_logprobs)


+@torch.inference_mode()
+def test_multi_step_with_batch_expansion_correct_output():
+    """
+    In this test we verify that the MultiStepWorker is able to handle bonus
+    tokens correctly. The test verifies that if a sequence has a
+    bonus token then the MultiStepWorker is able to expand the batch by adding
+    new sequences corresponding to the sequences with bonus tokens. The
+    expanded batch is then used for predicting the next tokens.
+    """
+    seed = 100
+    model_name = 'JackFram/llama-68m'
+
+    block_size = 16
+    num_gpu_blocks = 2048 // block_size
+    batch_size = 128
+    multi_step_worker = create_worker(
+        MultiStepWorker,
+        model_name,
+        block_size,
+        num_gpu_blocks,
+        seed,
+        model_runner_cls=TP1DraftModelRunner,
+    )
+    worker = create_worker(
+        Worker,
+        model_name,
+        block_size,
+        num_gpu_blocks,
+        seed,
+    )
+    random.seed(seed)
+    prompts = [[0] for _ in range(batch_size)]
+    num_steps = 2
+    final_prompt_lens = [(num_steps + 1) for prompt in prompts]
+    rand_seeds = list(random.randint(0, 100) for _ in range(num_steps))
+    multi_step_worker.execute_model = patch_execute_model_with_seeds(
+        multi_step_worker, rand_seeds)
+    worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)
+    # Create the test continuations
+    continuations = [[random.randint(0, 1000)] for _ in prompts]
+    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+        prompts,
+        num_gpu_blocks,
+        block_size,
+        continuations=continuations,
+        final_prompt_lens=final_prompt_lens)
+
+    # Run single-step twice to generate 2 tokens. This
+    # will simulate the bonus token case with the second token
+    # being the bonus token.
+    zero_kv_cache(worker.cache_engine)
+    single_step_output: List[SamplerOutput] = []
+    set_random_seed(seed)
+    for _ in range(num_steps):
+        seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+            prompts,
+            num_gpu_blocks,
+            block_size,
+            continuations=continuations,
+            final_prompt_lens=final_prompt_lens)
+        single_step_output.extend(
+            worker.execute_model(execute_model_req=ExecuteModelRequest(
+                seq_group_metadata_list=seq_group_metadata_list)))
+        # Append output tokens to new sequence data.
+        for i, seq_group_output in enumerate(single_step_output[-1]):
+            continuations[i].append(seq_group_output.samples[0].output_token)
+
+    # Create continuations for the MultiStepWorker. The continuations have
+    # 2 tokens in order to simulate the bonus token case.
+    multi_step_continuations = []
+    for continuation in continuations:
+        multi_step_continuations.append(continuation[:2])
+    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+        prompts,
+        num_gpu_blocks,
+        block_size,
+        continuations=multi_step_continuations,
+        final_prompt_lens=final_prompt_lens)
+
+    # Run multi-step and verify that the third token prediction is accurate
+    # for all sequences.
+    zero_kv_cache(multi_step_worker.cache_engine)
+    all_seq_ids = {i for i in range(batch_size)}
+    multi_step_output, _ = multi_step_worker.sampler_output(
+        execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=seq_group_metadata_list),
+        sample_len=1,
+        seq_ids_with_bonus_token_in_last_step=all_seq_ids)
+    for index, output in enumerate(multi_step_output[-1].outputs):
+        assert (continuations[index][-1] == output.samples[0].output_token)
+
+
+@torch.inference_mode()
+def test_multi_step_with_batch_expansion_incorrect_output():
+    """
+    Tests the MultiStepWorker's ability to handle batch expansion with bonus
+    tokens in a negative case scenario. This test provides the MultiStepWorker
+    with a batch containing sequences with bonus tokens but specifies the
+    sequence IDs with bonus tokens incorrectly. The test verifies that the
+    MultiStepWorker generates correct tokens for the sequences where the
+    sequence ID is specified correctly and incorrect tokens for those where
+    the sequence ID is specified incorrectly.
+    """
+    seed = 100
+    model_name = 'JackFram/llama-68m'
+
+    block_size = 16
+    num_gpu_blocks = 2048 // block_size
+    batch_size = 128
+    multi_step_worker = create_worker(
+        MultiStepWorker,
+        model_name,
+        block_size,
+        num_gpu_blocks,
+        seed,
+        model_runner_cls=TP1DraftModelRunner,
+    )
+    worker = create_worker(
+        Worker,
+        model_name,
+        block_size,
+        num_gpu_blocks,
+        seed,
+    )
+    random.seed(seed)
+    prompts = [[0] for _ in range(batch_size)]
+    num_steps = 2
+    final_prompt_lens = [(num_steps + 1) for prompt in prompts]
+    rand_seeds = list(random.randint(0, 100) for _ in range(num_steps))
+    multi_step_worker.execute_model = patch_execute_model_with_seeds(
+        multi_step_worker, rand_seeds)
+    worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)
+    # Create the test continuations
+    continuations = [[random.randint(0, 1000)] for _ in prompts]
+    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+        prompts,
+        num_gpu_blocks,
+        block_size,
+        continuations=continuations,
+        final_prompt_lens=final_prompt_lens)
+    # Run single-step twice to generate 2 tokens. This
+    # will simulate the bonus token case with the second token
+    # being the bonus token.
+    zero_kv_cache(worker.cache_engine)
+    single_step_output: List[SamplerOutput] = []
+    set_random_seed(seed)
+    for _ in range(num_steps):
+        seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+            prompts,
+            num_gpu_blocks,
+            block_size,
+            continuations=continuations,
+            final_prompt_lens=final_prompt_lens)
+        single_step_output.extend(
+            worker.execute_model(execute_model_req=ExecuteModelRequest(
+                seq_group_metadata_list=seq_group_metadata_list)))
+        # Append output tokens to new sequence data.
+        for i, seq_group_output in enumerate(single_step_output[-1]):
+            continuations[i].append(seq_group_output.samples[0].output_token)

+    # Create continuations for the MultiStepWorker. The continuations have
+    # 2 tokens in order to simulate the bonus token case.
+    multi_step_continuations = []
+    for continuation in continuations:
+        multi_step_continuations.append(continuation[:2])
+    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+        prompts,
+        num_gpu_blocks,
+        block_size,
+        continuations=multi_step_continuations,
+        final_prompt_lens=final_prompt_lens)
+
+    # Run multi-step. In this run INCORRECTLY specify that only the odd number
+    # sequences have bonus tokens. Verify that with this setting the third token
+    # prediction is accurate only for the odd numbered sequences. Also verify
+    # that the prediction might be wrong for some of the even numbered
+    # sequences.
+    zero_kv_cache(multi_step_worker.cache_engine)
+    set_random_seed(seed)
+    odd_seq_ids = {i for i in range(batch_size) if i % 2 != 0}
+    multi_step_output, _ = multi_step_worker.sampler_output(
+        execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=seq_group_metadata_list),
+        sample_len=1,
+        seq_ids_with_bonus_token_in_last_step=odd_seq_ids)
+    num_mismatch = 0
+    for index, output in enumerate(multi_step_output[-1].outputs):
+        if (index % 2) != 0:
+            assert (continuations[index][-1] == output.samples[0].output_token)
+        elif (continuations[index][-1] != output.samples[0].output_token):
+            num_mismatch += 1
+    # The prediction is accurate for some of the sequences even without proper
+    # handling of the bonus tokens. Hence verify that the number of sequences
+    # for which there is a mismatch is > 0.
+    assert (num_mismatch > 0)
+
+
 @torch.inference_mode()
 def test_draft_proposals_full_speculation_len():
     """Verify Top1Proposer correctly handles case where all sequences
@@ -318,7 +517,8 @@ def test_draft_proposals_full_speculation_len():
     proposals = proposer.get_spec_proposals(
         execute_model_req=ExecuteModelRequest(
             seq_group_metadata_list=seq_group_metadata_list,
-            num_lookahead_slots=k), )
+            num_lookahead_slots=k),
+        seq_ids_with_bonus_token_in_last_step=set())

     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
@@ -356,7 +556,8 @@ def test_draft_proposals_no_speculations():
     proposals = proposer.get_spec_proposals(
         execute_model_req=ExecuteModelRequest(
             seq_group_metadata_list=seq_group_metadata_list,
-            num_lookahead_slots=k), )
+            num_lookahead_slots=k),
+        seq_ids_with_bonus_token_in_last_step=set())

     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
@@ -428,7 +629,8 @@ def test_draft_proposals_mixed_k():
     proposals = proposer.get_spec_proposals(
         execute_model_req=ExecuteModelRequest(
             seq_group_metadata_list=seq_group_metadata_list,
-            num_lookahead_slots=k), )
+            num_lookahead_slots=k),
+        seq_ids_with_bonus_token_in_last_step=set())

     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
@@ -53,7 +53,8 @@ def test_ngram_algo_correctness_for_single_no_match():
     proposals = proposer.get_spec_proposals(
         execute_model_req=ExecuteModelRequest(
             seq_group_metadata_list=seq_group_metadata_list,
-            num_lookahead_slots=proposal_len), )
+            num_lookahead_slots=proposal_len),
+        seq_ids_with_bonus_token_in_last_step=None)

     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
@@ -121,7 +122,8 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
     proposals = proposer.get_spec_proposals(
         execute_model_req=ExecuteModelRequest(
             seq_group_metadata_list=seq_group_metadata_list,
-            num_lookahead_slots=proposal_len), )
+            num_lookahead_slots=proposal_len),
+        seq_ids_with_bonus_token_in_last_step=None)

     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
@@ -193,7 +195,8 @@ def test_ngram_algo_correctness_for_batches_match_all():
     proposals = proposer.get_spec_proposals(
         execute_model_req=ExecuteModelRequest(
             seq_group_metadata_list=seq_group_metadata_list,
-            num_lookahead_slots=proposal_len), )
+            num_lookahead_slots=proposal_len),
+        seq_ids_with_bonus_token_in_last_step=None)

     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
@@ -1,6 +1,7 @@
 import random
+from collections import defaultdict
 from types import SimpleNamespace
-from typing import Dict, List
+from typing import Dict, List, Set
 from unittest.mock import MagicMock

 import pytest
@@ -377,8 +378,10 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool,

     set_random_seed(1)

-    worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler,
-                              metrics_collector)
+    worker = SpecDecodeWorker(draft_worker,
+                              target_worker,
+                              spec_decode_sampler,
+                              metrics_collector=metrics_collector)
     worker.init_device()

     proposal_token_ids = torch.randint(low=0,
@@ -554,7 +557,6 @@ def test_init_device(acceptance_sampler_method: str):

     worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler,
                               metrics_collector)
-
     worker.init_device()

     draft_worker.init_device.assert_called_once()
@@ -645,3 +647,140 @@ def test_split_num_cache_blocks_evenly(available_gpu_blocks: int,
     assert (num_blocks * target_cache_block_size_bytes) + (
         num_blocks * draft_kv_size_bytes) <= (available_gpu_blocks *
                                               target_cache_block_size_bytes)
+
+
+@torch.inference_mode()
+def test_populate_seq_ids_with_bonus_tokens():
+    """
+    Verify that a call to _create_output_sampler_list correctly updates
+    seq_with_bonus_token_in_last_step.
+
+    seq_with_bonus_token_in_last_step is an internal data structure in
+    SpecDecodeWorker that tracks the sequence IDs which are assigned bonus
+    tokens by the target model in their last forward pass. This state is
+    maintained only for models relying on the KV cache, such as those using
+    the MultiStepWorker.
+    """
+    batch_size = 10
+    k = 5
+    vocab_size = 10000
+    num_sequences_with_bonus_tokens = 5
+    target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
+    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
+    target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)]
+    target_worker.device = 'cuda'
+
+    set_random_seed(1)
+    draft_worker = mock_worker(cls=MultiStepWorker)
+    draft_worker.device = 'cuda'
+    # The sequence_ids attached to each sequence in the batch.
+    # The sequence at index i has seq_id assigned_seq_ids[i]
+    assigned_seq_ids = list(range(batch_size))
+    seq_group_metadata_list, _, _ = create_batch(batch_size,
+                                                 k,
+                                                 seq_ids=assigned_seq_ids,
+                                                 prev_output_token_len=10)
+    target_token_logprobs = torch.rand(batch_size, (k + 1),
+                                       vocab_size,
+                                       dtype=torch.float32,
+                                       device='cuda')
+    accepted_token_ids = torch.randint(low=0,
+                                       high=vocab_size,
+                                       size=(batch_size, (k + 1)),
+                                       dtype=torch.int64,
+                                       device='cuda')
+    expected_request_id_seq_ids_mapping: Dict[str, Set[int]] = defaultdict(set)
+    for seq_group_metadata in seq_group_metadata_list:
+        for seq_id in seq_group_metadata.seq_data:
+            expected_request_id_seq_ids_mapping[
+                seq_group_metadata.request_id].add(seq_id)
+    # Generate a random sample of sequence indexes with bonus tokens
+    seq_indexes_with_bonus_tokens = random.sample(
+        range(batch_size), num_sequences_with_bonus_tokens)
+    # Create a mask that is True for indices in seq_indexes_with_bonus_tokens
+    mask = torch.ones(batch_size, dtype=torch.bool, device='cuda')
+    mask[seq_indexes_with_bonus_tokens] = False
+    # Set the last token ID to -1 for all indices not in
+    # seq_indexes_with_bonus_tokens to indicate the lack of bonus token in
+    # those indices.
+    accepted_token_ids[mask, -1:] = -1
+    worker = SpecDecodeWorker(draft_worker,
+                              target_worker,
+                              mock_spec_decode_sampler("rejection_sampler"),
+                              metrics_collector=metrics_collector)
+    # Initialize _seq_with_bonus_token_in_last_step with a set of sequence IDs.
+    # This set includes all sequence IDs in the batch as well as an additional
+    # `num_extra_sequence_ids` sequence IDs. Note that the sequence IDs are in
+    # the range [0, batch_size + num_extra_sequence_ids).
+    num_extra_sequence_ids = 10
+    worker._seq_with_bonus_token_in_last_step = set(
+        range(batch_size + num_extra_sequence_ids))
+    worker._create_output_sampler_list(
+        seq_group_metadata_list=seq_group_metadata_list,
+        accepted_token_ids=accepted_token_ids,
+        target_logprobs=target_token_logprobs,
+        k=k)
+    # Verify that _seq_with_bonus_token_in_last_step contains the following:
+    # 1. Sequence IDs that were already present in
+    #    _seq_with_bonus_token_in_last_step but were not part of the current
+    #    batch are retained.
+    # 2. Of the sequence IDs present in the current batch, only those with a
+    #    bonus token are retained in _seq_with_bonus_token_in_last_step.
+    #    Sequence IDs that are present in the current batch but do not have
+    #    bonus tokens are removed from _seq_with_bonus_token_in_last_step.
+    expected_seq_ids_with_bonus_tokens = \
+        set([assigned_seq_ids[i] for i in seq_indexes_with_bonus_tokens])
+    additional_sequence_ids = \
+        set(range(batch_size, batch_size + num_extra_sequence_ids))
+    assert worker._seq_with_bonus_token_in_last_step == \
+        expected_seq_ids_with_bonus_tokens.union(additional_sequence_ids)
+    assert worker._request_id_seq_id_mapping == \
+        expected_request_id_seq_ids_mapping
+
+
+@torch.inference_mode()
+def test_handle_finished_requests():
+    """
+    Test to verify that finished request IDs are appropriately processed to
+    update the internal state of the SpecDecodeWorker.
+
+    This test initializes the SpecDecodeWorker with mock data, marks certain
+    requests as finished, and ensures that the corresponding sequence IDs are
+    correctly removed from the internal mappings.
+    """
+    batch_size = 32
+    k = 3
+    draft_worker = mock_worker(cls=MultiStepWorker)
+    target_worker = mock_worker()
+    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
+    worker = SpecDecodeWorker(draft_worker, target_worker,
+                              mock_spec_decode_sampler("rejection_sampler"),
+                              metrics_collector)
+    # Initialize the request_id_seq_id_mapping mapping dict with a few fake
+    # request ids and corresponding sequence ids.
+    worker._request_id_seq_id_mapping = \
+        {'request-1': {1,2,3}, 'request-2': {4,5,6,7},
+         'request-3': {8,9}, 'request-4': {10,11}}
+    # Initialize seq_with_bonus_token_in_last_step with a few fake
+    # sequence ids.
+    worker._seq_with_bonus_token_in_last_step = {1, 4, 5, 8, 9, 10}
+    exception_secret = 'artificial stop'
+    draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)
+
+    seq_group_metadata_list, _, _ = create_batch(batch_size, k)
+    # Mark requests with ids request-1 and request-3 as finished.
+    execute_model_req = ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=k,
+        finished_requests_ids=['request-1', 'request-3'])
+
+    with pytest.raises(ValueError, match=exception_secret):
+        worker.execute_model(execute_model_req=execute_model_req)
+    # Verify that request-1 and request-3 are removed from
+    # request_id_seq_id_mapping
+    assert worker._request_id_seq_id_mapping == \
+        {'request-2': {4,5,6,7}, 'request-4': {10,11}}
+    # Verify that all sequence ids corresponding to 'request-1'
+    # and 'request-3' are removed from seq_with_bonus_token_in_last_step.
+    assert worker._seq_with_bonus_token_in_last_step == \
+        {4,5,10}
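The cleanup that test_handle_finished_requests exercises boils down to dropping every sequence of a finished request from both tracking structures. A minimal standalone sketch of that bookkeeping, using plain builtins rather than the real SpecDecodeWorker (function and variable names here are illustrative only):

from collections import defaultdict
from typing import Dict, List, Set

def drop_finished_requests(finished: List[str],
                           request_to_seqs: Dict[str, Set[int]],
                           bonus_seqs: Set[int]) -> None:
    # Every sequence belonging to a finished request is forgotten by both
    # tracking structures; other requests' sequences are untouched.
    for request_id in finished:
        for seq_id in request_to_seqs[request_id]:
            bonus_seqs.discard(seq_id)
        del request_to_seqs[request_id]

request_to_seqs = defaultdict(set, {
    'request-1': {1, 2, 3}, 'request-2': {4, 5, 6, 7},
    'request-3': {8, 9}, 'request-4': {10, 11}})
bonus_seqs = {1, 4, 5, 8, 9, 10}
drop_finished_requests(['request-1', 'request-3'], request_to_seqs, bonus_seqs)
assert dict(request_to_seqs) == {'request-2': {4, 5, 6, 7},
                                 'request-4': {10, 11}}
assert bonus_seqs == {4, 5, 10}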
@@ -3,8 +3,9 @@ import copy
 import enum
 import math
 from abc import ABC, abstractmethod
+from collections import defaultdict
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union

 import torch

@@ -916,6 +917,21 @@ def get_all_seq_ids(
     return [seq_id for sg in seq_group_metadata_list for seq_id in sg.seq_data]


+def get_all_seq_ids_and_request_ids(
+    seq_group_metadata_list: List[SequenceGroupMetadata]
+) -> Tuple[List[int], Dict[str, Set[int]]]:
+    """Given a list of SequenceGroupMetadata, create a list of all
+    sequence ids.
+    """
+    seq_ids: List[int] = []
+    request_id_seq_ids_mapping: Dict[str, Set[int]] = defaultdict(set)
+    for sg in seq_group_metadata_list:
+        for seq_id in sg.seq_data:
+            seq_ids.append(seq_id)
+            request_id_seq_ids_mapping[sg.request_id].add(seq_id)
+    return seq_ids, request_id_seq_ids_mapping
+
+
 class HiddenStates:
     """Hidden states corresponding to in-progress sequences.
     Used in speculative decoding to pass hidden states from
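get_all_seq_ids_and_request_ids builds both return values in a single pass over the batch. A small self-contained sketch of the same pattern, with plain dicts standing in for SequenceGroupMetadata (the stand-in field names request_id and seq_data mirror the vLLM attributes; everything else is illustrative):

from collections import defaultdict
from typing import Dict, List, Set, Tuple

def collect_seq_and_request_ids(
        seq_groups: List[dict]) -> Tuple[List[int], Dict[str, Set[int]]]:
    # One pass: flat list of sequence ids plus request-id -> seq-ids mapping.
    seq_ids: List[int] = []
    mapping: Dict[str, Set[int]] = defaultdict(set)
    for sg in seq_groups:
        for seq_id in sg["seq_data"]:
            seq_ids.append(seq_id)
            mapping[sg["request_id"]].add(seq_id)
    return seq_ids, mapping

# Example: two requests, three sequences in total.
groups = [
    {"request_id": "req-0", "seq_data": {0: None, 1: None}},
    {"request_id": "req-1", "seq_data": {2: None}},
]
seq_ids, mapping = collect_seq_and_request_ids(groups)
assert seq_ids == [0, 1, 2]
assert mapping == {"req-0": {0, 1}, "req-1": {2}}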
@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Set

 import torch

@@ -62,6 +62,9 @@ class SpeculativeProposer(ABC):
     def get_spec_proposals(
         self,
         execute_model_req: ExecuteModelRequest,
+        # If set, this contains all sequence IDs that were assigned
+        # bonus tokens in their last forward pass.
+        seq_ids_with_bonus_token_in_last_step: Set[int],
     ) -> SpeculativeProposals:
         raise NotImplementedError

@@ -1,5 +1,5 @@
 import weakref
-from typing import List, Optional, Tuple
+from typing import List, Optional, Set, Tuple

 import torch

@@ -40,6 +40,8 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
         self,
         execute_model_req: ExecuteModelRequest,
         sample_len: int,
+        # Unused parameter.
+        seq_ids_with_bonus_token_in_last_step: Set[int],
     ) -> Tuple[List[SamplerOutput], bool]:
         """Run the model forward pass to generate sample_len future tokens.
         Returns the list of sampler output, one per layer, along with indicator
@@ -97,12 +99,14 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
     def get_spec_proposals(
         self,
         execute_model_req: ExecuteModelRequest,
+        seq_ids_with_bonus_token_in_last_step: Set[int],
     ) -> SpeculativeProposals:
         """Produce speculations given an input batch of sequences. The number of
         speculative tokens per sequence is determined by max_proposal_len.
         """

-        return self._proposer.get_spec_proposals(execute_model_req)
+        return self._proposer.get_spec_proposals(
+            execute_model_req, seq_ids_with_bonus_token_in_last_step)

     def _raise_if_unsupported(
         self,
@@ -1,4 +1,4 @@
-from typing import List, Optional, Tuple
+from typing import List, Optional, Set, Tuple

 import torch

@@ -20,6 +20,9 @@ class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker):
         self,
         execute_model_req: ExecuteModelRequest,
         sample_len: int,
+        # Unused parameter. MLPSpeculatorWorker does not use the KV Cache and
+        # therefore does not need this parameter.
+        seq_ids_with_bonus_token_in_last_step: Set[int],
     ) -> Tuple[List[SamplerOutput], bool]:
         """Run the model forward pass to generate sample_len future tokens.
         Returns the list of sampler output, one per layer, along with indicator
@@ -1,6 +1,6 @@
 import copy
 import weakref
-from typing import Dict, List, Tuple
+from typing import Dict, List, Set, Tuple

 import torch

@@ -51,6 +51,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
         self,
         execute_model_req: ExecuteModelRequest,
         sample_len: int,
+        seq_ids_with_bonus_token_in_last_step: Set[int],
     ) -> Tuple[List[SamplerOutput], bool]:
         """Run the model forward pass sample_len times. Returns the list of
         sampler output, one per model forward pass, along with indicator of
@@ -60,44 +61,142 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
         For multi step worker, this indicator shall be True.
         """
         self._raise_if_unsupported(execute_model_req)
-        # Shallow copy input data so modifications (such as appending tokens)
-        # do not cause side-effects.
-        copied_seq_group_metadata_list = self._shallow_copy_inputs(
-            execute_model_req.seq_group_metadata_list)
-        copied_execute_model_req = execute_model_req.clone(
-            copied_seq_group_metadata_list)
+        # Expand the batch for sequences with a bonus token.
+        # Perform a forward pass on the expanded batch and filter the
+        # response to retain only the original sequences' responses.
+        expanded_request, indices_of_seq_with_bonus_tokens =\
+            self._expand_execute_model_request(
+                execute_model_req, seq_ids_with_bonus_token_in_last_step)

         # Run model sample_len times.
         model_outputs: List[SamplerOutput] = []
         if isinstance(self.model_runner, TP1DraftModelRunner):
-            copied_execute_model_req.num_steps = sample_len
+            expanded_request.num_steps = sample_len
             model_outputs = self.execute_model(
-                execute_model_req=copied_execute_model_req)
+                execute_model_req=expanded_request)
         else:
             # TODO: Remove this branch once DraftModelRunner supports TP>1.
             for _ in range(sample_len):
                 model_output: List[SamplerOutput] = super().execute_model(
-                    execute_model_req=copied_execute_model_req)
+                    execute_model_req=expanded_request)
                 assert (len(model_output) == 1
                         ), "composing multistep workers not supported"
                 model_output = model_output[0]

-                self._append_new_tokens(model_output,
-                                        copied_seq_group_metadata_list)
+                self._append_new_tokens(
+                    model_output, expanded_request.seq_group_metadata_list)
                 model_outputs.append(model_output)

-        return model_outputs, True
+        filtered_model_outputs = self._filter_model_output(
+            model_outputs, indices_of_seq_with_bonus_tokens)
+        return filtered_model_outputs, True
+
+    @staticmethod
+    def _expand_execute_model_request(
+        execute_model_req: ExecuteModelRequest,
+        seq_with_bonus_token_in_last_step: set,
+    ) -> Tuple[ExecuteModelRequest, List[int]]:
+        """
+        Expands the execute model request based on sequences with bonus
+        tokens.
+
+        For each sequence with a bonus token, this method creates a new
+        sequence without the bonus token and adds it to the execute model
+        request. The original sequence groups are also retained. The indices
+        of the original sequence groups are returned for further processing.
+
+        Args:
+            execute_model_req (ExecuteModelRequest): The original execute
+            model request.
+            seq_with_bonus_token_in_last_step (set): Set of sequence IDs that
+            contain bonus tokens.
+
+        Returns:
+            Tuple[ExecuteModelRequest, List[int]]: The updated execute model
+            request with expanded sequences and a list of indices corresponding
+            to the original sequence groups.
+        """
+        updated_seq_group_metadata_list: List[SequenceGroupMetadata] = []
+        updated_execute_model_req = execute_model_req.clone(
+            updated_seq_group_metadata_list)
+        indices_of_original_sequence_groups = []
+        for seq_group in execute_model_req.seq_group_metadata_list:
+            seq_group_has_bonus_tokens = False
+            for seq_id, _ in seq_group.seq_data.items():
+                # Identify sequences with bonus tokens in the sequence group.
+                if seq_id in seq_with_bonus_token_in_last_step:
+                    seq_group_has_bonus_tokens = True
+                    break
+            if seq_group_has_bonus_tokens:
+                # Create new sequences without the last bonus token. These new
+                # sequences have the same sequence id as the original sequence.
+                # We create a new sequence group and add them there.
+                updated_seq_group_without_bonus_token = \
+                    MultiStepWorker._copy_seq_metadata_excluding_last_token(
+                        seq_group, seq_with_bonus_token_in_last_step)
+                updated_seq_group_metadata_list.append(
+                    updated_seq_group_without_bonus_token)
+            # Add the original sequence group.
+            updated_seq_group_metadata_list.append(
+                MultiStepWorker._shallow_copy_seq_group_metadata(seq_group))
+            # Record the index of the original sequence group.
+            indices_of_original_sequence_groups.append(
+                len(updated_seq_group_metadata_list) - 1)
+
+        updated_execute_model_req.seq_group_metadata_list =\
+            updated_seq_group_metadata_list
+        return updated_execute_model_req, indices_of_original_sequence_groups
+
+    @staticmethod
+    def _filter_model_output(
+            expanded_batch_outputs: List[SamplerOutput],
+            output_indices_to_retain: List[int]) -> List[SamplerOutput]:
+        """
+        Filters the model output to include only the specified sequence
+        outputs. This method contracts the expanded batch output from the
+        model to retain the outputs of only those sequences indicated by the
+        provided indices.
+
+        Args:
+            expanded_batch_output (List[SamplerOutput]): The expanded output
+            batch from the model.
+            output_indices_to_retain (List[int]): Indices of the model outputs
+            to retain.
+
+        Returns:
+            List[SamplerOutput]: A list containing the filtered model
+            outputs for the specified indices.
+        """
+        return [
+            SamplerOutput(
+                outputs=[
+                    expanded_batch_output.outputs[i]
+                    for i in output_indices_to_retain
+                ],
+                sampled_token_probs=(
+                    expanded_batch_output.
+                    sampled_token_probs[output_indices_to_retain]
+                    if expanded_batch_output.sampled_token_probs is not None
+                    else None),
+                logprobs=(
+                    expanded_batch_output.logprobs[output_indices_to_retain]
+                    if expanded_batch_output.logprobs is not None else None),
+                sampled_token_ids=(expanded_batch_output.
+                                   sampled_token_ids[output_indices_to_retain]
+                                   if expanded_batch_output.sampled_token_ids
+                                   is not None else None))
+            for expanded_batch_output in expanded_batch_outputs
+        ]

     def get_spec_proposals(
         self,
         execute_model_req: ExecuteModelRequest,
+        seq_ids_with_bonus_token_in_last_step: set,
     ) -> SpeculativeProposals:
         """Produce speculations given an input batch of sequences. The number of
         speculative tokens per sequence is determined by max_proposal_len.
         """
-        return self._proposer.get_spec_proposals(execute_model_req)
+        return self._proposer.get_spec_proposals(
+            execute_model_req, seq_ids_with_bonus_token_in_last_step)

     @staticmethod
     def _append_new_tokens(
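The new sampler_output path expands the batch, runs the draft model once, then contracts the outputs back to the original sequence groups. A toy sketch of that expand-and-filter bookkeeping, deliberately using plain dicts and strings instead of ExecuteModelRequest and SamplerOutput (all names below are illustrative, not vLLM APIs):

from typing import Dict, List, Set, Tuple

def expand_batch(batch: List[Dict],
                 bonus_seq_ids: Set[int]) -> Tuple[List[Dict], List[int]]:
    # For every group containing a sequence with a bonus token, insert an
    # extra copy (minus the bonus token) before the original, and remember
    # where the originals ended up in the expanded batch.
    expanded: List[Dict] = []
    original_indices: List[int] = []
    for group in batch:
        if any(seq_id in bonus_seq_ids for seq_id in group["seq_ids"]):
            shrunk = dict(group, tokens=group["tokens"][:-1])
            expanded.append(shrunk)  # backfills KV for the penultimate token
        expanded.append(group)       # the original group is always kept
        original_indices.append(len(expanded) - 1)
    return expanded, original_indices

def filter_outputs(outputs: List[str],
                   original_indices: List[int]) -> List[str]:
    # Contract the expanded batch back to one output per original group.
    return [outputs[i] for i in original_indices]

batch = [{"seq_ids": [0], "tokens": [5, 9]},
         {"seq_ids": [1], "tokens": [7, 3]}]
expanded, idx = expand_batch(batch, bonus_seq_ids={0})
assert len(expanded) == 3 and idx == [1, 2]
assert filter_outputs(["a", "b", "c"], idx) == ["b", "c"]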
@@ -123,9 +222,8 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
             seq.update_num_computed_tokens(1)

     @staticmethod
-    def _shallow_copy_inputs(
-        seq_group_metadata_list: List[SequenceGroupMetadata]
-    ) -> List[SequenceGroupMetadata]:
+    def _shallow_copy_seq_group_metadata(
+            seq_group_metadata: SequenceGroupMetadata, ) -> SequenceGroupMetadata:
         """Copy input data structures to remove side-effects when input data
         structures are shared with other modules.

@@ -133,26 +231,62 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
         The alternative is deep-copying (or other form of deep copy); this has
         performance downsides.
         """
-        # Shallow-copy the list of SequenceGroupMetadata. This allows us to
+        # Shallow-copy the SequenceGroupMetadata. This allows us to
         # append tokens and change is_prompt without external side-effects.
-        new_seq_group_metadata_list: List[SequenceGroupMetadata] = []
-
-        for old_seq_group_metadata in seq_group_metadata_list:
-            # We must shallow-copy seq_group_metadata as is_prompt could change.
-            seq_group_metadata = copy.copy(old_seq_group_metadata)
-            new_seq_group_metadata_list.append(seq_group_metadata)
-
-            # We must shallow-copy seq_data as we will append token ids
-            new_seq_data: Dict[int, SequenceData] = {}
-            for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
-                new_seq_data[seq_id] = copy.copy(old_seq_data)
-                new_seq_data[
-                    seq_id].output_token_ids = old_seq_data.output_token_ids[:]
-
-            seq_group_metadata.seq_data = new_seq_data
-
-        return new_seq_group_metadata_list
+        # We must shallow-copy seq_group_metadata as is_prompt could change.
+        new_seq_group_metadata = copy.copy(seq_group_metadata)
+
+        # We must shallow-copy seq_data as we will append token ids
+        new_seq_data: Dict[int, SequenceData] = {}
+        for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
+            new_seq_data[seq_id] = copy.copy(old_seq_data)
+            new_seq_data[seq_id].output_token_ids =\
+                old_seq_data.output_token_ids[:]
+
+        new_seq_group_metadata.seq_data = new_seq_data
+        return new_seq_group_metadata
+
+    @staticmethod
+    def _copy_seq_metadata_excluding_last_token(
+        seq_group_metadata: SequenceGroupMetadata,
+        seq_ids_to_copy: Set[int],
+    ) -> SequenceGroupMetadata:
+        """
+        Creates a shallow copy of the given SequenceGroupMetadata, retaining
+        only the sequence IDs specified in seq_ids_to_copy. For each of these
+        sequence IDs, all output_token_ids except the last one are copied.
+        Sequence IDs not in seq_ids_to_copy are excluded from the copy.
+
+        Parameters:
+        seq_group_metadata (SequenceGroupMetadata): The original sequence
+        group metadata.
+        seq_ids_to_copy (Set[int]): The set of sequence IDs to include in the
+        copy.
+
+        Returns:
+        SequenceGroupMetadata: A shallow copy of the sequence group metadata
+        with the specified modifications.
+        """
+        # Shallow-copy the SequenceGroupMetadata.
+        new_seq_group_metadata = copy.copy(seq_group_metadata)
+        # Shallow-copy seq_data and modify the output_token_ids.
+        new_seq_data: Dict[int, SequenceData] = {}
+        for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
+            if (seq_id in seq_ids_to_copy):
+                new_seq_data[seq_id] = copy.copy(old_seq_data)
+                # Copy all the output token ids except the last.
+                # Also reduce num_computed_tokens by 1 since we are not
+                # including the last output token.
+                # NOTE: num_computed_tokens is not directly used by the
+                # speculative decoding workers, as it is only relevant for
+                # chunked prefill, which is disabled for speculative decoding.
+                # However, to maintain consistency in num_computed_tokens,
+                # we update it here.
+                new_seq_data[seq_id].output_token_ids =\
+                    old_seq_data.output_token_ids[:-1]
+                new_seq_data[seq_id].update_num_computed_tokens(-1)
+        new_seq_group_metadata.seq_data = new_seq_data
+        return new_seq_group_metadata

     def _assert_enough_kv_space(
             self, seq_group_metadata_list: List[SequenceGroupMetadata],
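_copy_seq_metadata_excluding_last_token relies on shallow copies plus a fresh output_token_ids list, so trimming the bonus token never mutates the shared sequence data. A minimal sketch of that copy discipline with a toy dataclass standing in for SequenceData (illustrative names only):

import copy
from dataclasses import dataclass, field
from typing import Dict, List, Set

@dataclass
class ToySeqData:  # stand-in for vllm's SequenceData
    output_token_ids: List[int] = field(default_factory=list)

def copy_excluding_last_token(seq_data: Dict[int, ToySeqData],
                              seq_ids_to_copy: Set[int]) -> Dict[int, ToySeqData]:
    new_data: Dict[int, ToySeqData] = {}
    for seq_id, old in seq_data.items():
        if seq_id in seq_ids_to_copy:
            new_data[seq_id] = copy.copy(old)
            # Fresh list object: later appends never touch the original.
            new_data[seq_id].output_token_ids = old.output_token_ids[:-1]
    return new_data

original = {0: ToySeqData([11, 22, 33]), 1: ToySeqData([44, 55])}
trimmed = copy_excluding_last_token(original, {0})
assert trimmed[0].output_token_ids == [11, 22]
assert 1 not in trimmed                              # seq 1 was excluded
assert original[0].output_token_ids == [11, 22, 33]  # original untouched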
@@ -1,5 +1,5 @@
 import weakref
-from typing import List, Optional, Tuple
+from typing import List, Optional, Set, Tuple

 import torch

@@ -48,6 +48,9 @@ class NGramWorker(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase):
         self,
         execute_model_req: ExecuteModelRequest,
         sample_len: int,
+        # Unused parameter. NGramWorker does not use the KV Cache and
+        # therefore does not need this parameter.
+        seq_ids_with_bonus_token_in_last_step: Set[int],
     ) -> Tuple[Optional[List[Optional[SamplerOutput]]], bool]:
         """NGram match algo to pick proposal candidate. Returns the list of
         sampler output, one per SequenceGroupMetadata.
@@ -133,12 +136,15 @@ class NGramWorker(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase):
     def get_spec_proposals(
         self,
         execute_model_req: ExecuteModelRequest,
+        # Unused parameter. NGramWorker does not use the KV Cache and
+        # therefore does not need this parameter.
+        seq_ids_with_bonus_token_in_last_step: Set[int],
     ) -> SpeculativeProposals:
         """Produce speculations given an input batch of sequences. The number of
         speculative tokens per sequence is determined by max_proposal_len.
         """
-        return self._proposer.get_spec_proposals(execute_model_req)
+        return self._proposer.get_spec_proposals(
+            execute_model_req, seq_ids_with_bonus_token_in_last_step)

     def _raise_if_unsupported(
         self,
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import List, Optional, Tuple
+from typing import List, Optional, Set, Tuple

 from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.spec_decode.interfaces import SpeculativeProposer
@@ -14,6 +14,13 @@ class ProposerWorkerBase(LoraNotSupportedWorkerBase, SpeculativeProposer):
         self,
         execute_model_req: ExecuteModelRequest,
         sample_len: int,
+        # A set containing all sequence IDs that were assigned bonus tokens
+        # in their last forward pass. This set is used to backfill the KV cache
+        # with the key-value pairs of the penultimate token in the sequences.
+        # This parameter is only used by the MultiStepWorker, which relies on
+        # the KV cache for token generation. It is not used by workers that
+        # do not utilize the KV cache.
+        seq_ids_with_bonus_token_in_last_step: Set[int]
     ) -> Tuple[Optional[List[SamplerOutput]], bool]:
         raise NotImplementedError

@@ -1,4 +1,4 @@
-from typing import List, Optional, Tuple
+from typing import List, Optional, Set, Tuple

 import torch

@@ -110,13 +110,17 @@ class SmallerTpProposerWorker(ProposerWorkerBase):
         self,
         execute_model_req: ExecuteModelRequest,
         sample_len: int,
+        seq_ids_with_bonus_token_in_last_step: Set[int],
     ) -> Tuple[List[SamplerOutput], bool]:
         # Do not check _is_dummy, as it's always called by get_spec_proposals
-        return self._worker.sampler_output(execute_model_req, sample_len)
+        return self._worker.sampler_output(
+            execute_model_req, sample_len,
+            seq_ids_with_bonus_token_in_last_step)

     def get_spec_proposals(
         self,
         execute_model_req: ExecuteModelRequest,
+        seq_ids_with_bonus_token_in_last_step: Set[int],
     ) -> SpeculativeProposals:
         """Produce speculations given an input batch of sequences. The number of
         speculative tokens per sequence is determined by max_proposal_len.
@@ -125,7 +129,8 @@ class SmallerTpProposerWorker(ProposerWorkerBase):
             return SpeculativeProposals(None, None, None)

         with self._patch_tensor_parallel_group():
-            return self._worker.get_spec_proposals(execute_model_req)
+            return self._worker.get_spec_proposals(
+                execute_model_req, seq_ids_with_bonus_token_in_last_step)

     def execute_model(
         self,
@@ -1,5 +1,6 @@
+from collections import defaultdict
 from functools import cached_property
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Set, Tuple

 import torch

@@ -13,7 +14,7 @@ from vllm.model_executor.layers.typical_acceptance_sampler import (
     TypicalAcceptanceSampler)
 from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest,
                            HiddenStates, SamplerOutput, SequenceGroupMetadata,
-                           get_all_seq_ids)
+                           get_all_seq_ids_and_request_ids)
 from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
 from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
@@ -112,11 +113,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
         ngram_prompt_lookup_min = (
             draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
-
-        disable_bonus_tokens = True
-
         if ngram_prompt_lookup_max > 0:
-            disable_bonus_tokens = False
             proposer_worker = NGramWorker(**draft_worker_kwargs)
             proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
                                                   ngram_prompt_lookup_max)
@@ -128,11 +125,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):

             if draft_worker_kwargs[
                     "model_config"].hf_config.model_type == "mlp_speculator":
-                disable_bonus_tokens = False
                 proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
             elif draft_worker_kwargs[
                     "model_config"].hf_config.model_type == "medusa":
-                disable_bonus_tokens = False
                 proposer_worker = MedusaWorker(**draft_worker_kwargs)
             else:
                 if draft_tp == 1:
@@ -149,10 +144,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         spec_decode_sampler: SpecDecodeBaseSampler = None
         if draft_token_acceptance_method == "rejection_sampler":
             spec_decode_sampler = RejectionSampler(
-                disable_bonus_tokens=disable_bonus_tokens, )
+                disable_bonus_tokens=False, )
         elif draft_token_acceptance_method == "typical_acceptance_sampler":
             spec_decode_sampler = TypicalAcceptanceSampler(
-                disable_bonus_tokens=disable_bonus_tokens,
+                disable_bonus_tokens=False,
                 posterior_threshold=\
                     typical_acceptance_sampler_posterior_threshold,
                 posterior_alpha=typical_acceptance_sampler_posterior_alpha,
@@ -200,6 +195,15 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         self._metrics = AsyncMetricsCollector(
             self.spec_decode_sampler
         ) if metrics_collector is None else metrics_collector
+        # Tracks the sequence IDs that received a bonus token ID in
+        # their last forward pass. Needed only if KV cache is being
+        # used for token generation such as in the case of MultiStepWorker.
+        self._seq_with_bonus_token_in_last_step: Set[int] = set()
+        # Tracks the currently active request ids and the sequence IDs
+        # corresponding to them
+        self._request_id_seq_id_mapping: Dict[str, Set[int]] = defaultdict(set)
+        # Tracks if the proposer worker uses the KV cache or not.
+
         self.probs_dtype = self.spec_decode_sampler.probs_dtype
         self.token_id_dtype = self.spec_decode_sampler.token_id_dtype
         # Lazy initiazliation.
@@ -307,6 +311,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             broadcast_tensor_dict({}, src=0)
             return []

+        self._track_finished_requests(execute_model_req)
         disable_all_speculation = self._should_disable_all_speculation(
             execute_model_req)
         num_lookahead_slots = execute_model_req.num_lookahead_slots
@@ -453,7 +458,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         self.previous_hidden_states = None

         # Generate proposals using draft worker.
-        proposals = self.proposer_worker.get_spec_proposals(execute_model_req)
+        proposals = self.proposer_worker.get_spec_proposals(
+            execute_model_req, self._seq_with_bonus_token_in_last_step)

         proposal_scores = self.scorer.score_proposals(
             execute_model_req,
@@ -585,7 +591,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):

         # Get the sequence ids and num_logprobs (sampling parameter) in the
         # batch.
-        seq_ids = get_all_seq_ids(seq_group_metadata_list)
+        seq_ids, request_ids_seq_ids_mapping = get_all_seq_ids_and_request_ids(
+            seq_group_metadata_list)
+
         num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)

         # Serialize all tensors to CPU Python lists.
@@ -608,7 +616,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         for sequence_index in range(batch_size):
             # Each sequence may have a different num_logprobs; retrieve it.
             num_logprobs = num_logprobs_per_seq[sequence_index]
-
             step_output_token_ids.append(
                 create_sequence_group_output(
                     token_id=accepted_token_ids_by_step[step_index]
@@ -623,18 +630,48 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
                     topk_logprobs=topk_logprobs_by_step[step_index]
                     [sequence_index][:num_logprobs],
                 ))

             sampler_output_list.append(
                 SamplerOutput(outputs=step_output_token_ids))

+        # Populate the data structures needed to keep track of sequences with
+        # bonus tokens.
+        self._track_sequences_with_bonus_tokens(seq_ids,
+                                                request_ids_seq_ids_mapping,
+                                                accepted_token_ids_by_step)
         maybe_rejsample_metrics = (
             self._metrics.maybe_collect_rejsample_metrics(k))
         if maybe_rejsample_metrics is not None:
             sampler_output_list[
                 0].spec_decode_worker_metrics = maybe_rejsample_metrics

         return sampler_output_list

+    def _track_finished_requests(self, execute_model_req: ExecuteModelRequest):
+        """
+        Removes the finished requests and their associated sequence ids from
+        internal book keeping data structures.
+        """
+        for finished_request in execute_model_req.finished_requests_ids:
+            for seq_id in self._request_id_seq_id_mapping[finished_request]:
+                self._seq_with_bonus_token_in_last_step.discard(seq_id)
+            del self._request_id_seq_id_mapping[finished_request]
+
+    def _track_sequences_with_bonus_tokens(
+            self, seq_ids: List[int],
+            request_ids_seq_ids_mapping: Dict[str, Set[int]],
+            accepted_token_ids_by_step: List[List[int]]):
+        """
+        Updates the internal data structures which keep track of sequences
+        which have been assigned bonus tokens in their last forward pass.
+        """
+        for seq_index, seq_id in enumerate(seq_ids):
+            last_token_id = accepted_token_ids_by_step[-1][seq_index]
+            if last_token_id == -1:
+                self._seq_with_bonus_token_in_last_step.discard(seq_id)
+            else:
+                self._seq_with_bonus_token_in_last_step.add(seq_id)
+        for request_id, sequences in request_ids_seq_ids_mapping.items():
+            self._request_id_seq_id_mapping[request_id].update(sequences)
+
     @cached_property
     def _vocab_size(self) -> int:
         """Get the vocab size of the model and make sure it's consistent between
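_track_sequences_with_bonus_tokens keys off the -1 sentinel in the last accepted step: sequences that kept their bonus token are added to the set, the rest are dropped, and IDs not in the batch are left alone. A compact standalone sketch of that rule (names are illustrative, not the worker's internals):

from typing import List, Set

def update_bonus_token_set(seq_ids: List[int],
                           accepted_token_ids_by_step: List[List[int]],
                           tracked: Set[int]) -> None:
    # -1 in the last step means no bonus token was emitted for that sequence,
    # so it must not be backfilled on the next draft pass.
    last_step = accepted_token_ids_by_step[-1]
    for seq_index, seq_id in enumerate(seq_ids):
        if last_step[seq_index] == -1:
            tracked.discard(seq_id)
        else:
            tracked.add(seq_id)

tracked: Set[int] = {7, 8}               # seq 8 is not in this batch; retained
seq_ids = [5, 6, 7]
accepted = [[10, 11, 12], [13, -1, -1]]  # only seq 5 kept its bonus token
update_bonus_token_set(seq_ids, accepted, tracked)
assert tracked == {5, 8}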
@@ -1,4 +1,4 @@
-from typing import List, Optional, Tuple
+from typing import List, Optional, Set, Tuple

 import torch

@@ -42,6 +42,7 @@ class Top1Proposer(SpeculativeProposer):
     def get_spec_proposals(
         self,
         execute_model_req: ExecuteModelRequest,
+        seq_ids_with_bonus_token_in_last_step: Set[int],
     ) -> SpeculativeProposals:
         """Get speculative proposals given the input batch.

@@ -76,6 +77,8 @@ class Top1Proposer(SpeculativeProposer):
             maybe_sampler_output, transposed = self._worker.sampler_output(
                 execute_model_req=nonzero_execute_model_req,
                 sample_len=proposal_len,
+                seq_ids_with_bonus_token_in_last_step=\
+                    seq_ids_with_bonus_token_in_last_step,
             )
             (
                 proposal_lens,