[Doc] Explicitly state that PP isn't compatible with speculative decoding yet (#10975)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent 39e227c7ae
commit c889d5888b
@@ -8,6 +8,9 @@ Speculative decoding
 not usually yield inter-token latency reductions for all prompt datasets or sampling parameters. The work
 to optimize it is ongoing and can be followed in `this issue. <https://github.com/vllm-project/vllm/issues/4630>`_
 
+.. warning::
+    Currently, speculative decoding in vLLM is not compatible with pipeline parallelism.
+
 This document shows how to use `Speculative Decoding <https://x.com/karpathy/status/1697318534555336961>`_ with vLLM.
 Speculative decoding is a technique which improves inter-token latency in memory-bound LLM inference.
 
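For context, a minimal sketch of how the two features collide, assuming the offline LLM API and the speculative-decoding flags of this vLLM release (model names and token counts are illustrative, mirroring the examples on the same doc page):

# Draft-model speculative decoding on its own:
from vllm import LLM

llm = LLM(
    model="facebook/opt-6.7b",
    speculative_model="facebook/opt-125m",
    num_speculative_tokens=5,
)

# Layering pipeline parallelism on top is what the new warning rules out; per
# the guard added to create_spec_worker below, this combination is expected to
# raise NotImplementedError instead of running:
# llm = LLM(
#     model="facebook/opt-6.7b",
#     speculative_model="facebook/opt-125m",
#     num_speculative_tokens=5,
#     pipeline_parallel_size=2,
# )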
@@ -247,9 +247,19 @@ def _compare_tp(
     *,
     method: Literal["generate", "encode"],
 ):
-    tp_size, pp_size, eager_mode, chunked_prefill = parallel_setup
-    multi_node_only, trust_remote_code, tokenizer_mode, \
-        load_format, hf_overrides = test_options
+    (
+        tp_size,
+        pp_size,
+        eager_mode,
+        chunked_prefill,
+    ) = parallel_setup
+    (
+        multi_node_only,
+        trust_remote_code,
+        tokenizer_mode,
+        load_format,
+        hf_overrides,
+    ) = test_options
 
     if num_gpus_available < tp_size * pp_size:
         pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
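The test change above only reshapes the tuple unpacking; a standalone illustration of the style, using made-up values in place of the real parallel_setup/test_options fixtures:

# Parenthesized targets let a multi-name assignment span lines without the
# backslash continuation used previously.
test_options = (False, True, "auto", "dummy", None)  # illustrative values only

(
    multi_node_only,
    trust_remote_code,
    tokenizer_mode,
    load_format,
    hf_overrides,
) = test_options

assert tokenizer_mode == "auto" and load_format == "dummy"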
@@ -473,10 +473,11 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
             self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                     config.vocab_size,
                                                     logit_scale)
-            self.sampler = get_sampler()
         else:
             self.lm_head = PPMissingLayer()
 
+        self.sampler = get_sampler()
+
         self.make_empty_intermediate_tensors = (
             self.transformer.make_empty_intermediate_tensors)
 
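The same one-line move repeats in the Granite, Llama, Nemotron, and Solar models below: get_sampler() leaves the last-rank-only branch so that every pipeline stage constructs a sampler, while the LM head stays a PPMissingLayer() everywhere except the last rank. A simplified sketch of the pattern (ToyForCausalLM is illustrative, not real vLLM code; the imports assume vLLM's module layout at this commit):

import torch.nn as nn

from vllm.model_executor.layers.sampler import get_sampler
from vllm.model_executor.models.utils import PPMissingLayer


class ToyForCausalLM(nn.Module):
    """Illustrative module, not an actual vLLM model."""

    def __init__(self, hidden_size: int, vocab_size: int, is_last_rank: bool):
        super().__init__()
        if is_last_rank:
            # Only the final pipeline stage owns the real LM head.
            self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        else:
            self.lm_head = PPMissingLayer()

        # Previously created only inside the is_last_rank branch; after this
        # change every rank builds a sampler.
        self.sampler = get_sampler()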
@@ -400,16 +400,17 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
                 self.lm_head.weight = self.model.embed_tokens.weight
 
             logit_scale = getattr(config, "logit_scale", 1.0)
 
             if hasattr(config, "logits_scaling"):
                 logit_scale /= config.logits_scaling
 
             self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                     config.vocab_size,
                                                     scale=logit_scale)
-            self.sampler = get_sampler()
         else:
             self.lm_head = PPMissingLayer()
 
+        self.sampler = get_sampler()
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
@@ -540,10 +540,11 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
             self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                     config.vocab_size,
                                                     logit_scale)
-            self.sampler = get_sampler()
         else:
             self.lm_head = PPMissingLayer()
 
+        self.sampler = get_sampler()
+
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
@@ -435,9 +435,11 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
             self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                     config.vocab_size,
                                                     logit_scale)
-            self.sampler = get_sampler()
         else:
             self.lm_head = PPMissingLayer()
+
+        self.sampler = get_sampler()
+
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
@@ -443,10 +443,11 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
             self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                     config.vocab_size,
                                                     logit_scale)
-            self.sampler = get_sampler()
         else:
             self.lm_head = PPMissingLayer()
 
+        self.sampler = get_sampler()
+
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
@@ -54,6 +54,10 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
     speculative_config: SpeculativeConfig = vllm_config.speculative_config
     assert speculative_config is not None
 
+    if vllm_config.parallel_config.pipeline_parallel_size > 1:
+        raise NotImplementedError("Speculative decoding is currently "
+                                  "incompatible with pipeline parallelism")
+
     draft_worker_kwargs = kwargs.copy()
 
     kwargs["model_runner_cls"] = TargetModelRunner
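A hypothetical regression test for the new guard (not part of this commit; it assumes the error propagates out of engine construction, which may depend on the executor backend used for pipeline parallelism):

import pytest

from vllm import LLM


def test_spec_decode_rejects_pipeline_parallelism():
    # Combining a draft model with pipeline_parallel_size > 1 should trip the
    # NotImplementedError raised in create_spec_worker above.
    with pytest.raises(NotImplementedError,
                       match="incompatible with pipeline parallelism"):
        LLM(
            model="facebook/opt-6.7b",              # illustrative target model
            speculative_model="facebook/opt-125m",  # illustrative draft model
            num_speculative_tokens=5,
            pipeline_parallel_size=2,
        )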