[Speculative Decoding] Add ProposerWorkerBase abstract class (#5252)

This commit is contained in:
parent f270a39537
commit faf71bcd4b
@@ -68,13 +68,13 @@ def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int):
     if queue_size < disable_by_batch_size:
         # Should raise exception when executing the mocked draft model.
         with pytest.raises(ValueError, match=exception_secret):
-            proposer.get_proposals(execute_model_req=ExecuteModelRequest(
+            proposer.get_spec_proposals(execute_model_req=ExecuteModelRequest(
                 seq_group_metadata_list=seq_group_metadata_list,
                 num_lookahead_slots=k), )
     else:
         # Should not execute the draft model because spec decode is disabled
         # for all requests. Accordingly, the proposal length should be 0.
-        proposals = proposer.get_proposals(
+        proposals = proposer.get_spec_proposals(
             execute_model_req=ExecuteModelRequest(
                 seq_group_metadata_list=seq_group_metadata_list,
                 num_lookahead_slots=k), )

@@ -307,7 +307,8 @@ def test_draft_proposals_full_speculation_len():

     seq_group_metadata_list, _, _ = create_batch(batch_size, k)

-    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
-        seq_group_metadata_list=seq_group_metadata_list,
-        num_lookahead_slots=k), )
+    proposals = proposer.get_spec_proposals(
+        execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=seq_group_metadata_list,
+            num_lookahead_slots=k), )

@@ -344,7 +345,8 @@ def test_draft_proposals_no_speculations():
         k,
         prompt_len=prompt_len)

-    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
-        seq_group_metadata_list=seq_group_metadata_list,
-        num_lookahead_slots=k), )
+    proposals = proposer.get_spec_proposals(
+        execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=seq_group_metadata_list,
+            num_lookahead_slots=k), )

@@ -415,7 +417,8 @@ def test_draft_proposals_mixed_k():
         prev_output_token_len=prev_output_token_len,
     )

-    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
-        seq_group_metadata_list=seq_group_metadata_list,
-        num_lookahead_slots=k), )
+    proposals = proposer.get_spec_proposals(
+        execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=seq_group_metadata_list,
+            num_lookahead_slots=k), )

@@ -50,7 +50,8 @@ def test_ngram_algo_correctness_for_single_no_match():
         block_size,
         final_prompt_lens=final_prompt_lens)

-    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
-        seq_group_metadata_list=seq_group_metadata_list,
-        num_lookahead_slots=proposal_len), )
+    proposals = proposer.get_spec_proposals(
+        execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=seq_group_metadata_list,
+            num_lookahead_slots=proposal_len), )

@@ -117,7 +118,8 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
         block_size,
         final_prompt_lens=final_prompt_lens)

-    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
-        seq_group_metadata_list=seq_group_metadata_list,
-        num_lookahead_slots=proposal_len), )
+    proposals = proposer.get_spec_proposals(
+        execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=seq_group_metadata_list,
+            num_lookahead_slots=proposal_len), )

@@ -188,7 +190,8 @@ def test_ngram_algo_correctness_for_batches_match_all():
         block_size,
         final_prompt_lens=final_prompt_lens)

-    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
-        seq_group_metadata_list=seq_group_metadata_list,
-        num_lookahead_slots=proposal_len), )
+    proposals = proposer.get_spec_proposals(
+        execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=seq_group_metadata_list,
+            num_lookahead_slots=proposal_len), )
@@ -55,7 +55,7 @@ class SpeculativeScores:
 class SpeculativeProposer(ABC):

     @abstractmethod
-    def get_proposals(
+    def get_spec_proposals(
         self,
         execute_model_req: ExecuteModelRequest,
     ) -> SpeculativeProposals:
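
For orientation, a self-contained toy sketch of the renamed interface method (the Toy* names and the dict-based request are illustrative assumptions; the real SpeculativeProposer takes an ExecuteModelRequest and returns SpeculativeProposals):

# Toy sketch, not vLLM source; only the get_spec_proposals name mirrors the diff.
from abc import ABC, abstractmethod
from typing import Dict, List


class ToySpeculativeProposer(ABC):

    @abstractmethod
    def get_spec_proposals(self, execute_model_req: Dict) -> List[int]:
        """Return draft token ids for the given request."""
        raise NotImplementedError


class FixedToyProposer(ToySpeculativeProposer):
    """Proposes num_lookahead_slots dummy token ids."""

    def get_spec_proposals(self, execute_model_req: Dict) -> List[int]:
        return list(range(execute_model_req["num_lookahead_slots"]))


print(FixedToyProposer().get_spec_proposals({"num_lookahead_slots": 3}))  # [0, 1, 2]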
@@ -7,11 +7,12 @@ import torch
 from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
                            SequenceGroupMetadata)
 from vllm.spec_decode.interfaces import SpeculativeProposals
+from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
 from vllm.spec_decode.top1_proposer import Top1Proposer
 from vllm.worker.worker import Worker


-class MultiStepWorker(Worker):
+class MultiStepWorker(Worker, ProposerWorkerBase):
     """The MultiStepWorker is equivalent to a Worker except that it allows
     multiple forward passes in a single call, assuming the scheduler has
     allocated enough space to store the additional KV. This reduces overhead

@@ -33,7 +34,7 @@ class MultiStepWorker(Worker):
         super().init_device()

         self._proposer = Top1Proposer(
-            weakref.proxy(self),
+            weakref.proxy(self),  # type: ignore[arg-type]
             self.device,
             self.vocab_size,
             max_proposal_len=self.max_model_len,

@@ -92,11 +93,12 @@ class MultiStepWorker(Worker):
         speculative tokens per sequence is determined by max_proposal_len.
         """

-        return self._proposer.get_proposals(execute_model_req)
+        return self._proposer.get_spec_proposals(execute_model_req)

+    @staticmethod
     def _append_new_tokens(
-            self, model_output: SamplerOutput,
-            seq_group_metadata_list: SequenceGroupMetadata) -> None:
+            model_output: List[SamplerOutput],
+            seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
         """Given model output from a single run, append the tokens to the
         sequences. This is normally done outside of the worker, but it is
         required if the worker is to perform multiple forward passes.

@@ -116,8 +118,9 @@ class MultiStepWorker(Worker):
             seq.append_token_id(token_id, token_logprob.logprob)
             seq.update_num_computed_tokens(1)

+    @staticmethod
     def _shallow_copy_inputs(
-        self, seq_group_metadata_list: List[SequenceGroupMetadata]
+        seq_group_metadata_list: List[SequenceGroupMetadata]
     ) -> List[SequenceGroupMetadata]:
         """Copy input data structures to remove side-effects when input data
         structures are shared with other modules.
@@ -5,15 +5,16 @@ import torch

 from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.spec_decode.interfaces import SpeculativeProposals
+from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
 from vllm.spec_decode.top1_proposer import Top1Proposer
 from vllm.worker.worker_base import LoraNotSupportedWorkerBase


-class NGramWorker(LoraNotSupportedWorkerBase):
+class NGramWorker(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase):
     """NGramWorker provides a light drafter without need for model.

     Current NGramWorker only implement prompt lookup decoding,
-    and in future we may also do RAG type drafter and other scenerios
+    and in future we may also do RAG type drafter and other scenarios
     which don't rely on LLM model to give proposals.
     """

@@ -38,34 +39,11 @@ class NGramWorker(LoraNotSupportedWorkerBase):

         # Current only support Top1Proposer
         self._proposer = Top1Proposer(
-            weakref.proxy(self),
+            weakref.proxy(self),  # type: ignore[arg-type]
             device=self.device,
             vocab_size=self.vocab_size,
         )

-    def set_include_gpu_probs_tensor(self):
-        # NGram don't need gpu sampler
-        pass
-
-    def execute_model(
-        self,
-        execute_model_req: Optional[ExecuteModelRequest] = None) -> None:
-        """NGram doesn't depend on model execution, just pass this function"""
-        pass
-
-    def determine_num_available_blocks(self) -> None:
-        """NGram doesn't depend on model execution, no need to check blocks"""
-        pass
-
-    def initialize_cache(self, num_gpu_blocks: int,
-                         num_cpu_blocks: int) -> None:
-        """As there is no cache need to handle, just pass this function"""
-        pass
-
-    def get_cache_block_size_bytes(self):
-        """Return the size of a cache block in bytes."""
-        return 0
-
     def sampler_output(
         self,
         execute_model_req: ExecuteModelRequest,

@@ -97,7 +75,6 @@ class NGramWorker(LoraNotSupportedWorkerBase):
                 -1,
         ):
             ngram_tensor = input_ids[-ngram_size:]
-            proposal_start_idx = None
             if ngram_size == 1:
                 # Do not match itself and do not use unfold and all
                 matches = (input_ids[:-1] == ngram_tensor)

@@ -161,7 +138,7 @@ class NGramWorker(LoraNotSupportedWorkerBase):
         speculative tokens per sequence is determined by max_proposal_len.
         """

-        return self._proposer.get_proposals(execute_model_req)
+        return self._proposer.get_spec_proposals(execute_model_req)

     def _raise_if_unsupported(
         self,
vllm/spec_decode/proposer_worker_base.py (new file, 44 lines)

@@ -0,0 +1,44 @@
+from abc import ABC, abstractmethod
+from typing import List, Optional, Tuple
+
+from vllm.sequence import ExecuteModelRequest, SamplerOutput
+from vllm.spec_decode.interfaces import SpeculativeProposer
+from vllm.worker.worker_base import WorkerBase
+
+
+class ProposerWorkerBase(WorkerBase, SpeculativeProposer):
+    """Interface for proposer workers"""
+
+    @abstractmethod
+    def sampler_output(
+        self,
+        execute_model_req: ExecuteModelRequest,
+        sample_len: int,
+    ) -> Tuple[Optional[List[SamplerOutput]], bool]:
+        raise NotImplementedError
+
+    def set_include_gpu_probs_tensor(self):
+        """Implementation optional"""
+        pass
+
+
+class NonLLMProposerWorkerBase(ProposerWorkerBase, ABC):
+    """Proposer worker which does not use a model with kvcache"""
+
+    def execute_model(
+        self,
+        execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> List[SamplerOutput]:
+        """get_spec_proposals is used to get the proposals"""
+        return []
+
+    def determine_num_available_blocks(self) -> Tuple[int, int]:
+        """This is never called on the proposer, only the target model"""
+        raise NotImplementedError
+
+    def initialize_cache(self, num_gpu_blocks: int,
+                         num_cpu_blocks: int) -> None:
+        pass
+
+    def get_cache_block_size_bytes(self) -> int:
+        return 0
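
To see the shape of the new hierarchy without depending on vLLM's WorkerBase, here is a self-contained toy sketch; the Toy* names, dict payloads, and printed values are assumptions, and only the method names and the ProposerWorkerBase / NonLLMProposerWorkerBase split mirror the file above:

# Toy sketch, not vLLM source.
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple


class ToyProposerWorkerBase(ABC):
    """Stands in for ProposerWorkerBase(WorkerBase, SpeculativeProposer)."""

    @abstractmethod
    def sampler_output(
        self,
        execute_model_req: Optional[dict],
        sample_len: int,
    ) -> Tuple[Optional[List[dict]], bool]:
        raise NotImplementedError

    def set_include_gpu_probs_tensor(self) -> None:
        """Optional hook, as in the real base class."""


class ToyNonLLMProposerWorker(ToyProposerWorkerBase):
    """Stands in for NonLLMProposerWorkerBase: no model run, no KV cache."""

    def execute_model(self, execute_model_req=None) -> List[dict]:
        return []  # proposals come via get_spec_proposals, not execute_model

    def get_cache_block_size_bytes(self) -> int:
        return 0  # nothing is cached

    def sampler_output(self, execute_model_req, sample_len):
        # Toy: no proposals; the second element mirrors the bool in the
        # real return signature.
        return None, False


worker = ToyNonLLMProposerWorker()
print(worker.execute_model(), worker.get_cache_block_size_bytes())  # [] 0

MultiStepWorker (an LLM-backed drafter) and NGramWorker (a non-LLM drafter) pick up the real base classes in the hunks earlier in this change.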
@@ -14,6 +14,7 @@ from vllm.spec_decode.interfaces import (SpeculativeProposals,
 from vllm.spec_decode.metrics import AsyncMetricsCollector
 from vllm.spec_decode.multi_step_worker import MultiStepWorker
 from vllm.spec_decode.ngram_worker import NGramWorker
+from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
 from vllm.spec_decode.util import (create_sequence_group_output,
                                    get_all_num_logprobs, get_all_seq_ids,
                                    get_sampled_token_logprobs, nvtx_range,

@@ -117,7 +118,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):

     def __init__(
         self,
-        proposer_worker: WorkerBase,
+        proposer_worker: ProposerWorkerBase,
         scorer_worker: WorkerBase,
         rejection_sampler: RejectionSampler,
         metrics_collector: Optional[AsyncMetricsCollector] = None,

@@ -260,7 +261,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):

         # This is required as if the number of draft model runs changes
         # dynamically, the non-driver workers won't know unless we perform a
-        # communication to inform then.
+        # communication to inform them.
         broadcast_dict = dict(
             num_lookahead_slots=num_lookahead_slots,
             disable_all_speculation=disable_all_speculation,
@@ -6,8 +6,8 @@ from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
                            SequenceGroupMetadata)
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeProposer)
+from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
 from vllm.spec_decode.util import sampler_output_to_torch
-from vllm.worker.worker_base import WorkerBase


 class Top1Proposer(SpeculativeProposer):

@@ -29,7 +29,7 @@ class Top1Proposer(SpeculativeProposer):

     def __init__(
         self,
-        worker: WorkerBase,
+        worker: ProposerWorkerBase,
         device: str,
         vocab_size: int,
         max_proposal_len: Optional[int] = None,

@@ -39,7 +39,7 @@ class Top1Proposer(SpeculativeProposer):
         self.max_proposal_len = max_proposal_len
         self._vocab_size = vocab_size

-    def get_proposals(
+    def get_spec_proposals(
         self,
         execute_model_req: ExecuteModelRequest,
     ) -> SpeculativeProposals:
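
Both MultiStepWorker and NGramWorker end up with the same wiring shown in the hunks above: build a Top1Proposer around weakref.proxy(self) during init_device and forward get_spec_proposals to it. A self-contained toy sketch of that delegation (the Toy* classes, dict request, and integer "tokens" are assumptions, not vLLM code):

# Toy sketch, not vLLM source.
import weakref
from typing import List


class ToyTop1Proposer:
    """Stands in for Top1Proposer: asks its worker for draft tokens."""

    def __init__(self, worker, device: str, vocab_size: int,
                 max_proposal_len: int = 8):
        self._worker = worker  # a weakref.proxy, so no reference cycle
        self.device = device
        self.vocab_size = vocab_size
        self.max_proposal_len = max_proposal_len

    def get_spec_proposals(self, execute_model_req: dict) -> List[int]:
        sample_len = min(execute_model_req["num_lookahead_slots"],
                         self.max_proposal_len)
        return self._worker.sampler_output(execute_model_req, sample_len)


class ToyProposerWorker:
    """Stands in for a ProposerWorkerBase subclass wiring up its proposer."""

    def init_device(self) -> None:
        self._proposer = ToyTop1Proposer(
            weakref.proxy(self),  # avoid a worker <-> proposer cycle
            device="cpu",
            vocab_size=32000,
        )

    def sampler_output(self, execute_model_req: dict,
                       sample_len: int) -> List[int]:
        return list(range(sample_len))  # toy draft tokens

    def get_spec_proposals(self, execute_model_req: dict) -> List[int]:
        return self._proposer.get_spec_proposals(execute_model_req)


worker = ToyProposerWorker()
worker.init_device()
print(worker.get_spec_proposals({"num_lookahead_slots": 4}))  # [0, 1, 2, 3]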