[Speculative Decoding] Add ProposerWorkerBase abstract class (#5252)

Nick Hill 2024-06-05 14:53:05 -07:00 committed by GitHub
parent f270a39537
commit faf71bcd4b
9 changed files with 91 additions and 60 deletions

tests/spec_decode/test_dynamic_spec_decode.py

@@ -68,13 +68,13 @@ def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int):
     if queue_size < disable_by_batch_size:
         # Should raise exception when executing the mocked draft model.
         with pytest.raises(ValueError, match=exception_secret):
-            proposer.get_proposals(execute_model_req=ExecuteModelRequest(
+            proposer.get_spec_proposals(execute_model_req=ExecuteModelRequest(
                 seq_group_metadata_list=seq_group_metadata_list,
                 num_lookahead_slots=k), )
     else:
         # Should not execute the draft model because spec decode is disabled
         # for all requests. Accordingly, the proposal length should be 0.
-        proposals = proposer.get_proposals(
+        proposals = proposer.get_spec_proposals(
             execute_model_req=ExecuteModelRequest(
                 seq_group_metadata_list=seq_group_metadata_list,
                 num_lookahead_slots=k), )

tests/spec_decode/test_multi_step_worker.py

@@ -307,7 +307,8 @@ def test_draft_proposals_full_speculation_len():
     seq_group_metadata_list, _, _ = create_batch(batch_size, k)
 
-    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
-        seq_group_metadata_list=seq_group_metadata_list,
-        num_lookahead_slots=k), )
+    proposals = proposer.get_spec_proposals(
+        execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=seq_group_metadata_list,
+            num_lookahead_slots=k), )
@@ -344,7 +345,8 @@ def test_draft_proposals_no_speculations():
         k,
         prompt_len=prompt_len)
 
-    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
-        seq_group_metadata_list=seq_group_metadata_list,
-        num_lookahead_slots=k), )
+    proposals = proposer.get_spec_proposals(
+        execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=seq_group_metadata_list,
+            num_lookahead_slots=k), )
@@ -415,7 +417,8 @@ def test_draft_proposals_mixed_k():
         prev_output_token_len=prev_output_token_len,
     )
 
-    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
-        seq_group_metadata_list=seq_group_metadata_list,
-        num_lookahead_slots=k), )
+    proposals = proposer.get_spec_proposals(
+        execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=seq_group_metadata_list,
+            num_lookahead_slots=k), )

tests/spec_decode/test_ngram_worker.py

@@ -50,7 +50,8 @@ def test_ngram_algo_correctness_for_single_no_match():
         block_size,
         final_prompt_lens=final_prompt_lens)
 
-    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
-        seq_group_metadata_list=seq_group_metadata_list,
-        num_lookahead_slots=proposal_len), )
+    proposals = proposer.get_spec_proposals(
+        execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=seq_group_metadata_list,
+            num_lookahead_slots=proposal_len), )
@@ -117,7 +118,8 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
         block_size,
         final_prompt_lens=final_prompt_lens)
 
-    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
-        seq_group_metadata_list=seq_group_metadata_list,
-        num_lookahead_slots=proposal_len), )
+    proposals = proposer.get_spec_proposals(
+        execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=seq_group_metadata_list,
+            num_lookahead_slots=proposal_len), )
@@ -188,7 +190,8 @@ def test_ngram_algo_correctness_for_batches_match_all():
         block_size,
         final_prompt_lens=final_prompt_lens)
 
-    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
-        seq_group_metadata_list=seq_group_metadata_list,
-        num_lookahead_slots=proposal_len), )
+    proposals = proposer.get_spec_proposals(
+        execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=seq_group_metadata_list,
+            num_lookahead_slots=proposal_len), )

vllm/spec_decode/interfaces.py

@@ -55,7 +55,7 @@ class SpeculativeScores:
 class SpeculativeProposer(ABC):
 
     @abstractmethod
-    def get_proposals(
+    def get_spec_proposals(
         self,
         execute_model_req: ExecuteModelRequest,
     ) -> SpeculativeProposals:
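With the abstract method renamed, every proposer must expose get_spec_proposals. Below is a minimal conforming sketch, not part of this commit: DummyProposer and its zero-length output are invented, and the SpeculativeProposals field names (proposal_token_ids, proposal_probs, proposal_lens) are assumed from the dataclass defined earlier in interfaces.py, which this hunk does not show.

import torch

from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeProposer)


class DummyProposer(SpeculativeProposer):
    """Hypothetical proposer that proposes zero draft tokens per sequence."""

    def get_spec_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> SpeculativeProposals:
        batch_size = len(execute_model_req.seq_group_metadata_list)
        # Empty proposals: a zero-length draft for each sequence group.
        return SpeculativeProposals(
            proposal_token_ids=torch.empty(batch_size, 0, dtype=torch.long),
            proposal_probs=torch.empty(batch_size, 0, 0),
            proposal_lens=torch.zeros(batch_size, dtype=torch.long),
        )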

vllm/spec_decode/multi_step_worker.py

@@ -7,11 +7,12 @@ import torch
 from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
                            SequenceGroupMetadata)
 from vllm.spec_decode.interfaces import SpeculativeProposals
+from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
 from vllm.spec_decode.top1_proposer import Top1Proposer
 from vllm.worker.worker import Worker
 
 
-class MultiStepWorker(Worker):
+class MultiStepWorker(Worker, ProposerWorkerBase):
     """The MultiStepWorker is equivalent to a Worker except that it allows
     multiple forward passes in a single call, assuming the scheduler has
     allocated enough space to store the additional KV. This reduces overhead
@@ -33,7 +34,7 @@ class MultiStepWorker(Worker):
         super().init_device()
 
         self._proposer = Top1Proposer(
-            weakref.proxy(self),
+            weakref.proxy(self),  # type: ignore[arg-type]
             self.device,
             self.vocab_size,
             max_proposal_len=self.max_model_len,
@@ -92,11 +93,12 @@ class MultiStepWorker(Worker):
         speculative tokens per sequence is determined by max_proposal_len.
         """
-        return self._proposer.get_proposals(execute_model_req)
+        return self._proposer.get_spec_proposals(execute_model_req)
 
+    @staticmethod
     def _append_new_tokens(
-            self, model_output: SamplerOutput,
-            seq_group_metadata_list: SequenceGroupMetadata) -> None:
+            model_output: List[SamplerOutput],
+            seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
         """Given model output from a single run, append the tokens to the
         sequences. This is normally done outside of the worker, but it is
         required if the worker is to perform multiple forward passes.
@@ -116,8 +118,9 @@ class MultiStepWorker(Worker):
             seq.append_token_id(token_id, token_logprob.logprob)
             seq.update_num_computed_tokens(1)
 
+    @staticmethod
     def _shallow_copy_inputs(
-            self, seq_group_metadata_list: List[SequenceGroupMetadata]
+            seq_group_metadata_list: List[SequenceGroupMetadata]
     ) -> List[SequenceGroupMetadata]:
         """Copy input data structures to remove side-effects when input data
         structures are shared with other modules.
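_append_new_tokens and _shallow_copy_inputs drop their self parameter and become static methods: they are pure helpers that read no worker state, and the decorator makes that explicit while keeping them callable without an instance. A generic sketch of the pattern follows; the class and helper names are hypothetical, not vLLM code.

from typing import List


class BatchHelper:
    """Hypothetical class illustrating the @staticmethod conversion above."""

    def __init__(self) -> None:
        self.history: List[int] = []

    @staticmethod
    def _double_all(values: List[int]) -> List[int]:
        # Pure helper: it touches no instance state, so `self` is dropped.
        return [v * 2 for v in values]

    def record_doubled(self, values: List[int]) -> None:
        # Instance methods still reach the helper through the class.
        self.history.extend(self._double_all(values))


# The helper stays callable without building an instance, e.g. in a test:
assert BatchHelper._double_all([1, 2]) == [2, 4]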

vllm/spec_decode/ngram_worker.py

@@ -5,15 +5,16 @@ import torch
 from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.spec_decode.interfaces import SpeculativeProposals
+from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
 from vllm.spec_decode.top1_proposer import Top1Proposer
 from vllm.worker.worker_base import LoraNotSupportedWorkerBase
 
 
-class NGramWorker(LoraNotSupportedWorkerBase):
+class NGramWorker(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase):
     """NGramWorker provides a light drafter without need for model.
 
     Current NGramWorker only implement prompt lookup decoding,
-    and in future we may also do RAG type drafter and other scenerios
+    and in future we may also do RAG type drafter and other scenarios
     which don't rely on LLM model to give proposals.
     """
@@ -38,34 +39,11 @@ class NGramWorker(LoraNotSupportedWorkerBase):
         # Current only support Top1Proposer
         self._proposer = Top1Proposer(
-            weakref.proxy(self),
+            weakref.proxy(self),  # type: ignore[arg-type]
             device=self.device,
             vocab_size=self.vocab_size,
         )
 
-    def set_include_gpu_probs_tensor(self):
-        # NGram don't need gpu sampler
-        pass
-
-    def execute_model(
-            self,
-            execute_model_req: Optional[ExecuteModelRequest] = None) -> None:
-        """NGram doesn't depend on model execution, just pass this function"""
-        pass
-
-    def determine_num_available_blocks(self) -> None:
-        """NGram doesn't depend on model execution, no need to check blocks"""
-        pass
-
-    def initialize_cache(self, num_gpu_blocks: int,
-                         num_cpu_blocks: int) -> None:
-        """As there is no cache need to handle, just pass this function"""
-        pass
-
-    def get_cache_block_size_bytes(self):
-        """Return the size of a cache block in bytes."""
-        return 0
-
     def sampler_output(
         self,
         execute_model_req: ExecuteModelRequest,
@@ -97,7 +75,6 @@ class NGramWorker(LoraNotSupportedWorkerBase):
                 -1,
             ):
                 ngram_tensor = input_ids[-ngram_size:]
-                proposal_start_idx = None
                 if ngram_size == 1:
                     # Do not match itself and do not use unfold and all
                     matches = (input_ids[:-1] == ngram_tensor)
@@ -161,7 +138,7 @@ class NGramWorker(LoraNotSupportedWorkerBase):
         speculative tokens per sequence is determined by max_proposal_len.
         """
-        return self._proposer.get_proposals(execute_model_req)
+        return self._proposer.get_spec_proposals(execute_model_req)
 
     def _raise_if_unsupported(
         self,

vllm/spec_decode/proposer_worker_base.py (new file)

@@ -0,0 +1,44 @@
+from abc import ABC, abstractmethod
+from typing import List, Optional, Tuple
+
+from vllm.sequence import ExecuteModelRequest, SamplerOutput
+from vllm.spec_decode.interfaces import SpeculativeProposer
+from vllm.worker.worker_base import WorkerBase
+
+
+class ProposerWorkerBase(WorkerBase, SpeculativeProposer):
+    """Interface for proposer workers"""
+
+    @abstractmethod
+    def sampler_output(
+        self,
+        execute_model_req: ExecuteModelRequest,
+        sample_len: int,
+    ) -> Tuple[Optional[List[SamplerOutput]], bool]:
+        raise NotImplementedError
+
+    def set_include_gpu_probs_tensor(self):
+        """Implementation optional"""
+        pass
+
+
+class NonLLMProposerWorkerBase(ProposerWorkerBase, ABC):
+    """Proposer worker which does not use a model with kvcache"""
+
+    def execute_model(
+        self,
+        execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> List[SamplerOutput]:
+        """get_spec_proposals is used to get the proposals"""
+        return []
+
+    def determine_num_available_blocks(self) -> Tuple[int, int]:
+        """This is never called on the proposer, only the target model"""
+        raise NotImplementedError
+
+    def initialize_cache(self, num_gpu_blocks: int,
+                         num_cpu_blocks: int) -> None:
+        pass
+
+    def get_cache_block_size_bytes(self) -> int:
+        return 0
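As a usage sketch, not part of the commit: a model-free drafter now mirrors NGramWorker above, mixing NonLLMProposerWorkerBase with LoraNotSupportedWorkerBase and supplying only the proposal path, while the kv-cache and execute_model no-ops come from the new base class. EchoDrafter, its init_device body, and its placeholder returns are invented for illustration.

from typing import List, Optional, Tuple

from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.worker.worker_base import LoraNotSupportedWorkerBase


class EchoDrafter(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase):
    """Hypothetical drafter: only the proposal methods are left to write."""

    def init_device(self) -> None:
        # Nothing to load; there is no draft model.
        pass

    def sampler_output(
        self,
        execute_model_req: ExecuteModelRequest,
        sample_len: int,
    ) -> Tuple[Optional[List[SamplerOutput]], bool]:
        # A real drafter builds draft tokens here; (None, False) stands in
        # for "no proposals" in this sketch.
        return None, False

    def get_spec_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> SpeculativeProposals:
        # Concrete workers delegate to Top1Proposer, as the NGramWorker and
        # MultiStepWorker hunks in this commit show.
        raise NotImplementedError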

vllm/spec_decode/spec_decode_worker.py

@@ -14,6 +14,7 @@ from vllm.spec_decode.interfaces import (SpeculativeProposals,
 from vllm.spec_decode.metrics import AsyncMetricsCollector
 from vllm.spec_decode.multi_step_worker import MultiStepWorker
 from vllm.spec_decode.ngram_worker import NGramWorker
+from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
 from vllm.spec_decode.util import (create_sequence_group_output,
                                    get_all_num_logprobs, get_all_seq_ids,
                                    get_sampled_token_logprobs, nvtx_range,
@@ -117,7 +118,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
     def __init__(
         self,
-        proposer_worker: WorkerBase,
+        proposer_worker: ProposerWorkerBase,
         scorer_worker: WorkerBase,
         rejection_sampler: RejectionSampler,
         metrics_collector: Optional[AsyncMetricsCollector] = None,
@@ -260,7 +261,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         # This is required as if the number of draft model runs changes
         # dynamically, the non-driver workers won't know unless we perform a
-        # communication to inform then.
+        # communication to inform them.
         broadcast_dict = dict(
             num_lookahead_slots=num_lookahead_slots,
             disable_all_speculation=disable_all_speculation,
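Narrowing the annotation from WorkerBase to ProposerWorkerBase means SpecDecodeWorker can call proposer-specific methods such as get_spec_proposals and set_include_gpu_probs_tensor on proposer_worker without casts, and a type checker flags a non-proposer worker at the call site. A small sketch of that guarantee; the describe helper is hypothetical, while the subclass relationships follow directly from the hunks above.

from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase


def describe(proposer_worker: ProposerWorkerBase) -> str:
    # Anything accepted here is statically known to expose the proposer
    # interface, so no isinstance checks or casts are needed downstream.
    return type(proposer_worker).__name__


# Both concrete proposers touched by this commit satisfy the new bound:
assert issubclass(MultiStepWorker, ProposerWorkerBase)
assert issubclass(NGramWorker, ProposerWorkerBase)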

vllm/spec_decode/top1_proposer.py

@@ -6,8 +6,8 @@ from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
                            SequenceGroupMetadata)
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeProposer)
+from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
 from vllm.spec_decode.util import sampler_output_to_torch
-from vllm.worker.worker_base import WorkerBase
 
 
 class Top1Proposer(SpeculativeProposer):
@@ -29,7 +29,7 @@ class Top1Proposer(SpeculativeProposer):
     def __init__(
         self,
-        worker: WorkerBase,
+        worker: ProposerWorkerBase,
         device: str,
         vocab_size: int,
         max_proposal_len: Optional[int] = None,
@@ -39,7 +39,7 @@ class Top1Proposer(SpeculativeProposer):
         self.max_proposal_len = max_proposal_len
         self._vocab_size = vocab_size
 
-    def get_proposals(
+    def get_spec_proposals(
         self,
         execute_model_req: ExecuteModelRequest,
     ) -> SpeculativeProposals:
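For reference, this sketch mirrors how the two workers in this commit wire up Top1Proposer under the narrowed worker type. The wrapper function is hypothetical; the weakref.proxy call and the type: ignore comment are taken from the MultiStepWorker and NGramWorker hunks above.

import weakref
from typing import Optional

from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer


def build_top1_proposer(
        worker: ProposerWorkerBase,
        device: str,
        vocab_size: int,
        max_proposal_len: Optional[int] = None) -> Top1Proposer:
    # weakref.proxy matches the worker hunks above; it avoids a reference
    # cycle between the worker and the proposer that holds it.
    return Top1Proposer(
        weakref.proxy(worker),  # type: ignore[arg-type]
        device=device,
        vocab_size=vocab_size,
        max_proposal_len=max_proposal_len,
    )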