[Spec Decode] Disable Log Prob serialization to CPU for spec decoding for both draft and target models. (#6485)
parent d7f4178dd9
commit 14f91fe67c
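For orientation: the e2e tests below forward their parametrized kwargs to vllm.LLM through the test conftest, so the new knob can be exercised offline roughly as in this sketch (the target model name and the use_v2_block_manager flag are assumptions, not taken from this diff):

from vllm import LLM, SamplingParams

llm = LLM(
    model="JackFram/llama-68m",               # assumed target model, not from this diff
    speculative_model="JackFram/llama-160m",  # draft model used by the tests below
    num_speculative_tokens=3,
    use_v2_block_manager=True,                # assumed spec-decode prerequisite at the time
    # Defaults to True when left unset (logprob serialization is skipped);
    # pass False to keep returning logprobs per SamplingParams.
    disable_logprobs_during_spec_decoding=False,
)
outputs = llm.generate(["The future of AI is"],
                       SamplingParams(temperature=0.0, max_tokens=32, logprobs=3))
print(outputs[0].outputs[0].logprobs)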
@@ -22,10 +22,12 @@ from .conftest import get_logprobs_from_llm_generator
 }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
-@pytest.mark.parametrize("test_llm_kwargs", [{
-    "speculative_model": "JackFram/llama-160m",
-    "num_speculative_tokens": 3,
-}])
+@pytest.mark.parametrize("test_llm_kwargs",
+                         [{
+                             "speculative_model": "JackFram/llama-160m",
+                             "num_speculative_tokens": 3,
+                             "disable_logprobs_during_spec_decoding": False,
+                         }])
 @pytest.mark.parametrize("batch_size", [8])
 @pytest.mark.parametrize(
     "output_len",
@@ -59,10 +61,12 @@ def test_logprobs_equality(baseline_llm_generator, test_llm_generator,
 }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
-@pytest.mark.parametrize("test_llm_kwargs", [{
-    "speculative_model": "JackFram/llama-160m",
-    "num_speculative_tokens": 3,
-}])
+@pytest.mark.parametrize("test_llm_kwargs",
+                         [{
+                             "speculative_model": "JackFram/llama-160m",
+                             "num_speculative_tokens": 3,
+                             "disable_logprobs_during_spec_decoding": False,
+                         }])
 @pytest.mark.parametrize("batch_size", [1])
 @pytest.mark.parametrize("num_logprobs", [6])
 @pytest.mark.parametrize(
@@ -99,13 +103,16 @@ def test_diff_num_logprobs(baseline_llm_generator, test_llm_generator,
 }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
-@pytest.mark.parametrize("test_llm_kwargs", [{
-    "speculative_model": "JackFram/llama-160m",
-    "num_speculative_tokens": 3,
-}, {
-    "speculative_model": "JackFram/llama-160m",
-    "num_speculative_tokens": 6,
-}])
+@pytest.mark.parametrize("test_llm_kwargs",
+                         [{
+                             "speculative_model": "JackFram/llama-160m",
+                             "num_speculative_tokens": 3,
+                             "disable_logprobs_during_spec_decoding": False,
+                         }, {
+                             "speculative_model": "JackFram/llama-160m",
+                             "num_speculative_tokens": 6,
+                             "disable_logprobs_during_spec_decoding": False,
+                         }])
 @pytest.mark.parametrize("batch_size", [8])
 @pytest.mark.parametrize(
     "output_len",
@@ -143,6 +150,7 @@ def test_logprobs_different_k(baseline_llm_generator, test_llm_generator,
                          [{
                              "speculative_model": "JackFram/llama-160m",
                              "num_speculative_tokens": 3,
+                             "disable_logprobs_during_spec_decoding": False,

                              # Artificially limit the draft model max model len; this forces vLLM
                              # to skip speculation once the sequences grow beyond 32-k tokens.
@@ -181,10 +189,12 @@ def test_logprobs_when_skip_speculation(baseline_llm_generator,
 }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
-@pytest.mark.parametrize("test_llm_kwargs", [{
-    "speculative_model": "JackFram/llama-160m",
-    "num_speculative_tokens": 3,
-}])
+@pytest.mark.parametrize("test_llm_kwargs",
+                         [{
+                             "speculative_model": "JackFram/llama-160m",
+                             "num_speculative_tokens": 3,
+                             "disable_logprobs_during_spec_decoding": False,
+                         }])
 @pytest.mark.parametrize("batch_size", [1])
 @pytest.mark.parametrize(
     "output_len",
@@ -32,6 +32,7 @@ def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int,
         scorer_worker=target_worker,
         spec_decode_sampler=mock_spec_decode_sampler(
             acceptance_sampler_method),
+        disable_logprobs=False,
         metrics_collector=metrics_collector,
         disable_by_batch_size=disable_by_batch_size)

@@ -381,6 +381,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool,
     worker = SpecDecodeWorker(draft_worker,
                               target_worker,
                               spec_decode_sampler,
+                              disable_logprobs=False,
                               metrics_collector=metrics_collector)
     worker.init_device()

@@ -479,7 +480,8 @@ def test_k_equals_zero(k: int, batch_size: int,

     worker = SpecDecodeWorker(
         draft_worker, target_worker,
-        mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
+        mock_spec_decode_sampler(acceptance_sampler_method), False,
+        metrics_collector)

     seq_group_metadata_list, _, _ = create_batch(batch_size,
                                                  k,
@@ -490,9 +492,10 @@ def test_k_equals_zero(k: int, batch_size: int,
     out = worker.execute_model(execute_model_req=execute_model_req)

     assert len(out) == 1, f"expected only one token output when {k=}"
-    assert out[0].probs is None, "expect gpu tensor references to be None"
+    assert out[0].sampled_token_probs is None, (
+        "expect gpu tensor references to be None")
     assert out[
-        0].sampled_tokens is None, "expect gpu tensor references to be None"
+        0].sampled_token_ids is None, "expect gpu tensor references to be None"

     draft_worker.execute_model.assert_called_once_with(execute_model_req)
     target_worker.execute_model.assert_called_once_with(execute_model_req)
@@ -524,7 +527,8 @@ def test_empty_input_batch(k: int, batch_size: int,

     worker = SpecDecodeWorker(
         draft_worker, target_worker,
-        mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
+        mock_spec_decode_sampler(acceptance_sampler_method), False,
+        metrics_collector)

     seq_group_metadata_list, _, _ = create_batch(batch_size,
                                                  k,
@@ -535,9 +539,10 @@ def test_empty_input_batch(k: int, batch_size: int,
     out = worker.execute_model(execute_model_req=execute_model_req)

     assert len(out) == 1, f"expected only one token output when {k=}"
-    assert out[0].probs is None, "expect gpu tensor references to be None"
+    assert out[0].sampled_token_probs is None, (
+        "expect gpu tensor references to be None")
     assert out[
-        0].sampled_tokens is None, "expect gpu tensor references to be None"
+        0].sampled_token_ids is None, "expect gpu tensor references to be None"

     draft_worker.execute_model.assert_called_once_with(execute_model_req)
     target_worker.execute_model.assert_called_once_with(execute_model_req)
@@ -556,7 +561,7 @@ def test_init_device(acceptance_sampler_method: str):
     metrics_collector = MagicMock(spec=AsyncMetricsCollector)

     worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler,
-                              metrics_collector)
+                              False, metrics_collector)
     worker.init_device()

     draft_worker.init_device.assert_called_once()
@@ -707,6 +712,7 @@ def test_populate_seq_ids_with_bonus_tokens():
     worker = SpecDecodeWorker(draft_worker,
                               target_worker,
                               mock_spec_decode_sampler("rejection_sampler"),
+                              disable_logprobs=False,
                               metrics_collector=metrics_collector)
     # Initialize _seq_with_bonus_token_in_last_step with a set of sequence IDs.
     # This set includes all sequence IDs in the batch as well as an additional
@@ -894,6 +894,7 @@ class SpeculativeConfig:
         draft_token_acceptance_method: str,
         typical_acceptance_sampler_posterior_threshold: Optional[float],
         typical_acceptance_sampler_posterior_alpha: Optional[float],
+        disable_logprobs: Optional[bool],
     ) -> Optional["SpeculativeConfig"]:
         """Create a SpeculativeConfig if possible, else return None.

@@ -943,6 +944,11 @@ class SpeculativeConfig:
             typical_acceptance_sampler_posterior_alpha (Optional[float]):
                 A scaling factor for the entropy-based threshold in the
                 TypicalAcceptanceSampler.
+            disable_logprobs (Optional[bool]): If set to True, token log
+                probabilities are not returned during speculative decoding.
+                If set to False, token log probabilities are returned
+                according to the log probability settings in SamplingParams.
+                If not specified, it defaults to True.

         Returns:
             Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
@@ -1055,6 +1061,8 @@ class SpeculativeConfig:
             typical_acceptance_sampler_posterior_threshold = 0.09
         if typical_acceptance_sampler_posterior_alpha is None:
             typical_acceptance_sampler_posterior_alpha = 0.3
+        if disable_logprobs is None:
+            disable_logprobs = True

         return SpeculativeConfig(
             draft_model_config,
@@ -1068,6 +1076,7 @@ class SpeculativeConfig:
             typical_acceptance_sampler_posterior_threshold,
             typical_acceptance_sampler_posterior_alpha=\
                 typical_acceptance_sampler_posterior_alpha,
+            disable_logprobs=disable_logprobs
         )

     @staticmethod
@@ -1152,6 +1161,7 @@ class SpeculativeConfig:
         draft_token_acceptance_method: str,
         typical_acceptance_sampler_posterior_threshold: float,
         typical_acceptance_sampler_posterior_alpha: float,
+        disable_logprobs: bool,
     ):
         """Create a SpeculativeConfig object.

@@ -1178,6 +1188,12 @@ class SpeculativeConfig:
             typical_acceptance_sampler_posterior_alpha (Optional[float]):
                 A scaling factor for the entropy-based threshold in the
                 TypicalAcceptanceSampler.
+            disable_logprobs: If set to True, token log probabilities will not
+                be returned even if requested by sampling parameters. This
+                reduces latency by skipping logprob calculation in proposal
+                sampling, target sampling, and after accepted tokens are
+                determined. If set to False, log probabilities will be
+                returned.
         """
         self.draft_model_config = draft_model_config
         self.draft_parallel_config = draft_parallel_config
@@ -1191,6 +1207,7 @@ class SpeculativeConfig:
             typical_acceptance_sampler_posterior_threshold
         self.typical_acceptance_sampler_posterior_alpha = \
             typical_acceptance_sampler_posterior_alpha
+        self.disable_logprobs = disable_logprobs

         self._verify_args()

@@ -110,6 +110,7 @@ class EngineArgs:
     typical_acceptance_sampler_posterior_threshold: Optional[float] = None
     typical_acceptance_sampler_posterior_alpha: Optional[float] = None
     qlora_adapter_name_or_path: Optional[str] = None
+    disable_logprobs_during_spec_decoding: Optional[bool] = None

     otlp_traces_endpoint: Optional[str] = None

@@ -592,6 +593,18 @@ class EngineArgs:
             'to sqrt of --typical-acceptance-sampler-posterior-threshold '
             'i.e. 0.3')

+        parser.add_argument(
+            '--disable-logprobs-during-spec-decoding',
+            type=bool,
+            default=EngineArgs.disable_logprobs_during_spec_decoding,
+            help='If set to True, token log probabilities are not returned '
+            'during speculative decoding. If set to False, log probabilities '
+            'are returned according to the settings in SamplingParams. If '
+            'not specified, it defaults to True. Disabling log probabilities '
+            'during speculative decoding reduces latency by skipping logprob '
+            'calculation in proposal sampling, target sampling, and after '
+            'accepted tokens are determined.')
+
         parser.add_argument('--model-loader-extra-config',
                             type=nullable_str,
                             default=EngineArgs.model_loader_extra_config,
@@ -736,6 +749,7 @@ class EngineArgs:
             typical_acceptance_sampler_posterior_threshold,
             typical_acceptance_sampler_posterior_alpha=self.
             typical_acceptance_sampler_posterior_alpha,
+            disable_logprobs=self.disable_logprobs_during_spec_decoding,
         )

         scheduler_config = SchedulerConfig(
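A rough sketch (assumed usage, placeholder model names) of how the new engine argument flows into SpeculativeConfig.disable_logprobs via create_engine_config:

from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="facebook/opt-6.7b",              # placeholder target model
    speculative_model="facebook/opt-125m",  # placeholder draft model
    num_speculative_tokens=3,
    # Leaving this as None resolves to True in SpeculativeConfig, i.e. logprob
    # serialization during speculative decoding is skipped by default.
    disable_logprobs_during_spec_decoding=False,
)
engine_config = engine_args.create_engine_config()
assert engine_config.speculative_config.disable_logprobs is False

Note that the parser option above uses type=bool, so on the command line any non-empty value (including "False") parses as True; the Python-level argument is the unambiguous way to re-enable logprobs.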
@@ -14,7 +14,7 @@ from vllm.model_executor.layers.typical_acceptance_sampler import (
     TypicalAcceptanceSampler)
 from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest,
                            HiddenStates, SamplerOutput, SequenceGroupMetadata,
-                           get_all_seq_ids_and_request_ids)
+                           get_all_seq_ids, get_all_seq_ids_and_request_ids)
 from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
 from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
@@ -26,6 +26,7 @@ from vllm.spec_decode.multi_step_worker import MultiStepWorker
 from vllm.spec_decode.ngram_worker import NGramWorker
 from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
 from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
+from vllm.spec_decode.target_model_runner import TargetModelRunner
 from vllm.spec_decode.util import (create_sequence_group_output,
                                    get_all_num_logprobs,
                                    get_sampled_token_logprobs, nvtx_range,
@@ -44,9 +45,15 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
     speculative_config: SpeculativeConfig = kwargs.get("speculative_config")
     assert speculative_config is not None

-    target_worker = Worker(*args, **kwargs)
-
     draft_worker_kwargs = kwargs.copy()
+
+    kwargs["model_runner_cls"] = TargetModelRunner
+    target_worker = Worker(*args, **kwargs)
+    # Set the disable_logprobs variable in the TargetModelRunner instance
+    # as per its value specified in the SpeculativeConfig.
+    target_worker.model_runner.disable_logprobs =\
+        speculative_config.disable_logprobs
+
     # Override draft-model specific worker args.
     draft_worker_kwargs.update(
         model_config=speculative_config.draft_model_config,
@@ -67,7 +74,8 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
         typical_acceptance_sampler_posterior_threshold=speculative_config.
         typical_acceptance_sampler_posterior_threshold,
         typical_acceptance_sampler_posterior_alpha=speculative_config.
-        typical_acceptance_sampler_posterior_alpha)
+        typical_acceptance_sampler_posterior_alpha,
+        disable_logprobs=speculative_config.disable_logprobs)

     return spec_decode_worker

@@ -107,6 +115,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         draft_token_acceptance_method: str,
         typical_acceptance_sampler_posterior_threshold: float,
         typical_acceptance_sampler_posterior_alpha: float,
+        disable_logprobs: bool,
     ) -> "SpecDecodeWorker":

         allow_zero_draft_token_step = True
@@ -161,6 +170,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         return SpecDecodeWorker(
             proposer_worker,
             scorer_worker,
+            disable_logprobs=disable_logprobs,
             disable_by_batch_size=disable_by_batch_size,
             spec_decode_sampler=spec_decode_sampler,
             allow_zero_draft_token_step=allow_zero_draft_token_step)
@@ -170,6 +180,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         proposer_worker: ProposerWorkerBase,
         scorer_worker: WorkerBase,
         spec_decode_sampler: SpecDecodeBaseSampler,
+        disable_logprobs: bool,
         metrics_collector: Optional[AsyncMetricsCollector] = None,
         disable_by_batch_size: Optional[int] = None,
         allow_zero_draft_token_step: Optional[bool] = True,
@@ -189,6 +200,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
                 types of sampler namely RejectionSampler and
                 TypicalAcceptanceSampler. 'spec_decode_sampler' is either an
                 instance of RejectionSampler or TypicalAcceptanceSampler.
+            disable_logprobs: If set to True, token log probabilities will
+                not be output in both the draft worker and the target worker.
+                If set to False, log probabilities will be output by both.
             disable_by_batch_size: If the batch size is larger than this,
                 disable speculative decoding for new incoming requests.
             metrics_collector: Helper class for collecting metrics; can be set
@@ -222,6 +236,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         # Hidden states from target model to pass to proposer
         # in the subsequent step.
         self.previous_hidden_states: Optional[HiddenStates] = None
+        self._disable_logprobs = disable_logprobs

     def init_device(self) -> None:
         """Initialize both scorer and proposer models.
@@ -357,7 +372,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         ) == 0 or disable_all_speculation:
             return self._run_no_spec(execute_model_req,
                                      skip_proposer=disable_all_speculation)
-
         return self._run_speculative_decoding_step(execute_model_req,
                                                    num_lookahead_slots)

@@ -391,6 +405,42 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
                 # this state within spec decode worker.
                 seq_group_metadata.num_speculative_tokens = 0

+    def _serialize_sampler_output_no_logprobs(
+            self, execute_model_req: ExecuteModelRequest,
+            sampler_output: SamplerOutput) -> SamplerOutput:
+        """
+        Creates and returns a `SamplerOutput` with only the sampled token IDs
+        being serialized to CPU & populated in `CompletionSequenceGroupOutput`.
+        All other parameters in `CompletionSequenceGroupOutput` related to log
+        probabilities are skipped.
+
+        Args:
+            execute_model_req (ExecuteModelRequest): The model request that
+                was executed.
+            sampler_output (SamplerOutput): The output from the sampler with
+                only GPU tensors populated.
+
+        Returns:
+            SamplerOutput: A new `SamplerOutput` instance containing a list of
+            `CompletionSequenceGroupOutput` objects with only sampled token
+            IDs populated.
+        """
+        seq_ids = get_all_seq_ids(execute_model_req.seq_group_metadata_list)
+        sampled_token_ids_list = sampler_output.sampled_token_ids.tolist()
+        completion_seq_group_output_list: List[
+            CompletionSequenceGroupOutput] = []
+        for index, seq_id in enumerate(seq_ids):
+            completion_seq_group_output_list.append(
+                create_sequence_group_output(
+                    token_id=sampled_token_ids_list[index][0],
+                    token_id_logprob_rank=-1,
+                    token_id_logprob=0.0,
+                    seq_id=seq_id,
+                    topk_token_ids=[],
+                    topk_logprobs=[],
+                ))
+        return SamplerOutput(outputs=completion_seq_group_output_list)
+
     @nvtx_range("spec_decode_worker._run_no_spec")
     def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
                      skip_proposer: bool) -> List[SamplerOutput]:
@@ -417,12 +467,17 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             self.previous_hidden_states.update(
                 execute_model_req.seq_group_metadata_list, hidden_states)

+        sampler_output_to_return = (self._serialize_sampler_output_no_logprobs(
+            execute_model_req=execute_model_req, sampler_output=sampler_output)
+                                    if self._disable_logprobs else
+                                    sampler_output)
+
         # Clear device tensors from sampler output. This reduces communication
         # overhead when the engine runs in a different process than the workers.
-        sampler_output.probs = None
-        sampler_output.sampled_tokens = None
+        sampler_output.sampled_token_probs = None
+        sampler_output.sampled_token_ids = None
         sampler_output.logprobs = None
-        return [sampler_output]
+        return [sampler_output_to_return]

     def _run_non_driver_rank(self) -> bool:
         """Run proposer and verifier model in non-driver workers. This is used
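As a reading aid: with logprobs disabled, each sequence still carries a logprobs entry for its sampled token, but with the placeholder values used above (rank -1, logprob 0.0) and no top-k entries. A self-contained stand-in (plain dataclasses, not the real vLLM types) for what _serialize_sampler_output_no_logprobs builds per sequence:

from dataclasses import dataclass, field
from typing import Dict, List

@dataclass
class FakeLogprob:                 # stand-in for vllm.sequence.Logprob
    logprob: float
    rank: int

@dataclass
class FakeSequenceOutput:          # stand-in for one sequence's output
    seq_id: int
    token_id: int
    logprobs: Dict[int, FakeLogprob] = field(default_factory=dict)

def serialize_no_logprobs(seq_ids: List[int],
                          sampled_token_ids: List[List[int]]
                          ) -> List[FakeSequenceOutput]:
    """Only the sampled token IDs leave the GPU; logprob fields are placeholders."""
    outputs = []
    for idx, seq_id in enumerate(seq_ids):
        token_id = sampled_token_ids[idx][0]
        outputs.append(
            FakeSequenceOutput(
                seq_id=seq_id,
                token_id=token_id,
                # rank=-1 / logprob=0.0 mirror the dummy values in the diff.
                logprobs={token_id: FakeLogprob(logprob=0.0, rank=-1)},
            ))
    return outputs

print(serialize_no_logprobs(seq_ids=[7, 8], sampled_token_ids=[[42], [1337]]))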
@@ -480,7 +535,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             execute_model_req,
             proposals,
         )
-
         accepted_token_ids, target_logprobs = self._verify_tokens(
             execute_model_req.seq_group_metadata_list, proposal_scores,
             proposals, execute_model_req.num_lookahead_slots)
@@ -601,25 +655,27 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
            the same number of outputs.
        """
        batch_size, num_steps = accepted_token_ids.shape
-
-        # Organize input tensors by step instead of by sequence.
-        target_logprobs_by_step = target_logprobs.transpose(0, 1)
        accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1)
-
-        # Get the logprobs/rank of the accepted tokens.
-        (accepted_token_id_ranks_by_step,
-         accepted_token_id_logprobs_by_step) = get_sampled_token_logprobs(
-             logprob_tensor=target_logprobs_by_step,
-             sampled_token_ids=accepted_token_ids_by_step,
-         )
-
-        # Get the top-k logprobs (which may or may not include the logprob of
-        # the accepted token).
-        (topk_logprobs_by_step,
-         topk_indices_by_step) = target_logprobs_by_step.topk(
-             k=self.scorer_worker.model_config.max_logprobs,
-             dim=-1,
-         )
+        if self._disable_logprobs:
+            # We are skipping the logprobs. Hence don't serialize the
+            # logprobs related tensors from the GPU. Instead create
+            # empty/dummy lists.
+            (accepted_token_id_ranks_by_step,
+             accepted_token_id_logprobs_by_step,
+             topk_logprobs_by_step, topk_indices_by_step) =\
+                self._create_dummy_logprob_lists(
+                    batch_size, num_steps,
+                    self.scorer_worker.model_config.max_logprobs)
+        else:
+            # Organize input tensors by step instead of by sequence.
+            target_logprobs_by_step = target_logprobs.transpose(0, 1)
+            # Serialize all tensors into Python lists.
+            (accepted_token_id_ranks_by_step,
+             accepted_token_id_logprobs_by_step,
+             topk_logprobs_by_step, topk_indices_by_step) =\
+                self._create_logprob_lists_from_tensors(
+                    target_logprobs_by_step, accepted_token_ids_by_step,
+                    self.scorer_worker.model_config.max_logprobs)

        # Get the sequence ids and num_logprobs (sampling parameter) in the
        # batch.
@@ -628,14 +684,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):

        num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)

-        # Serialize all tensors to CPU Python lists.
+        # Serialize tensor to CPU Python list.
        accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
-        accepted_token_id_ranks_by_step = (
-            accepted_token_id_ranks_by_step.tolist())
-        accepted_token_id_logprobs_by_step = (
-            accepted_token_id_logprobs_by_step.tolist())
-        topk_logprobs_by_step = topk_logprobs_by_step.tolist()
-        topk_indices_by_step = topk_indices_by_step.tolist()

        # Construct the output on a per-step, per-sequence basis.
        sampler_output_list: List[SamplerOutput] = []
@@ -677,6 +727,108 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
                     0].spec_decode_worker_metrics = maybe_rejsample_metrics
         return sampler_output_list

+    def _create_dummy_logprob_lists(
+        self,
+        batch_size: int,
+        num_steps: int,
+        num_top_k: int,
+    ) -> Tuple[List[List[int]], List[List[float]],
+               List[List[List[Optional[float]]]],
+               List[List[List[Optional[int]]]]]:
+        """
+        Creates and returns four dummy lists representing token probabilities
+        and their ranks.
+
+        This method initializes and returns:
+        - The ranks of the accepted tokens, shaped (num_steps, batch_size)
+        - The log probabilities of the accepted tokens,
+          shaped (num_steps, batch_size)
+        - The log probabilities of the top k tokens,
+          shaped (num_steps, batch_size, num_top_k)
+        - The token IDs of the top k tokens,
+          shaped (num_steps, batch_size, num_top_k)
+
+        Args:
+            batch_size (int): The size of the batch.
+            num_steps (int): The number of steps in the sequence.
+            num_top_k (int): The number of top-k token log probabilities to
+                return.
+
+        Returns:
+            A tuple containing four dummy lists as described above.
+        """
+        accepted_token_id_ranks_by_step = [[-1] * batch_size
+                                           for _ in range(num_steps)]
+        accepted_token_id_logprobs_by_step = [[0.0] * batch_size
+                                              for _ in range(num_steps)]
+        topk_logprobs_by_step: List[List[List[Optional[float]]]] = [[
+            [None] * num_top_k for _ in range(batch_size)
+        ] for _ in range(num_steps)]
+        topk_indices_by_step: List[List[List[Optional[int]]]] = [[
+            [None] * num_top_k for _ in range(batch_size)
+        ] for _ in range(num_steps)]
+        return (accepted_token_id_ranks_by_step,
+                accepted_token_id_logprobs_by_step, topk_logprobs_by_step,
+                topk_indices_by_step)
+
+    def _create_logprob_lists_from_tensors(
+        self,
+        target_logprobs_by_step: torch.Tensor,
+        accepted_token_ids_by_step: torch.Tensor,
+        num_top_k: int,
+    ) -> Tuple[List[List[int]], List[List[float]],
+               List[List[List[Optional[float]]]],
+               List[List[List[Optional[int]]]]]:
+        """
+        Creates and returns four lists representing token probabilities and
+        their ranks.
+
+        This method initializes and returns four lists containing:
+        - The ranks of the accepted tokens, shaped (num_steps, batch_size)
+        - The log probabilities of the accepted tokens,
+          shaped (num_steps, batch_size)
+        - The log probabilities of the top k tokens,
+          shaped (num_steps, batch_size, num_top_k)
+        - The token IDs of the top k tokens,
+          shaped (num_steps, batch_size, num_top_k)
+
+        Args:
+            target_logprobs_by_step (torch.Tensor): Tensor representing the
+                log probabilities of the target model,
+                shaped (num_steps, batch_size, vocab_size)
+            accepted_token_ids_by_step (torch.Tensor): Tensor representing
+                the accepted token_ids, shaped (num_steps, batch_size)
+            num_top_k (int): The number of top-k token log probabilities to
+                return.
+
+        Returns:
+            A tuple containing the lists as described above.
+        """
+        # Serialize all tensors to CPU Python lists.
+        # Get the logprobs/rank of the accepted tokens.
+        (accepted_token_id_ranks_by_step_tensor,
+         accepted_token_id_logprobs_by_step_tensor
+         ) = get_sampled_token_logprobs(
+             logprob_tensor=target_logprobs_by_step,
+             sampled_token_ids=accepted_token_ids_by_step,
+         )
+        # Get the top-k logprobs (which may or may not include the
+        # logprob of the accepted token).
+        (topk_logprobs_by_step_tensor,
+         topk_indices_by_step_tensor) = target_logprobs_by_step.topk(
+             k=num_top_k,
+             dim=-1,
+         )
+        accepted_token_id_ranks_by_step = (
+            accepted_token_id_ranks_by_step_tensor.tolist())
+        accepted_token_id_logprobs_by_step = (
+            accepted_token_id_logprobs_by_step_tensor.tolist())
+        topk_logprobs_by_step = topk_logprobs_by_step_tensor.tolist()
+        topk_indices_by_step = topk_indices_by_step_tensor.tolist()
+        return (accepted_token_id_ranks_by_step,
+                accepted_token_id_logprobs_by_step, topk_logprobs_by_step,
+                topk_indices_by_step)
+
     def _track_finished_requests(self, execute_model_req: ExecuteModelRequest):
         """
         Removes the finished requests and their associated sequence ids from
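The dummy lists above deliberately mirror the shapes of the tensor-derived lists, which is what lets the downstream per-step, per-sequence output loop stay unchanged. A small self-contained sketch of the same construction (plain Python, hypothetical helper name):

from typing import List, Optional

def make_dummy_logprob_lists(batch_size: int, num_steps: int, num_top_k: int):
    """Builds the same shapes as _create_dummy_logprob_lists above."""
    ranks: List[List[int]] = [[-1] * batch_size for _ in range(num_steps)]
    logprobs: List[List[float]] = [[0.0] * batch_size for _ in range(num_steps)]
    topk_logprobs: List[List[List[Optional[float]]]] = [
        [[None] * num_top_k for _ in range(batch_size)] for _ in range(num_steps)
    ]
    topk_token_ids: List[List[List[Optional[int]]]] = [
        [[None] * num_top_k for _ in range(batch_size)] for _ in range(num_steps)
    ]
    return ranks, logprobs, topk_logprobs, topk_token_ids

ranks, logprobs, topk_logprobs, topk_token_ids = make_dummy_logprob_lists(
    batch_size=2, num_steps=3, num_top_k=4)
assert len(ranks) == 3 and len(ranks[0]) == 2    # (num_steps, batch_size)
assert len(topk_token_ids[0][0]) == 4            # (..., num_top_k)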
vllm/spec_decode/target_model_runner.py (new file, 69 lines)
@@ -0,0 +1,69 @@
+from typing import List, Optional
+
+from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
+                         ModelConfig, MultiModalConfig, ParallelConfig,
+                         PromptAdapterConfig, SchedulerConfig)
+from vllm.sequence import SequenceGroupMetadata
+from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
+                                      ModelRunner)
+
+
+class TargetModelRunner(ModelRunner):
+    """Specialized model runner for speculative decoding target model.
+    In speculative decoding, the log probabilities selected finally may not
+    be the same ones as selected by the target model sampling. This means
+    that the time spent in the log probability calculation of the target model
+    is time wasted, since we calculate log probabilities after deciding which
+    tokens are accepted. For this reason disabling log probabilities in the
+    target model will make decode faster. The model runner sets the
+    SamplingMetadata parameters according to whether log probabilities are
+    requested or not.
+    """
+
+    def __init__(self,
+                 model_config: ModelConfig,
+                 parallel_config: ParallelConfig,
+                 scheduler_config: SchedulerConfig,
+                 device_config: DeviceConfig,
+                 cache_config: CacheConfig,
+                 load_config: LoadConfig,
+                 lora_config: Optional[LoRAConfig],
+                 kv_cache_dtype: Optional[str] = "auto",
+                 is_driver_worker: bool = False,
+                 prompt_adapter_config: Optional[PromptAdapterConfig] = None,
+                 multimodal_config: Optional[MultiModalConfig] = None,
+                 return_hidden_states: bool = False):
+        # An internal boolean member variable to indicate if token log
+        # probabilities are needed or not.
+        self.disable_logprobs = True
+        super().__init__(
+            model_config=model_config,
+            parallel_config=parallel_config,
+            scheduler_config=scheduler_config,
+            device_config=device_config,
+            cache_config=cache_config,
+            load_config=load_config,
+            lora_config=lora_config,
+            kv_cache_dtype=kv_cache_dtype,
+            is_driver_worker=is_driver_worker,
+            multimodal_config=multimodal_config,
+            prompt_adapter_config=prompt_adapter_config,
+            return_hidden_states=return_hidden_states,
+        )
+
+    def prepare_model_input(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        virtual_engine: int = 0,
+        finished_requests_ids: Optional[List[str]] = None
+    ) -> ModelInputForGPUWithSamplingMetadata:
+        model_input: ModelInputForGPUWithSamplingMetadata = super(
+        ).prepare_model_input(seq_group_metadata_list, virtual_engine,
+                              finished_requests_ids)
+        # If token log probabilities is disabled then skip generating sampler
+        # CPU output. We directly serialize the GPU sampled_token_id tensors
+        # as needed. If log probabilities is enabled then synchronize all the
+        # sampling related tensors which includes the logprobs tensors.
+        model_input.sampling_metadata.skip_sampler_cpu_output = (
+            self.disable_logprobs)
+        return model_input
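The point of skip_sampler_cpu_output is to shrink the GPU-to-CPU transfer to just the sampled token IDs. A simplified torch illustration of the two paths (arbitrary sizes; not vLLM code):

import torch

batch_size, vocab_size, num_top_k = 8, 32000, 5
logprobs = torch.randn(batch_size, vocab_size)   # lives on the GPU in practice
sampled_token_ids = logprobs.argmax(dim=-1, keepdim=True)

# disable_logprobs=True: only a small int tensor is copied to the host.
ids_on_cpu = sampled_token_ids.tolist()          # batch_size token IDs

# disable_logprobs=False: ranks, logprobs and top-k data are copied as well.
topk_logprobs, topk_token_ids = logprobs.topk(k=num_top_k, dim=-1)
full_payload = (sampled_token_ids.tolist(), topk_logprobs.tolist(),
                topk_token_ids.tolist())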
@@ -1,5 +1,5 @@
 from contextlib import contextmanager
-from typing import Dict, List, Tuple
+from typing import Dict, List, Optional, Tuple

 import torch

@@ -53,8 +53,8 @@ def create_sequence_group_output(
         token_id_logprob_rank: int,
         token_id_logprob: float,
         seq_id: SeqId,
-        topk_token_ids: List[int],
-        topk_logprobs: List[float],
+        topk_token_ids: List[Optional[int]],
+        topk_logprobs: List[Optional[float]],
 ) -> CompletionSequenceGroupOutput:
     """Create a SequenceGroupOutput given the sampling results.

@@ -68,7 +68,7 @@ def create_sequence_group_output(
     """
     # vLLM logprobs always include the sampled token. In addition, the user may
    # request topk-logprobs (where top-k varies per user up to max_logprobs).
-    logprobs: Dict[int, Logprob] = {
+    logprobs: Dict[Optional[int], Logprob] = {
         token_id: Logprob(
             logprob=token_id_logprob,
             rank=token_id_logprob_rank,