[Speculative Decoding] Add ProposerWorkerBase abstract class (#5252)

commit faf71bcd4b
parent f270a39537
@@ -68,13 +68,13 @@ def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int):
     if queue_size < disable_by_batch_size:
         # Should raise exception when executing the mocked draft model.
         with pytest.raises(ValueError, match=exception_secret):
-            proposer.get_proposals(execute_model_req=ExecuteModelRequest(
+            proposer.get_spec_proposals(execute_model_req=ExecuteModelRequest(
                 seq_group_metadata_list=seq_group_metadata_list,
                 num_lookahead_slots=k), )
     else:
         # Should not execute the draft model because spec decode is disabled
         # for all requests. Accordingly, the proposal length should be 0.
-        proposals = proposer.get_proposals(
+        proposals = proposer.get_spec_proposals(
             execute_model_req=ExecuteModelRequest(
                 seq_group_metadata_list=seq_group_metadata_list,
                 num_lookahead_slots=k), )
@@ -307,7 +307,8 @@ def test_draft_proposals_full_speculation_len():

     seq_group_metadata_list, _, _ = create_batch(batch_size, k)

-    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
+    proposals = proposer.get_spec_proposals(
+        execute_model_req=ExecuteModelRequest(
         seq_group_metadata_list=seq_group_metadata_list,
         num_lookahead_slots=k), )

@@ -344,7 +345,8 @@ def test_draft_proposals_no_speculations():
         k,
         prompt_len=prompt_len)

-    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
+    proposals = proposer.get_spec_proposals(
+        execute_model_req=ExecuteModelRequest(
         seq_group_metadata_list=seq_group_metadata_list,
         num_lookahead_slots=k), )

@@ -415,7 +417,8 @@ def test_draft_proposals_mixed_k():
         prev_output_token_len=prev_output_token_len,
     )

-    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
+    proposals = proposer.get_spec_proposals(
+        execute_model_req=ExecuteModelRequest(
         seq_group_metadata_list=seq_group_metadata_list,
         num_lookahead_slots=k), )

@@ -50,7 +50,8 @@ def test_ngram_algo_correctness_for_single_no_match():
         block_size,
         final_prompt_lens=final_prompt_lens)

-    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
+    proposals = proposer.get_spec_proposals(
+        execute_model_req=ExecuteModelRequest(
         seq_group_metadata_list=seq_group_metadata_list,
         num_lookahead_slots=proposal_len), )

@@ -117,7 +118,8 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
         block_size,
         final_prompt_lens=final_prompt_lens)

-    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
+    proposals = proposer.get_spec_proposals(
+        execute_model_req=ExecuteModelRequest(
         seq_group_metadata_list=seq_group_metadata_list,
         num_lookahead_slots=proposal_len), )

@@ -188,7 +190,8 @@ def test_ngram_algo_correctness_for_batches_match_all():
         block_size,
         final_prompt_lens=final_prompt_lens)

-    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
+    proposals = proposer.get_spec_proposals(
+        execute_model_req=ExecuteModelRequest(
         seq_group_metadata_list=seq_group_metadata_list,
         num_lookahead_slots=proposal_len), )

@@ -55,7 +55,7 @@ class SpeculativeScores:
 class SpeculativeProposer(ABC):

     @abstractmethod
-    def get_proposals(
+    def get_spec_proposals(
         self,
         execute_model_req: ExecuteModelRequest,
     ) -> SpeculativeProposals:
@@ -7,11 +7,12 @@ import torch
 from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
                            SequenceGroupMetadata)
 from vllm.spec_decode.interfaces import SpeculativeProposals
+from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
 from vllm.spec_decode.top1_proposer import Top1Proposer
 from vllm.worker.worker import Worker


-class MultiStepWorker(Worker):
+class MultiStepWorker(Worker, ProposerWorkerBase):
     """The MultiStepWorker is equivalent to a Worker except that it allows
     multiple forward passes in a single call, assuming the scheduler has
     allocated enough space to store the additional KV. This reduces overhead
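For orientation: the multi-step drafting that the MultiStepWorker docstring describes boils down to running the draft model once per lookahead slot and appending each sampled token before the next pass, which is why the scheduler must reserve extra KV space. A minimal sketch of that loop with hypothetical stand-in types (DraftSequence, run_once), not the worker's actual implementation:

# Illustrative sketch only -- hypothetical types, not vLLM code.
from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class DraftSequence:
    token_ids: List[int] = field(default_factory=list)


def multi_step_draft(
    run_once: Callable[[List[DraftSequence]], List[int]],
    sequences: List[DraftSequence],
    sample_len: int,
) -> List[List[int]]:
    """Run the draft model sample_len times, feeding each sampled token back
    so the next forward pass can attend to it."""
    proposals: List[List[int]] = [[] for _ in sequences]
    for _ in range(sample_len):
        sampled = run_once(sequences)  # one new token id per sequence
        for seq, token_id, out in zip(sequences, sampled, proposals):
            seq.token_ids.append(token_id)
            out.append(token_id)
    return proposals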
@@ -33,7 +34,7 @@ class MultiStepWorker(Worker):
         super().init_device()

         self._proposer = Top1Proposer(
-            weakref.proxy(self),
+            weakref.proxy(self),  # type: ignore[arg-type]
             self.device,
             self.vocab_size,
             max_proposal_len=self.max_model_len,
@@ -92,11 +93,12 @@ class MultiStepWorker(Worker):
         speculative tokens per sequence is determined by max_proposal_len.
         """

-        return self._proposer.get_proposals(execute_model_req)
+        return self._proposer.get_spec_proposals(execute_model_req)

+    @staticmethod
     def _append_new_tokens(
-            self, model_output: SamplerOutput,
-            seq_group_metadata_list: SequenceGroupMetadata) -> None:
+            model_output: List[SamplerOutput],
+            seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
         """Given model output from a single run, append the tokens to the
         sequences. This is normally done outside of the worker, but it is
         required if the worker is to perform multiple forward passes.
@@ -116,8 +118,9 @@ class MultiStepWorker(Worker):
             seq.append_token_id(token_id, token_logprob.logprob)
             seq.update_num_computed_tokens(1)

+    @staticmethod
     def _shallow_copy_inputs(
-            self, seq_group_metadata_list: List[SequenceGroupMetadata]
+            seq_group_metadata_list: List[SequenceGroupMetadata]
     ) -> List[SequenceGroupMetadata]:
         """Copy input data structures to remove side-effects when input data
         structures are shared with other modules.
@@ -5,15 +5,16 @@ import torch

 from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.spec_decode.interfaces import SpeculativeProposals
+from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
 from vllm.spec_decode.top1_proposer import Top1Proposer
 from vllm.worker.worker_base import LoraNotSupportedWorkerBase


-class NGramWorker(LoraNotSupportedWorkerBase):
+class NGramWorker(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase):
     """NGramWorker provides a light drafter without need for model.

     Current NGramWorker only implement prompt lookup decoding,
-    and in future we may also do RAG type drafter and other scenerios
+    and in future we may also do RAG type drafter and other scenarios
     which don't rely on LLM model to give proposals.
     """

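As background for the docstring above: prompt lookup decoding drafts tokens by matching the most recent n-gram against earlier context and copying whatever followed the match. A rough, self-contained sketch of the idea (plain Python lists, not the tensor-based matching used further down in this file):

from typing import List, Optional


def prompt_lookup_propose(token_ids: List[int], ngram_size: int,
                          num_proposal_tokens: int) -> Optional[List[int]]:
    """Match the trailing ngram_size tokens against earlier context and, on a
    hit, propose the tokens that followed the earlier occurrence."""
    if len(token_ids) <= ngram_size:
        return None
    tail = token_ids[-ngram_size:]
    # Scan backwards from the most recent candidate, excluding the tail itself.
    for start in range(len(token_ids) - ngram_size - 1, -1, -1):
        if token_ids[start:start + ngram_size] == tail:
            begin = start + ngram_size
            return token_ids[begin:begin + num_proposal_tokens]
    return None


# The bigram (7, 8) occurred earlier followed by 9, 1, so those are proposed.
assert prompt_lookup_propose([5, 7, 8, 9, 1, 7, 8], 2, 2) == [9, 1]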
@@ -38,34 +39,11 @@ class NGramWorker(LoraNotSupportedWorkerBase):

         # Current only support Top1Proposer
         self._proposer = Top1Proposer(
-            weakref.proxy(self),
+            weakref.proxy(self),  # type: ignore[arg-type]
             device=self.device,
             vocab_size=self.vocab_size,
         )

-    def set_include_gpu_probs_tensor(self):
-        # NGram don't need gpu sampler
-        pass
-
-    def execute_model(
-        self,
-        execute_model_req: Optional[ExecuteModelRequest] = None) -> None:
-        """NGram doesn't depend on model execution, just pass this function"""
-        pass
-
-    def determine_num_available_blocks(self) -> None:
-        """NGram doesn't depend on model execution, no need to check blocks"""
-        pass
-
-    def initialize_cache(self, num_gpu_blocks: int,
-                         num_cpu_blocks: int) -> None:
-        """As there is no cache need to handle, just pass this function"""
-        pass
-
-    def get_cache_block_size_bytes(self):
-        """Return the size of a cache block in bytes."""
-        return 0
-
     def sampler_output(
         self,
         execute_model_req: ExecuteModelRequest,
@@ -97,7 +75,6 @@ class NGramWorker(LoraNotSupportedWorkerBase):
                     -1,
             ):
                 ngram_tensor = input_ids[-ngram_size:]
-                proposal_start_idx = None
                 if ngram_size == 1:
                     # Do not match itself and do not use unfold and all
                     matches = (input_ids[:-1] == ngram_tensor)
@@ -161,7 +138,7 @@ class NGramWorker(LoraNotSupportedWorkerBase):
         speculative tokens per sequence is determined by max_proposal_len.
         """

-        return self._proposer.get_proposals(execute_model_req)
+        return self._proposer.get_spec_proposals(execute_model_req)

     def _raise_if_unsupported(
         self,
vllm/spec_decode/proposer_worker_base.py (new file, 44 lines)
@@ -0,0 +1,44 @@
+from abc import ABC, abstractmethod
+from typing import List, Optional, Tuple
+
+from vllm.sequence import ExecuteModelRequest, SamplerOutput
+from vllm.spec_decode.interfaces import SpeculativeProposer
+from vllm.worker.worker_base import WorkerBase
+
+
+class ProposerWorkerBase(WorkerBase, SpeculativeProposer):
+    """Interface for proposer workers"""
+
+    @abstractmethod
+    def sampler_output(
+        self,
+        execute_model_req: ExecuteModelRequest,
+        sample_len: int,
+    ) -> Tuple[Optional[List[SamplerOutput]], bool]:
+        raise NotImplementedError
+
+    def set_include_gpu_probs_tensor(self):
+        """Implementation optional"""
+        pass
+
+
+class NonLLMProposerWorkerBase(ProposerWorkerBase, ABC):
+    """Proposer worker which does not use a model with kvcache"""
+
+    def execute_model(
+        self,
+        execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> List[SamplerOutput]:
+        """get_spec_proposals is used to get the proposals"""
+        return []
+
+    def determine_num_available_blocks(self) -> Tuple[int, int]:
+        """This is never called on the proposer, only the target model"""
+        raise NotImplementedError
+
+    def initialize_cache(self, num_gpu_blocks: int,
+                         num_cpu_blocks: int) -> None:
+        pass
+
+    def get_cache_block_size_bytes(self) -> int:
+        return 0
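To show how the new hierarchy is intended to be extended, here is a hypothetical proposer worker (not part of this commit) that follows the NGramWorker pattern: it subclasses NonLLMProposerWorkerBase together with LoraNotSupportedWorkerBase, wires itself into a Top1Proposer, and implements sampler_output and get_spec_proposals, while the KV-cache stubs come from the base class. The class and its behavior (it never proposes anything) are purely illustrative.

# Hypothetical sketch, not vLLM code; assumes vLLM with this commit applied.
import weakref
from typing import List, Optional, Tuple

from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker_base import LoraNotSupportedWorkerBase


class NullProposerWorker(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase):
    """Hypothetical drafter that never proposes tokens; interface demo only."""

    def __init__(self, device: str, vocab_size: int):
        self.device = device
        self.vocab_size = vocab_size

    def init_device(self) -> None:
        self._proposer = Top1Proposer(
            weakref.proxy(self),  # type: ignore[arg-type]
            device=self.device,
            vocab_size=self.vocab_size,
        )

    def sampler_output(
        self,
        execute_model_req: ExecuteModelRequest,
        sample_len: int,
    ) -> Tuple[Optional[List[SamplerOutput]], bool]:
        # Returning None signals to Top1Proposer that nothing was drafted.
        return None, False

    def get_spec_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> SpeculativeProposals:
        return self._proposer.get_spec_proposals(execute_model_req)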
@@ -14,6 +14,7 @@ from vllm.spec_decode.interfaces import (SpeculativeProposals,
 from vllm.spec_decode.metrics import AsyncMetricsCollector
 from vllm.spec_decode.multi_step_worker import MultiStepWorker
 from vllm.spec_decode.ngram_worker import NGramWorker
+from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
 from vllm.spec_decode.util import (create_sequence_group_output,
                                    get_all_num_logprobs, get_all_seq_ids,
                                    get_sampled_token_logprobs, nvtx_range,
@@ -117,7 +118,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):

     def __init__(
         self,
-        proposer_worker: WorkerBase,
+        proposer_worker: ProposerWorkerBase,
         scorer_worker: WorkerBase,
         rejection_sampler: RejectionSampler,
         metrics_collector: Optional[AsyncMetricsCollector] = None,
@@ -260,7 +261,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):

         # This is required as if the number of draft model runs changes
         # dynamically, the non-driver workers won't know unless we perform a
-        # communication to inform then.
+        # communication to inform them.
         broadcast_dict = dict(
             num_lookahead_slots=num_lookahead_slots,
             disable_all_speculation=disable_all_speculation,
@@ -6,8 +6,8 @@ from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
                            SequenceGroupMetadata)
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeProposer)
+from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
 from vllm.spec_decode.util import sampler_output_to_torch
-from vllm.worker.worker_base import WorkerBase


 class Top1Proposer(SpeculativeProposer):
@@ -29,7 +29,7 @@ class Top1Proposer(SpeculativeProposer):

     def __init__(
         self,
-        worker: WorkerBase,
+        worker: ProposerWorkerBase,
         device: str,
         vocab_size: int,
         max_proposal_len: Optional[int] = None,
@@ -39,7 +39,7 @@ class Top1Proposer(SpeculativeProposer):
         self.max_proposal_len = max_proposal_len
         self._vocab_size = vocab_size

-    def get_proposals(
+    def get_spec_proposals(
         self,
         execute_model_req: ExecuteModelRequest,
     ) -> SpeculativeProposals:
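End to end, the renamed entry point gives one call chain: SpecDecodeWorker asks its proposer_worker (a ProposerWorkerBase) for get_spec_proposals, the worker delegates to its Top1Proposer, and Top1Proposer calls back into the worker's sampler_output. A small usage sketch, assuming draft_worker is an already-initialized MultiStepWorker or NGramWorker and req is an ExecuteModelRequest with num_lookahead_slots set (both names are illustrative, not from this commit):

# Hypothetical wiring; variable names are placeholders.
proposals = draft_worker.get_spec_proposals(execute_model_req=req)
# SpeculativeProposals carries the drafted tokens, their probabilities,
# and per-sequence proposal lengths.
print(proposals.proposal_lens)
print(proposals.proposal_token_ids.shape)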