[Spec Decode] Disable Log Prob serialization to CPU for spec decoding for both draft and target models. (#6485)

This commit is contained in:
sroy745 2024-07-20 23:58:58 -07:00 committed by GitHub
parent d7f4178dd9
commit 14f91fe67c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 333 additions and 64 deletions

View File

@ -22,10 +22,12 @@ from .conftest import get_logprobs_from_llm_generator
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": False,
}])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
"output_len",
@ -59,10 +61,12 @@ def test_logprobs_equality(baseline_llm_generator, test_llm_generator,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": False,
}])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("num_logprobs", [6])
@pytest.mark.parametrize(
@ -99,13 +103,16 @@ def test_diff_num_logprobs(baseline_llm_generator, test_llm_generator,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
}, {
"speculative_model": "JackFram/llama-160m",
"num_speculative_tokens": 6,
}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": False,
}, {
"speculative_model": "JackFram/llama-160m",
"num_speculative_tokens": 6,
"disable_logprobs_during_spec_decoding": False,
}])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
"output_len",
@ -143,6 +150,7 @@ def test_logprobs_different_k(baseline_llm_generator, test_llm_generator,
[{
"speculative_model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": False,
# Artificially limit the draft model max model len; this forces vLLM
# to skip speculation once the sequences grow beyond 32-k tokens.
@ -181,10 +189,12 @@ def test_logprobs_when_skip_speculation(baseline_llm_generator,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": False,
}])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize(
"output_len",

View File

@ -32,6 +32,7 @@ def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int,
scorer_worker=target_worker,
spec_decode_sampler=mock_spec_decode_sampler(
acceptance_sampler_method),
disable_logprobs=False,
metrics_collector=metrics_collector,
disable_by_batch_size=disable_by_batch_size)

View File

@ -381,6 +381,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool,
worker = SpecDecodeWorker(draft_worker,
target_worker,
spec_decode_sampler,
disable_logprobs=False,
metrics_collector=metrics_collector)
worker.init_device()
@ -479,7 +480,8 @@ def test_k_equals_zero(k: int, batch_size: int,
worker = SpecDecodeWorker(
draft_worker, target_worker,
mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
mock_spec_decode_sampler(acceptance_sampler_method), False,
metrics_collector)
seq_group_metadata_list, _, _ = create_batch(batch_size,
k,
@ -490,9 +492,10 @@ def test_k_equals_zero(k: int, batch_size: int,
out = worker.execute_model(execute_model_req=execute_model_req)
assert len(out) == 1, f"expected only one token output when {k=}"
assert out[0].probs is None, "expect gpu tensor references to be None"
assert out[0].sampled_token_probs is None, (
"expect gpu tensor references to be None")
assert out[
0].sampled_tokens is None, "expect gpu tensor references to be None"
0].sampled_token_ids is None, "expect gpu tensor references to be None"
draft_worker.execute_model.assert_called_once_with(execute_model_req)
target_worker.execute_model.assert_called_once_with(execute_model_req)
@ -524,7 +527,8 @@ def test_empty_input_batch(k: int, batch_size: int,
worker = SpecDecodeWorker(
draft_worker, target_worker,
mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
mock_spec_decode_sampler(acceptance_sampler_method), False,
metrics_collector)
seq_group_metadata_list, _, _ = create_batch(batch_size,
k,
@ -535,9 +539,10 @@ def test_empty_input_batch(k: int, batch_size: int,
out = worker.execute_model(execute_model_req=execute_model_req)
assert len(out) == 1, f"expected only one token output when {k=}"
assert out[0].probs is None, "expect gpu tensor references to be None"
assert out[0].sampled_token_probs is None, (
"expect gpu tensor references to be None")
assert out[
0].sampled_tokens is None, "expect gpu tensor references to be None"
0].sampled_token_ids is None, "expect gpu tensor references to be None"
draft_worker.execute_model.assert_called_once_with(execute_model_req)
target_worker.execute_model.assert_called_once_with(execute_model_req)
@ -556,7 +561,7 @@ def test_init_device(acceptance_sampler_method: str):
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler,
metrics_collector)
False, metrics_collector)
worker.init_device()
draft_worker.init_device.assert_called_once()
@ -707,6 +712,7 @@ def test_populate_seq_ids_with_bonus_tokens():
worker = SpecDecodeWorker(draft_worker,
target_worker,
mock_spec_decode_sampler("rejection_sampler"),
disable_logprobs=False,
metrics_collector=metrics_collector)
# Initialize _seq_with_bonus_token_in_last_step with a set of sequence IDs.
# This set includes all sequence IDs in the batch as well as an additional

View File

@ -894,6 +894,7 @@ class SpeculativeConfig:
draft_token_acceptance_method: str,
typical_acceptance_sampler_posterior_threshold: Optional[float],
typical_acceptance_sampler_posterior_alpha: Optional[float],
disable_logprobs: Optional[bool],
) -> Optional["SpeculativeConfig"]:
"""Create a SpeculativeConfig if possible, else return None.
@ -943,6 +944,11 @@ class SpeculativeConfig:
typical_acceptance_sampler_posterior_alpha (Optional[float]):
A scaling factor for the entropy-based threshold in the
TypicalAcceptanceSampler.
disable_logprobs (Optional[bool]): If set to True, token log
probabilities are not returned during speculative decoding.
If set to False, token log probabilities are returned
according to the log probability settings in SamplingParams.
If not specified, it defaults to True.
Returns:
Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
@ -1055,6 +1061,8 @@ class SpeculativeConfig:
typical_acceptance_sampler_posterior_threshold = 0.09
if typical_acceptance_sampler_posterior_alpha is None:
typical_acceptance_sampler_posterior_alpha = 0.3
if disable_logprobs is None:
disable_logprobs = True
return SpeculativeConfig(
draft_model_config,
@ -1068,6 +1076,7 @@ class SpeculativeConfig:
typical_acceptance_sampler_posterior_threshold,
typical_acceptance_sampler_posterior_alpha=\
typical_acceptance_sampler_posterior_alpha,
disable_logprobs=disable_logprobs
)
@staticmethod
@ -1152,6 +1161,7 @@ class SpeculativeConfig:
draft_token_acceptance_method: str,
typical_acceptance_sampler_posterior_threshold: float,
typical_acceptance_sampler_posterior_alpha: float,
disable_logprobs: bool,
):
"""Create a SpeculativeConfig object.
@ -1178,6 +1188,12 @@ class SpeculativeConfig:
typical_acceptance_sampler_posterior_alpha (Optional[float]):
A scaling factor for the entropy-based threshold in the
TypicalAcceptanceSampler.
disable_logprobs: If set to True, token log probabilities will not
be returned even if requested by sampling parameters. This
reduces latency by skipping logprob calculation in proposal
sampling, target sampling, and after accepted tokens are
determined. If set to False, log probabilities will be
returned.
"""
self.draft_model_config = draft_model_config
self.draft_parallel_config = draft_parallel_config
@ -1191,6 +1207,7 @@ class SpeculativeConfig:
typical_acceptance_sampler_posterior_threshold
self.typical_acceptance_sampler_posterior_alpha = \
typical_acceptance_sampler_posterior_alpha
self.disable_logprobs = disable_logprobs
self._verify_args()

View File

@ -110,6 +110,7 @@ class EngineArgs:
typical_acceptance_sampler_posterior_threshold: Optional[float] = None
typical_acceptance_sampler_posterior_alpha: Optional[float] = None
qlora_adapter_name_or_path: Optional[str] = None
disable_logprobs_during_spec_decoding: Optional[bool] = None
otlp_traces_endpoint: Optional[str] = None
@ -592,6 +593,18 @@ class EngineArgs:
'to sqrt of --typical-acceptance-sampler-posterior-threshold '
'i.e. 0.3')
parser.add_argument(
'--disable-logprobs-during-spec-decoding',
type=bool,
default=EngineArgs.disable_logprobs_during_spec_decoding,
help='If set to True, token log probabilities are not returned '
'during speculative decoding. If set to False, log probabilities '
'are returned according to the settings in SamplingParams. If '
'not specified, it defaults to True. Disabling log probabilities '
'during speculative decoding reduces latency by skipping logprob '
'calculation in proposal sampling, target sampling, and after '
'accepted tokens are determined.')
parser.add_argument('--model-loader-extra-config',
type=nullable_str,
default=EngineArgs.model_loader_extra_config,
@ -736,6 +749,7 @@ class EngineArgs:
typical_acceptance_sampler_posterior_threshold,
typical_acceptance_sampler_posterior_alpha=self.
typical_acceptance_sampler_posterior_alpha,
disable_logprobs=self.disable_logprobs_during_spec_decoding,
)
scheduler_config = SchedulerConfig(

View File

@ -14,7 +14,7 @@ from vllm.model_executor.layers.typical_acceptance_sampler import (
TypicalAcceptanceSampler)
from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest,
HiddenStates, SamplerOutput, SequenceGroupMetadata,
get_all_seq_ids_and_request_ids)
get_all_seq_ids, get_all_seq_ids_and_request_ids)
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.interfaces import (SpeculativeProposals,
@ -26,6 +26,7 @@ from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
from vllm.spec_decode.target_model_runner import TargetModelRunner
from vllm.spec_decode.util import (create_sequence_group_output,
get_all_num_logprobs,
get_sampled_token_logprobs, nvtx_range,
@ -44,9 +45,15 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
speculative_config: SpeculativeConfig = kwargs.get("speculative_config")
assert speculative_config is not None
target_worker = Worker(*args, **kwargs)
draft_worker_kwargs = kwargs.copy()
kwargs["model_runner_cls"] = TargetModelRunner
target_worker = Worker(*args, **kwargs)
# Set the disable_logprobs variable in the TargetModelRunner instance
# as per its value specified in the SpeculativeConfig.
target_worker.model_runner.disable_logprobs =\
speculative_config.disable_logprobs
# Override draft-model specific worker args.
draft_worker_kwargs.update(
model_config=speculative_config.draft_model_config,
@ -67,7 +74,8 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
typical_acceptance_sampler_posterior_threshold=speculative_config.
typical_acceptance_sampler_posterior_threshold,
typical_acceptance_sampler_posterior_alpha=speculative_config.
typical_acceptance_sampler_posterior_alpha)
typical_acceptance_sampler_posterior_alpha,
disable_logprobs=speculative_config.disable_logprobs)
return spec_decode_worker
@ -107,6 +115,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
draft_token_acceptance_method: str,
typical_acceptance_sampler_posterior_threshold: float,
typical_acceptance_sampler_posterior_alpha: float,
disable_logprobs: bool,
) -> "SpecDecodeWorker":
allow_zero_draft_token_step = True
@ -161,6 +170,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
return SpecDecodeWorker(
proposer_worker,
scorer_worker,
disable_logprobs=disable_logprobs,
disable_by_batch_size=disable_by_batch_size,
spec_decode_sampler=spec_decode_sampler,
allow_zero_draft_token_step=allow_zero_draft_token_step)
@ -170,6 +180,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposer_worker: ProposerWorkerBase,
scorer_worker: WorkerBase,
spec_decode_sampler: SpecDecodeBaseSampler,
disable_logprobs: bool,
metrics_collector: Optional[AsyncMetricsCollector] = None,
disable_by_batch_size: Optional[int] = None,
allow_zero_draft_token_step: Optional[bool] = True,
@ -189,6 +200,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
types of sampler namely RejectionSampler and
TypicalAcceptanceSampler. 'spec_decode_sampler' is either an
instance of RejectionSampler or TypicalAcceptanceSampler.
disable_logprobs: If set to True, token log probabilities will
not be output in both the draft worker and the target worker.
If set to False, log probabilities will be output by both.
disable_by_batch_size: If the batch size is larger than this,
disable speculative decoding for new incoming requests.
metrics_collector: Helper class for collecting metrics; can be set
@ -222,6 +236,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# Hidden states from target model to pass to proposer
# in the subsequent step.
self.previous_hidden_states: Optional[HiddenStates] = None
self._disable_logprobs = disable_logprobs
def init_device(self) -> None:
"""Initialize both scorer and proposer models.
@ -357,7 +372,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
) == 0 or disable_all_speculation:
return self._run_no_spec(execute_model_req,
skip_proposer=disable_all_speculation)
return self._run_speculative_decoding_step(execute_model_req,
num_lookahead_slots)
@ -391,6 +405,42 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# this state within spec decode worker.
seq_group_metadata.num_speculative_tokens = 0
def _serialize_sampler_output_no_logprobs(
self, execute_model_req: ExecuteModelRequest,
sampler_output: SamplerOutput) -> SamplerOutput:
"""
Creates and returns a `SamplerOutput` with only the sampled token IDs
being serialized to CPU & populated in `CompletionSequenceGroupOutput`.
All other parameters in `CompletionSequenceGroupOutput` related to log
probabilities are skipped.
Args:
execute_model_req (ExecuteModelRequest): The model request that
was executed.
sampler_output (SamplerOutput): The output from the sampler with
only GPU tensors populated.
Returns:
SamplerOutput: A new `SamplerOutput` instance containing a list of
`CompletionSequenceGroupOutput` objects with only sampled token
IDs populated.
"""
seq_ids = get_all_seq_ids(execute_model_req.seq_group_metadata_list)
sampled_token_ids_list = sampler_output.sampled_token_ids.tolist()
completion_seq_group_output_list: List[
CompletionSequenceGroupOutput] = []
for index, seq_id in enumerate(seq_ids):
completion_seq_group_output_list.append(
create_sequence_group_output(
token_id=sampled_token_ids_list[index][0],
token_id_logprob_rank=-1,
token_id_logprob=0.0,
seq_id=seq_id,
topk_token_ids=[],
topk_logprobs=[],
))
return SamplerOutput(outputs=completion_seq_group_output_list)
@nvtx_range("spec_decode_worker._run_no_spec")
def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
skip_proposer: bool) -> List[SamplerOutput]:
@ -417,12 +467,17 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self.previous_hidden_states.update(
execute_model_req.seq_group_metadata_list, hidden_states)
sampler_output_to_return = (self._serialize_sampler_output_no_logprobs(
execute_model_req=execute_model_req, sampler_output=sampler_output)
if self._disable_logprobs else
sampler_output)
# Clear device tensors from sampler output. This reduces communication
# overhead when the engine runs in a different process than the workers.
sampler_output.probs = None
sampler_output.sampled_tokens = None
sampler_output.sampled_token_probs = None
sampler_output.sampled_token_ids = None
sampler_output.logprobs = None
return [sampler_output]
return [sampler_output_to_return]
def _run_non_driver_rank(self) -> bool:
"""Run proposer and verifier model in non-driver workers. This is used
@ -480,7 +535,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
execute_model_req,
proposals,
)
accepted_token_ids, target_logprobs = self._verify_tokens(
execute_model_req.seq_group_metadata_list, proposal_scores,
proposals, execute_model_req.num_lookahead_slots)
@ -601,25 +655,27 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
the same number of outputs.
"""
batch_size, num_steps = accepted_token_ids.shape
# Organize input tensors by step instead of by sequence.
target_logprobs_by_step = target_logprobs.transpose(0, 1)
accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1)
# Get the logprobs/rank of the accepted tokens.
(accepted_token_id_ranks_by_step,
accepted_token_id_logprobs_by_step) = get_sampled_token_logprobs(
logprob_tensor=target_logprobs_by_step,
sampled_token_ids=accepted_token_ids_by_step,
)
# Get the top-k logprobs (which may or may not include the logprob of
# the accepted token).
(topk_logprobs_by_step,
topk_indices_by_step) = target_logprobs_by_step.topk(
k=self.scorer_worker.model_config.max_logprobs,
dim=-1,
)
if self._disable_logprobs:
# We are skipping the logprobs. Hence don't serialize the
# logprobs related tensors from the GPU. Instead create
# empty/dummy lists.
(accepted_token_id_ranks_by_step,
accepted_token_id_logprobs_by_step,
topk_logprobs_by_step, topk_indices_by_step) =\
self._create_dummy_logprob_lists(
batch_size, num_steps,
self.scorer_worker.model_config.max_logprobs)
else:
# Organize input tensors by step instead of by sequence.
target_logprobs_by_step = target_logprobs.transpose(0, 1)
# Serialize all tensors into Python lists.
(accepted_token_id_ranks_by_step,
accepted_token_id_logprobs_by_step,
topk_logprobs_by_step, topk_indices_by_step) =\
self._create_logprob_lists_from_tensors(
target_logprobs_by_step, accepted_token_ids_by_step,
self.scorer_worker.model_config.max_logprobs)
# Get the sequence ids and num_logprobs (sampling parameter) in the
# batch.
@ -628,14 +684,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)
# Serialize all tensors to CPU Python lists.
# Serialize tensor to CPU Python list.
accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
accepted_token_id_ranks_by_step = (
accepted_token_id_ranks_by_step.tolist())
accepted_token_id_logprobs_by_step = (
accepted_token_id_logprobs_by_step.tolist())
topk_logprobs_by_step = topk_logprobs_by_step.tolist()
topk_indices_by_step = topk_indices_by_step.tolist()
# Construct the output on a per-step, per-sequence basis.
sampler_output_list: List[SamplerOutput] = []
@ -677,6 +727,108 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
0].spec_decode_worker_metrics = maybe_rejsample_metrics
return sampler_output_list
def _create_dummy_logprob_lists(
self,
batch_size: int,
num_steps: int,
num_top_k: int,
) -> Tuple[List[List[int]], List[List[float]],
List[List[List[Optional[float]]]],
List[List[List[Optional[int]]]]]:
"""
Creates and returns four dummy lists representing token probabilities
and their ranks.
This method initializes and returns:
- The ranks of the accepted tokens, shaped (num_steps, batch_size)
- The log probabilities of the accepted tokens,
shaped (num_steps, batch_size)
- The log probabilities of the top k tokens,
shaped (num_steps, batch_size, num_top_k)
- The token IDs of the top k tokens,
shaped (num_steps, batch_size, num_top_k)
Args:
batch_size (int): The size of the batch.
num_steps (int): The number of steps in the sequence.
num_top_k (int): The number of top-k token log probabilities to
return.
Returns:
A tuple containing four dummy lists as described above.
"""
accepted_token_id_ranks_by_step = [[-1] * batch_size
for _ in range(num_steps)]
accepted_token_id_logprobs_by_step = [[0.0] * batch_size
for _ in range(num_steps)]
topk_logprobs_by_step: List[List[List[Optional[float]]]] = [[
[None] * num_top_k for _ in range(batch_size)
] for _ in range(num_steps)]
topk_indices_by_step: List[List[List[Optional[int]]]] = [[
[None] * num_top_k for _ in range(batch_size)
] for _ in range(num_steps)]
return (accepted_token_id_ranks_by_step,
accepted_token_id_logprobs_by_step, topk_logprobs_by_step,
topk_indices_by_step)
def _create_logprob_lists_from_tensors(
self,
target_logprobs_by_step: torch.Tensor,
accepted_token_ids_by_step: torch.Tensor,
num_top_k: int,
) -> Tuple[List[List[int]], List[List[float]],
List[List[List[Optional[float]]]],
List[List[List[Optional[int]]]]]:
"""
Creates and returns four lists representing token probabilities and
their ranks.
This method initializes and returns four lists containing:
- The ranks of the accepted tokens, shaped (num_steps, batch_size)
- The log probabilities of the accepted tokens,
shaped (num_steps, batch_size)
- The log probabilities of the top k tokens,
shaped (num_steps, batch_size, num_top_k)
- The token IDs of the top k tokens,
shaped (num_steps, batch_size, num_top_k)
Args:
target_logprobs_by_step (torch.Tensor): Tensor representing the
log probabilities of the target model,
shaped (num_steps, batch_size, vocab_size)
accepted_token_ids_by_step (torch.Tensor): Tensor representing
the accepted token_ids, shaped (num_steps, batch_size)
num_top_k (int): The number of top-k token log probabilities to
return.
Returns:
A tuple containing the lists as described above.
"""
# Serialize all tensors to CPU Python lists.
# Get the logprobs/rank of the accepted tokens.
(accepted_token_id_ranks_by_step_tensor,
accepted_token_id_logprobs_by_step_tensor
) = get_sampled_token_logprobs(
logprob_tensor=target_logprobs_by_step,
sampled_token_ids=accepted_token_ids_by_step,
)
# Get the top-k logprobs (which may or may not include the
# logprob of the accepted token).
(topk_logprobs_by_step_tensor,
topk_indices_by_step_tensor) = target_logprobs_by_step.topk(
k=num_top_k,
dim=-1,
)
accepted_token_id_ranks_by_step = (
accepted_token_id_ranks_by_step_tensor.tolist())
accepted_token_id_logprobs_by_step = (
accepted_token_id_logprobs_by_step_tensor.tolist())
topk_logprobs_by_step = topk_logprobs_by_step_tensor.tolist()
topk_indices_by_step = topk_indices_by_step_tensor.tolist()
return (accepted_token_id_ranks_by_step,
accepted_token_id_logprobs_by_step, topk_logprobs_by_step,
topk_indices_by_step)
def _track_finished_requests(self, execute_model_req: ExecuteModelRequest):
"""
Removes the finished requests and their associated sequence ids from

View File

@ -0,0 +1,69 @@
from typing import List, Optional
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
from vllm.sequence import SequenceGroupMetadata
from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
ModelRunner)
class TargetModelRunner(ModelRunner):
"""Specialized model runner for speculative decoding target model.
In speculative decoding, the log probabilities selected finally may not
be the same ones as selected by the target model sampling. This means
that the time spent in the log probability calculation of the target model
is time wasted, since we calculate log probabilities after deciding which
tokens are accepted. For this reason disabling log probabilities in the
target model will make decode faster. The model runner sets the
SamplingMetadata parameters according to whether log probabilities are
requested or not.
"""
def __init__(self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
return_hidden_states: bool = False):
# An internal boolean member variable to indicate if token log
# probabilities are needed or not.
self.disable_logprobs = True
super().__init__(
model_config=model_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
device_config=device_config,
cache_config=cache_config,
load_config=load_config,
lora_config=lora_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker,
multimodal_config=multimodal_config,
prompt_adapter_config=prompt_adapter_config,
return_hidden_states=return_hidden_states,
)
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> ModelInputForGPUWithSamplingMetadata:
model_input: ModelInputForGPUWithSamplingMetadata = super(
).prepare_model_input(seq_group_metadata_list, virtual_engine,
finished_requests_ids)
# If token log probabilities is disabled then skip generating sampler
# CPU output. We directly serialize the GPU sampled_token_id tensors
# as needed. If log probabilities is enabled then synchronize all the
# sampling related tensors which includes the logprobs tensors.
model_input.sampling_metadata.skip_sampler_cpu_output = (
self.disable_logprobs)
return model_input

View File

@ -1,5 +1,5 @@
from contextlib import contextmanager
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple
import torch
@ -53,8 +53,8 @@ def create_sequence_group_output(
token_id_logprob_rank: int,
token_id_logprob: float,
seq_id: SeqId,
topk_token_ids: List[int],
topk_logprobs: List[float],
topk_token_ids: List[Optional[int]],
topk_logprobs: List[Optional[float]],
) -> CompletionSequenceGroupOutput:
"""Create a SequenceGroupOutput given the sampling results.
@ -68,7 +68,7 @@ def create_sequence_group_output(
"""
# vLLM logprobs always include the sampled token. In addition, the user may
# request topk-logprobs (where top-k varies per user up to max_logprobs).
logprobs: Dict[int, Logprob] = {
logprobs: Dict[Optional[int], Logprob] = {
token_id: Logprob(
logprob=token_id_logprob,
rank=token_id_logprob_rank,