# SPDX-License-Identifier: Apache-2.0
import bisect
import time
from typing import TYPE_CHECKING, Optional, cast
from unittest.mock import patch

import numpy as np
import torch
import torch.distributed
import torch.nn as nn
# TPU XLA related
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
                                               PallasAttentionBackend,
                                               PallasMetadata)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                        KVCacheSpec)
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
                             ModelRunnerOutput, SamplerOutput)
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch

from .utils import sanity_check_mm_encoder_outputs

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput

logger = init_logger(__name__)

# Here we utilize the behavior that out-of-bound index is ignored.
# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
_PAD_SLOT_ID = 1_000_000_000
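# Padded tokens are mapped to slot 1_000_000_000; since the KV cache holds far
# fewer slots, those writes fall out of bounds and are dropped.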
INVALID_TOKEN_ID = -1
# Smallest padded batch size (number of sequences).
MIN_NUM_SEQS = 8


class TPUModelRunner:

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.speculative_config = vllm_config.speculative_config
        self.prompt_adapter_config = vllm_config.prompt_adapter_config
        self.observability_config = vllm_config.observability_config
        self.device_config = vllm_config.device_config

        model_config = self.model_config
        cache_config = self.cache_config
        scheduler_config = self.scheduler_config
        parallel_config = self.parallel_config
        self.device = device
        self.check_recompilation = envs.VLLM_XLA_CHECK_RECOMPILATION

        self.enforce_eager = model_config.enforce_eager

        self.num_xla_graphs = 0
        self._update_num_xla_graphs("init")

        self.pin_memory = is_pin_memory_available()
        self.dtype = self.model_config.dtype
        self._hidden_states_dtype = self.dtype

        self.is_multimodal_model = model_config.is_multimodal_model
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size
        self.max_model_len = model_config.max_model_len
        self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
        self.max_num_tokens = scheduler_config.max_num_batched_tokens
        # InputBatch needs to work with sampling tensors greater than padding
        # to avoid dynamic shapes. Also, avoid suboptimal alignment.
        self.max_num_reqs = max(scheduler_config.max_num_seqs, MIN_NUM_SEQS)

        # Model-related.
        self.num_attn_layers = model_config.get_num_layers_by_block_type(
            parallel_config, LayerBlockType.attention)
        self.num_query_heads = model_config.get_num_attention_heads(
            parallel_config)
        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
        self.head_size = model_config.get_head_size()
        self.hidden_size = model_config.get_hidden_size()

        # Multi-modal data support
        self.mm_registry = MULTIMODAL_REGISTRY
        self.uses_mrope = model_config.uses_mrope
        # TODO: Support M-RoPE (e.g., Qwen2-VL)
        assert not self.uses_mrope, "TPU does not support M-RoPE yet."

        encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
            model_config=model_config,
            scheduler_config=scheduler_config,
            mm_registry=self.mm_registry,
        )
        self.max_num_encoder_input_tokens = encoder_compute_budget
        self.encoder_cache_size = encoder_cache_size

        # Lazy initialization
        # self.model: nn.Module  # Set after load_model
        self.kv_caches: list[torch.Tensor] = []
        # req_id -> (input_id -> encoder_output)
        self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}

        # Request states.
        self.requests: dict[str, CachedRequestState] = {}
        # Persistent batch.
        self.input_batch = InputBatch(
            max_num_reqs=self.max_num_reqs,
            max_model_len=self.max_model_len,
            max_num_blocks_per_req=self.max_num_blocks_per_req,
            device=self.device,
            pin_memory=self.pin_memory,
            vocab_size=model_config.get_vocab_size(),
        )

        # Cached torch/numpy tensor
        # The pytorch tensor and numpy array share the same buffer.
        # Sometimes the numpy op is faster so we create both.
        self.input_ids_cpu = torch.zeros(self.max_num_tokens,
                                         dtype=torch.int32,
                                         device="cpu")
        self.input_ids_np = self.input_ids_cpu.numpy()

        self.positions_cpu = torch.zeros(self.max_num_tokens,
                                         dtype=torch.int32,
                                         device="cpu")
        self.positions_np = self.positions_cpu.numpy()

        self.slot_mapping_cpu = torch.zeros(self.max_num_tokens,
                                            dtype=torch.int64,
                                            device="cpu")
        self.slot_mapping_np = self.slot_mapping_cpu.numpy()

        padded_max_num_blocks_per_req = _get_padded_number(
            self.max_num_blocks_per_req, NUM_KV_PAGES_PER_BLOCK)
        self.block_table_cpu = torch.zeros(
            (self.max_num_tokens, padded_max_num_blocks_per_req),
            dtype=self.input_batch.block_table.get_cpu_tensor().dtype,
            device="cpu")

        self.query_start_loc_cpu = torch.zeros(self.max_num_tokens + 1,
                                               dtype=torch.int32,
                                               device="cpu",
                                               pin_memory=self.pin_memory)
        self.query_start_loc_np = self.query_start_loc_cpu.numpy()

        self.seq_lens_cpu = torch.zeros(self.max_num_tokens,
                                        dtype=torch.int32,
                                        device="cpu",
                                        pin_memory=self.pin_memory)
        self.seq_lens_np = self.seq_lens_cpu.numpy()

        # Range tensor with values [0 .. self.max_num_tokens - 1].
        # Used to initialize positions / context_lens / seq_lens
        self.arange_np = np.arange(self.max_num_tokens, dtype=np.int32)
        self.num_tokens_paddings = _get_paddings(
            min_token_size=16,
            max_token_size=self.max_num_tokens,
            padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP)
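        # For illustration (hypothetical values): with max_num_tokens=256,
        # padding_gap=0 yields token-count buckets [16, 32, 64, 128, 256],
        # while padding_gap=64 yields [16, 32, 64, 128, 192, 256].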

    def _update_num_xla_graphs(self, case_str):
        check_comp = self.check_recompilation and not self.enforce_eager
        if not check_comp:
            return

        total_cached_graphs = xr.get_num_cached_compilation_graph()
        new_compiled_graphs = total_cached_graphs - self.num_xla_graphs
        if new_compiled_graphs == 0:
            return

        logger.info("Added %d new compiled XLA graphs due to %s",
                    new_compiled_graphs, case_str)
        self.num_xla_graphs += new_compiled_graphs

    def _verify_num_xla_graphs(self, case_str):
        check_comp = self.check_recompilation and not self.enforce_eager
        if not check_comp:
            return

        curr_cached_graph = xr.get_num_cached_compilation_graph()
        assert self.num_xla_graphs == curr_cached_graph, (
            "Recompilation after warm up is detected during {}."
            " num_xla_graphs = {} curr_cached_graph = {}".format(
                case_str, self.num_xla_graphs, curr_cached_graph))

    def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
        """Update the cached states and the persistent batch with the scheduler
        output.

        The updated states are used by the `_prepare_inputs` function to create
        the input tensors for the model.

        Returns:
            True if there is a new/resumed/paused/finished request.
            If False, we can skip copying SamplingMetadata to the device.
        """
        # Remove finished requests from the cached states.
        for req_id in scheduler_output.finished_req_ids:
            self.requests.pop(req_id, None)
            self.encoder_cache.pop(req_id, None)

        # Remove the finished requests from the persistent batch.
        # NOTE(woosuk): There could be an edge case where finished_req_ids and
        # scheduled_req_ids overlap. This happens when a request is aborted and
        # then resubmitted with the same ID. In this case, we treat them as two
        # distinct requests - clearing the cached states for the first request
        # and handling the second as a new request.
        removed_req_indices: list[int] = []
        for req_id in scheduler_output.finished_req_ids:
            req_index = self.input_batch.remove_request(req_id)
            if req_index is not None:
                removed_req_indices.append(req_index)

        # Free the cached encoder outputs.
        for req_id, input_id in scheduler_output.free_encoder_input_ids:
            encoder_outputs = self.encoder_cache.get(req_id)
            if encoder_outputs is not None:
                encoder_outputs.pop(input_id, None)
                if not encoder_outputs:
                    self.encoder_cache.pop(req_id, None)

        # Remove the unscheduled requests from the persistent batch.
        # NOTE(woosuk): The unscheduled requests are either preempted requests
        # or running requests that are not scheduled in this step. We remove
        # them from the persistent batch but keep their cached states since
        # they will be scheduled again sometime in the future.
        scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
        cached_req_ids = self.input_batch.req_id_to_index.keys()
        unscheduled_req_ids = cached_req_ids - scheduled_req_ids
        # NOTE(woosuk): The persistent batch optimization assumes that
        # consecutive batches contain mostly the same requests. If batches
        # have low request overlap (e.g., alternating between two distinct
        # sets of requests), this optimization becomes very inefficient.
        for req_id in unscheduled_req_ids:
            req_index = self.input_batch.remove_request(req_id)
            assert req_index is not None
            removed_req_indices.append(req_index)

        req_ids_to_add: list[str] = []
        # Add new requests to the cached states.
        for new_req_data in scheduler_output.scheduled_new_reqs:
            req_id = new_req_data.req_id
            sampling_params = new_req_data.sampling_params
            if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
                generator = torch.Generator(device=self.device)
                generator.manual_seed(sampling_params.seed)
            else:
                generator = None

            self.requests[req_id] = CachedRequestState(
                req_id=req_id,
                prompt_token_ids=new_req_data.prompt_token_ids,
                prompt=new_req_data.prompt,
                mm_inputs=new_req_data.mm_inputs,
                mm_positions=new_req_data.mm_positions,
                sampling_params=sampling_params,
                generator=generator,
                block_ids=new_req_data.block_ids,
                num_computed_tokens=new_req_data.num_computed_tokens,
                output_token_ids=[],
                lora_request=new_req_data.lora_request,
            )

            req_ids_to_add.append(req_id)

        # Update the states of the running/resumed requests.
        for req_data in scheduler_output.scheduled_cached_reqs:
            req_id = req_data.req_id
            req_state = self.requests[req_id]

            # Update the cached states.
            req_state.num_computed_tokens = req_data.num_computed_tokens
            if not req_data.resumed_from_preemption:
                # Append the new blocks to the existing block IDs.
                req_state.block_ids.extend(req_data.new_block_ids)
            else:
                # The request is resumed from preemption.
                # Replace the existing block IDs with the new ones.
                req_state.block_ids = req_data.new_block_ids

            req_index = self.input_batch.req_id_to_index.get(req_id)
            if req_index is None:
                # The request is not in the persistent batch.
                # The request was either preempted and resumed later, or was not
                # scheduled in the previous step and needs to be added again.
                req_ids_to_add.append(req_id)
                continue

            # Update the persistent batch.
            self.input_batch.num_computed_tokens_cpu[req_index] = (
                req_data.num_computed_tokens)
            self.input_batch.block_table.append_row(req_data.new_block_ids,
                                                    req_index)

        # Add the new or resumed requests to the persistent batch.
        # The smaller empty indices are filled first.
        removed_req_indices = sorted(removed_req_indices, reverse=True)
        for req_id in req_ids_to_add:
            req_state = self.requests[req_id]
            if removed_req_indices:
                # Fill the empty index.
                req_index = removed_req_indices.pop()
            else:
                # Append to the end.
                req_index = None
            self.input_batch.add_request(req_state, req_index)

        # Condense the batched states if there are empty indices.
        if removed_req_indices:
            self.input_batch.condense(removed_req_indices)

        return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0

    def get_model(self) -> nn.Module:
        assert self.model is not None
        return self.model

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """
        Generates the KVCacheSpec by parsing the kv cache format from each
        Attention module in the static forward context.
        Returns:
            KVCacheSpec: A dictionary mapping layer names to their KV cache
            format. Layers that do not need KV cache are not included.
        """

        forward_ctx = self.vllm_config.compilation_config.static_forward_context
        block_size = self.vllm_config.cache_config.block_size
        kv_cache_spec: dict[str, KVCacheSpec] = {}
        for layer_name, attn_module in forward_ctx.items():
            # TODO: Support other attention modules, e.g., sliding window,
            # cross-attention, MLA.
            assert isinstance(attn_module, Attention)
            if attn_module.attn_type == AttentionType.DECODER:
                kv_cache_spec[layer_name] = FullAttentionSpec(
                    block_size=block_size,
                    num_kv_heads=attn_module.num_kv_heads,
                    head_size=attn_module.head_size,
                    dtype=attn_module.dtype,
                    use_mla=False,
                )
            elif attn_module.attn_type in (AttentionType.ENCODER,
                                           AttentionType.ENCODER_ONLY):
                # encoder-only attention does not need KV cache.
                continue
            elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
                raise NotImplementedError
            else:
                raise ValueError(
                    f"Unknown attention type: {attn_module.attn_type}")

        return kv_cache_spec

    def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        assert total_num_scheduled_tokens > 0
        num_reqs = self.input_batch.num_reqs
        assert num_reqs > 0

        # Get the number of scheduled tokens for each request.
        num_scheduled_tokens_per_req = []
        max_num_scheduled_tokens_all_reqs = 0
        for req_id in self.input_batch.req_ids[:num_reqs]:
            assert req_id is not None
            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
            num_scheduled_tokens_per_req.append(num_tokens)
            max_num_scheduled_tokens_all_reqs = max(
                max_num_scheduled_tokens_all_reqs, num_tokens)
        num_scheduled_tokens_per_req = np.array(num_scheduled_tokens_per_req,
                                                dtype=np.int32)
        assert max_num_scheduled_tokens_all_reqs > 0

        # Get request indices.
        # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
        # For each scheduled token, the index of its request.
        req_indices = np.repeat(self.arange_np[:num_reqs],
                                num_scheduled_tokens_per_req)

        # Get batched arange.
        # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # For each scheduled token, its position within its request.
        arange = np.concatenate(
            [self.arange_np[:n] for n in num_scheduled_tokens_per_req])

        # Get positions.
        positions_np = self.positions_np[:total_num_scheduled_tokens]
        np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
               arange,
               out=positions_np)

        # Get token indices.
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
        # where M is the max_model_len.
        token_indices = (positions_np +
                         req_indices * self.input_batch.token_ids_cpu.shape[1])

        # NOTE(woosuk): We use torch.index_select instead of np.take here
        # because torch.index_select is much faster than np.take for large
        # tensors.
        torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
                           0,
                           torch.from_numpy(token_indices),
                           out=self.input_ids_cpu[:total_num_scheduled_tokens])

        # Calculate the slot mapping.
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
        # where K is the max_num_blocks_per_req and the block size is 2.
        # NOTE(woosuk): We can't simply use `token_indices // block_size` here
        # because M (max_model_len) is not necessarily divisible by block_size.
        # req_indices: E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
        block_table_indices = (req_indices * self.max_num_blocks_per_req +
                               positions_np // self.block_size)
        # NOTE(woosuk): We use torch.index_select instead of np.take here
        # because torch.index_select is much faster than np.take for large
        # tensors.
        block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
        block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
        block_offsets = positions_np % self.block_size
        np.add(block_numbers * self.block_size,
               block_offsets,
               out=self.slot_mapping_np[:total_num_scheduled_tokens])
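        # Worked example (hypothetical values): with block_size=16 and a
        # request whose block table row is [7, 9, ...], the token at
        # position 19 is in block 19 // 16 = 1, i.e. physical block 9, at
        # offset 19 % 16 = 3, so its slot is 9 * 16 + 3 = 147.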

        # Prepare the attention metadata.
        self.query_start_loc_np[0] = 0
        np.cumsum(num_scheduled_tokens_per_req,
                  out=self.query_start_loc_np[1:num_reqs + 1])
        self.query_start_loc_np[num_reqs + 1:] = 1

        self.seq_lens_np[:num_reqs] = (
            self.input_batch.num_computed_tokens_cpu[:num_reqs] +
            num_scheduled_tokens_per_req)

        # Do the padding and copy the tensors to the TPU.
        padded_total_num_scheduled_tokens = _get_padded_token_len(
            self.num_tokens_paddings, total_num_scheduled_tokens)
        # Zero out to avoid spurious values from prev iteration (last cp chunk)
        self.input_ids_cpu[
            total_num_scheduled_tokens:padded_total_num_scheduled_tokens] = 0
        self.input_ids = self.input_ids_cpu[
            :padded_total_num_scheduled_tokens].to(self.device)
        self.position_ids = self.positions_cpu[
            :padded_total_num_scheduled_tokens].to(self.device)
        self.slot_mapping_cpu[total_num_scheduled_tokens:] = _PAD_SLOT_ID
        slot_mapping = self.slot_mapping_cpu[
            :padded_total_num_scheduled_tokens].to(self.device)
        block_tables = self.block_table_cpu[:self.max_num_reqs]
        block_tables[:num_reqs, :self.max_num_blocks_per_req] = (
            self.input_batch.block_table.get_cpu_tensor()[:num_reqs])
        block_tables = block_tables.to(self.device)
        query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1].to(
            self.device)
        seq_lens = self.seq_lens_cpu[:self.max_num_reqs].to(self.device)

        attn_metadata = PallasMetadata(
            slot_mapping=slot_mapping,
            block_tables=block_tables,
            context_lens=seq_lens,
            query_start_loc=query_start_loc,
            num_seqs=torch.tensor([num_reqs],
                                  dtype=torch.int32,
                                  device=self.device),
        )
        # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
        # request in the batch. While we should not sample any token from this
        # partial request, we do so for simplicity. We will ignore the sampled
        # token from the partial request.
        # TODO: Support prompt logprobs.
        padded_num_reqs = _get_padded_num_reqs_with_upper_limit(
            num_reqs, self.max_num_reqs)
        # Indices at which we sample (positions of last token in the sequence).
        # Padded to avoid recompiling when `num_reqs` varies.
        logits_indices = self.query_start_loc_cpu[1:padded_num_reqs + 1] - 1
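        # For illustration: if query_start_loc is [0, 2, 7, 10, ...], then
        # logits_indices is [1, 6, 9, ...], i.e. the flattened index of the
        # last scheduled token of each request.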
        logits_indices = logits_indices.to(self.device)
        return attn_metadata, logits_indices

    def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
        scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
        if not scheduled_encoder_inputs:
            return

        # Batch the multi-modal inputs.
        mm_inputs: list[MultiModalKwargs] = []
        req_input_ids: list[tuple[str, int]] = []
        for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
            req_state = self.requests[req_id]
            for input_id in encoder_input_ids:
                mm_inputs.append(req_state.mm_inputs[input_id])
                req_input_ids.append((req_id, input_id))

        # Batch mm inputs as much as we can: if a request in the batch has
        # multiple modalities or a different modality than the previous one,
        # we process it separately to preserve item order.
        # FIXME(ywang96): This is a hacky way to deal with multiple modalities
        # in the same batch while still being able to benefit from batching
        # multimodal inputs. The proper solution should be reordering the
        # encoder outputs.
        grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs)

        encoder_outputs = []
        for grouped_mm_inputs in grouped_mm_inputs_list:
            batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
            batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
                                                           device=self.device)

            # Run the encoder.
            # `curr_group_outputs` is either of the following:
            # 1. A tensor of shape (num_items, feature_size, hidden_size)
            #    in case feature_size is fixed across all multimodal items.
            # 2. A list or tuple (length: num_items) of tensors, each of shape
            #    (feature_size, hidden_size) in case the feature size is
            #    dynamic depending on the input multimodal items.
            curr_group_outputs = self.model.get_multimodal_embeddings(
                **batched_mm_inputs)

            sanity_check_mm_encoder_outputs(
                curr_group_outputs,
                expected_num_items=len(grouped_mm_inputs),
            )

            for output in curr_group_outputs:
                encoder_outputs.append(output)

        # Cache the encoder outputs.
        for (req_id, input_id), output in zip(req_input_ids, encoder_outputs):
            if req_id not in self.encoder_cache:
                self.encoder_cache[req_id] = {}
            self.encoder_cache[req_id][input_id] = output

    def _gather_encoder_outputs(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> list[torch.Tensor]:
        encoder_outputs: list[torch.Tensor] = []
        for req_id in self.input_batch.req_ids:
            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
                req_id]
            req_state = self.requests[req_id]
            num_computed_tokens = req_state.num_computed_tokens
            mm_positions = req_state.mm_positions
            for i, pos_info in enumerate(mm_positions):
                start_pos = pos_info["offset"]
                num_encoder_tokens = pos_info["length"]

                # The encoder output is needed if the two ranges overlap:
                # [num_computed_tokens,
                #  num_computed_tokens + num_scheduled_tokens) and
                # [start_pos, start_pos + num_encoder_tokens)
                if start_pos >= num_computed_tokens + num_scheduled_tokens:
                    # The encoder output is not needed in this step.
                    break
                if start_pos + num_encoder_tokens <= num_computed_tokens:
                    # The encoder output is already processed and stored
                    # in the decoder's KV cache.
                    continue

                start_idx = max(num_computed_tokens - start_pos, 0)
                end_idx = min(
                    num_computed_tokens - start_pos + num_scheduled_tokens,
                    num_encoder_tokens)
                assert start_idx < end_idx
                assert req_id in self.encoder_cache
                assert i in self.encoder_cache[req_id]
                encoder_output = self.encoder_cache[req_id][i]
                encoder_outputs.append(encoder_output[start_idx:end_idx])
        return encoder_outputs

    @torch.no_grad()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> ModelRunnerOutput:
        # Update cached state
        self._update_states(scheduler_output)
        if not scheduler_output.total_num_scheduled_tokens:
            # Return empty ModelRunnerOutput if there's no work to do.
            return EMPTY_MODEL_RUNNER_OUTPUT

        if self.is_multimodal_model:
            # Run the multimodal encoder if any.
            self._execute_encoder(scheduler_output)
            encoder_outputs = self._gather_encoder_outputs(scheduler_output)
        else:
            encoder_outputs = []

        # Prepare inputs
        attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
        if self.is_multimodal_model:
            # NOTE(woosuk): To unify token ids and soft tokens (vision
            # embeddings), we always use embeddings (rather than token ids)
            # as input to the multimodal model, even when the input is text.
            if encoder_outputs:
                inputs_embeds = self.model.get_input_embeddings(
                    self.input_ids, encoder_outputs)
            else:
                inputs_embeds = self.model.get_input_embeddings(self.input_ids)
            input_ids = None
        else:
            # For text-only models, we use token ids as input.
            # While it is possible to use embeddings as input just like the
            # multimodal models, it is not desirable for performance since
            # then the embedding layer is not included in the compiled XLA
            # graph.
            input_ids = self.input_ids
            inputs_embeds = None
        num_reqs = self.input_batch.num_reqs
        # NOTE (NickLucche) here we sync with TPU: sampling params tensors
        # are copied to device in chunks of pre-compiled padded shape to
        # avoid recompilations.
        tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
            from_input_batch(self.input_batch, logits_indices)
        # Run the decoder
        with set_forward_context(attn_metadata, self.vllm_config):
            hidden_states = self.model(
                input_ids=input_ids,
                positions=self.position_ids,
                kv_caches=self.kv_caches,
                inputs_embeds=inputs_embeds,
            )
        selected_token_ids = self.model.sample_from_hidden(
            hidden_states, tpu_sampling_metadata)
        # Remove padding on cpu and keep dynamic op outside of xla graph.
        selected_token_ids = selected_token_ids.cpu()[:num_reqs]

        # Update the cached states concurrently: the code above does not block
        # until `selected_token_ids` is actually used. Add a mark_step here if
        # the post-processing below changes.
        request_seq_lens: list[tuple[int, CachedRequestState, int]] = []
        discard_sampled_tokens_req_indices = []
        for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
            assert req_id is not None
            req_state = self.requests[req_id]
            seq_len = (req_state.num_computed_tokens +
                       scheduler_output.num_scheduled_tokens[req_id])
            if seq_len >= req_state.num_tokens:
                request_seq_lens.append((i, req_state, seq_len))
            else:
                # Ignore the sampled token from the partial request.
                # Rewind the generator state as if the token was not sampled.
                generator = self.input_batch.generators.get(i)
                if generator is not None:
                    # This relies on cuda-specific torch-internal impl details
                    generator.set_offset(generator.get_offset() - 4)

                # Record the index of the request that should not be sampled,
                # so that we could clear the sampled tokens before returning.
                discard_sampled_tokens_req_indices.append(i)

        assert all(
            req_id is not None for req_id in
            self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
        req_ids = cast(list[str], self.input_batch.req_ids[:num_reqs])

        prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {}
        for req_id in self.input_batch.req_ids[:num_reqs]:
            prompt_logprobs_dict[req_id] = None

        max_gen_len = selected_token_ids.shape[-1]
        if max_gen_len == 1:
            valid_sampled_token_ids = selected_token_ids.tolist()

            # Mask out the sampled tokens that should not be sampled.
            # TODO: Keep in sync with gpu_model_runner.py, in particular
            # the "else" case here
            for i in discard_sampled_tokens_req_indices:
                valid_sampled_token_ids[i].clear()

            # Append sampled tokens
            for i, req_state, seq_len in request_seq_lens:
                token_id = valid_sampled_token_ids[i][0]
                self.input_batch.token_ids_cpu[i, seq_len] = token_id
                req_state.output_token_ids.append(token_id)
                self.input_batch.num_tokens[i] += 1

        else:
            valid_mask = selected_token_ids != INVALID_TOKEN_ID
            gen_lens = valid_mask.sum(dim=1).tolist()
            valid_sampled_token_ids = [
                seq.tolist()
                for seq in selected_token_ids[valid_mask].split(gen_lens)
            ]
            self.input_batch.num_tokens[:num_reqs] += gen_lens
            for i, req_state, seq_len in request_seq_lens:
                target_slice = slice(seq_len - gen_lens[i] + 1, seq_len + 1)
                self.input_batch.token_ids_cpu[
                    i, target_slice] = valid_sampled_token_ids[i]
                req_state.output_token_ids.extend(valid_sampled_token_ids[i])

        model_runner_output = ModelRunnerOutput(
            req_ids=req_ids,
            req_id_to_index=self.input_batch.req_id_to_index,
            sampled_token_ids=valid_sampled_token_ids,
            spec_token_ids=None,
            logprobs=None,
            prompt_logprobs_dict=prompt_logprobs_dict,
        )

        # Check there are no new graphs compiled - all the graphs should be
        # captured and compiled during warm up.
        self._verify_num_xla_graphs("execute_model")

        return model_runner_output

    def load_model(self) -> None:
        self.device = self.device_config.device

        # NOTE(woosuk): While the executor assigns the TP ranks to the worker
        # process, the ranks can be different from the ranks internally assigned
        # by the xm runtime. Therefore, there is a mismatch in the rank
        # assignment between the gloo (cpu) runtime and the xm (tpu) runtime.
        # This is not a problem in linear layers because all-reduce is
        # rank-agnostic. However, it matters for all-gather as the ranks
        # determine the order of concatenating the output tensors.
        # As a workaround, we use the xm's rank assignment only when loading
        # the embedding weights.
        xm_tp_rank = xr.global_ordinal()
        with patch(
                "vllm.model_executor.layers.vocab_parallel_embedding."
                "get_tensor_model_parallel_rank",
                return_value=xm_tp_rank):
            model = get_model(vllm_config=self.vllm_config)
        model = model.eval()
        xm.mark_step()
        xm.wait_device_ops()
        model = ModelWrapperV1(model)
        self.model = torch.compile(model,
                                   backend="openxla",
                                   fullgraph=True,
                                   dynamic=False)

    @torch.no_grad()
    def _dummy_run(self, kv_caches, num_tokens: int) -> None:
        if self.is_multimodal_model:
            input_ids = None
            inputs_embeds = torch.zeros((num_tokens, self.hidden_size),
                                        dtype=self.dtype,
                                        device=self.device)
        else:
            input_ids = torch.zeros((num_tokens),
                                    dtype=torch.int32,
                                    device=self.device)
            inputs_embeds = None
        actual_num_reqs = min(num_tokens, self.max_num_reqs)
        position_ids = torch.zeros(num_tokens,
                                   dtype=torch.int32,
                                   device=self.device)
        slot_mapping = torch.zeros(num_tokens,
                                   dtype=torch.int64,
                                   device=self.device)
        block_tables = torch.zeros(
            (self.max_num_reqs, self.block_table_cpu.shape[1]),
            dtype=torch.int32,
            device=self.device)
        query_lens = [1] * self.max_num_reqs
        query_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
                                                    dtype=torch.int32),
                                       dim=0,
                                       dtype=torch.int32).to(self.device)
        context_lens = torch.ones((self.max_num_reqs, ),
                                  dtype=torch.int32,
                                  device=self.device)
        num_seqs = torch.tensor([actual_num_reqs],
                                dtype=torch.int32,
                                device=self.device)
        attn_metadata = PallasMetadata(
            slot_mapping=slot_mapping,
            block_tables=block_tables,
            context_lens=context_lens,
            query_start_loc=query_start_loc,
            num_seqs=num_seqs,
        )

        if self.is_multimodal_model:
            torch._dynamo.mark_dynamic(inputs_embeds, 0)
        else:
            torch._dynamo.mark_dynamic(input_ids, 0)
        torch._dynamo.mark_dynamic(position_ids, 0)
        torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)

        with set_forward_context(attn_metadata, self.vllm_config, 0):
            out = self.model(input_ids=input_ids,
                             positions=position_ids,
                             kv_caches=kv_caches,
                             inputs_embeds=inputs_embeds)
        self._hidden_states_dtype = out.dtype

    def capture_model(self) -> None:
        """Compile the model."""

        logger.info("Compiling the model with different input shapes.")

        start = time.perf_counter()
        for num_tokens in self.num_tokens_paddings:
            logger.info("  -- num_tokens: %d", num_tokens)
            self._dummy_run(self.kv_caches, num_tokens)
        xm.mark_step()
        xm.wait_device_ops()
        end = time.perf_counter()

        logger.info("Compilation finished in %.2f [secs].", end - start)
        self._update_num_xla_graphs("model")

        logger.info("Compiling sampling with different input shapes.")
        start = time.perf_counter()
        hsize = self.model_config.get_hidden_size()
        device = self.device
        # Compile sampling step for different model+sampler outputs in bucketed
        # n_tokens x max_num_reqs. Graph is really small so this is fine.
        for num_tokens in self.num_tokens_paddings:
            num_reqs_to_sample = MIN_NUM_SEQS
            dummy_hidden = torch.randn((num_tokens, hsize),
                                       device=device,
                                       dtype=self._hidden_states_dtype)
            # Compile for [8, 16, .., 128,.., `self.max_num_reqs`]
            while True:
                indices = torch.zeros(
                    num_reqs_to_sample,
                    dtype=torch.int32,
                    device=device,
                )
                xm.mark_step()
                sampling_meta = TPUSupportedSamplingMetadata.\
                    from_input_batch(self.input_batch, indices)
                logger.info("  -- num_tokens: %d, num_seqs: %d", num_tokens,
                            num_reqs_to_sample)
                out = self.model.sample_from_hidden(dummy_hidden,
                                                    sampling_meta)
                out = out.cpu()
                if num_reqs_to_sample >= self.max_num_reqs:
                    break
                # Make sure to compile the `max_num_reqs` upper-limit case
                num_reqs_to_sample = _get_padded_num_reqs_with_upper_limit(
                    num_reqs_to_sample + 1, self.max_num_reqs)
        xm.wait_device_ops()
        end = time.perf_counter()

        logger.info("Compilation finished in %.2f [secs].", end - start)
        self._update_num_xla_graphs("sampling")

    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Initialize KV cache based on `kv_cache_config`.
        Args:
            kv_cache_config: Configuration for the KV cache, including the KV
            cache size of each layer
        """
        if len(kv_cache_config.kv_cache_groups) > 1:
            raise NotImplementedError(
                "Hybrid models with more than one KV cache type are not "
                "supported yet.")

        kv_caches: dict[str, torch.Tensor] = {}

        for kv_cache_group in kv_cache_config.kv_cache_groups:
            kv_cache_spec = kv_cache_group.kv_cache_spec
            for layer_name in kv_cache_group.layer_names:
                tensor_config = kv_cache_config.tensors[layer_name]
                assert tensor_config.size % kv_cache_spec.page_size_bytes == 0
                num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes
                if isinstance(kv_cache_spec, FullAttentionSpec):
                    kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape(
                        num_blocks, kv_cache_spec.block_size,
                        kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
                    dtype = kv_cache_spec.dtype

                    tpu_kv_cache = torch.zeros(kv_cache_shape,
                                               dtype=dtype,
                                               device=self.device)

                    kv_caches[layer_name] = tpu_kv_cache
                else:
                    raise NotImplementedError

        bind_kv_cache(
            kv_caches,
            self.vllm_config.compilation_config.static_forward_context,
            self.kv_caches)


class ModelWrapperV1(nn.Module):

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model
        self.sampler = TPUSampler()

    def sample(
        self, logits: torch.Tensor,
        sampling_metadata: TPUSupportedSamplingMetadata) -> SamplerOutput:
        sampler_out = self.sampler(logits, sampling_metadata)
        return sampler_out

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: list[torch.Tensor],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Executes the forward pass of the model.

        Args:
            input_ids: The input token IDs of shape [num_tokens].
            positions: The input position IDs of shape [num_tokens].
            kv_caches: The key and value caches. They can be None during the
                memory profiling at initialization.
            inputs_embeds: The input embeddings of shape [num_tokens,
                hidden_size]. It is used for multimodal models.
        """

        hidden_states = self.model(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
        )

        return hidden_states

    def sample_from_hidden(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: TPUSupportedSamplingMetadata,
    ) -> torch.Tensor:
        """
        Sample with an xla-friendly function. This function is traced
        separately from `forward` for lighter compilation overhead.
        """
        # Tensor `sample_hidden_states` is of fixed pre-compiled size.
        sample_hidden_states = \
            hidden_states[sampling_metadata.indices_do_sample]
        logits = self.compute_logits(sample_hidden_states)
        # Optimized greedy sampling branch, tracing both paths in a single
        # pass. NOTE: all_greedy is a scalar, so this is just an optimized
        # if/else.
        out_tokens = torch.where(sampling_metadata.all_greedy,
                                 torch.argmax(logits, dim=-1, keepdim=True),
                                 self.sample(logits, sampling_metadata)
                                 .sampled_token_ids)
        return out_tokens

    def compute_logits(self,
                       hidden_states: torch.Tensor) -> Optional[torch.Tensor]:
        # SamplingMetadata is only used here to prune output in the
        # LogitsProcessor; it is disabled (None).
        logits = self.model.compute_logits(hidden_states, None)
        return logits

    def get_multimodal_embeddings(self, *args, **kwargs):
        return self.model.get_multimodal_embeddings(*args, **kwargs)

    def get_input_embeddings(self, *args, **kwargs):
        return self.model.get_input_embeddings(*args, **kwargs)


def _get_padded_number(n: int, multiple: int) -> int:
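    # Round n up to the nearest multiple of `multiple`,
    # e.g. _get_padded_number(5, 4) -> 8 and _get_padded_number(8, 4) -> 8.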
    return ((n + multiple - 1) // multiple) * multiple


def _get_padded_num_reqs_with_upper_limit(x, upper_limit) -> int:
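    # Pad x up to MIN_NUM_SEQS or the next power of two, clamped to
    # upper_limit. For illustration, with MIN_NUM_SEQS == 8:
    #   _get_padded_num_reqs_with_upper_limit(3, 32)  -> 8
    #   _get_padded_num_reqs_with_upper_limit(9, 32)  -> 16
    #   _get_padded_num_reqs_with_upper_limit(33, 32) -> 32 (clamped)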
    res = MIN_NUM_SEQS if x <= MIN_NUM_SEQS else 1 << (x - 1).bit_length()
    return min(res, upper_limit)


def _get_paddings(min_token_size: int, max_token_size: int,
                  padding_gap: int) -> list[int]:
    """Generate a list of padding sizes, starting from min_token_size and
    ending with a number that can cover max_token_size.

    If padding_gap == 0:
        double the size each time (exponential).
    else:
        double the size until it exceeds padding_gap,
        then increase the padding size by padding_gap.
    """
    paddings = []
    num = min_token_size

    if padding_gap == 0:
        logger.info("Using exponential paddings:")
        while num <= max_token_size:
            logger.info("    %d", num)
            paddings.append(num)
            num *= 2

    else:
        logger.info("Using incremental paddings:")
        while num <= padding_gap:
            logger.info("    %d", num)
            paddings.append(num)
            num *= 2
        num //= 2
        while num < max_token_size:
            num += padding_gap
            logger.info("    %d", num)
            paddings.append(num)

    return paddings


def _get_padded_token_len(paddings: list[int], x: int) -> int:
    """Return the first element in paddings greater than or equal to x."""
    index = bisect.bisect_left(paddings, x)
    assert index < len(paddings)
    return paddings[index]