2024-11-21 12:53:39 -08:00
|
|
|
import gc
|
2024-11-05 22:16:04 -08:00
|
|
|
import time
|
2024-12-14 17:54:04 +00:00
|
|
|
from typing import TYPE_CHECKING, Dict, List, Tuple, cast
|
2024-10-22 01:24:07 -07:00
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
|
|
import torch.distributed
|
|
|
|
import torch.nn as nn
|
|
|
|
|
2025-01-17 15:39:35 +08:00
|
|
|
from vllm.attention.backends.abstract import AttentionType
|
|
|
|
from vllm.attention.layer import Attention
|
2024-11-19 10:09:03 -08:00
|
|
|
from vllm.config import CompilationLevel, VllmConfig
|
2024-11-26 00:00:16 -06:00
|
|
|
from vllm.distributed.parallel_state import graph_capture
|
2024-10-22 01:24:07 -07:00
|
|
|
from vllm.forward_context import set_forward_context
|
2024-12-16 22:10:57 -08:00
|
|
|
from vllm.inputs import INPUT_REGISTRY
|
2024-10-22 01:24:07 -07:00
|
|
|
from vllm.logger import init_logger
|
|
|
|
from vllm.model_executor.model_loader import get_model
|
2024-12-16 22:10:57 -08:00
|
|
|
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
|
2024-12-09 12:33:41 -05:00
|
|
|
from vllm.sampling_params import SamplingType
|
2024-12-11 04:53:37 +02:00
|
|
|
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
|
2025-01-17 15:39:35 +08:00
|
|
|
LayerBlockType, cdiv, is_pin_memory_available)
|
2024-10-22 01:24:07 -07:00
|
|
|
from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
|
|
|
|
FlashAttentionMetadata)
|
2025-01-15 11:29:00 -08:00
|
|
|
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
|
2025-01-06 11:58:16 -08:00
|
|
|
from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
|
2025-01-17 15:39:35 +08:00
|
|
|
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
|
|
|
KVCacheSpec)
|
2024-10-22 01:24:07 -07:00
|
|
|
from vllm.v1.outputs import ModelRunnerOutput
|
|
|
|
from vllm.v1.sample.metadata import SamplingMetadata
|
2025-01-17 15:39:35 +08:00
|
|
|
from vllm.v1.utils import bind_kv_cache
|
2024-12-09 12:33:41 -05:00
|
|
|
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
2024-10-22 01:24:07 -07:00
|
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
from vllm.v1.core.scheduler import SchedulerOutput
|
|
|
|
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
class GPUModelRunner:
|
|
|
|
|
|
|
|
def __init__(
    self,
    vllm_config: VllmConfig,
    device: torch.device,
):
    """Build the runner's config views, request state and reusable buffers.

    Args:
        vllm_config: Top-level configuration; each sub-config is also
            stored as its own attribute on the runner.
        device: Device that hosts the persistent input buffers.
    """
    self.vllm_config = vllm_config
    self.model_config = vllm_config.model_config
    self.cache_config = vllm_config.cache_config
    self.lora_config = vllm_config.lora_config
    self.load_config = vllm_config.load_config
    self.parallel_config = vllm_config.parallel_config
    self.scheduler_config = vllm_config.scheduler_config
    self.speculative_config = vllm_config.speculative_config
    self.prompt_adapter_config = vllm_config.prompt_adapter_config
    self.observability_config = vllm_config.observability_config

    # Short aliases for the configs consulted repeatedly below.
    model_cfg = self.model_config
    cache_cfg = self.cache_config
    sched_cfg = self.scheduler_config
    parallel_cfg = self.parallel_config

    self.device = device
    self.pin_memory = is_pin_memory_available()
    self.dtype = self.model_config.dtype
    # "auto" makes the KV cache use the model dtype; any other value is
    # looked up in the string -> torch dtype table.
    requested_kv_dtype = cache_cfg.cache_dtype
    self.kv_cache_dtype = (self.dtype if requested_kv_dtype == "auto" else
                           STR_DTYPE_TO_TORCH_DTYPE[requested_kv_dtype])

    self.is_multimodal_model = model_cfg.is_multimodal_model
    self.sliding_window = model_cfg.get_sliding_window()
    self.block_size = cache_cfg.block_size
    self.max_model_len = model_cfg.max_model_len
    self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
    self.max_num_tokens = sched_cfg.max_num_batched_tokens
    self.max_num_reqs = sched_cfg.max_num_seqs

    # Model-related.
    self.num_attn_layers = model_cfg.get_num_layers_by_block_type(
        parallel_cfg, LayerBlockType.attention)
    self.num_query_heads = model_cfg.get_num_attention_heads(parallel_cfg)
    self.num_kv_heads = model_cfg.get_num_kv_heads(parallel_cfg)
    self.head_size = model_cfg.get_head_size()
    self.hidden_size = model_cfg.get_hidden_size()

    # Multi-modal data support.
    self.input_registry = INPUT_REGISTRY
    self.mm_registry = MULTIMODAL_REGISTRY

    # NOTE: This input mapper is only used to turn dummy multimodal data
    # into multimodal kwargs for GPU memory profiling, so its cache is
    # switched off.
    self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config)
    self.mm_input_mapper_profiling.use_cache = False

    (self.max_num_encoder_input_tokens,
     self.encoder_cache_size) = compute_encoder_budget(
         model_config=model_cfg,
         scheduler_config=sched_cfg,
     )

    # Lazy initialization.
    # self.model: nn.Module  # Set after load_model
    self.kv_caches: List[torch.Tensor] = []
    # req_id -> (input_id -> encoder_output)
    self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}

    # Request states.
    self.requests: Dict[str, CachedRequestState] = {}
    # Persistent batch.
    self.input_batch = InputBatch(
        max_num_reqs=self.max_num_reqs,
        max_model_len=self.max_model_len,
        max_num_blocks_per_req=self.max_num_blocks_per_req,
        device=self.device,
        pin_memory=self.pin_memory,
        vocab_size=model_cfg.get_vocab_size(),
    )

    piecewise = (self.vllm_config.compilation_config.level ==
                 CompilationLevel.PIECEWISE)
    self.use_cuda_graph = (piecewise
                           and not self.model_config.enforce_eager)
    # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
    # The convention is different: self.cudagraph_batch_sizes is sorted in
    # ascending order, while the batch sizes in the config are in
    # descending order.
    self.cudagraph_batch_sizes = list(
        reversed(self.vllm_config.compilation_config.capture_sizes))

    # Cache the device properties.
    self.device_properties = torch.cuda.get_device_properties(self.device)
    self.num_sms = self.device_properties.multi_processor_count

    # Persistent buffers for CUDA graphs.
    self.input_ids = torch.zeros(self.max_num_tokens,
                                 dtype=torch.int32,
                                 device=self.device)
    self.positions = torch.zeros(self.max_num_tokens,
                                 dtype=torch.int64,
                                 device=self.device)
    self.inputs_embeds = torch.zeros(
        (self.max_num_tokens, self.hidden_size),
        dtype=self.dtype,
        device=self.device)

    # OPTIMIZATION: Cache the tensors rather than creating them every step.
    arange_len = max(self.max_num_reqs + 1, self.max_model_len)
    self.arange_np = np.arange(arange_len, dtype=np.int32)

    def _host_buffer(length: int, dtype: torch.dtype) -> torch.Tensor:
        # One zero-filled CPU tensor, pinned when pinning is available.
        return torch.zeros(length,
                           dtype=dtype,
                           device="cpu",
                           pin_memory=self.pin_memory)

    # NOTE(woosuk): These tensors are "stateless", i.e., they are literally
    # a faster version of creating a new tensor every time. Thus, we should
    # not make any assumptions about the values in these tensors.
    # Each tensor is paired with a numpy view obtained through .numpy().
    self.input_ids_cpu = _host_buffer(self.max_num_tokens, torch.int32)
    self.input_ids_np = self.input_ids_cpu.numpy()
    self.positions_cpu = _host_buffer(self.max_num_tokens, torch.int64)
    self.positions_np = self.positions_cpu.numpy()
    self.slot_mapping_cpu = _host_buffer(self.max_num_tokens, torch.int32)
    self.slot_mapping_np = self.slot_mapping_cpu.numpy()
    self.query_start_loc_cpu = _host_buffer(self.max_num_reqs + 1,
                                            torch.int32)
    self.query_start_loc_np = self.query_start_loc_cpu.numpy()
    self.seq_start_loc_cpu = _host_buffer(self.max_num_reqs + 1,
                                          torch.int32)
    self.seq_start_loc_np = self.seq_start_loc_cpu.numpy()
|
|
|
|
|
2024-10-22 01:24:07 -07:00
|
|
|
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
    """Bring the cached request states and persistent batch up to date.

    Applies the scheduler's decision for this step: drops finished
    requests, frees encoder outputs, removes finished/preempted requests
    from the persistent batch, refreshes running requests, and inserts new
    or resumed requests.

    Args:
        scheduler_output: Scheduling decision for the current step.
    """
    batch = self.input_batch

    # Forget finished requests entirely. Preempted requests keep their
    # cached state so that they can be resumed later.
    for req_id in scheduler_output.finished_req_ids:
        self.requests.pop(req_id, None)
        self.encoder_cache.pop(req_id, None)

    # Free the cached encoder outputs.
    for req_id, input_id in scheduler_output.free_encoder_input_ids:
        cached_outputs = self.encoder_cache.get(req_id)
        if cached_outputs is None:
            continue
        cached_outputs.pop(input_id, None)
        if not cached_outputs:
            # Nothing left for this request; drop the empty entry.
            self.encoder_cache.pop(req_id, None)

    # Remove the stopped (preempted or finished) requests from the
    # persistent batch and remember which batch indices became free.
    stopped_req_ids = set().union(
        scheduler_output.preempted_req_ids,
        scheduler_output.finished_req_ids,
    )
    removed_req_indices: List[int] = [
        index for index in map(batch.remove_request, stopped_req_ids)
        if index is not None
    ]

    # Update the states of the running requests.
    for running in scheduler_output.scheduled_running_reqs:
        req_id = running.req_id
        state = self.requests[req_id]
        row = batch.req_id_to_index[req_id]

        # Update the num_computed_tokens.
        computed = running.num_computed_tokens
        state.num_computed_tokens = computed
        batch.num_computed_tokens_cpu[row] = computed

        # Update the block table, appending after the blocks already held.
        new_block_ids = running.new_block_ids
        if len(new_block_ids) != 0:
            first_new = len(state.block_ids)
            state.block_ids.extend(new_block_ids)
            batch.block_table.append_row(row, first_new, new_block_ids)

    pending_req_ids: List[str] = []

    # Add new requests to the cached states.
    for new_req in scheduler_output.scheduled_new_reqs:
        params = new_req.sampling_params
        generator = None
        if params.sampling_type == SamplingType.RANDOM_SEED:
            # Seeded sampling gets its own generator on the runner device.
            generator = torch.Generator(device=self.device)
            generator.manual_seed(params.seed)

        self.requests[new_req.req_id] = CachedRequestState(
            req_id=new_req.req_id,
            prompt_token_ids=new_req.prompt_token_ids,
            prompt=new_req.prompt,
            mm_inputs=new_req.mm_inputs,
            mm_positions=new_req.mm_positions,
            sampling_params=params,
            generator=generator,
            block_ids=new_req.block_ids,
            num_computed_tokens=new_req.num_computed_tokens,
            output_token_ids=[],
        )
        pending_req_ids.append(new_req.req_id)

    # Update the cached states of the resumed requests.
    for resumed in scheduler_output.scheduled_resumed_reqs:
        state = self.requests[resumed.req_id]
        state.block_ids = resumed.block_ids
        state.num_computed_tokens = resumed.num_computed_tokens
        pending_req_ids.append(resumed.req_id)

    # Add the new or resumed requests to the persistent batch.
    # Sorting in descending order lets pop() hand out the smallest free
    # index first.
    free_indices = sorted(removed_req_indices, reverse=True)
    for req_id in pending_req_ids:
        state = self.requests[req_id]
        # Fill an empty index if one exists; None appends to the end.
        target = free_indices.pop() if free_indices else None
        batch.add_request(state, target)

    # Condense the batched states if there are empty indices left.
    if free_indices:
        batch.condense(free_indices)
|
|
|
|
|
|
|
|
def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
    """Assemble this step's device inputs and attention metadata.

    Fills the persistent input-id and position buffers for the scheduled
    tokens, derives the slot mapping and the cumulative query/sequence
    offsets, and decides whether cascade attention is used.

    Args:
        scheduler_output: Scheduling decision for the current step.

    Returns:
        A tuple ``(attn_metadata, logits_indices)``: the
        ``FlashAttentionMetadata`` for the step and, per request, the
        index of its last scheduled token.
    """
    num_tokens_total = scheduler_output.total_num_scheduled_tokens
    assert num_tokens_total > 0
    batch = self.input_batch
    num_reqs = batch.num_reqs
    assert num_reqs > 0

    # OPTIMIZATION: Start copying the block table first.
    # This way, we can overlap the copy with the following CPU operations.
    batch.block_table.commit(num_reqs)

    # Get the number of scheduled tokens for each request.
    # TODO: The Python loop can be slow. Optimize.
    tokens_per_req = []
    for req_id in batch.req_ids[:num_reqs]:
        assert req_id is not None
        tokens_per_req.append(scheduler_output.num_scheduled_tokens[req_id])
    max_query_len = max([0, *tokens_per_req])
    query_lens = np.array(tokens_per_req, dtype=np.int32)
    assert max_query_len > 0

    # Get request indices.
    # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
    req_indices = np.repeat(self.arange_np[:num_reqs], query_lens)

    # Get batched arange.
    # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
    intra_req_offsets = np.concatenate(
        [self.arange_np[:n] for n in query_lens])

    # Get positions: already-computed tokens plus the offset within the
    # request, written straight into the persistent CPU buffer.
    positions_np = self.positions_np[:num_tokens_total]
    np.add(batch.num_computed_tokens_cpu[req_indices],
           intra_req_offsets,
           out=positions_np)

    # Get token indices.
    # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
    # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
    # where M is the max_model_len.
    row_stride = batch.token_ids_cpu.shape[1]
    token_indices = (positions_np + req_indices * row_stride)
    # NOTE(woosuk): We use torch.index_select instead of np.take here
    # because torch.index_select is much faster than np.take for large
    # tensors.
    torch.index_select(batch.token_ids_cpu_tensor.flatten(),
                       0,
                       torch.from_numpy(token_indices),
                       out=self.input_ids_cpu[:num_tokens_total])

    # Calculate the slot mapping.
    # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
    # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
    # where K is the max_num_blocks_per_req and the block size is 2.
    # NOTE(woosuk): We can't simply use `token_indices // block_size` here
    # because M (max_model_len) is not necessarily divisible by block_size.
    block_table_indices = (req_indices * self.max_num_blocks_per_req +
                           positions_np // self.block_size)
    # The flattened CPU block table is indexed through torch and only the
    # gathered result is converted to numpy.
    block_table_cpu = batch.block_table.get_cpu_tensor()
    block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
    block_offsets = positions_np % self.block_size
    np.add(block_numbers * self.block_size,
           block_offsets,
           out=self.slot_mapping_np[:num_tokens_total])

    # Prepare the attention metadata: cumulative query lengths ...
    self.query_start_loc_np[0] = 0
    np.cumsum(query_lens, out=self.query_start_loc_np[1:num_reqs + 1])

    # ... and cumulative sequence lengths (computed + scheduled tokens).
    seq_lens = (batch.num_computed_tokens_cpu[:num_reqs] + query_lens)
    max_seq_len = seq_lens.max()
    self.seq_start_loc_np[0] = 0
    np.cumsum(seq_lens, out=self.seq_start_loc_np[1:num_reqs + 1])

    # Copy the tensors to the GPU.
    self.input_ids[:num_tokens_total].copy_(
        self.input_ids_cpu[:num_tokens_total], non_blocking=True)
    self.positions[:num_tokens_total].copy_(
        self.positions_cpu[:num_tokens_total], non_blocking=True)
    query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to(
        self.device, non_blocking=True)
    seq_start_loc = self.seq_start_loc_cpu[:num_reqs + 1].to(
        self.device, non_blocking=True)
    slot_mapping = self.slot_mapping_cpu[:num_tokens_total].to(
        self.device, non_blocking=True).long()

    # Prepare for cascade attention if needed.
    common_prefix_len = (scheduler_output.num_common_prefix_blocks *
                         self.block_size)
    # Common case: no shared prefix, hence no cascade attention.
    use_cascade = False
    if common_prefix_len != 0:
        # NOTE(woosuk): Cascade attention uses two attention kernels: one
        # for the common prefix and the other for the rest. For the first
        # kernel, we concatenate all the query tokens (possibly from
        # different requests) and treat them as if they are from the same
        # request. Then, we use bi-directional attention to process the
        # common prefix in the KV cache. Importantly, this means that the
        # first kernel does not do any masking.
        #
        # Consider the following example:
        #   Request 1's input query: [D, E, X]
        #   Request 1's kv cache: [A, B, C, D, E, X]
        #   Request 1's num_computed_tokens: 3 (i.e., [A, B, C])
        #   Request 2's input query: [E, Y]
        #   Request 2's kv cache: [A, B, C, D, E, Y]
        #   Request 2's num_computed_tokens: 4 (i.e., [A, B, C, D])
        #
        # If we use [A, B, C, D, E] as the common prefix, then the
        # first kernel will compute the bi-directional attention between
        # input query [D, E, X, E, Y] and common prefix [A, B, C, D, E].
        # However, this is wrong because D in Request 1 should not attend
        # to E in the common prefix (i.e., we need masking).
        # To avoid this, [A, B, C, D] should be the common prefix.
        # That is, the common prefix should be capped by the minimum
        # num_computed_tokens among the requests, and plus one to include
        # the first token of the query.
        #
        # In practice, we use [A, B, C] as the common prefix, instead of
        # [A, B, C, D] (i.e., the common prefix is capped by the minimum
        # num_computed_tokens, without plus one).
        # This is because of an implementation detail: We want to always
        # use two kernels for cascade attention. Let's imagine:
        #   Request 3's input query: [D]
        #   Request 3's kv cache: [A, B, C, D]
        #   Request 3's num_computed_tokens: 4 (i.e., [A, B, C, D])
        # If we use [A, B, C, D] as the common prefix for Request 1-3,
        # then Request 3 will be processed only by the first kernel,
        # and the second kernel will get an empty input. While this is not
        # a fundamental problem, our current implementation does not
        # support this case.
        common_prefix_len = min(
            common_prefix_len,
            batch.num_computed_tokens_cpu[:num_reqs].min())
        # common_prefix_len should be a multiple of the block size.
        common_prefix_len = (common_prefix_len // self.block_size *
                             self.block_size)
        use_cascade = FlashAttentionBackend.use_cascade_attention(
            common_prefix_len=common_prefix_len,
            query_lens=query_lens,
            num_query_heads=self.num_query_heads,
            num_kv_heads=self.num_kv_heads,
            use_alibi=False,  # FIXME
            use_sliding_window=self.sliding_window is not None,
            num_sms=self.num_sms,
        )

    # The cascade-only tensors stay None unless cascade attention is on.
    cu_prefix_query_lens = None
    cu_prefix_kv_lens = None
    cu_suffix_kv_lens = None
    if use_cascade:
        # TODO: Optimize.
        cu_prefix_query_lens = torch.tensor([0, num_tokens_total],
                                            dtype=torch.int32,
                                            device=self.device)
        cu_prefix_kv_lens = torch.tensor([0, common_prefix_len],
                                         dtype=torch.int32,
                                         device=self.device)
        # Cumulative sequence lengths with the common prefix subtracted
        # once per preceding request.
        cu_suffix_kv_lens = (
            self.seq_start_loc_np[:num_reqs + 1] -
            self.arange_np[:num_reqs + 1] * common_prefix_len)
        cu_suffix_kv_lens = torch.from_numpy(cu_suffix_kv_lens).to(
            self.device)

    attn_metadata = FlashAttentionMetadata(
        num_actual_tokens=num_tokens_total,
        max_query_len=max_query_len,
        query_start_loc=query_start_loc,
        max_seq_len=max_seq_len,
        seq_start_loc=seq_start_loc,
        block_table=(batch.block_table.get_device_tensor()[:num_reqs]),
        slot_mapping=slot_mapping,
        use_cascade=use_cascade,
        common_prefix_len=common_prefix_len,
        cu_prefix_query_lens=cu_prefix_query_lens,
        cu_prefix_kv_lens=cu_prefix_kv_lens,
        cu_suffix_kv_lens=cu_suffix_kv_lens,
    )
    # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
    # request in the batch. While we should not sample any token from this
    # partial request, we do so for simplicity. We will ignore the sampled
    # token from the partial request.
    # TODO: Support prompt logprobs.
    logits_indices = query_start_loc[1:] - 1
    return attn_metadata, logits_indices
|
2024-10-22 01:24:07 -07:00
|
|
|
|
|
|
|
def _prepare_sampling(
    self,
    scheduler_output: "SchedulerOutput",
) -> SamplingMetadata:
    """Create the sampling metadata for the current step.

    Args:
        scheduler_output: Scheduling decision for the current step.

    Returns:
        The metadata produced by the persistent batch.
    """
    # The copy can be skipped only when no request finished, was
    # preempted, was newly scheduled, or was resumed in this step.
    batch_changed = (scheduler_output.finished_req_ids
                     or scheduler_output.preempted_req_ids
                     or scheduler_output.scheduled_new_reqs
                     or scheduler_output.scheduled_resumed_reqs)
    skip_copy = not batch_changed

    # Map every cached request to the tokens it has generated so far.
    output_tokens_by_req: Dict[str, List[int]] = {
        req_id: state.output_token_ids
        for req_id, state in self.requests.items()
    }

    return self.input_batch.make_sampling_metadata(output_tokens_by_req,
                                                   skip_copy)
|
|
|
|
|
2024-11-12 20:53:13 -08:00
|
|
|
def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
    """Run the multimodal encoder on the scheduled inputs and cache results.

    Does nothing when no encoder input is scheduled for this step.

    Args:
        scheduler_output: Scheduling decision for the current step.
    """
    scheduled = scheduler_output.scheduled_encoder_inputs
    if not scheduled:
        return

    # Batch the multi-modal inputs, remembering which (request, input)
    # pair each batched entry came from.
    mm_inputs: List[MultiModalKwargs] = []
    origins: List[Tuple[str, int]] = []
    for req_id, input_ids in scheduled.items():
        state = self.requests[req_id]
        for input_id in input_ids:
            mm_inputs.append(state.mm_inputs[input_id])
            origins.append((req_id, input_id))
    batched = MultiModalKwargs.batch(mm_inputs)
    batched = MultiModalKwargs.as_kwargs(batched, device=self.device)

    # Run the encoder.
    # `encoder_outputs` is either of the following:
    # 1. A tensor of shape [num_images, feature_size, hidden_size]
    # in case when feature_size is fixed across all images.
    # 2. A list (length: num_images) of tensors, each of shape
    # [feature_size, hidden_size] in case when the feature size is
    # dynamic depending on input images.
    encoder_outputs = self.model.get_multimodal_embeddings(**batched)

    # Cache the encoder outputs under req_id -> input_id.
    for (req_id, input_id), output in zip(origins, encoder_outputs):
        self.encoder_cache.setdefault(req_id, {})[input_id] = output
|
|
|
|
|
|
|
|
def _gather_encoder_outputs(
    self,
    scheduler_output: "SchedulerOutput",
) -> List[torch.Tensor]:
    """Collect the cached encoder-output slices needed by this step.

    Args:
        scheduler_output: Scheduling decision for the current step.

    Returns:
        The slices of cached encoder outputs that overlap the token
        window scheduled for each request, in batch order.
    """
    gathered: List[torch.Tensor] = []
    num_reqs = self.input_batch.num_reqs
    for req_id in self.input_batch.req_ids[:num_reqs]:
        assert req_id is not None
        num_scheduled = scheduler_output.num_scheduled_tokens[req_id]
        state = self.requests[req_id]
        window_start = state.num_computed_tokens
        window_end = window_start + num_scheduled
        for i, pos_info in enumerate(state.mm_positions):
            start_pos = pos_info["offset"]
            num_encoder_tokens = pos_info["length"]

            # The encoder output is needed if the two ranges overlap:
            # [num_computed_tokens,
            #  num_computed_tokens + num_scheduled_tokens) and
            # [start_pos, start_pos + num_encoder_tokens)
            if start_pos >= window_end:
                # The encoder output is not needed in this step; the
                # remaining entries are not examined either.
                break
            if start_pos + num_encoder_tokens <= window_start:
                # The encoder output is already processed and stored
                # in the decoder's KV cache.
                continue

            # Translate the overlap into indices local to this output.
            shift = window_start - start_pos
            lo = max(shift, 0)
            hi = min(shift + num_scheduled, num_encoder_tokens)
            assert lo < hi
            assert req_id in self.encoder_cache
            assert i in self.encoder_cache[req_id]
            cached_output = self.encoder_cache[req_id][i]
            gathered.append(cached_output[lo:hi])
    return gathered
|
|
|
|
|
2024-10-22 01:24:07 -07:00
|
|
|
@torch.inference_mode()
def execute_model(
    self,
    scheduler_output: "SchedulerOutput",
) -> ModelRunnerOutput:
    """Run one scheduled step: update state, run the model and sample.

    Args:
        scheduler_output: Scheduling decision for the current step.

    Returns:
        The request ids, sampled token ids and (optional) logprob data
        for the requests in the persistent batch.
    """
    self._update_states(scheduler_output)

    encoder_outputs = []
    if self.is_multimodal_model:
        # Run the multimodal encoder if any.
        self._execute_encoder(scheduler_output)
        encoder_outputs = self._gather_encoder_outputs(scheduler_output)

    # Prepare the decoder inputs.
    attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
    num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
    fits_cuda_graph = (
        self.use_cuda_graph
        and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1])
    if fits_cuda_graph:
        # Use piecewise CUDA graphs.
        # Add padding to the batch size.
        num_input_tokens = self.vllm_config.pad_for_cudagraph(
            num_scheduled_tokens)
    else:
        # Eager mode.
        num_input_tokens = num_scheduled_tokens
    attn_metadata.num_input_tokens = num_input_tokens

    if not self.is_multimodal_model:
        # For text-only models, we use token ids as input.
        # While it is possible to use embeddings as input just like the
        # multimodal models, it is not desirable for performance since
        # then the embedding layer is not included in the CUDA graph.
        input_ids = self.input_ids[:num_input_tokens]
        inputs_embeds = None
    else:
        # NOTE(woosuk): To unify token ids and soft tokens (vision
        # embeddings), we always use embeddings (rather than token ids)
        # as input to the multimodal model, even when the input is text.
        scheduled_ids = self.input_ids[:num_scheduled_tokens]
        if encoder_outputs:
            embeds = self.model.get_input_embeddings(
                scheduled_ids, encoder_outputs)
        else:
            embeds = self.model.get_input_embeddings(scheduled_ids)
        # TODO(woosuk): Avoid the copy. Optimize.
        self.inputs_embeds[:num_scheduled_tokens].copy_(embeds)
        inputs_embeds = self.inputs_embeds[:num_input_tokens]
        input_ids = None

    # Run the decoder.
    # Use persistent buffers for CUDA graphs.
    with set_forward_context(attn_metadata, self.vllm_config):
        hidden_states = self.model(
            input_ids=input_ids,
            positions=self.positions[:num_input_tokens],
            kv_caches=self.kv_caches,
            attn_metadata=None,
            inputs_embeds=inputs_embeds,
        )
    # Drop any padding rows, then keep one row per request.
    hidden_states = hidden_states[:num_scheduled_tokens]
    hidden_states = hidden_states[logits_indices]
    logits = self.model.compute_logits(hidden_states, None)

    # Sample the next token and get logprobs if needed.
    sampling_metadata = self._prepare_sampling(scheduler_output)
    sampler_output = self.model.sample(
        logits=logits,
        sampling_metadata=sampling_metadata,
    )

    sampled_token_ids = sampler_output.sampled_token_ids
    # TODO(woosuk): The following loop can be slow since it iterates over
    # the requests one by one. Optimize.
    batch = self.input_batch
    num_reqs = batch.num_reqs
    for i, req_id in enumerate(batch.req_ids[:num_reqs]):
        assert req_id is not None
        state = self.requests[req_id]
        seq_len = (state.num_computed_tokens +
                   scheduler_output.num_scheduled_tokens[req_id])
        assert seq_len <= state.num_tokens
        if seq_len != state.num_tokens:
            # Ignore the sampled token from the partial request.
            # Rewind the generator state as if the token was not sampled.
            generator = batch.generators.get(i)
            if generator is not None:
                # This relies on cuda-specific torch-internal impl details
                generator.set_offset(generator.get_offset() - 4)
            continue
        # Append the sampled token to the output token ids.
        token_id = sampled_token_ids[i]
        batch.token_ids_cpu[i, seq_len] = token_id
        batch.num_tokens[i] += 1
        state.output_token_ids.append(token_id)

    # Move logprob data to the CPU only when the sampler produced it.
    logprob_token_ids = (None if sampler_output.logprob_token_ids is None
                         else sampler_output.logprob_token_ids.cpu())
    logprobs = (None if sampler_output.logprobs is None else
                sampler_output.logprobs.cpu())

    # num_reqs entries should be non-None
    assert all(
        req_id is not None for req_id in
        batch.req_ids[:num_reqs]), "req_ids contains None"
    req_ids = cast(List[str], batch.req_ids[:num_reqs])

    return ModelRunnerOutput(
        req_ids=req_ids,
        req_id_to_index=batch.req_id_to_index,
        sampled_token_ids=sampled_token_ids,
        logprob_token_ids_cpu=logprob_token_ids,
        logprobs_cpu=logprobs,
    )
|
|
|
|
|
|
|
|
def load_model(self) -> None:
    """Build the model and record the memory consumed while loading it.

    The object returned by ``get_model`` is stored on ``self.model``. The
    ``consumed_memory`` value reported by ``DeviceMemoryProfiler`` for the
    loading phase is stored on ``self.model_memory_usage`` and logged
    (divided by 2**30 for the log message).
    """
    model_name = self.model_config.model
    logger.info("Starting to load model %s...", model_name)

    with DeviceMemoryProfiler() as mem_profiler:  # noqa: SIM117
        self.model = get_model(vllm_config=self.vllm_config)

    # The measurement is read only after the profiling context has exited.
    consumed = mem_profiler.consumed_memory
    self.model_memory_usage = consumed
    logger.info("Loading model weights took %.4f GB",
                consumed / float(2**30))
|
|
|
|
|
2024-11-21 12:53:39 -08:00
|
|
|
@torch.inference_mode()
def _dummy_run(
    self,
    model: nn.Module,
    num_tokens: int,
    kv_caches: List[torch.Tensor],
) -> torch.Tensor:
    """Invoke ``model`` once on the first ``num_tokens`` buffered inputs.

    The call is made inside ``set_forward_context(None, self.vllm_config)``
    and with ``attn_metadata=None``.

    Args:
        model: Module to call.
        num_tokens: Number of leading entries sliced from
            ``self.input_ids`` / ``self.inputs_embeds`` and
            ``self.positions``.
        kv_caches: KV cache tensors forwarded to the model unchanged.

    Returns:
        The value returned by the model call (its hidden states).
    """
    feed_embeddings = self.is_multimodal_model
    # Multimodal models receive embeddings and no token ids; all other
    # models receive token ids and no embeddings.
    input_ids = None if feed_embeddings else self.input_ids[:num_tokens]
    inputs_embeds = (self.inputs_embeds[:num_tokens]
                     if feed_embeddings else None)

    with set_forward_context(None, self.vllm_config):
        return model(
            input_ids=input_ids,
            positions=self.positions[:num_tokens],
            kv_caches=kv_caches,
            attn_metadata=None,
            inputs_embeds=inputs_embeds,
        )
|
|
|
|
|
|
|
|
def profile_run(self) -> None:
    """Execute a worst-case dummy forward pass for profiling.

    When the model is multimodal and both ``max_num_encoder_input_tokens``
    and ``encoder_cache_size`` are positive, a batch of dummy multimodal
    items is first pushed through ``get_multimodal_embeddings`` and the
    outputs are stored in ``self.encoder_cache``. The model is then run on
    ``self.max_num_tokens`` dummy tokens with empty placeholder KV caches,
    and logits are computed from the resulting hidden states. The hidden
    states, logits and encoder cache entries are released before
    returning.

    Raises:
        AssertionError: If the multimodal encoder returns a different
            number of outputs than the number of dummy items it was given.
    """
    # Use an empty tensor instead of `None` to force Dynamo to pass
    # it by reference, rather than specializing on the value `None`.
    # The `dtype` argument does not matter, and we use `float32` as
    # a placeholder (it has wide hardware support).
    # It is important to create the tensors inside the loop, rather than
    # multiplying the list, to prevent Dynamo from treating them as
    # tensor aliasing.
    dummy_kv_caches = [
        torch.tensor([], dtype=torch.float32, device=self.device)
        for _ in range(self.num_attn_layers)
    ]

    # Profile with multimodal encoder & encoder cache.
    # TODO: handle encoder-decoder models once we support them.
    if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0
            and self.encoder_cache_size > 0):

        # NOTE: Currently model is profiled with a single non-text
        # modality with the max possible input tokens even when
        # it supports multiple.
        # All registry lookups in this method go through the runner's own
        # `self.mm_registry` so that they consult the same registry object.
        max_tokens_by_modality_dict = (
            self.mm_registry.get_max_tokens_per_item_by_nonzero_modality(
                self.model_config))
        dummy_data_modality, max_tokens_per_mm_item = max(
            max_tokens_by_modality_dict.items(), key=lambda item: item[1])

        # Check how many items of this modality can be supported by
        # the encoder budget.
        encoder_budget = min(self.max_num_encoder_input_tokens,
                             self.encoder_cache_size)

        # Ceiling division: a partially filled item still counts as one.
        max_num_mm_items_encoder_budget = cdiv(encoder_budget,
                                               max_tokens_per_mm_item)

        # Check how many items of this modality can be supported by
        # the decoder budget.
        max_mm_items_per_req = self.mm_registry.get_mm_limits_per_prompt(
            self.model_config)[dummy_data_modality]

        # NOTE: We do not consider max_num_batched_tokens on purpose
        # because the multimodal embeddings can be generated in advance
        # and chunked prefilled.
        max_num_mm_items_decoder_budget = self.max_num_reqs * \
            max_mm_items_per_req

        max_num_mm_items = min(max_num_mm_items_encoder_budget,
                               max_num_mm_items_decoder_budget)

        logger.info(
            "Encoder cache will be initialized with a budget of %s tokens,"
            " and profiled with %s %s items of the maximum feature size.",
            encoder_budget, max_num_mm_items, dummy_data_modality)

        # Create dummy batch of multimodal inputs.
        dummy_request_data = self.input_registry.dummy_data_for_profiling(
            model_config=self.model_config,
            seq_len=self.max_num_tokens,
            mm_registry=self.mm_registry,
        )
        dummy_mm_data = dummy_request_data.multi_modal_data

        # Dummy data definition in V0 may contain multiple multimodal items
        # (e.g., multiple images) for a single request, therefore here we
        # always replicate first item by max_num_mm_items times since in V1
        # they are scheduled to be processed separately.
        if isinstance(dummy_mm_data, MultiModalKwargs):
            # Case when models have a merged processor, their dummy data is
            # already batched `MultiModalKwargs`, therefore we take the
            # first `MultiModalKwargsItem` from the desired modality to
            # profile on.
            dummy_mm_item = dummy_mm_data.get_item(
                modality=dummy_data_modality, item_index=0)
            dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item])
        else:
            # Case when models have dummy data explicitly defined as
            # `MultiModalDataDict`, so they need to be processed through
            # input mapper.
            # TODO (ywang96): deprecate this path once merged processor is
            # supported on all models.
            mm_kwargs_list = self.mm_input_mapper_profiling.process_inputs(
                mm_data=dummy_mm_data,
                mm_hashes=None,
                mm_processor_kwargs=None,
                precomputed_mm_inputs=None)
            dummy_mm_kwargs = mm_kwargs_list[0]

        batched_dummy_mm_inputs = MultiModalKwargs.batch(
            [dummy_mm_kwargs] * max_num_mm_items)
        batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
            batched_dummy_mm_inputs, device=self.device)

        # Run multimodal encoder.
        dummy_encoder_outputs = self.model.get_multimodal_embeddings(
            **batched_dummy_mm_inputs)
        assert len(dummy_encoder_outputs) == max_num_mm_items, (
            "Expected dimension 0 of encoder outputs to match the number "
            f"of multimodal data items: {max_num_mm_items}, got "
            f"{len(dummy_encoder_outputs)=} instead. This is most likely "
            "due to the 'get_multimodal_embeddings' method of the model "
            "not implemented correctly.")

        # Cache the dummy encoder outputs.
        self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))

    # Trigger compilation for general shape.
    hidden_states = self._dummy_run(self.model, self.max_num_tokens,
                                    dummy_kv_caches)
    logits = self.model.compute_logits(hidden_states, None)
    logits = logits[:self.max_num_tokens]
    # TODO(woosuk): Consider the memory usage of the sampler.
    torch.cuda.synchronize()
    # Drop every temporary created above before returning.
    del hidden_states, logits
    self.encoder_cache.clear()
    gc.collect()
|
2024-10-22 01:24:07 -07:00
|
|
|
|
|
|
|
def capture_model(self) -> None:
    """Run the CUDA-graph capture pass over all configured batch sizes.

    If ``self.use_cuda_graph`` is false, a warning is logged and nothing
    else happens. Otherwise, inside ``graph_capture``, each size in
    ``self.cudagraph_batch_sizes`` (largest first) gets
    ``cudagraph_num_of_warmups`` dummy runs followed by one more dummy
    run. The elapsed time and the drop in free GPU memory are logged.
    """
    if not self.use_cuda_graph:
        logger.warning(
            "Skipping CUDA graph capture. Please add "
            "-O %s to use CUDA graphs.", CompilationLevel.PIECEWISE)
        return

    t_begin = time.perf_counter()
    free_before = torch.cuda.mem_get_info()[0]

    # Trigger CUDA graph capture for specific shapes.
    # The large shapes go first so that the smaller shapes can reuse the
    # memory pool allocated for the large ones.
    with graph_capture(device=self.device):
        for batch_tokens in reversed(self.cudagraph_batch_sizes):
            warmup_count = (
                self.vllm_config.compilation_config.cudagraph_num_of_warmups)
            for _ in range(warmup_count):
                self._dummy_run(self.model, batch_tokens, self.kv_caches)
            self._dummy_run(self.model, batch_tokens, self.kv_caches)

    t_end = time.perf_counter()
    free_after = torch.cuda.mem_get_info()[0]
    # This usually takes 5~20 seconds.
    logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
                t_end - t_begin, (free_before - free_after) / (1 << 30))
|
2024-10-22 01:24:07 -07:00
|
|
|
|
2025-01-17 15:39:35 +08:00
|
|
|
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
    """
    Allocate one zero-filled KV cache tensor per layer listed in
    `kv_cache_config` and pass them to `bind_kv_cache`.

    Args:
        kv_cache_config: Configuration for the KV cache; supplies the spec
            and the tensor size for each layer.

    Raises:
        NotImplementedError: If the config has more than one group, or a
            layer spec is not a `FullAttentionSpec`.
    """
    if len(kv_cache_config.groups) > 1:
        raise NotImplementedError(
            "Hybrid models with more than one KV cache type are not "
            "supported yet.")

    allocated: Dict[str, torch.Tensor] = {}

    for name, spec in kv_cache_config.kv_cache_spec.items():
        tensor_size = kv_cache_config.tensors[name].size
        # The layer's tensor size must be a whole number of pages.
        num_blocks, leftover = divmod(tensor_size, spec.page_size_bytes)
        assert leftover == 0

        if not isinstance(spec, FullAttentionSpec):
            raise NotImplementedError

        shape = FlashAttentionBackend.get_kv_cache_shape(
            num_blocks, spec.block_size, spec.num_kv_heads, spec.head_size)
        allocated[name] = torch.zeros(shape,
                                      dtype=spec.dtype,
                                      device=self.device)

    bind_kv_cache(
        allocated,
        self.vllm_config.compilation_config.static_forward_context,
        self.kv_caches)
|
|
|
|
|
|
|
|
def get_kv_cache_spec(self) -> KVCacheSpec:
    """
    Build the KVCacheSpec from the Attention modules registered in the
    static forward context.

    Returns:
        KVCacheSpec: A dictionary mapping layer names to their KV cache
        format. Only decoder attention layers get an entry; encoder and
        encoder-only layers are skipped.

    Raises:
        NotImplementedError: If a layer uses encoder-decoder attention.
        ValueError: If a layer reports an unrecognized attention type.
    """
    compilation_config = self.vllm_config.compilation_config
    cache_block_size = self.vllm_config.cache_config.block_size
    cacheless_types = (AttentionType.ENCODER, AttentionType.ENCODER_ONLY)

    spec_by_layer: KVCacheSpec = {}
    for name, module in compilation_config.static_forward_context.items():
        # TODO: Support other attention modules, e.g., sliding window,
        # cross-attention, MLA.
        assert isinstance(module, Attention)
        attn_type = module.attn_type

        if attn_type in cacheless_types:
            # Encoder-only attention does not need KV cache.
            continue
        if attn_type == AttentionType.ENCODER_DECODER:
            raise NotImplementedError
        if attn_type != AttentionType.DECODER:
            raise ValueError(f"Unknown attention type: {attn_type}")

        spec_by_layer[name] = FullAttentionSpec(
            block_size=cache_block_size,
            num_kv_heads=module.num_kv_heads,
            head_size=module.head_size,
            dtype=module.dtype,
        )

    return spec_by_layer
|