[Spec Decode] Introduce DraftModelRunner (#5799)

commit b2c620230a
parent b90d8cd832
@@ -7,6 +7,7 @@ import torch
 from vllm.model_executor.utils import set_random_seed
 from vllm.sequence import ExecuteModelRequest, Logprob, SamplerOutput
+from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
 from vllm.spec_decode.multi_step_worker import MultiStepWorker
 from vllm.spec_decode.top1_proposer import Top1Proposer
 from vllm.worker.worker import Worker
@@ -85,6 +86,7 @@ def test_same_output_for_single_step():
         block_size,
         num_gpu_blocks,
         seed,
+        model_runner_cls=TP1DraftModelRunner,
     )
     worker = create_worker(
         Worker,
@@ -168,6 +170,7 @@ def test_same_output_for_multi_step():
         block_size,
         num_gpu_blocks,
         seed,
+        model_runner_cls=TP1DraftModelRunner,
     )

     worker = create_worker(
@@ -14,6 +14,7 @@ from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
                            SequenceOutput)
 from vllm.utils import get_distributed_init_method, get_ip, get_open_port
 from vllm.worker.cache_engine import CacheEngine
+from vllm.worker.model_runner import ModelRunner
 from vllm.worker.worker import Worker

 T = TypeVar("T", bound=Worker)
@@ -66,7 +67,8 @@ def create_worker(cls: Callable[..., T],
                   num_gpu_blocks: int,
                   seed: int,
                   is_driver_worker: bool = True,
-                  enforce_eager: bool = True) -> T:
+                  enforce_eager: bool = True,
+                  model_runner_cls: Optional[ModelRunner] = None) -> T:
     engine_args = EngineArgs(
         model=model_name,
         seed=seed,
@@ -89,6 +91,7 @@ def create_worker(cls: Callable[..., T],
         rank=0,
         distributed_init_method=distributed_init_method,
         is_driver_worker=is_driver_worker,
+        model_runner_cls=model_runner_cls,
     )

     worker.init_device()
@@ -880,6 +880,8 @@ class ExecuteModelRequest:
     running_queue_size: int = 0
     # Optional hidden states from prior step.
     previous_hidden_states: Optional[HiddenStates] = None
+    # The number of forward steps to run.
+    num_steps: int = 1

     def clone(
         self, seq_group_metadata_list: List[SequenceGroupMetadata]
@@ -893,4 +895,5 @@ class ExecuteModelRequest:
             num_lookahead_slots=self.num_lookahead_slots,
             running_queue_size=self.running_queue_size,
             previous_hidden_states=self.previous_hidden_states,
+            num_steps=self.num_steps,
         )
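Note: the new num_steps field is how a proposer asks a runner for several consecutive forward passes in one call. A minimal standalone sketch of that pattern (plain dataclasses standing in for ExecuteModelRequest; ToyExecuteModelRequest and make_multi_step_request are hypothetical names, not vLLM APIs):

from dataclasses import dataclass, field, replace
from typing import List


@dataclass
class ToyExecuteModelRequest:
    # Simplified stand-in for ExecuteModelRequest: only the fields we need.
    seq_group_metadata_list: List[dict] = field(default_factory=list)
    num_steps: int = 1  # number of forward passes the runner should take


def make_multi_step_request(req: ToyExecuteModelRequest,
                            sample_len: int) -> ToyExecuteModelRequest:
    # Clone the request and ask for sample_len consecutive forward passes,
    # mirroring how MultiStepWorker sets copied_execute_model_req.num_steps.
    return replace(req, num_steps=sample_len)


if __name__ == "__main__":
    req = ToyExecuteModelRequest(seq_group_metadata_list=[{"seq_id": 0}])
    multi_step_req = make_multi_step_request(req, sample_len=4)
    print(multi_step_req.num_steps)  # 4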
vllm/spec_decode/draft_model_runner.py (new file, 170 lines)
@@ -0,0 +1,170 @@
+from typing import List, Optional
+
+import torch
+
+from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
+                         ModelConfig, ParallelConfig, SchedulerConfig,
+                         VisionLanguageConfig)
+from vllm.logger import init_logger
+from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
+                                      ModelRunner)
+
+logger = init_logger(__name__)
+
+
+class TP1DraftModelRunner(ModelRunner):
+    """Specialized model runner for the speculative decoding draft model.
+
+    Since the draft model always executes k forward passes consecutively to
+    generate k speculative tokens in a single speculative decoding step,
+    we could get rid of most CPU-GPU synchronization and data transfer
+    overheads by keeping model input and output tensors on GPU all the time.
+
+    This runner is still under development so there's no performance gain
+    at this moment. Currently we adopt a temporary solution that caches the
+    seq_group_metadata_list for multi-step execution, so that we can
+    leverage the existing prepare_model_input to be compatible with the
+    current execution flow, but we plan to remove this cache and avoid
+    calling prepare_model_input in execute_model at all.
+
+    The detailed development plan includes:
+    1. Use "update_model_input" to update the existing model_input without
+       creating a new one.
+    2. Improve the performance of "update_model_input" with a GPU kernel.
+    3. Support TP > 1 (this requires some design because we do not expect
+       any broadcasting inside execute_model).
+    """
+
+    def __init__(
+        self,
+        model_config: ModelConfig,
+        parallel_config: ParallelConfig,
+        scheduler_config: SchedulerConfig,
+        device_config: DeviceConfig,
+        cache_config: CacheConfig,
+        load_config: LoadConfig,
+        lora_config: Optional[LoRAConfig],
+        kv_cache_dtype: Optional[str] = "auto",
+        is_driver_worker: bool = False,
+        vision_language_config: Optional[VisionLanguageConfig] = None,
+        return_hidden_states: bool = False,
+    ):
+        if return_hidden_states:
+            raise ValueError(
+                "return_hidden_states is not supported for TP1DraftModelRunner."
+            )
+
+        super().__init__(
+            model_config=model_config,
+            parallel_config=parallel_config,
+            scheduler_config=scheduler_config,
+            device_config=device_config,
+            cache_config=cache_config,
+            load_config=load_config,
+            lora_config=lora_config,
+            kv_cache_dtype=kv_cache_dtype,
+            is_driver_worker=is_driver_worker,
+            vision_language_config=vision_language_config,
+            return_hidden_states=return_hidden_states,
+        )
+
+        # TODO: Remove this cache when we are able to update model_input
+        # directly in advance_step.
+        self.cached_seq_group_metadata_list: Optional[
+            List[SequenceGroupMetadata]] = None
+
+    def prepare_model_input(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+    ) -> ModelInputForGPUWithSamplingMetadata:
+        """A temporary solution that caches the seq_group_metadata_list
+        for multi-step execution.
+        TODO: In-place update model_input and remove this function.
+        """
+        self.cached_seq_group_metadata_list = seq_group_metadata_list
+        return super().prepare_model_input(seq_group_metadata_list)
+
+    def update_model_input(
+            self, model_input: ModelInputForGPUWithSamplingMetadata,
+            last_output: SamplerOutput
+    ) -> ModelInputForGPUWithSamplingMetadata:
+        """Prepare the model inputs for the next step.
+        TODO: In-place update model_input instead of calling
+        prepare_model_input.
+        """
+
+        # Append the output token to the sequence data.
+        assert self.cached_seq_group_metadata_list is not None
+        for seq_group_metadata, sequence_group_outputs in zip(
+                self.cached_seq_group_metadata_list, last_output.outputs):
+            seq_group_metadata.is_prompt = False
+
+            for seq_output in sequence_group_outputs.samples:
+                seq = seq_group_metadata.seq_data[seq_output.parent_seq_id]
+
+                token_id = seq_output.output_token
+                token_logprob = seq_output.logprobs[token_id]
+
+                seq.append_token_id(token_id, token_logprob.logprob)
+                seq.update_num_computed_tokens(1)
+
+        return self.prepare_model_input(self.cached_seq_group_metadata_list)
+
+    @torch.inference_mode()
+    def execute_model(
+        self,
+        model_input: ModelInputForGPUWithSamplingMetadata,
+        kv_caches: List[torch.Tensor],
+        num_steps: int = 1,
+    ) -> Optional[List[SamplerOutput]]:
+        # Since we do not broadcast data inside execute_model anymore,
+        # we need to figure out the best way to support TP > 1 in this
+        # case, because we will at least need to broadcast the sampled
+        # tokens to all workers.
+        if not self.is_driver_worker:
+            raise ValueError("TP1DraftModelRunner only supports TP=1.")
+
+        if self.lora_config:
+            assert model_input.lora_requests is not None
+            assert model_input.lora_mapping is not None
+            self.set_active_loras(model_input.lora_requests,
+                                  model_input.lora_mapping)
+
+        outputs: List[SamplerOutput] = []
+        for step in range(num_steps):
+            # Currently cuda graph is only supported by the decode phase.
+            assert model_input.attn_metadata is not None
+            prefill_meta = model_input.attn_metadata.prefill_metadata
+            decode_meta = model_input.attn_metadata.decode_metadata
+            if prefill_meta is None and decode_meta.use_cuda_graph:
+                assert model_input.input_tokens is not None
+                graph_batch_size = model_input.input_tokens.shape[0]
+                model_executable = self.graph_runners[graph_batch_size]
+            else:
+                model_executable = self.model
+
+            multi_modal_kwargs = model_input.multi_modal_kwargs or {}
+            hidden_states = model_executable(
+                input_ids=model_input.input_tokens,
+                positions=model_input.input_positions,
+                kv_caches=kv_caches,
+                attn_metadata=model_input.attn_metadata,
+                **multi_modal_kwargs,
+            )
+
+            # Compute the logits.
+            logits = self.model.compute_logits(hidden_states,
+                                               model_input.sampling_metadata)
+
+            # Sample the next token.
+            outputs.append(
+                self.model.sample(
+                    logits=logits,
+                    sampling_metadata=model_input.sampling_metadata,
+                ))
+
+            # Prepare the inputs for the next step.
+            if step != num_steps - 1:
+                model_input = self.update_model_input(model_input, outputs[-1])
+
+        return outputs
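Note: the heart of the runner above is the per-step loop: run the model, sample a token, then rebuild the inputs for the next step from what was just sampled. A simplified, framework-free sketch of that loop (ToyDraftRunner is a hypothetical stand-in, not the vLLM class; the "model" here is just a callable over token ids):

from typing import Callable, List


class ToyDraftRunner:
    """Minimal sketch of the multi-step pattern used by TP1DraftModelRunner.

    `model` maps a list of token ids to the next token id; the real runner
    runs a GPU forward pass, computes logits, and samples instead.
    """

    def __init__(self, model: Callable[[List[int]], int]):
        self.model = model

    def update_model_input(self, tokens: List[int],
                           last_token: int) -> List[int]:
        # Append the sampled token so the next step decodes one more position.
        return tokens + [last_token]

    def execute_model(self, tokens: List[int],
                      num_steps: int = 1) -> List[int]:
        outputs: List[int] = []
        for step in range(num_steps):
            next_token = self.model(tokens)      # forward pass + sampling
            outputs.append(next_token)
            if step != num_steps - 1:            # prepare inputs for next step
                tokens = self.update_model_input(tokens, next_token)
        return outputs


if __name__ == "__main__":
    # A toy "model" that just returns the length of its input as the token.
    runner = ToyDraftRunner(model=lambda toks: len(toks))
    print(runner.execute_model([10, 11, 12], num_steps=3))  # [3, 4, 5]

The real runner replaces the callable with a GPU forward pass plus compute_logits/sample and keeps all tensors on the GPU between steps, which is where the intended savings come from.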
@@ -6,6 +6,7 @@ import torch
 from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData,
                            SequenceGroupMetadata)
+from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeProposer)
 from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
@@ -67,22 +68,24 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
         copied_execute_model_req = execute_model_req.clone(
             copied_seq_group_metadata_list)

-        # Assert enough KV space for sample_len tokens per sequence.
-        self._assert_enough_kv_space(execute_model_req.seq_group_metadata_list,
-                                     sample_len)
-
         # Run model sample_len times.
         model_outputs: List[SamplerOutput] = []
-        for _ in range(sample_len):
-            model_output: List[SamplerOutput] = super().execute_model(
-                execute_model_req=copied_execute_model_req)
-            assert (len(model_output) == 1
-                    ), "composing multistep workers not supported"
-            model_output = model_output[0]
-
-            self._append_new_tokens(model_output,
-                                    copied_seq_group_metadata_list)
-            model_outputs.append(model_output)
+        if isinstance(self.model_runner, TP1DraftModelRunner):
+            copied_execute_model_req.num_steps = sample_len
+            model_outputs = self.execute_model(
+                execute_model_req=copied_execute_model_req)
+        else:
+            # TODO: Remove this branch once DraftModelRunner supports TP>1.
+            for _ in range(sample_len):
+                model_output: List[SamplerOutput] = super().execute_model(
+                    execute_model_req=copied_execute_model_req)
+                assert (len(model_output) == 1
+                        ), "composing multistep workers not supported"
+                model_output = model_output[0]
+
+                self._append_new_tokens(model_output,
+                                        copied_seq_group_metadata_list)
+                model_outputs.append(model_output)

         return model_outputs, True
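Note: the proposer now has two paths: a single multi-step call when the runner supports it, and the old one-call-per-token loop otherwise. A rough standalone sketch of that dispatch (toy types; propose_tokens is a hypothetical helper that pairs with the ToyDraftRunner sketch above):

from typing import List


def propose_tokens(runner, tokens: List[int], sample_len: int,
                   supports_multi_step: bool) -> List[int]:
    # Fast path: one call that runs sample_len forward passes on the runner.
    if supports_multi_step:
        return runner.execute_model(tokens, num_steps=sample_len)

    # Fallback: call the runner once per step and append tokens on the host,
    # paying a CPU-GPU round trip per speculative token.
    outputs: List[int] = []
    for _ in range(sample_len):
        next_token = runner.execute_model(tokens, num_steps=1)[0]
        outputs.append(next_token)
        tokens = tokens + [next_token]
    return outputs

Both branches produce the same token list for the toy runner; the difference in the real code is how often control returns to the CPU.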
@@ -11,6 +11,7 @@ from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest,
                            HiddenStates, SamplerOutput, SequenceGroupMetadata,
                            get_all_seq_ids)
 from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
+from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeScorer, SpeculativeScores)
 from vllm.spec_decode.metrics import AsyncMetricsCollector
@@ -117,6 +118,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             draft_tp = draft_parallel_config.tensor_parallel_size
             target_tp = scorer_worker.parallel_config.tensor_parallel_size

+            if draft_tp == 1:
+                draft_worker_kwargs["model_runner_cls"] = TP1DraftModelRunner
             proposer_worker = MultiStepWorker(**draft_worker_kwargs)
             proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
                 proposer_worker, draft_tp, target_tp)
@@ -351,7 +351,12 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
         self,
         model_input: CPUModelInput,
         kv_caches: List[torch.Tensor],
-    ) -> Optional[SamplerOutput]:
+        num_steps: int = 1,
+    ) -> Optional[List[SamplerOutput]]:
+        if num_steps > 1:
+            raise ValueError(
+                "CPU worker does not support multi-step execution.")
+
         model_executable = self.model
         execute_model_kwargs = {
             "input_ids": model_input.input_tokens,
@@ -371,11 +376,11 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):

         # Only perform sampling in the driver worker.
         if not self.is_driver_worker:
-            return None
+            return []

         # Sample the next token.
         output = self.model.sample(
             logits=logits,
             sampling_metadata=model_input.sampling_metadata,
         )
-        return output
+        return [output]
@@ -57,7 +57,12 @@ class EmbeddingModelRunner(
         self,
         model_input: ModelInputForGPUWithPoolingMetadata,
         kv_caches: List[torch.Tensor],
-    ) -> Optional[PoolerOutput]:
+        num_steps: int = 1,
+    ) -> Optional[List[PoolerOutput]]:
+        if num_steps > 1:
+            raise ValueError(
+                "EmbeddingModelRunner does not support multi-step execution.")
+
         if self.lora_config:
             assert model_input.lora_requests is not None
             assert model_input.lora_mapping is not None
@@ -91,10 +96,12 @@ class EmbeddingModelRunner(

         # Only perform pooling in the driver worker.
         if not self.is_driver_worker:
-            return None
+            return []

-        return self.model.pooler(hidden_states=hidden_states,
-                                 pooling_metadata=model_input.pooling_metadata)
+        return [
+            self.model.pooler(hidden_states=hidden_states,
+                              pooling_metadata=model_input.pooling_metadata)
+        ]

     def make_model_input_from_broadcasted_tensor_dict(
         self,
@@ -959,7 +959,11 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
         self,
         model_input: ModelInputForGPUWithSamplingMetadata,
         kv_caches: List[torch.Tensor],
-    ) -> SamplerOutput:
+        num_steps: int = 1,
+    ) -> Optional[List[SamplerOutput]]:
+        if num_steps > 1:
+            raise ValueError("num_steps > 1 is not supported in ModelRunner")
+
         if self.lora_config:
             assert model_input.lora_requests is not None
             assert model_input.lora_mapping is not None
@@ -992,7 +996,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):

         # Only perform sampling in the driver worker.
         if not self.is_driver_worker:
-            return None
+            return []

         # Sample the next token.
         output: SamplerOutput = self.model.sample(
@@ -1011,7 +1015,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):

             output.hidden_states = hidden_states

-        return output
+        return [output]


 class CUDAGraphRunner:
@@ -150,7 +150,8 @@ class ModelRunnerBase(ABC, Generic[T]):
         self,
         model_input: T,
         kv_caches: Optional[List[torch.Tensor]],
-    ) -> Optional[SamplerOutput]:
+        num_steps: int = 1,
+    ) -> Optional[List[SamplerOutput]]:
         """
         Execute the model on the given input.
         """
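Note: under the updated base interface, every runner's execute_model takes num_steps and returns a list of outputs (an empty list on non-driver ranks). A minimal conforming single-step runner might look like this sketch (ToySingleStepRunner and its integer "outputs" are hypothetical, not vLLM's SamplerOutput):

from typing import List, Optional


class ToySingleStepRunner:
    """Sketch of a runner that follows the updated execute_model contract."""

    def __init__(self, is_driver_worker: bool = True):
        self.is_driver_worker = is_driver_worker

    def execute_model(self, model_input: List[int],
                      num_steps: int = 1) -> Optional[List[int]]:
        if num_steps > 1:
            # Single-step runners reject multi-step requests explicitly.
            raise ValueError(
                "ToySingleStepRunner does not support multi-step execution.")
        if not self.is_driver_worker:
            # Non-driver ranks skip sampling and return an empty list.
            return []
        output = sum(model_input)  # stand-in for forward pass + sampling
        return [output]            # always wrap the single step in a list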
@@ -207,7 +207,12 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
         self,
         model_input: ModelInputForNeuron,
         kv_caches: Optional[List[torch.Tensor]] = None,
-    ) -> Optional[SamplerOutput]:
+        num_steps: int = 1,
+    ) -> Optional[List[SamplerOutput]]:
+        if num_steps > 1:
+            raise ValueError(
+                "NeuronModelRunner does not support multi-step execution.")
+
         hidden_states = self.model(
             input_ids=model_input.input_tokens,
             positions=model_input.input_positions,
@@ -223,7 +228,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
             logits=logits,
             sampling_metadata=model_input.sampling_metadata,
         )
-        return output
+        return [output]

     @property
     def vocab_size(self) -> int:
@@ -444,7 +444,12 @@ class TPUModelRunner:
         self,
         seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
         kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
-    ) -> SamplerOutput:
+        num_steps: int = 1,
+    ) -> List[SamplerOutput]:
+        if num_steps > 1:
+            raise ValueError(
+                "TPUModelRunner does not support multi-step execution.")
+
         assert seq_group_metadata_list is not None
         assert len(seq_group_metadata_list) > 0
         if seq_group_metadata_list[0].is_prompt:
@@ -462,7 +467,7 @@ class TPUModelRunner:
         else:
             sampler_outputs = self._execute_model(seq_group_metadata_list,
                                                   kv_caches)
-        return SamplerOutput(sampler_outputs)
+        return [SamplerOutput(sampler_outputs)]


 class ModelWrapper(nn.Module):
@@ -45,6 +45,7 @@ class Worker(LocalOrDistributedWorkerBase):
         vision_language_config: Optional[VisionLanguageConfig] = None,
         speculative_config: Optional[SpeculativeConfig] = None,
         is_driver_worker: bool = False,
+        model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None,
     ) -> None:
         self.model_config = model_config
         self.parallel_config = parallel_config
@@ -78,7 +79,9 @@ class Worker(LocalOrDistributedWorkerBase):
             "mlp_speculator") else {"return_hidden_states": True}

         ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
-        if self.model_config.embedding_mode:
+        if model_runner_cls is not None:
+            ModelRunnerClass = model_runner_cls
+        elif self.model_config.embedding_mode:
             ModelRunnerClass = EmbeddingModelRunner
         self.model_runner: GPUModelRunnerBase = ModelRunnerClass(
             model_config,
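Note: the selection order in Worker is now: explicit model_runner_cls override first, then embedding mode, then the default ModelRunner. A standalone sketch of that precedence (toy runner classes with hypothetical names):

from typing import Optional, Type


class BaseRunner:
    pass


class DefaultRunner(BaseRunner):
    pass


class EmbeddingRunner(BaseRunner):
    pass


def select_runner_cls(embedding_mode: bool,
                      model_runner_cls: Optional[Type[BaseRunner]] = None
                      ) -> Type[BaseRunner]:
    # An explicit override (e.g. a draft-model runner) wins over everything.
    if model_runner_cls is not None:
        return model_runner_cls
    # Otherwise fall back to the embedding runner or the default runner.
    if embedding_mode:
        return EmbeddingRunner
    return DefaultRunner


assert select_runner_cls(False) is DefaultRunner
assert select_runner_cls(True) is EmbeddingRunner
assert select_runner_cls(True, model_runner_cls=DefaultRunner) is DefaultRunner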
@@ -228,11 +228,13 @@ class LocalOrDistributedWorkerBase(WorkerBase):
             model_input: ModelRunnerInputBase = (
                 self.model_runner.prepare_model_input(
                     execute_model_req.seq_group_metadata_list))
+            num_steps = execute_model_req.num_steps

             if self.do_metadata_broadcast:
                 broadcast_data = worker_input.as_broadcastable_tensor_dict()
                 broadcast_data.update(
                     model_input.as_broadcastable_tensor_dict())
+                broadcast_data["num_steps"] = num_steps
                 broadcast_tensor_dict(broadcast_data, src=0)
         else:
             assert self.do_metadata_broadcast
@@ -240,6 +242,7 @@ class LocalOrDistributedWorkerBase(WorkerBase):
             if not broadcast_data:
                 return None

+            num_steps = broadcast_data.pop("num_steps")
             worker_input = WorkerInput.from_broadcasted_tensor_dict(
                 broadcast_data)
             model_input = (
@@ -252,10 +255,8 @@ class LocalOrDistributedWorkerBase(WorkerBase):
         if worker_input.num_seq_groups == 0:
             return []

-        output = self.model_runner.execute_model(model_input, self.kv_cache)
-        # Worker only supports single-step execution. Wrap the output in a
-        # list to conform to interface.
-        return [output]
+        return self.model_runner.execute_model(model_input, self.kv_cache,
+                                               num_steps)


 class WorkerWrapperBase:
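Note: only the driver prepares model inputs, so num_steps has to ride along with the broadcast metadata for every rank to run the same number of forward passes. A rough sketch of that round trip over a plain dict (standing in for broadcast_tensor_dict; helper names are hypothetical):

from typing import Any, Dict, Tuple


def pack_driver_metadata(model_input_dict: Dict[str, Any],
                         num_steps: int) -> Dict[str, Any]:
    # Driver side: attach num_steps to the broadcast payload.
    broadcast_data = dict(model_input_dict)
    broadcast_data["num_steps"] = num_steps
    return broadcast_data


def unpack_worker_metadata(
        broadcast_data: Dict[str, Any]) -> Tuple[Dict[str, Any], int]:
    # Non-driver side: pop num_steps back out before rebuilding model input.
    data = dict(broadcast_data)
    num_steps = data.pop("num_steps")
    return data, num_steps


payload = pack_driver_metadata({"input_tokens": [1, 2, 3]}, num_steps=4)
model_input, num_steps = unpack_worker_metadata(payload)
assert num_steps == 4 and "num_steps" not in model_input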
|
@ -334,7 +334,12 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
|
|||||||
self,
|
self,
|
||||||
model_input: ModelInputForXPU,
|
model_input: ModelInputForXPU,
|
||||||
kv_caches: List[torch.Tensor],
|
kv_caches: List[torch.Tensor],
|
||||||
) -> Optional[SamplerOutput]:
|
num_steps: int = 1,
|
||||||
|
) -> Optional[List[SamplerOutput]]:
|
||||||
|
if num_steps > 1:
|
||||||
|
raise ValueError(
|
||||||
|
"XPUModelRunner does not support multi-step execution.")
|
||||||
|
|
||||||
model_executable = self.model
|
model_executable = self.model
|
||||||
execute_model_kwargs = {
|
execute_model_kwargs = {
|
||||||
"input_ids": model_input.input_tokens,
|
"input_ids": model_input.input_tokens,
|
||||||
@ -354,14 +359,14 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
|
|||||||
|
|
||||||
# Only perform sampling in the driver worker.
|
# Only perform sampling in the driver worker.
|
||||||
if not self.is_driver_worker:
|
if not self.is_driver_worker:
|
||||||
return None
|
return []
|
||||||
|
|
||||||
# Sample the next token.
|
# Sample the next token.
|
||||||
output = self.model.sample(
|
output = self.model.sample(
|
||||||
logits=logits,
|
logits=logits,
|
||||||
sampling_metadata=model_input.sampling_metadata,
|
sampling_metadata=model_input.sampling_metadata,
|
||||||
)
|
)
|
||||||
return output
|
return [output]
|
||||||
|
|
||||||
def _prepare_prompt(
|
def _prepare_prompt(
|
||||||
self,
|
self,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user