[Spec Decode] Introduce DraftModelRunner (#5799)

Parent: b90d8cd832
Commit: b2c620230a
@@ -7,6 +7,7 @@ import torch

from vllm.model_executor.utils import set_random_seed
from vllm.sequence import ExecuteModelRequest, Logprob, SamplerOutput
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker import Worker

@@ -85,6 +86,7 @@ def test_same_output_for_single_step():
        block_size,
        num_gpu_blocks,
        seed,
        model_runner_cls=TP1DraftModelRunner,
    )
    worker = create_worker(
        Worker,

@@ -168,6 +170,7 @@ def test_same_output_for_multi_step():
        block_size,
        num_gpu_blocks,
        seed,
        model_runner_cls=TP1DraftModelRunner,
    )

    worker = create_worker(
@@ -14,6 +14,7 @@ from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
                           SequenceOutput)
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner import ModelRunner
from vllm.worker.worker import Worker

T = TypeVar("T", bound=Worker)

@@ -66,7 +67,8 @@ def create_worker(cls: Callable[..., T],
                  num_gpu_blocks: int,
                  seed: int,
                  is_driver_worker: bool = True,
                  enforce_eager: bool = True) -> T:
                  enforce_eager: bool = True,
                  model_runner_cls: Optional[ModelRunner] = None) -> T:
    engine_args = EngineArgs(
        model=model_name,
        seed=seed,

@@ -89,6 +91,7 @@ def create_worker(cls: Callable[..., T],
        rank=0,
        distributed_init_method=distributed_init_method,
        is_driver_worker=is_driver_worker,
        model_runner_cls=model_runner_cls,
    )

    worker.init_device()
@@ -880,6 +880,8 @@ class ExecuteModelRequest:
    running_queue_size: int = 0
    # Optional hidden states from prior step.
    previous_hidden_states: Optional[HiddenStates] = None
    # The number of forward steps to run.
    num_steps: int = 1

    def clone(
        self, seq_group_metadata_list: List[SequenceGroupMetadata]

@@ -893,4 +895,5 @@ class ExecuteModelRequest:
            num_lookahead_slots=self.num_lookahead_slots,
            running_queue_size=self.running_queue_size,
            previous_hidden_states=self.previous_hidden_states,
            num_steps=self.num_steps,
        )
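The two hunks above add a num_steps field to ExecuteModelRequest and propagate it through clone(). As a reading aid only, here is a minimal, self-contained sketch of that contract with toy class and field names (not the vLLM types): a proposer asks for k consecutive forward steps on the cloned request instead of issuing k single-step requests.

from dataclasses import dataclass, field, replace
from typing import List

@dataclass
class ToyExecuteModelRequest:
    # Stand-in for ExecuteModelRequest: only the fields relevant to this sketch.
    seq_group_metadata_list: List[str] = field(default_factory=list)
    num_steps: int = 1  # mirrors the new "number of forward steps to run" field

    def clone(self, seq_group_metadata_list: List[str]) -> "ToyExecuteModelRequest":
        # Like ExecuteModelRequest.clone above, the copy carries num_steps along.
        return replace(self, seq_group_metadata_list=seq_group_metadata_list)

original = ToyExecuteModelRequest(["seq-group-0"])
draft_req = original.clone(["seq-group-0-copy"])
draft_req.num_steps = 4          # ask the draft runner for 4 consecutive steps
assert (original.num_steps, draft_req.num_steps) == (1, 4)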
vllm/spec_decode/draft_model_runner.py (new file, 170 lines)
@@ -0,0 +1,170 @@
from typing import List, Optional

import torch

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, ParallelConfig, SchedulerConfig,
                         VisionLanguageConfig)
from vllm.logger import init_logger
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
                                      ModelRunner)

logger = init_logger(__name__)


class TP1DraftModelRunner(ModelRunner):
    """Specialized model runner for the speculative decoding draft model.

    Since the draft model always executes k forward passes consecutively to
    generate k speculative tokens in a single speculative decoding step,
    we could get rid of most CPU-GPU synchronization and data transfer
    overheads by keeping model input and output tensors on GPU all the time.

    This runner is still under development, so there is no performance gain
    at the moment. Currently we adopt a temporary solution that caches the
    seq_group_metadata_list for multi-step execution, so that we can
    leverage the existing prepare_model_input to stay compatible with the
    current execution flow, but we plan to remove this cache and avoid
    calling prepare_model_input in execute_model at all.

    The detailed development plan includes:
    1. Use "update_model_input" to update the existing model_input without
       creating a new one.
    2. Improve the performance of "update_model_input" with a GPU kernel.
    3. Support TP > 1 (this requires some design work because we do not
       expect any broadcasting inside execute_model).
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        vision_language_config: Optional[VisionLanguageConfig] = None,
        return_hidden_states: bool = False,
    ):
        if return_hidden_states:
            raise ValueError(
                "return_hidden_states is not supported for TP1DraftModelRunner."
            )

        super().__init__(
            model_config=model_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            cache_config=cache_config,
            load_config=load_config,
            lora_config=lora_config,
            kv_cache_dtype=kv_cache_dtype,
            is_driver_worker=is_driver_worker,
            vision_language_config=vision_language_config,
            return_hidden_states=return_hidden_states,
        )

        # TODO: Remove this cache when we are able to update model_input
        # directly in advance_step.
        self.cached_seq_group_metadata_list: Optional[
            List[SequenceGroupMetadata]] = None

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> ModelInputForGPUWithSamplingMetadata:
        """A temporary solution that caches the seq_group_metadata_list
        for multi-step execution.
        TODO: In-place update model_input and remove this function.
        """
        self.cached_seq_group_metadata_list = seq_group_metadata_list
        return super().prepare_model_input(seq_group_metadata_list)

    def update_model_input(
            self, model_input: ModelInputForGPUWithSamplingMetadata,
            last_output: SamplerOutput
    ) -> ModelInputForGPUWithSamplingMetadata:
        """Prepare the model inputs for the next step.
        TODO: In-place update model_input instead of calling
        prepare_model_input.
        """

        # Append the output token to the sequence data.
        assert self.cached_seq_group_metadata_list is not None
        for seq_group_metadata, sequence_group_outputs in zip(
                self.cached_seq_group_metadata_list, last_output.outputs):
            seq_group_metadata.is_prompt = False

            for seq_output in sequence_group_outputs.samples:
                seq = seq_group_metadata.seq_data[seq_output.parent_seq_id]

                token_id = seq_output.output_token
                token_logprob = seq_output.logprobs[token_id]

                seq.append_token_id(token_id, token_logprob.logprob)
                seq.update_num_computed_tokens(1)

        return self.prepare_model_input(self.cached_seq_group_metadata_list)

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForGPUWithSamplingMetadata,
        kv_caches: List[torch.Tensor],
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        # Since we do not broadcast data inside execute_model anymore,
        # we need to figure out the best way to support TP > 1 in this
        # case, because we will at least need to broadcast the sampled
        # tokens to all workers.
        if not self.is_driver_worker:
            raise ValueError("TP1DraftModelRunner only supports TP=1.")

        if self.lora_config:
            assert model_input.lora_requests is not None
            assert model_input.lora_mapping is not None
            self.set_active_loras(model_input.lora_requests,
                                  model_input.lora_mapping)

        outputs: List[SamplerOutput] = []
        for step in range(num_steps):
            # Currently cuda graph is only supported by the decode phase.
            assert model_input.attn_metadata is not None
            prefill_meta = model_input.attn_metadata.prefill_metadata
            decode_meta = model_input.attn_metadata.decode_metadata
            if prefill_meta is None and decode_meta.use_cuda_graph:
                assert model_input.input_tokens is not None
                graph_batch_size = model_input.input_tokens.shape[0]
                model_executable = self.graph_runners[graph_batch_size]
            else:
                model_executable = self.model

            multi_modal_kwargs = model_input.multi_modal_kwargs or {}
            hidden_states = model_executable(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
                kv_caches=kv_caches,
                attn_metadata=model_input.attn_metadata,
                **multi_modal_kwargs,
            )

            # Compute the logits.
            logits = self.model.compute_logits(hidden_states,
                                               model_input.sampling_metadata)

            # Sample the next token.
            outputs.append(
                self.model.sample(
                    logits=logits,
                    sampling_metadata=model_input.sampling_metadata,
                ))

            # Prepare the inputs for the next step.
            if step != num_steps - 1:
                model_input = self.update_model_input(model_input, outputs[-1])

        return outputs
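The docstring above explains the idea: because the draft model runs k forward passes back to back, the runner can loop over the steps internally and only rebuild its inputs between steps, instead of paying per-call preparation and synchronization costs on every step. As an illustration only, here is a toy, self-contained sketch of that control flow; all names below are hypothetical stand-ins, not vLLM APIs, and there is no GPU involved.

from dataclasses import dataclass, field
from typing import List

@dataclass
class ToyModelInput:
    # Stand-in for ModelInputForGPUWithSamplingMetadata.
    tokens: List[int] = field(default_factory=list)

class ToyDraftRunner:
    """Hypothetical stand-in for TP1DraftModelRunner's multi-step loop."""

    def _forward_and_sample(self, model_input: ToyModelInput) -> int:
        # Pretend model: "samples" the next token deterministically.
        return model_input.tokens[-1] + 1

    def update_model_input(self, model_input: ToyModelInput,
                           last_token: int) -> ToyModelInput:
        # Mirrors update_model_input: append the sampled token, rebuild inputs.
        return ToyModelInput(tokens=model_input.tokens + [last_token])

    def execute_model(self, model_input: ToyModelInput,
                      num_steps: int = 1) -> List[int]:
        outputs: List[int] = []
        for step in range(num_steps):
            outputs.append(self._forward_and_sample(model_input))
            if step != num_steps - 1:
                model_input = self.update_model_input(model_input, outputs[-1])
        return outputs

# One call with num_steps=4 replaces four separate single-step runner calls.
assert ToyDraftRunner().execute_model(ToyModelInput([7]), num_steps=4) == [8, 9, 10, 11]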
@@ -6,6 +6,7 @@ import torch

from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData,
                           SequenceGroupMetadata)
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeProposer)
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase

@@ -67,12 +68,14 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
        copied_execute_model_req = execute_model_req.clone(
            copied_seq_group_metadata_list)

        # Assert enough KV space for sample_len tokens per sequence.
        self._assert_enough_kv_space(execute_model_req.seq_group_metadata_list,
                                     sample_len)

        # Run model sample_len times.
        model_outputs: List[SamplerOutput] = []
        if isinstance(self.model_runner, TP1DraftModelRunner):
            copied_execute_model_req.num_steps = sample_len
            model_outputs = self.execute_model(
                execute_model_req=copied_execute_model_req)
        else:
            # TODO: Remove this branch once DraftModelRunner supports TP>1.
            for _ in range(sample_len):
                model_output: List[SamplerOutput] = super().execute_model(
                    execute_model_req=copied_execute_model_req)
@@ -11,6 +11,7 @@ from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest,
                           HiddenStates, SamplerOutput, SequenceGroupMetadata,
                           get_all_seq_ids)
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.metrics import AsyncMetricsCollector

@@ -117,6 +118,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
        draft_tp = draft_parallel_config.tensor_parallel_size
        target_tp = scorer_worker.parallel_config.tensor_parallel_size

        if draft_tp == 1:
            draft_worker_kwargs["model_runner_cls"] = TP1DraftModelRunner
        proposer_worker = MultiStepWorker(**draft_worker_kwargs)
        proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
            proposer_worker, draft_tp, target_tp)
@@ -351,7 +351,12 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
        self,
        model_input: CPUModelInput,
        kv_caches: List[torch.Tensor],
    ) -> Optional[SamplerOutput]:
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        if num_steps > 1:
            raise ValueError(
                "CPU worker does not support multi-step execution.")

        model_executable = self.model
        execute_model_kwargs = {
            "input_ids": model_input.input_tokens,

@@ -371,11 +376,11 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):

        # Only perform sampling in the driver worker.
        if not self.is_driver_worker:
            return None
            return []

        # Sample the next token.
        output = self.model.sample(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )
        return output
        return [output]
@@ -57,7 +57,12 @@ class EmbeddingModelRunner(
        self,
        model_input: ModelInputForGPUWithPoolingMetadata,
        kv_caches: List[torch.Tensor],
    ) -> Optional[PoolerOutput]:
        num_steps: int = 1,
    ) -> Optional[List[PoolerOutput]]:
        if num_steps > 1:
            raise ValueError(
                "EmbeddingModelRunner does not support multi-step execution.")

        if self.lora_config:
            assert model_input.lora_requests is not None
            assert model_input.lora_mapping is not None

@@ -91,6 +96,12 @@ class EmbeddingModelRunner(

        # Only perform pooling in the driver worker.
        if not self.is_driver_worker:
            return None
            return []

        return self.model.pooler(hidden_states=hidden_states,
        return [
            self.model.pooler(hidden_states=hidden_states,
                              pooling_metadata=model_input.pooling_metadata)
        ]

    def make_model_input_from_broadcasted_tensor_dict(
        self,
@@ -959,7 +959,11 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
        self,
        model_input: ModelInputForGPUWithSamplingMetadata,
        kv_caches: List[torch.Tensor],
    ) -> SamplerOutput:
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        if num_steps > 1:
            raise ValueError("num_steps > 1 is not supported in ModelRunner")

        if self.lora_config:
            assert model_input.lora_requests is not None
            assert model_input.lora_mapping is not None

@@ -992,7 +996,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):

        # Only perform sampling in the driver worker.
        if not self.is_driver_worker:
            return None
            return []

        # Sample the next token.
        output: SamplerOutput = self.model.sample(

@@ -1011,7 +1015,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):

        output.hidden_states = hidden_states

        return output
        return [output]


class CUDAGraphRunner:
@@ -150,7 +150,8 @@ class ModelRunnerBase(ABC, Generic[T]):
        self,
        model_input: T,
        kv_caches: Optional[List[torch.Tensor]],
    ) -> Optional[SamplerOutput]:
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        """
        Execute the model on the given input.
        """
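The base-class hunk above changes the execute_model contract for every runner: it now takes num_steps and returns a list of sampler outputs, one per forward step. Runners that cannot do multi-step execution reject num_steps > 1 and wrap their single result in a one-element list, which is what the backend-specific hunks below implement. A toy illustration of that contract with hypothetical names (not the vLLM classes):

from typing import List, Optional

class ToySingleStepRunner:
    """Toy runner following the updated ModelRunnerBase-style contract."""

    def execute_model(self,
                      model_input: str,
                      kv_caches: Optional[list],
                      num_steps: int = 1) -> Optional[List[str]]:
        if num_steps > 1:
            raise ValueError(
                "ToySingleStepRunner does not support multi-step execution.")
        output = f"sampled({model_input})"  # stands in for a SamplerOutput
        return [output]                     # one-element list instead of a bare output

assert ToySingleStepRunner().execute_model("decode-batch", kv_caches=None) == [
    "sampled(decode-batch)"
]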
@@ -207,7 +207,12 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
        self,
        model_input: ModelInputForNeuron,
        kv_caches: Optional[List[torch.Tensor]] = None,
    ) -> Optional[SamplerOutput]:
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        if num_steps > 1:
            raise ValueError(
                "NeuronModelRunner does not support multi-step execution.")

        hidden_states = self.model(
            input_ids=model_input.input_tokens,
            positions=model_input.input_positions,

@@ -223,7 +228,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )
        return output
        return [output]

    @property
    def vocab_size(self) -> int:
@@ -444,7 +444,12 @@ class TPUModelRunner:
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
    ) -> SamplerOutput:
        num_steps: int = 1,
    ) -> List[SamplerOutput]:
        if num_steps > 1:
            raise ValueError(
                "TPUModelRunner does not support multi-step execution.")

        assert seq_group_metadata_list is not None
        assert len(seq_group_metadata_list) > 0
        if seq_group_metadata_list[0].is_prompt:

@@ -462,7 +467,7 @@ class TPUModelRunner:
        else:
            sampler_outputs = self._execute_model(seq_group_metadata_list,
                                                  kv_caches)
        return SamplerOutput(sampler_outputs)
        return [SamplerOutput(sampler_outputs)]


class ModelWrapper(nn.Module):
@@ -45,6 +45,7 @@ class Worker(LocalOrDistributedWorkerBase):
        vision_language_config: Optional[VisionLanguageConfig] = None,
        speculative_config: Optional[SpeculativeConfig] = None,
        is_driver_worker: bool = False,
        model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None,
    ) -> None:
        self.model_config = model_config
        self.parallel_config = parallel_config

@@ -78,7 +79,9 @@ class Worker(LocalOrDistributedWorkerBase):
                "mlp_speculator") else {"return_hidden_states": True}

        ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
        if self.model_config.embedding_mode:
        if model_runner_cls is not None:
            ModelRunnerClass = model_runner_cls
        elif self.model_config.embedding_mode:
            ModelRunnerClass = EmbeddingModelRunner
        self.model_runner: GPUModelRunnerBase = ModelRunnerClass(
            model_config,
@@ -228,11 +228,13 @@ class LocalOrDistributedWorkerBase(WorkerBase):
            model_input: ModelRunnerInputBase = (
                self.model_runner.prepare_model_input(
                    execute_model_req.seq_group_metadata_list))
            num_steps = execute_model_req.num_steps

            if self.do_metadata_broadcast:
                broadcast_data = worker_input.as_broadcastable_tensor_dict()
                broadcast_data.update(
                    model_input.as_broadcastable_tensor_dict())
                broadcast_data["num_steps"] = num_steps
                broadcast_tensor_dict(broadcast_data, src=0)
        else:
            assert self.do_metadata_broadcast

@@ -240,6 +242,7 @@ class LocalOrDistributedWorkerBase(WorkerBase):
            if not broadcast_data:
                return None

            num_steps = broadcast_data.pop("num_steps")
            worker_input = WorkerInput.from_broadcasted_tensor_dict(
                broadcast_data)
            model_input = (

@@ -252,10 +255,8 @@ class LocalOrDistributedWorkerBase(WorkerBase):
        if worker_input.num_seq_groups == 0:
            return []

        output = self.model_runner.execute_model(model_input, self.kv_cache)
        # Worker only supports single-step execution. Wrap the output in a
        # list to conform to interface.
        return [output]
        return self.model_runner.execute_model(model_input, self.kv_cache,
                                               num_steps)


class WorkerWrapperBase:
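In the distributed path above, num_steps rides along in the metadata broadcast so that non-driver ranks execute the same number of forward steps as the driver. A dictionary-only sketch of that round trip, with toy helper names and no torch.distributed involved:

from typing import Any, Dict, Tuple

def driver_pack(model_input: Dict[str, Any], num_steps: int) -> Dict[str, Any]:
    # Driver side: fold num_steps into the broadcastable tensor dict.
    broadcast_data = dict(model_input)
    broadcast_data["num_steps"] = num_steps
    return broadcast_data

def worker_unpack(broadcast_data: Dict[str, Any]) -> Tuple[Dict[str, Any], int]:
    # Non-driver side: pop num_steps back out before rebuilding the model input.
    num_steps = broadcast_data.pop("num_steps")
    return broadcast_data, num_steps

payload = driver_pack({"input_tokens": [1, 2, 3]}, num_steps=4)
model_input, num_steps = worker_unpack(payload)
assert num_steps == 4 and "num_steps" not in model_input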
@@ -334,7 +334,12 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
        self,
        model_input: ModelInputForXPU,
        kv_caches: List[torch.Tensor],
    ) -> Optional[SamplerOutput]:
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        if num_steps > 1:
            raise ValueError(
                "XPUModelRunner does not support multi-step execution.")

        model_executable = self.model
        execute_model_kwargs = {
            "input_ids": model_input.input_tokens,

@@ -354,14 +359,14 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):

        # Only perform sampling in the driver worker.
        if not self.is_driver_worker:
            return None
            return []

        # Sample the next token.
        output = self.model.sample(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )
        return output
        return [output]

    def _prepare_prompt(
        self,