from collections import defaultdict
from functools import cached_property
from typing import Any, Dict, List, Optional, Set, Tuple

import torch

from vllm.config import ParallelConfig, SpeculativeConfig
from vllm.distributed.communication_op import broadcast_tensor_dict
from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.layers.spec_decode_base_sampler import (
    SpecDecodeBaseSampler)
from vllm.model_executor.layers.typical_acceptance_sampler import (
    TypicalAcceptanceSampler)
from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest,
                           HiddenStates, SamplerOutput, SequenceGroupMetadata,
                           get_all_seq_ids_and_request_ids)
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.medusa_worker import MedusaWorker
from vllm.spec_decode.metrics import AsyncMetricsCollector
from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
from vllm.spec_decode.util import (create_sequence_group_output,
                                   get_all_num_logprobs,
                                   get_sampled_token_logprobs, nvtx_range,
                                   split_batch_by_proposal_len)
from vllm.worker.worker import Worker
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase

logger = init_logger(__name__)


def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
    """Helper method that is the entrypoint for Executors which use
    WorkerWrapper. It constructs a SpecDecodeWorker from the speculative config.
    """
    assert "speculative_config" in kwargs
    speculative_config: SpeculativeConfig = kwargs.get("speculative_config")
    assert speculative_config is not None

    target_worker = Worker(*args, **kwargs)

    draft_worker_kwargs = kwargs.copy()
    # Override draft-model specific worker args.
    draft_worker_kwargs.update(
        model_config=speculative_config.draft_model_config,
        parallel_config=speculative_config.draft_parallel_config,
        ngram_prompt_lookup_max=speculative_config.ngram_prompt_lookup_max,
        ngram_prompt_lookup_min=speculative_config.ngram_prompt_lookup_min,
        # TODO allow draft-model specific load config.
        #load_config=load_config,
    )

    spec_decode_worker = SpecDecodeWorker.create_worker(
        scorer_worker=target_worker,
        draft_worker_kwargs=draft_worker_kwargs,
        disable_by_batch_size=speculative_config.
        speculative_disable_by_batch_size,
        draft_token_acceptance_method=speculative_config.
        draft_token_acceptance_method,
        typical_acceptance_sampler_posterior_threshold=speculative_config.
        typical_acceptance_sampler_posterior_threshold,
        typical_acceptance_sampler_posterior_alpha=speculative_config.
        typical_acceptance_sampler_posterior_alpha)

    return spec_decode_worker


class SpecDecodeWorker(LoraNotSupportedWorkerBase):
    """Worker which implements speculative decoding.

    Speculative decoding reduces decoding per-token latency by using a proposal
    method, such as a small draft model, to speculate ahead of a larger LLM. The
    probabilities of the speculative tokens are then determined by the larger
    LLM, after which some verification routine determines which (if any) of the
    speculative tokens are accepted by the larger LLM.

    See https://github.com/vllm-project/vllm/pull/2188 and
    https://github.com/vllm-project/vllm/pull/3103 for more info.

    The current implementation has the following limitations:
    * Only draft-model proposal is implemented (contributions for more forms
      are welcome!).
    * Only top-1 proposal and scoring are implemented. Tree-attention is left
      as future work.
    * All sequences in a batch must have the same proposal length, or zero.
      This can be improved by having per-sequence speculation in the future.
    * The scoring forward pass is done without an MQA kernel, which is
      suboptimal especially as the batch size, proposal length, and sequence
      lengths grow. Contributions to add an MQA scoring kernel are welcome once
      correctness tests pass.

    More info here:
    https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit
    """

    @classmethod
    def create_worker(
        cls,
        scorer_worker: Worker,
        draft_worker_kwargs: Dict[str, Any],
        disable_by_batch_size: Optional[int],
        draft_token_acceptance_method: str,
        typical_acceptance_sampler_posterior_threshold: float,
        typical_acceptance_sampler_posterior_alpha: float,
    ) -> "SpecDecodeWorker":

        ngram_prompt_lookup_max = (
            draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
        ngram_prompt_lookup_min = (
            draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
        if ngram_prompt_lookup_max > 0:
            proposer_worker = NGramWorker(**draft_worker_kwargs)
            proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
                                                  ngram_prompt_lookup_max)
        else:
            draft_parallel_config: ParallelConfig = draft_worker_kwargs[
                'parallel_config']
            draft_tp = draft_parallel_config.tensor_parallel_size
            target_tp = scorer_worker.parallel_config.tensor_parallel_size

            if draft_worker_kwargs[
                    "model_config"].hf_config.model_type == "mlp_speculator":
                proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
            elif draft_worker_kwargs[
                    "model_config"].hf_config.model_type == "medusa":
                proposer_worker = MedusaWorker(**draft_worker_kwargs)
            else:
                if draft_tp == 1:
                    draft_worker_kwargs[
                        "model_runner_cls"] = TP1DraftModelRunner
                proposer_worker = MultiStepWorker(**draft_worker_kwargs)

            proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
                proposer_worker, draft_tp, target_tp)
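
        # Examples of the selection above (hypothetical configs): setting
        # ngram_prompt_lookup_max=4 picks NGramWorker; a draft model whose
        # hf_config.model_type is "medusa" picks MedusaWorker; any other draft
        # model runs through MultiStepWorker, and a draft with a smaller
        # tensor-parallel degree than the target (e.g. TP-1 under TP-4) is
        # typically wrapped by SmallerTpProposerWorker so that it runs on a
        # subset of the target's ranks.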

        logger.info("Configuring SpecDecodeWorker with proposer=%s",
                    type(proposer_worker))

        spec_decode_sampler: SpecDecodeBaseSampler = None
        if draft_token_acceptance_method == "rejection_sampler":
            spec_decode_sampler = RejectionSampler(
                disable_bonus_tokens=False, )
        elif draft_token_acceptance_method == "typical_acceptance_sampler":
            spec_decode_sampler = TypicalAcceptanceSampler(
                disable_bonus_tokens=False,
                posterior_threshold=\
                    typical_acceptance_sampler_posterior_threshold,
                posterior_alpha=typical_acceptance_sampler_posterior_alpha,
            )
        logger.info("Configuring SpecDecodeWorker with sampler=%s",
                    type(spec_decode_sampler))

        return SpecDecodeWorker(proposer_worker,
                                scorer_worker,
                                disable_by_batch_size=disable_by_batch_size,
                                spec_decode_sampler=spec_decode_sampler)

    def __init__(
        self,
        proposer_worker: ProposerWorkerBase,
        scorer_worker: WorkerBase,
        spec_decode_sampler: SpecDecodeBaseSampler,
        metrics_collector: Optional[AsyncMetricsCollector] = None,
        disable_by_batch_size: Optional[int] = None,
    ):
        """
        Create a SpecDecodeWorker.

        Args:
            proposer_worker: A worker that can produce speculative tokens for
                sequences.
            scorer_worker: A worker that produces probabilities of speculative
                tokens according to some base model. Typically a vanilla vLLM
                Worker.
            spec_decode_sampler: A Torch module used to perform acceptance
                sampling of the draft tokens in the verification step of
                speculative decoding. Currently two sampler types are
                supported: RejectionSampler and TypicalAcceptanceSampler.
            disable_by_batch_size: If the batch size is larger than this,
                disable speculative decoding for new incoming requests.
            metrics_collector: Helper class for collecting metrics; can be set
                for testing purposes.
        """
        self.proposer_worker = proposer_worker
        self.scorer_worker = scorer_worker
        self.disable_by_batch_size = disable_by_batch_size or float("inf")
        self.spec_decode_sampler = spec_decode_sampler
        self._metrics = AsyncMetricsCollector(
            self.spec_decode_sampler
        ) if metrics_collector is None else metrics_collector
        # Tracks the sequence IDs that received a bonus token ID in
        # their last forward pass. Needed only if KV cache is being
        # used for token generation, such as in the case of MultiStepWorker.
        self._seq_with_bonus_token_in_last_step: Set[int] = set()
        # Tracks the currently active request ids and the sequence IDs
        # corresponding to them.
        self._request_id_seq_id_mapping: Dict[str, Set[int]] = defaultdict(set)
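        # For example (hypothetical ids), a request "req-0" with sequences
        # {7, 8} is stored here as {"req-0": {7, 8}}; when the request
        # finishes, _track_finished_requests drops both entries again.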
        # Tracks if the proposer worker uses the KV cache or not.

        self.probs_dtype = self.spec_decode_sampler.probs_dtype
        self.token_id_dtype = self.spec_decode_sampler.token_id_dtype
        # Lazy initialization.
        self.scorer: SpeculativeScorer

        # Hidden states from target model to pass to proposer
        # in the subsequent step.
        self.previous_hidden_states: Optional[HiddenStates] = None

    def init_device(self) -> None:
        """Initialize both scorer and proposer models.
        """
        # The scorer worker model is initialized first in case the proposer
        # model has a smaller TP degree than the target worker.
        self.scorer_worker.init_device()
        self.proposer_worker.init_device()

        # NOTE(cade): load_model is not part of the WorkerBase interface.
        self.scorer_worker.load_model()
        self.proposer_worker.load_model()

        self._metrics.init_gpu_tensors(self.rank)
        self.spec_decode_sampler.init_gpu_tensors(self.rank)

        self.scorer = BatchExpansionTop1Scorer(
            scorer_worker=self.scorer_worker,
            device=self.device,
            vocab_size=self._vocab_size)

        self._configure_model_sampler_for_spec_decode()

    def load_model(self, *args, **kwargs):
        pass

    def _configure_model_sampler_for_spec_decode(self):
        """Configure model sampler to emit GPU tensors. This allows spec decode
        to keep data on device without transferring to CPU and serializing,
        which significantly reduces overhead of sampling during verification.

        NOTE(cade): This breaks abstraction boundaries pretty badly. The better
        design is to have the "move to CPU and serialize" sampling decision be
        done outside of the model/sampler; this way the "last-mile" worker
        object which interfaces with the scheduler can serialize and incur the
        performance hit as necessary. This allows us to run the worker several
        iterations in a row without incurring the "move to CPU and serialize"
        performance penalty.

        Since this requires a large change to vLLM, we defer it to later and
        temporarily accept this broken abstraction boundary.

        NOTE(cade): This will require a special check if the proposer worker
        does not have a sampler (e.g. ngram speculation).
        """
        (self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor
         ) = True
        self.proposer_worker.set_include_gpu_probs_tensor()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of cache blocks to use.

        This is done by profiling the scorer model (which is typically the
        larger of the two). Then the total memory which would be used by the
        scorer cache is divided evenly between the proposer and scorer model KV,
        such that the number of blocks is equal in both KV caches.
        """
        num_gpu_blocks, num_cpu_blocks = (
            self.scorer_worker.determine_num_available_blocks())

        scorer_cache_block_size_bytes = (
            self.scorer_worker.get_cache_block_size_bytes())
        proposer_cache_block_size_bytes = (
            self.proposer_worker.get_cache_block_size_bytes())

        new_num_gpu_blocks = split_num_cache_blocks_evenly(
            scorer_cache_block_size_bytes, proposer_cache_block_size_bytes,
            num_gpu_blocks)
        return new_num_gpu_blocks, num_cpu_blocks

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the cache engine of the scorer and proposer workers.
        """
        self.scorer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
                                            num_cpu_blocks=num_cpu_blocks)
        self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
                                              num_cpu_blocks=num_cpu_blocks)

    @torch.inference_mode()
    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Perform speculative decoding on the input batch.
        """
        if self.rank != self._driver_rank:
            self._run_non_driver_rank()
            return []

        if execute_model_req is None:
            # This signals that there are no more requests to process for now.
            # All workers are running infinite loop with broadcast_tensor_dict,
            # and it stops the loop when the driver broadcasts an empty input.
            # Send an empty input to notify all other workers to stop their
            # execution loop.
            broadcast_tensor_dict({}, src=0)
            return []

        self._track_finished_requests(execute_model_req)
        disable_all_speculation = self._should_disable_all_speculation(
            execute_model_req)
        num_lookahead_slots = execute_model_req.num_lookahead_slots

        # Broadcast how many lookahead slots are scheduled for this step, and
        # whether all speculation is disabled, to all non-driver workers.

        # This is required because, if the number of draft model runs changes
        # dynamically, the non-driver workers won't know unless we communicate
        # it to them.
        broadcast_dict = dict(
            num_lookahead_slots=num_lookahead_slots,
            disable_all_speculation=disable_all_speculation,
        )
        broadcast_tensor_dict(broadcast_dict, src=self._driver_rank)

        assert execute_model_req.seq_group_metadata_list is not None, (
            "speculative decoding requires non-None seq_group_metadata_list")

        self._maybe_disable_speculative_tokens(
            disable_all_speculation, execute_model_req.seq_group_metadata_list)

        # Speculative decoding is disabled in the following cases:
        # 1. Prefill phase: Speculative decoding is not
        #    used during the prefill phase.
        # 2. Auto-disable enabled: The running queue size exceeds
        #    the specified threshold.
        # 3. No request: There are no requests in the batch.
        # In any of these cases, the proposer and scorer workers
        # are called normally.
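        # For example, a prefill-only step arrives with num_lookahead_slots=0
        # and therefore takes the _run_no_spec path below even when
        # speculation is otherwise enabled.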
        if num_lookahead_slots == 0 or len(
                execute_model_req.seq_group_metadata_list
        ) == 0 or disable_all_speculation:
            return self._run_no_spec(execute_model_req,
                                     skip_proposer=disable_all_speculation)

        return self._run_speculative_decoding_step(execute_model_req,
                                                   num_lookahead_slots)

    @torch.inference_mode()
    def start_worker_execution_loop(self) -> None:
        """Execute model loop to perform speculative decoding
        in parallel worker."""
        while self._run_non_driver_rank():
            pass

    def _should_disable_all_speculation(
            self, execute_model_req: ExecuteModelRequest) -> bool:
        # When the batch size is too large, disable speculative decoding
        # to stop trading off throughput for latency.
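        # Example (hypothetical numbers): with disable_by_batch_size=32, a
        # running queue of 40 requests disables speculation for this step,
        # while a queue of 16 leaves it enabled.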
        disable_all_speculation = (execute_model_req.running_queue_size >=
                                   self.disable_by_batch_size)

        return disable_all_speculation

    def _maybe_disable_speculative_tokens(
            self, disable_all_speculation: bool,
            seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
        if not disable_all_speculation:
            return

        for seq_group_metadata in seq_group_metadata_list:
            # Once num_speculative_tokens is set to 0, the spec decode
            # of this request will be disabled forever.
            # TODO(comaniac): We currently store spec decoding specific
            # state in the global data structure, but we should maintain
            # this state within spec decode worker.
            seq_group_metadata.num_speculative_tokens = 0

    @nvtx_range("spec_decode_worker._run_no_spec")
    def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
                     skip_proposer: bool) -> List[SamplerOutput]:
        """Run a single generation step without any speculation. The input is
        sent to the proposer and scorer model so that the KV cache is
        consistent between the two. When skip_proposer is True, the proposer
        model is not called, meaning that the proposer's KV cache is not
        updated for these requests, so they cannot use spec decode for the
        rest of decoding.
        """
        if not skip_proposer:
            self.proposer_worker.execute_model(execute_model_req)

        sampler_output = self.scorer_worker.execute_model(execute_model_req)
        assert len(sampler_output) == 1
        sampler_output = sampler_output[0]

        # Store hidden states from target model execution.
        hidden_states = sampler_output.hidden_states
        if hidden_states is not None:
            if self.previous_hidden_states is None:
                self.previous_hidden_states = HiddenStates(
                    execute_model_req.seq_group_metadata_list, hidden_states)
            else:
                self.previous_hidden_states.update(
                    execute_model_req.seq_group_metadata_list, hidden_states)

        # Clear device tensors from sampler output. This reduces communication
        # overhead when the engine runs in a different process than the
        # workers.
        sampler_output.probs = None
        sampler_output.sampled_tokens = None
        sampler_output.logprobs = None
        return [sampler_output]

    def _run_non_driver_rank(self) -> bool:
        """Run proposer and verifier model in non-driver workers. This is used
        for both speculation cases (num_lookahead_slots>0) and non-speculation
        cases (e.g. prefill).

        Returns True iff there are remaining sequences to process.
        """
        assert self.rank != self._driver_rank

        data = broadcast_tensor_dict(src=self._driver_rank)
        if not data:
            return False
        num_lookahead_slots = data["num_lookahead_slots"]

        # Even if num_lookahead_slots is zero, we want to run the proposer model
        # as it may have KV.
        #
        # We run the proposer once per lookahead slot. In the future we should
        # delegate how many times it runs to the proposer.
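        # e.g. num_lookahead_slots=4 runs the proposer four times and the
        # scorer once; num_lookahead_slots=0 still runs each of them once.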
        for _ in range(max(num_lookahead_slots, 1)):
            self.proposer_worker.execute_model()

        self.scorer_worker.execute_model()
        return True

    @nvtx_range("spec_decode_worker._run_speculative_decoding_step")
    def _run_speculative_decoding_step(
            self, execute_model_req: ExecuteModelRequest,
            num_lookahead_slots: int) -> List[SamplerOutput]:
        """Execute a single step of speculative decoding.

        This invokes the proposer worker to get k speculative tokens for each
        sequence, then scores each speculative token using the scoring worker.

        Returns a list of SamplerOutput, each containing a single token per
        sequence.
        """
        assert num_lookahead_slots == execute_model_req.num_lookahead_slots

        # Pass last hidden states from target model to proposer
        execute_model_req.previous_hidden_states = self.previous_hidden_states
        self.previous_hidden_states = None

        # Generate proposals using draft worker.
        proposals = self.proposer_worker.get_spec_proposals(
            execute_model_req, self._seq_with_bonus_token_in_last_step)

        proposal_scores = self.scorer.score_proposals(
            execute_model_req,
            proposals,
        )

        accepted_token_ids, target_logprobs = self._verify_tokens(
            execute_model_req.seq_group_metadata_list, proposal_scores,
            proposals, execute_model_req.num_lookahead_slots)

        return self._create_output_sampler_list(
            execute_model_req.seq_group_metadata_list,
            accepted_token_ids,
            target_logprobs=target_logprobs,
            k=execute_model_req.num_lookahead_slots)

    @nvtx_range("spec_decode_worker._verify_tokens")
    def _verify_tokens(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_scores: SpeculativeScores,
        proposals: SpeculativeProposals,
        max_proposal_len: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Determine which speculative tokens are accepted using the
        probabilities of each token according to the proposer and scorer
        models.

        Returns a tuple of Tensors, one for the accepted token ids and one for
        the logprobs according to the scoring model.
        """
        proposal_lens_list = proposals.proposal_lens.tolist()

        # vLLM currently only supports proposal lens equal to zero or the batch
        # proposal len. This adds some complexity (splitting the batch into spec
        # and non spec sequences) and should be removed in the future. It can be
        # done by supporting per-sequence proposal lens.
        _, spec_indices = split_batch_by_proposal_len(
            seq_group_metadata_list,
            proposal_lens_list,
            select_proposal_len_zero=False)
        _, non_spec_indices = split_batch_by_proposal_len(
            seq_group_metadata_list,
            proposal_lens_list,
            select_proposal_len_zero=True)
        original_indices = spec_indices + non_spec_indices

        # Get probabilities of target model, excluding bonus token.
        proposal_verifier_probs = proposal_scores.probs[spec_indices, :-1]

        # Get non-speculative sampled tokens from target model.
        non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]

        # Get bonus tokens from target model.
        bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:]

        # Get probabilities according to proposal method.
        proposal_probs = proposals.proposal_probs[spec_indices]

        # Get proposed tokens.
        proposal_token_ids = proposals.proposal_token_ids[spec_indices]

        accepted_token_ids = self.spec_decode_sampler(
            target_probs=proposal_verifier_probs,
            bonus_token_ids=bonus_token_ids,
            draft_probs=proposal_probs,
            draft_token_ids=proposal_token_ids,
        )

        # Append output tokens from non-speculative sequences to
        # the accepted token ids tensor.
        non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len +
                                                       1).clone()
        non_spec_token_ids[:, 1:] = -1
        accepted_token_ids = torch.cat(
            [accepted_token_ids, non_spec_token_ids])
        logprobs = proposal_scores.logprobs
        # Rearrange so that results are in the order of the original seq group
        # metadata.
        accepted_token_ids[original_indices] = accepted_token_ids.clone()

        hidden_states = proposal_scores.hidden_states
        if hidden_states is not None:
            # Contract hidden states based on accepted tokens
            hs_size = hidden_states.shape[1]
            hidden_states = hidden_states.reshape(-1, max_proposal_len + 1,
                                                  hs_size)
            accepted_index = accepted_token_ids + 1  # Convert -1 to 0
            accepted_index = accepted_index.count_nonzero(dim=1).add_(-1)
            index = accepted_index[:, None, None].expand(-1, 1, hs_size)
            hidden_states = hidden_states.gather(1, index).squeeze(1)  # b x d
            # Store hidden states from target model for subsequent decode step
            self.previous_hidden_states = HiddenStates(seq_group_metadata_list,
                                                       hidden_states)

        return accepted_token_ids, logprobs

    def _create_output_sampler_list(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        accepted_token_ids: torch.Tensor,  # shape: [batch_size, k+1]
        target_logprobs: torch.Tensor,  # shape: [batch_size, k+1, vocab_size]
        k: int,
    ) -> List[SamplerOutput]:
        """Given the accepted token ids, create a list of SamplerOutput.

        The output is padded with -1 tokens such that each sequence has
        the same number of outputs.
        """
        batch_size, num_steps = accepted_token_ids.shape

        # Organize input tensors by step instead of by sequence.
        target_logprobs_by_step = target_logprobs.transpose(0, 1)
        accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1)

        # Get the logprobs/rank of the accepted tokens.
        (accepted_token_id_ranks_by_step,
         accepted_token_id_logprobs_by_step) = get_sampled_token_logprobs(
             logprob_tensor=target_logprobs_by_step,
             sampled_token_ids=accepted_token_ids_by_step,
         )

        # Get the top-k logprobs (which may or may not include the logprob of
        # the accepted token).
        (topk_logprobs_by_step,
         topk_indices_by_step) = target_logprobs_by_step.topk(
             k=self.scorer_worker.model_config.max_logprobs,
             dim=-1,
         )

        # Get the sequence ids and num_logprobs (sampling parameter) in the
        # batch.
        seq_ids, request_ids_seq_ids_mapping = get_all_seq_ids_and_request_ids(
            seq_group_metadata_list)

        num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)

        # Serialize all tensors to CPU Python lists.
        accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
        accepted_token_id_ranks_by_step = (
            accepted_token_id_ranks_by_step.tolist())
        accepted_token_id_logprobs_by_step = (
            accepted_token_id_logprobs_by_step.tolist())
        topk_logprobs_by_step = topk_logprobs_by_step.tolist()
        topk_indices_by_step = topk_indices_by_step.tolist()

        # Construct the output on a per-step, per-sequence basis.
        sampler_output_list: List[SamplerOutput] = []
        for step_index in range(num_steps):
            if all(token_id == -1
                   for token_id in accepted_token_ids_by_step[step_index]):
                break

            step_output_token_ids: List[CompletionSequenceGroupOutput] = []
            for sequence_index in range(batch_size):
                # Each sequence may have a different num_logprobs; retrieve it.
                num_logprobs = num_logprobs_per_seq[sequence_index]
                step_output_token_ids.append(
                    create_sequence_group_output(
                        token_id=accepted_token_ids_by_step[step_index]
                        [sequence_index],
                        token_id_logprob_rank=accepted_token_id_ranks_by_step[
                            step_index][sequence_index],
                        token_id_logprob=accepted_token_id_logprobs_by_step[
                            step_index][sequence_index],
                        seq_id=seq_ids[sequence_index],
                        topk_token_ids=topk_indices_by_step[step_index]
                        [sequence_index][:num_logprobs],
                        topk_logprobs=topk_logprobs_by_step[step_index]
                        [sequence_index][:num_logprobs],
                    ))
            sampler_output_list.append(
                SamplerOutput(outputs=step_output_token_ids))

        # Populate the data structures needed to keep track of sequences with
        # bonus tokens.
        self._track_sequences_with_bonus_tokens(seq_ids,
                                                request_ids_seq_ids_mapping,
                                                accepted_token_ids_by_step)
        maybe_rejsample_metrics = (
            self._metrics.maybe_collect_rejsample_metrics(k))
        if maybe_rejsample_metrics is not None:
            sampler_output_list[
                0].spec_decode_worker_metrics = maybe_rejsample_metrics
        return sampler_output_list

    def _track_finished_requests(self, execute_model_req: ExecuteModelRequest):
        """
        Removes the finished requests and their associated sequence ids from
        internal bookkeeping data structures.
        """
        for finished_request in execute_model_req.finished_requests_ids:
            for seq_id in self._request_id_seq_id_mapping[finished_request]:
                self._seq_with_bonus_token_in_last_step.discard(seq_id)
            del self._request_id_seq_id_mapping[finished_request]

    def _track_sequences_with_bonus_tokens(
            self, seq_ids: List[int],
            request_ids_seq_ids_mapping: Dict[str, Set[int]],
            accepted_token_ids_by_step: List[List[int]]):
        """
        Updates the internal data structures which keep track of sequences
        which have been assigned bonus tokens in their last forward pass.
        """
        for seq_index, seq_id in enumerate(seq_ids):
            last_token_id = accepted_token_ids_by_step[-1][seq_index]
            if last_token_id == -1:
                self._seq_with_bonus_token_in_last_step.discard(seq_id)
            else:
                self._seq_with_bonus_token_in_last_step.add(seq_id)
        for request_id, sequences in request_ids_seq_ids_mapping.items():
            self._request_id_seq_id_mapping[request_id].update(sequences)

    @cached_property
    def _vocab_size(self) -> int:
        """Get the vocab size of the model and make sure it's consistent between
        draft and target workers.
        """
        vocab_sizes = [
            worker.vocab_size
            for worker in [self.proposer_worker, self.scorer_worker]
        ]
        assert all(vocab_sizes[0] == vocab_size for vocab_size in vocab_sizes)
        return vocab_sizes[0]

    @property
    def rank(self):
        return self.scorer_worker.rank

    @property
    def device(self):
        return self.scorer_worker.device

    @property
    def _driver_rank(self) -> int:
        return 0

    def get_cache_block_size_bytes(self):
        """Return the size of a cache block in bytes.

        This function is only used to compose workers within a SpecDecodeWorker.
        We leave composing a SpecDecodeWorker within a SpecDecodeWorker
        undefined for now, although it could be implemented in the future.
        See https://arxiv.org/abs/2308.04623.
        """
        raise NotImplementedError


def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int,
                                  proposer_cache_block_size_bytes: int,
                                  total_num_gpu_blocks: int) -> int:
    """Given total_num_gpu_blocks, the number of GPU blocks that could be
    allocated to the target model, this function calculates how many blocks
    should be given to the draft and target model.

    Note that usually the block size, in bytes, of each model is different,
    as it's a function of number of KV/layer, number of heads, and hidden
    dimension size.

    Since the target and draft models allocate the same number of blocks, we
    simply calculate the number of blocks where if allocated by both models,
    the total memory usage from KV cache is no larger than the number of
    blocks allocatable by the target model alone.
    """
    new_num_gpu_blocks = int(
        total_num_gpu_blocks * scorer_cache_block_size_bytes /
        (proposer_cache_block_size_bytes + scorer_cache_block_size_bytes))

    return new_num_gpu_blocks