2024-06-25 20:30:03 -07:00
|
|
|
|
import dataclasses
|
2024-06-08 19:14:43 -07:00
|
|
|
|
import gc
|
2023-12-16 21:12:08 -08:00
|
|
|
|
import time
|
2024-05-22 13:28:20 -07:00
|
|
|
|
import warnings
|
2024-06-03 13:56:41 +08:00
|
|
|
|
from collections import defaultdict
|
2024-07-03 11:34:00 +08:00
|
|
|
|
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set,
|
|
|
|
|
Tuple, Type, TypeVar, Union)
|
2023-11-29 22:16:37 -08:00
|
|
|
|
|
2023-12-16 21:12:08 -08:00
|
|
|
|
import numpy as np
|
2023-11-29 22:16:37 -08:00
|
|
|
|
import torch
|
2024-07-02 10:58:08 -07:00
|
|
|
|
import torch.distributed
|
2023-12-16 21:12:08 -08:00
|
|
|
|
import torch.nn as nn
|
2023-11-29 22:16:37 -08:00
|
|
|
|
|
2024-06-28 15:28:49 -07:00
|
|
|
|
try:
|
|
|
|
|
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
|
|
|
|
|
from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
|
|
|
|
|
from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
|
|
|
|
|
FLASHINFER_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024
|
|
|
|
|
except ImportError:
|
|
|
|
|
BatchDecodeWithPagedKVCacheWrapper = None
|
|
|
|
|
CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
|
|
|
|
|
BatchPrefillWithPagedKVCacheWrapper = None
|
|
|
|
|
FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
|
|
|
|
|
|
2024-05-15 14:00:10 +09:00
|
|
|
|
from vllm.attention import AttentionMetadata, get_attn_backend
|
2024-05-09 09:04:59 -07:00
|
|
|
|
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
2024-07-03 15:14:16 -07:00
|
|
|
|
ModelConfig, MultiModalConfig, ParallelConfig,
|
|
|
|
|
SchedulerConfig)
|
2024-07-02 10:58:08 -07:00
|
|
|
|
from vllm.distributed import get_pp_group
|
2024-06-12 17:27:08 -07:00
|
|
|
|
from vllm.distributed.parallel_state import graph_capture
|
2024-06-28 20:09:56 +08:00
|
|
|
|
from vllm.inputs import INPUT_REGISTRY
|
2023-11-29 22:16:37 -08:00
|
|
|
|
from vllm.logger import init_logger
|
2024-03-25 23:59:47 +09:00
|
|
|
|
from vllm.lora.layers import LoRAMapping
|
|
|
|
|
from vllm.lora.request import LoRARequest
|
|
|
|
|
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
|
2024-03-24 21:39:33 -07:00
|
|
|
|
from vllm.model_executor import SamplingMetadata
|
2024-03-21 18:22:17 -07:00
|
|
|
|
from vllm.model_executor.model_loader import get_model
|
2024-06-12 15:13:52 -06:00
|
|
|
|
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
2024-07-03 15:14:16 -07:00
|
|
|
|
from vllm.model_executor.models.interfaces import (supports_lora,
|
|
|
|
|
supports_vision)
|
2024-07-03 11:34:00 +08:00
|
|
|
|
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
|
|
|
|
|
MultiModalInputs)
|
2024-04-26 22:02:02 +09:00
|
|
|
|
from vllm.sampling_params import SamplingParams
|
2024-07-02 10:58:08 -07:00
|
|
|
|
from vllm.sequence import (IntermediateTensors, SamplerOutput,
|
|
|
|
|
SequenceGroupMetadata)
|
2024-05-03 15:51:27 -07:00
|
|
|
|
from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip,
|
|
|
|
|
is_pin_memory_available, make_tensor_with_pad)
|
2024-06-25 20:30:03 -07:00
|
|
|
|
from vllm.worker.model_runner_base import (
|
|
|
|
|
ModelRunnerBase, ModelRunnerInputBase,
|
|
|
|
|
_add_attn_metadata_broadcastable_dict,
|
|
|
|
|
_add_sampling_metadata_broadcastable_dict,
|
|
|
|
|
_init_attn_metadata_from_tensor_dict,
|
|
|
|
|
_init_sampling_metadata_from_tensor_dict)
|
|
|
|
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
|
from vllm.attention.backends.abstract import AttentionBackend
|
2023-11-29 22:16:37 -08:00
|
|
|
|
|
|
|
|
|
# Logger for this module, obtained from vLLM's init_logger helper.
logger = init_logger(__name__)
|
|
|
|
|
|
|
|
|
|
# Slot index written into slot_mapping for positions that must not be
# stored in the KV cache (padding, masked prompt tokens, profiling runs).
_PAD_SLOT_ID = -1
# NOTE(review): not used in the code visible here; presumably the rank of
# dummy LoRA adapters used during warmup/profiling — confirm.
LORA_WARMUP_RANK = 8

# CUDA graphs are captured for batch sizes 1, 2 and 4, followed by every
# multiple of _BATCH_SIZE_ALIGNMENT up to 32 * _BATCH_SIZE_ALIGNMENT,
# i.e. 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZE_ALIGNMENT = 8
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + list(
    range(_BATCH_SIZE_ALIGNMENT, 33 * _BATCH_SIZE_ALIGNMENT,
          _BATCH_SIZE_ALIGNMENT))
# NOTE(review): not used in the code visible here; presumably the number of
# warmup iterations run before graph capture — confirm.
_NUM_WARMUP_ITERS = 2

# Generic model-input type used by GPUModelRunnerBase and its subclasses.
TModelInputForGPU = TypeVar('TModelInputForGPU', bound="ModelInputForGPU")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class ModelInputForGPU(ModelRunnerInputBase):
    """
    Inputs for the base model forward pass on GPU.

    Only metadata needed by the forward pass itself lives here; metadata for
    possible follow-up steps (e.g. sampling) does not. Model runners that run
    such additional steps should subclass this class and add the extra
    fields there.
    """
    input_tokens: Optional[torch.Tensor] = None
    input_positions: Optional[torch.Tensor] = None
    seq_lens: Optional[List[int]] = None
    query_lens: Optional[List[int]] = None
    lora_mapping: Optional["LoRAMapping"] = None
    lora_requests: Optional[Set[LoRARequest]] = None
    attn_metadata: Optional["AttentionMetadata"] = None
    multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None
    request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None
    finished_requests_ids: Optional[List[str]] = None
    virtual_engine: int = 0

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """Return the fields to broadcast, as a plain dict.

        seq_lens and query_lens are not included. Attention metadata is
        merged into the dict by _add_attn_metadata_broadcastable_dict.
        """
        broadcast_fields = (
            "input_tokens",
            "input_positions",
            "lora_requests",
            "lora_mapping",
            "multi_modal_kwargs",
            "virtual_engine",
            "request_ids_to_seq_ids",
            "finished_requests_ids",
        )
        tensor_dict = {name: getattr(self, name) for name in broadcast_fields}
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls: Type[TModelInputForGPU],
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> TModelInputForGPU:
        """Rebuild an instance from a dict made by as_broadcastable_tensor_dict.

        When attn_backend is given, the attention metadata entries are first
        converted back via _init_attn_metadata_from_tensor_dict.
        """
        if attn_backend is None:
            return cls(**tensor_dict)
        restored = _init_attn_metadata_from_tensor_dict(attn_backend,
                                                         tensor_dict)
        return cls(**restored)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
    """
    Used by the ModelRunner.

    Extends ModelInputForGPU with sampling metadata. (De)serialization for
    broadcasting delegates to the parent and only layers the sampling
    metadata on top, so fields broadcast by the parent are picked up here
    automatically instead of being duplicated.
    """
    sampling_metadata: Optional["SamplingMetadata"] = None
    # Used for speculative decoding. We do not broadcast it because it is only
    # used by the driver worker.
    is_prompt: Optional[bool] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """Return the parent's broadcastable dict plus sampling metadata."""
        # The parent defines the broadcast field list (and adds attention
        # metadata); reuse it so the list lives in exactly one place.
        tensor_dict = super().as_broadcastable_tensor_dict()
        _add_sampling_metadata_broadcastable_dict(tensor_dict,
                                                  self.sampling_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls,
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "ModelInputForGPUWithSamplingMetadata":
        """Rebuild an instance from a dict made by as_broadcastable_tensor_dict.

        Sampling metadata is restored first; the parent classmethod then
        restores attention metadata (when attn_backend is given) and
        constructs ``cls`` from the resulting dict.
        """
        tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
        # Zero-argument super() in a classmethod stays bound to cls, so the
        # parent's cls(**tensor_dict) builds an instance of this subclass.
        return super().from_broadcasted_tensor_dict(tensor_dict,
                                                    attn_backend=attn_backend)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
|
|
|
|
"""
|
|
|
|
|
Helper class for shared methods between GPU model runners.
|
|
|
|
|
"""
|
|
|
|
|
_model_input_cls: Type[TModelInputForGPU]
|
2023-11-29 22:16:37 -08:00
|
|
|
|
|
|
|
|
|
def __init__(
    self,
    model_config: ModelConfig,
    parallel_config: ParallelConfig,
    scheduler_config: SchedulerConfig,
    device_config: DeviceConfig,
    cache_config: CacheConfig,
    load_config: LoadConfig,
    lora_config: Optional[LoRAConfig],
    kv_cache_dtype: Optional[str] = "auto",
    is_driver_worker: bool = False,
    multimodal_config: Optional[MultiModalConfig] = None,
    return_hidden_states: bool = False,
):
    """Store the engine configs and set up state shared by GPU runners.

    The model and the LoRA manager are only declared here; load_model()
    assigns them. The FlashInfer workspace buffers and wrappers start out
    as None.
    """
    # Configuration objects, kept as given.
    self.model_config = model_config
    self.parallel_config = parallel_config
    self.scheduler_config = scheduler_config
    self.device_config = device_config
    self.cache_config = cache_config
    self.lora_config = lora_config
    self.load_config = load_config
    self.is_driver_worker = is_driver_worker
    self.multimodal_config = multimodal_config
    self.return_hidden_states = return_hidden_states

    # Values derived from the configs.
    self.device = self.device_config.device
    self.pin_memory = is_pin_memory_available()
    self.kv_cache_dtype = kv_cache_dtype
    self.sliding_window = model_config.get_sliding_window()
    self.block_size = cache_config.block_size
    self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture

    # CUDA graph bookkeeping: one dict of captured runners per pipeline
    # parallel rank slot (presumably indexed by virtual_engine — confirm).
    self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [
        dict() for _ in range(self.parallel_config.pipeline_parallel_size)
    ]
    # Set during graph capture.
    self.graph_memory_pool: Optional[Tuple[int, int]] = None

    self.has_seqlen_agnostic = model_config.contains_seqlen_agnostic_layers(
        parallel_config)

    # With CUDA graphs the input block tables must be padded to
    # max_seq_len_to_capture. Building that padded table in Python each step
    # is expensive, so a numpy buffer of shape
    # (largest captured batch size, max blocks per sequence) is allocated
    # once and only the live entries are copied in per iteration.
    # get_max_block_per_batch() reads block_size and max_seq_len_to_capture,
    # which is why both are assigned above.
    largest_graph_batch = max(_BATCH_SIZES_TO_CAPTURE)
    self.graph_block_tables = np.zeros(
        (largest_graph_batch, self.get_max_block_per_batch()),
        dtype=np.int32)

    # Attention backend; left as None when the model reports no attention
    # heads.
    num_heads = self.model_config.get_num_attention_heads(
        self.parallel_config)
    if num_heads:
        self.attn_backend = get_attn_backend(
            num_heads,
            self.model_config.get_head_size(),
            self.model_config.get_num_kv_heads(self.parallel_config),
            self.model_config.get_sliding_window(),
            self.model_config.dtype,
            self.kv_cache_dtype,
            self.block_size,
        )
    else:
        self.attn_backend = None

    # Multi-modal data support.
    self.multi_modal_input_mapper = MULTIMODAL_REGISTRY.create_input_mapper(
        self.model_config)

    # Lazily initialized by load_model().
    self.model: nn.Module
    self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None

    # FlashInfer workspace buffers and wrappers; unset until assigned
    # elsewhere.
    self.flashinfer_decode_workspace_buffer = None
    self.flashinfer_decode_wrapper = None
    self.flashinfer_prefill_workspace_buffer = None
    self.flashinfer_prefill_wrapper = None
|
|
|
|
|
|
2023-11-29 22:16:37 -08:00
|
|
|
|
def load_model(self) -> None:
    """Load the model weights and finish model-dependent setup.

    Steps:
      1. Build the model with ``get_model`` under ``CudaMemoryProfiler`` and
         record the consumed memory in ``self.model_memory_usage``.
      2. If ``lora_config`` is set, check LoRA support, create an
         ``LRUCacheWorkerLoRAManager`` and replace ``self.model`` with the
         result of ``create_lora_manager(self.model)``.
      3. For an FP8 KV cache on ROCm (``is_hip()``), load KV-cache scaling
         factors from ``model_config.quantization_param_path`` when it is
         set, otherwise warn that scaling factors default to 1.0.

    Raises:
        RuntimeError: if scaling factors are provided but the model has no
            callable ``load_kv_cache_scales``.
        AssertionError: if LoRA is requested for a model that does not
            support LoRA, or for a vision-language model.
    """
    with CudaMemoryProfiler() as m:
        self.model = get_model(
            model_config=self.model_config,
            device_config=self.device_config,
            load_config=self.load_config,
            lora_config=self.lora_config,
            multimodal_config=self.multimodal_config,
            parallel_config=self.parallel_config,
            scheduler_config=self.scheduler_config,
            cache_config=self.cache_config,
        )

    self.model_memory_usage = m.consumed_memory
    logger.info("Loading model weights took %.4f GB",
                self.model_memory_usage / float(2**30))

    if self.lora_config:
        assert supports_lora(self.model), "Model does not support LoRA"
        assert not supports_vision(
            self.model
        ), "To be tested: vision language model with LoRA settings."

        # NOTE(review): self.vocab_size is not assigned in __init__;
        # presumably provided elsewhere on the runner — confirm.
        self.lora_manager = LRUCacheWorkerLoRAManager(
            self.scheduler_config.max_num_seqs,
            self.scheduler_config.max_num_batched_tokens,
            self.vocab_size,
            self.lora_config,
            self.device,
            self.model.embedding_modules,
            self.model.embedding_padding_modules,
            max_position_embeddings=self.model.config.
            max_position_embeddings,
        )
        self.model = self.lora_manager.create_lora_manager(self.model)

    if self.kv_cache_dtype == "fp8" and is_hip():
        # Currently only ROCm accepts kv-cache scaling factors
        # via quantization_param_path and this will be deprecated
        # in the future.
        if self.model_config.quantization_param_path is not None:
            if callable(getattr(self.model, "load_kv_cache_scales", None)):
                warnings.warn(
                    "Loading kv cache scaling factor from JSON is "
                    "deprecated and will be removed. Please include "
                    "kv cache scaling factors in the model checkpoint.",
                    FutureWarning,
                    stacklevel=2)
                self.model.load_kv_cache_scales(
                    self.model_config.quantization_param_path)
                logger.info("Loaded KV cache scaling factors from %s",
                            self.model_config.quantization_param_path)
            else:
                # Exceptions, unlike logger calls, do not %-format extra
                # arguments; format the message explicitly so the model
                # class actually appears in it.
                raise RuntimeError(
                    "Using FP8 KV cache and scaling factors provided but "
                    "model %s does not support loading scaling factors." %
                    (self.model.__class__, ))
        else:
            logger.warning(
                "Using FP8 KV cache but no scaling factors "
                "provided. Defaulting to scaling factors of 1.0. "
                "This may lead to less accurate results!")
|
2024-04-03 16:15:55 -05:00
|
|
|
|
|
2024-05-16 01:11:54 -04:00
|
|
|
|
def save_sharded_state(
    self,
    path: str,
    pattern: Optional[str] = None,
    max_size: Optional[int] = None,
) -> None:
    """Save the loaded model's state under ``path`` via ShardedStateLoader.

    ``pattern`` and ``max_size`` are passed through unchanged to
    ``ShardedStateLoader.save_model``.
    """
    # Imported inside the method rather than at module level.
    from vllm.model_executor.model_loader.loader import ShardedStateLoader
    model = self.model
    ShardedStateLoader.save_model(model, path, pattern=pattern,
                                  max_size=max_size)
|
|
|
|
|
|
2024-06-12 15:13:52 -06:00
|
|
|
|
def save_tensorized_model(
    self,
    tensorizer_config: TensorizerConfig,
) -> None:
    """Serialize the loaded model with TensorizerLoader.

    ``tensorizer_config`` is passed through unchanged to
    ``TensorizerLoader.save_model``.
    """
    # Imported inside the method rather than at module level.
    from vllm.model_executor.model_loader.loader import TensorizerLoader
    model = self.model
    TensorizerLoader.save_model(model, tensorizer_config=tensorizer_config)
|
|
|
|
|
|
2024-03-21 06:46:05 +09:00
|
|
|
|
def get_max_block_per_batch(self) -> int:
    """Return how many KV-cache blocks cover max_seq_len_to_capture tokens.

    This is ceil(max_seq_len_to_capture / block_size), computed with
    integer arithmetic.
    """
    size = self.block_size
    rounded_up = self.max_seq_len_to_capture + size - 1
    return rounded_up // size
|
2023-12-16 21:12:08 -08:00
|
|
|
|
|
2024-06-25 20:30:03 -07:00
|
|
|
|
def _prepare_model_input_tensors(
|
2023-11-29 22:16:37 -08:00
|
|
|
|
self,
|
|
|
|
|
seq_group_metadata_list: List[SequenceGroupMetadata],
|
2024-07-03 02:11:29 +03:00
|
|
|
|
finished_requests_ids: Optional[List[str]] = None
|
2024-06-25 20:30:03 -07:00
|
|
|
|
) -> TModelInputForGPU:
|
|
|
|
|
"""Helper method to prepare the model input based on a given sequence
|
|
|
|
|
group. Prepares metadata needed for the base model forward pass but not
|
|
|
|
|
metadata for possible additional steps, e.g., sampling.
|
2024-05-15 14:00:10 +09:00
|
|
|
|
|
|
|
|
|
The API assumes seq_group_metadata_list is sorted by prefill -> decode.
|
|
|
|
|
|
|
|
|
|
The result tensors and data structure also batches input in prefill
|
|
|
|
|
-> decode order. For example,
|
|
|
|
|
|
|
|
|
|
- input_tokens[:num_prefill_tokens] contains prefill tokens.
|
|
|
|
|
- input_tokens[num_prefill_tokens:] contains decode tokens.
|
|
|
|
|
|
|
|
|
|
If cuda graph is required, this API automatically pads inputs.
|
|
|
|
|
"""
|
2024-03-21 06:46:05 +09:00
|
|
|
|
input_tokens: List[int] = []
|
|
|
|
|
input_positions: List[int] = []
|
|
|
|
|
slot_mapping: List[int] = []
|
2024-01-24 00:26:37 +01:00
|
|
|
|
lora_index_mapping: List[int] = []
|
|
|
|
|
lora_prompt_mapping: List[int] = []
|
|
|
|
|
lora_requests: Set[LoRARequest] = set()
|
2023-11-29 22:16:37 -08:00
|
|
|
|
|
2024-05-04 02:20:12 +09:00
|
|
|
|
seq_lens: List[int] = []
|
2024-05-15 14:00:10 +09:00
|
|
|
|
prefill_seq_lens: List[int] = []
|
|
|
|
|
decode_seq_lens: List[int] = []
|
2024-01-17 16:32:10 -08:00
|
|
|
|
context_lens: List[int] = []
|
2024-05-04 02:20:12 +09:00
|
|
|
|
query_lens: List[int] = []
|
2023-11-29 22:16:37 -08:00
|
|
|
|
block_tables: List[List[int]] = []
|
2024-07-03 11:34:00 +08:00
|
|
|
|
multi_modal_inputs_list: List[MultiModalInputs] = []
|
2024-07-03 02:11:29 +03:00
|
|
|
|
request_ids_to_seq_ids: Dict[str, List[int]] = defaultdict(list)
|
2024-05-15 14:00:10 +09:00
|
|
|
|
decode_only = True
|
|
|
|
|
num_prefills = 0
|
|
|
|
|
num_prefill_tokens = 0
|
|
|
|
|
num_decode_tokens = 0
|
2023-11-29 22:16:37 -08:00
|
|
|
|
|
2024-05-03 15:51:27 -07:00
|
|
|
|
# The following fields are only for flashinfer
|
|
|
|
|
# Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
|
|
|
|
|
# for the precise definition of the following fields.
|
|
|
|
|
# An example:
|
|
|
|
|
# request 1, page indices [0, 5, 8]
|
|
|
|
|
# request 2, page indices [1, 6, 7]
|
|
|
|
|
# request 3, page indices [3, 4]
|
|
|
|
|
# paged_kv_indices is a concatenation of page indices of all requests:
|
|
|
|
|
# [0, 5, 8, 1, 6, 7, 3, 4]
|
|
|
|
|
# paged_kv_indptr is used to index into paged_kv_indices:
|
|
|
|
|
# [0, 3, 6, 8]
|
|
|
|
|
paged_kv_indices: List[int] = []
|
|
|
|
|
# 0 at the beginning of paged_kv_indptr indicates the start of the
|
|
|
|
|
# first request’s page indices in the paged_kv_indices list.
|
|
|
|
|
paged_kv_indptr: List[int] = [0]
|
|
|
|
|
# paged_kv_last_page_len is the length of the last page of each request
|
|
|
|
|
paged_kv_last_page_len: List[int] = []
|
|
|
|
|
|
2024-04-11 09:56:48 +09:00
|
|
|
|
if len(seq_group_metadata_list) == 0:
|
2024-06-25 20:30:03 -07:00
|
|
|
|
return self._model_input_cls()
|
2024-04-11 09:56:48 +09:00
|
|
|
|
|
2024-05-27 19:07:07 -07:00
|
|
|
|
if self.sliding_window is not None:
|
|
|
|
|
sliding_window_blocks = (self.sliding_window + self.block_size -
|
|
|
|
|
1) // self.block_size
|
|
|
|
|
block_aligned_sliding_window = \
|
|
|
|
|
sliding_window_blocks * self.block_size
|
|
|
|
|
|
2023-11-29 22:16:37 -08:00
|
|
|
|
for seq_group_metadata in seq_group_metadata_list:
|
|
|
|
|
seq_ids = list(seq_group_metadata.seq_data.keys())
|
2024-05-15 14:00:10 +09:00
|
|
|
|
is_prompt = seq_group_metadata.is_prompt
|
2024-01-24 00:26:37 +01:00
|
|
|
|
|
2023-11-29 22:16:37 -08:00
|
|
|
|
for seq_id in seq_ids:
|
2024-05-15 14:00:10 +09:00
|
|
|
|
computed_block_nums = seq_group_metadata.computed_block_nums
|
|
|
|
|
if (self.scheduler_config is not None
|
|
|
|
|
and self.scheduler_config.chunked_prefill_enabled
|
|
|
|
|
and not (computed_block_nums is None
|
|
|
|
|
or computed_block_nums == [])):
|
|
|
|
|
raise RuntimeError(
|
|
|
|
|
"chunked prefill cannot be used with prefix caching "
|
|
|
|
|
"now.")
|
|
|
|
|
|
2023-11-29 22:16:37 -08:00
|
|
|
|
seq_data = seq_group_metadata.seq_data[seq_id]
|
2024-05-15 14:00:10 +09:00
|
|
|
|
if is_prompt:
|
|
|
|
|
context_len = seq_data.get_num_computed_tokens()
|
|
|
|
|
else:
|
|
|
|
|
# get_num_computed_tokens is incorrect for spec decoding.
|
|
|
|
|
# So, we should have a special logic here.
|
|
|
|
|
# TODO(sang): Fix it.
|
|
|
|
|
context_len = seq_data.get_len() - 1
|
|
|
|
|
|
|
|
|
|
seq_len = min(
|
|
|
|
|
seq_data.get_len(),
|
|
|
|
|
context_len + seq_group_metadata.token_chunk_size)
|
|
|
|
|
if is_prompt:
|
|
|
|
|
tokens = seq_data.get_token_ids()[context_len:seq_len]
|
|
|
|
|
else:
|
|
|
|
|
# Optimization. get_token_ids requires the entire copy of
|
|
|
|
|
# tokens.
|
|
|
|
|
tokens = [seq_data.get_last_token_id()]
|
2023-11-29 22:16:37 -08:00
|
|
|
|
|
2024-05-15 14:00:10 +09:00
|
|
|
|
# Prefix cache was hit.
|
|
|
|
|
# Prefix is not supported with sliding_window
|
|
|
|
|
prefix_cache_hit = (computed_block_nums is not None
|
|
|
|
|
and len(computed_block_nums) > 0
|
|
|
|
|
and self.sliding_window is None
|
|
|
|
|
and is_prompt)
|
|
|
|
|
|
2024-05-27 19:07:07 -07:00
|
|
|
|
# These are seq_len/context_len capped to the sliding window.
|
|
|
|
|
# They are passed to decode kernel.
|
|
|
|
|
# We still need original seq_len/context_len to compute slot
|
|
|
|
|
# mapping (and input position) below.
|
|
|
|
|
curr_sliding_window_blocks = None
|
|
|
|
|
sliding_seq_len = seq_len
|
|
|
|
|
sliding_context_len = context_len
|
|
|
|
|
|
|
|
|
|
# TODO(sang): This is a hack to make sliding window work with
|
|
|
|
|
# paged attn. We can remove it if we make paged attn kernel
|
|
|
|
|
# to properly handle slinding window attn.
|
|
|
|
|
if (self.sliding_window is not None and not is_prompt):
|
|
|
|
|
curr_sliding_window_blocks = sliding_window_blocks
|
|
|
|
|
if self.scheduler_config.use_v2_block_manager:
|
|
|
|
|
# number of elements in last block
|
|
|
|
|
suff_len = seq_len % self.block_size
|
|
|
|
|
sliding_seq_len = min(
|
|
|
|
|
seq_len, block_aligned_sliding_window + suff_len)
|
|
|
|
|
if suff_len > 0:
|
|
|
|
|
curr_sliding_window_blocks += 1
|
|
|
|
|
else:
|
|
|
|
|
sliding_seq_len = min(seq_len, self.sliding_window)
|
|
|
|
|
sliding_context_len = sliding_seq_len - 1
|
|
|
|
|
|
2024-05-15 14:00:10 +09:00
|
|
|
|
# TODO(sang): Combine chunked prefill and prefix caching by
|
|
|
|
|
# only allowing multiple of block_size chunk size.
|
|
|
|
|
# NOTE: This only works for oooooooxxx style attention.
|
|
|
|
|
if prefix_cache_hit:
|
|
|
|
|
assert computed_block_nums is not None
|
|
|
|
|
context_len = len(computed_block_nums) * self.block_size
|
|
|
|
|
tokens = tokens[context_len:]
|
2024-05-27 19:07:07 -07:00
|
|
|
|
|
|
|
|
|
# need to think what to set it to when we have both sliding
|
|
|
|
|
# window and prefix caching...
|
|
|
|
|
assert self.sliding_window is None, \
|
|
|
|
|
"Prefix caching is not supported with sliding window"
|
|
|
|
|
sliding_context_len = context_len
|
|
|
|
|
|
2024-05-15 14:00:10 +09:00
|
|
|
|
if self.attn_backend.get_name() == "flash-attn":
|
|
|
|
|
# NOTE(woosuk): For flash-attn, the block table should
|
|
|
|
|
# include the entries for the incoming prefill tokens.
|
|
|
|
|
# TODO(woosuk): This is a temporary fix. We should
|
|
|
|
|
# provide a unified interface for different backends.
|
|
|
|
|
block_table = seq_group_metadata.block_tables[seq_id]
|
|
|
|
|
else:
|
|
|
|
|
block_table = computed_block_nums
|
|
|
|
|
elif (self.scheduler_config.chunked_prefill_enabled
|
|
|
|
|
or not is_prompt):
|
|
|
|
|
if seq_group_metadata.block_tables is not None:
|
|
|
|
|
# chunked prefill or decode
|
|
|
|
|
block_table = seq_group_metadata.block_tables[seq_id]
|
2024-05-27 19:07:07 -07:00
|
|
|
|
if curr_sliding_window_blocks is not None:
|
|
|
|
|
block_table = block_table[
|
|
|
|
|
-curr_sliding_window_blocks:]
|
2024-05-15 14:00:10 +09:00
|
|
|
|
else:
|
|
|
|
|
# Only happens when memory profiling runs.
|
|
|
|
|
block_table = []
|
|
|
|
|
else:
|
|
|
|
|
# Prefill without chunked prefill or memory profiling.
|
|
|
|
|
block_table = []
|
|
|
|
|
block_tables.append(block_table)
|
|
|
|
|
|
2024-05-27 19:07:07 -07:00
|
|
|
|
seq_lens.append(sliding_seq_len)
|
|
|
|
|
context_lens.append(sliding_context_len)
|
|
|
|
|
query_len = sliding_seq_len - sliding_context_len
|
2024-05-15 14:00:10 +09:00
|
|
|
|
query_lens.append(query_len)
|
|
|
|
|
input_tokens.extend(tokens)
|
|
|
|
|
input_positions.extend(list(range(context_len, seq_len)))
|
|
|
|
|
lora_id = seq_group_metadata.lora_int_id
|
|
|
|
|
|
|
|
|
|
if is_prompt:
|
|
|
|
|
assert len(seq_ids) == 1
|
|
|
|
|
num_prefills += 1
|
|
|
|
|
num_prefill_tokens += len(tokens)
|
|
|
|
|
decode_only = False
|
|
|
|
|
prefill_seq_lens.append(seq_len)
|
|
|
|
|
else:
|
|
|
|
|
assert query_len == 1, (
|
|
|
|
|
"seq_len: {}, context_len: {}, query_len: {}".format(
|
|
|
|
|
seq_len, context_len, query_len))
|
|
|
|
|
num_decode_tokens += query_len
|
2024-05-27 19:07:07 -07:00
|
|
|
|
decode_seq_lens.append(sliding_seq_len)
|
2024-05-15 14:00:10 +09:00
|
|
|
|
|
|
|
|
|
if lora_id > 0:
|
|
|
|
|
lora_requests.add(seq_group_metadata.lora_request)
|
|
|
|
|
|
2024-05-27 19:07:07 -07:00
|
|
|
|
lora_index_mapping += [lora_id] * query_len
|
2024-05-15 14:00:10 +09:00
|
|
|
|
lora_prompt_mapping.extend(
|
|
|
|
|
[lora_id] *
|
2024-05-27 19:07:07 -07:00
|
|
|
|
(query_len if seq_group_metadata.sampling_params
|
2024-05-15 14:00:10 +09:00
|
|
|
|
and seq_group_metadata.sampling_params.prompt_logprobs
|
2024-06-04 09:59:30 +09:00
|
|
|
|
is not None else 1))
|
2024-05-15 14:00:10 +09:00
|
|
|
|
|
2024-06-03 13:56:41 +08:00
|
|
|
|
mm_data = seq_group_metadata.multi_modal_data
|
2024-07-02 00:57:09 -07:00
|
|
|
|
if mm_data:
|
2024-06-03 13:56:41 +08:00
|
|
|
|
# Process multi-modal data
|
2024-06-28 20:09:56 +08:00
|
|
|
|
mm_kwargs = self.multi_modal_input_mapper(mm_data)
|
2024-07-03 11:34:00 +08:00
|
|
|
|
multi_modal_inputs_list.append(mm_kwargs)
|
2024-05-15 14:00:10 +09:00
|
|
|
|
|
2024-06-28 15:28:49 -07:00
|
|
|
|
is_profile_run = _is_block_tables_empty(
|
|
|
|
|
seq_group_metadata.block_tables)
|
|
|
|
|
if is_profile_run:
|
2024-05-15 14:00:10 +09:00
|
|
|
|
# During memory profiling, the block tables are not
|
|
|
|
|
# initialized yet. In this case, we just use a dummy
|
|
|
|
|
# slot mapping.
|
|
|
|
|
# In embeddings, the block tables are {seq_id: None}.
|
|
|
|
|
slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
|
|
|
|
|
continue
|
2023-12-13 12:28:13 -08:00
|
|
|
|
|
2024-05-15 14:00:10 +09:00
|
|
|
|
# Compute the slot mapping.
|
2023-11-29 22:16:37 -08:00
|
|
|
|
block_table = seq_group_metadata.block_tables[seq_id]
|
|
|
|
|
|
2024-05-15 14:00:10 +09:00
|
|
|
|
# Mask the [0, start_idx) tokens of the prompt with
|
|
|
|
|
# _PAD_SLOT_ID, where start_idx is max(0, seq_len -
|
|
|
|
|
# sliding_window). For example, if the prompt len is 10,
|
|
|
|
|
# sliding window is 8, and block size is 4, the first two
|
|
|
|
|
# tokens are masked and the slot mapping will be
|
|
|
|
|
# [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
|
|
|
|
|
start_idx = 0
|
2023-11-29 22:16:37 -08:00
|
|
|
|
if self.sliding_window is not None:
|
2024-05-15 14:00:10 +09:00
|
|
|
|
if is_prompt:
|
2024-05-27 19:07:07 -07:00
|
|
|
|
assert self.scheduler_config.use_v2_block_manager \
|
|
|
|
|
or context_len == 0, (
|
2024-05-15 14:00:10 +09:00
|
|
|
|
"Prefix caching is currently not supported with "
|
2024-05-27 19:07:07 -07:00
|
|
|
|
"sliding window attention in V1 block manager")
|
2024-05-15 14:00:10 +09:00
|
|
|
|
# It is an optimization. When it is decoding, it is always
|
|
|
|
|
# 0. When prefill, we use it to not write slots to kv cache
|
|
|
|
|
# to save memory.
|
|
|
|
|
start_idx = max(0, query_len - self.sliding_window)
|
|
|
|
|
|
|
|
|
|
for i in range(context_len, seq_len):
|
|
|
|
|
if i < start_idx:
|
|
|
|
|
slot_mapping.append(_PAD_SLOT_ID)
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
block_number = block_table[i // self.block_size]
|
|
|
|
|
block_offset = i % self.block_size
|
|
|
|
|
slot = block_number * self.block_size + block_offset
|
|
|
|
|
slot_mapping.append(slot)
|
2023-11-29 22:16:37 -08:00
|
|
|
|
|
2024-06-28 15:28:49 -07:00
|
|
|
|
# Prepare input tensors for flashinfer
|
|
|
|
|
if self.attn_backend.get_name() == "flashinfer":
|
|
|
|
|
seq_len = seq_data.get_len()
|
|
|
|
|
# Get the number of valid blocks based on sequence length.
|
|
|
|
|
# If seq_len = 16, block_size = 16,
|
|
|
|
|
# block_table_bound is 1 with 1 valid block.
|
|
|
|
|
# If seq_len = 15, block_size = 16,
|
|
|
|
|
# block_table_bound is 0 + 1 with 1 valid block.
|
|
|
|
|
block_table_bound = seq_len // self.block_size + 1 \
|
|
|
|
|
if seq_len % self.block_size != 0 \
|
|
|
|
|
else seq_len // self.block_size
|
|
|
|
|
|
|
|
|
|
paged_kv_indices.extend(block_table[:block_table_bound])
|
|
|
|
|
paged_kv_indptr.append(paged_kv_indptr[-1] +
|
|
|
|
|
block_table_bound)
|
|
|
|
|
|
|
|
|
|
last_page_len = seq_len % self.block_size
|
|
|
|
|
if last_page_len == 0:
|
|
|
|
|
last_page_len = self.block_size
|
|
|
|
|
paged_kv_last_page_len.append(last_page_len)
|
|
|
|
|
|
2024-05-15 14:00:10 +09:00
|
|
|
|
batch_size = len(input_tokens)
|
|
|
|
|
max_query_len = max(query_lens)
|
|
|
|
|
max_prefill_seq_len = max(prefill_seq_lens, default=0)
|
|
|
|
|
max_decode_seq_len = max(decode_seq_lens, default=0)
|
2024-05-03 15:51:27 -07:00
|
|
|
|
|
2024-05-15 14:00:10 +09:00
|
|
|
|
# If cuda graph can be used, pad tensors accordingly.
|
2024-03-21 06:46:05 +09:00
|
|
|
|
# See `capture_model` API for more details.
|
2024-05-15 14:00:10 +09:00
|
|
|
|
# vLLM uses cuda graph only for decoding requests.
|
|
|
|
|
use_captured_graph = (
|
|
|
|
|
decode_only and not self.model_config.enforce_eager
|
|
|
|
|
and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
|
|
|
|
|
and max_decode_seq_len <= self.max_seq_len_to_capture)
|
2023-12-16 21:12:08 -08:00
|
|
|
|
if use_captured_graph:
|
|
|
|
|
graph_batch_size = _get_graph_batch_size(batch_size)
|
|
|
|
|
assert graph_batch_size >= batch_size
|
|
|
|
|
for _ in range(graph_batch_size - batch_size):
|
2024-03-21 06:46:05 +09:00
|
|
|
|
input_tokens.append(0)
|
|
|
|
|
input_positions.append(0)
|
|
|
|
|
slot_mapping.append(_PAD_SLOT_ID)
|
2024-05-04 02:20:12 +09:00
|
|
|
|
seq_lens.append(1)
|
2023-12-16 21:12:08 -08:00
|
|
|
|
block_tables.append([])
|
2024-03-21 06:46:05 +09:00
|
|
|
|
lora_index_mapping.append(0)
|
2024-06-28 15:28:49 -07:00
|
|
|
|
|
|
|
|
|
if self.attn_backend.get_name() == "flashinfer":
|
|
|
|
|
last_paged_kv_indptr = paged_kv_indptr[-1]
|
|
|
|
|
paged_kv_indptr.append(last_paged_kv_indptr)
|
|
|
|
|
paged_kv_last_page_len.append(0)
|
|
|
|
|
|
2023-12-16 21:12:08 -08:00
|
|
|
|
batch_size = graph_batch_size
|
2024-05-15 14:00:10 +09:00
|
|
|
|
num_decode_tokens = batch_size
|
2023-12-16 21:12:08 -08:00
|
|
|
|
|
|
|
|
|
if use_captured_graph:
|
|
|
|
|
# The shape of graph_block_tables is
|
|
|
|
|
# [max batch size, max context len // block size].
|
|
|
|
|
input_block_tables = self.graph_block_tables[:batch_size]
|
|
|
|
|
for i, block_table in enumerate(block_tables):
|
|
|
|
|
if block_table:
|
|
|
|
|
input_block_tables[i, :len(block_table)] = block_table
|
2024-02-02 07:46:39 +08:00
|
|
|
|
block_tables = torch.tensor(input_block_tables, device=self.device)
|
2023-12-16 21:12:08 -08:00
|
|
|
|
else:
|
2024-01-21 16:31:47 -08:00
|
|
|
|
max_block_table_len = max(
|
|
|
|
|
len(block_table) for block_table in block_tables)
|
2024-03-21 18:22:17 -07:00
|
|
|
|
block_tables = make_tensor_with_pad(
|
2023-12-16 21:12:08 -08:00
|
|
|
|
block_tables,
|
2024-01-08 10:11:06 -08:00
|
|
|
|
max_len=max_block_table_len,
|
2023-12-16 21:12:08 -08:00
|
|
|
|
pad=0,
|
|
|
|
|
dtype=torch.int,
|
2024-02-02 07:46:39 +08:00
|
|
|
|
device=self.device,
|
2023-12-16 21:12:08 -08:00
|
|
|
|
)
|
2024-05-15 14:00:10 +09:00
|
|
|
|
assert max_query_len > 0, ("query_lens: {}".format(query_lens))
|
|
|
|
|
|
2024-06-28 15:28:49 -07:00
|
|
|
|
context_lens_tensor = torch.tensor(context_lens,
|
|
|
|
|
dtype=torch.int,
|
|
|
|
|
device=self.device)
|
|
|
|
|
|
2024-05-15 14:00:10 +09:00
|
|
|
|
seq_lens_tensor = torch.tensor(seq_lens,
|
|
|
|
|
dtype=torch.int,
|
|
|
|
|
device=self.device)
|
2024-06-28 15:28:49 -07:00
|
|
|
|
query_lens_tensor = torch.tensor(query_lens,
|
|
|
|
|
dtype=torch.long,
|
|
|
|
|
device=self.device)
|
|
|
|
|
query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
|
|
|
|
|
dtype=torch.int32,
|
|
|
|
|
device=self.device)
|
2024-05-15 14:00:10 +09:00
|
|
|
|
seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
|
|
|
|
|
dtype=torch.int32,
|
|
|
|
|
device=self.device)
|
|
|
|
|
|
|
|
|
|
torch.cumsum(seq_lens_tensor,
|
|
|
|
|
dim=0,
|
|
|
|
|
dtype=seq_start_loc.dtype,
|
|
|
|
|
out=seq_start_loc[1:])
|
2024-06-28 15:28:49 -07:00
|
|
|
|
torch.cumsum(query_lens_tensor,
|
|
|
|
|
dim=0,
|
|
|
|
|
dtype=query_start_loc.dtype,
|
|
|
|
|
out=query_start_loc[1:])
|
2024-05-15 14:00:10 +09:00
|
|
|
|
|
|
|
|
|
input_tokens_tensor = torch.tensor(input_tokens,
|
|
|
|
|
dtype=torch.long,
|
|
|
|
|
device=self.device)
|
|
|
|
|
input_positions_tensor = torch.tensor(input_positions,
|
|
|
|
|
dtype=torch.long,
|
|
|
|
|
device=self.device)
|
|
|
|
|
slot_mapping_tensor = torch.tensor(slot_mapping,
|
|
|
|
|
dtype=torch.long,
|
|
|
|
|
device=self.device)
|
2023-11-29 22:16:37 -08:00
|
|
|
|
|
2024-05-08 09:59:31 -07:00
|
|
|
|
if self.attn_backend.get_name() == "flashinfer":
|
2024-06-28 15:28:49 -07:00
|
|
|
|
if len(paged_kv_indptr) > 0:
|
|
|
|
|
paged_kv_indices_tensor = torch.tensor(paged_kv_indices,
|
|
|
|
|
device='cpu',
|
|
|
|
|
dtype=torch.int)
|
|
|
|
|
paged_kv_indptr_tensor = torch.tensor(paged_kv_indptr,
|
|
|
|
|
device='cpu',
|
|
|
|
|
dtype=torch.int)
|
|
|
|
|
paged_kv_last_page_len_tensor = torch.tensor(
|
|
|
|
|
paged_kv_last_page_len, device='cpu', dtype=torch.int)
|
|
|
|
|
else:
|
|
|
|
|
paged_kv_indices_tensor = None
|
|
|
|
|
paged_kv_indptr_tensor = None
|
|
|
|
|
paged_kv_last_page_len_tensor = None
|
|
|
|
|
|
2024-05-03 15:51:27 -07:00
|
|
|
|
kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype,
|
|
|
|
|
self.model_config.dtype)
|
2024-06-28 15:28:49 -07:00
|
|
|
|
|
2024-05-03 15:51:27 -07:00
|
|
|
|
attn_metadata = self.attn_backend.make_metadata(
|
2024-05-15 14:00:10 +09:00
|
|
|
|
num_prefills=num_prefills,
|
|
|
|
|
slot_mapping=slot_mapping_tensor,
|
|
|
|
|
num_prefill_tokens=num_prefill_tokens,
|
|
|
|
|
num_decode_tokens=num_decode_tokens,
|
|
|
|
|
max_prefill_seq_len=max_prefill_seq_len,
|
|
|
|
|
block_tables=block_tables,
|
|
|
|
|
paged_kv_indptr=paged_kv_indptr_tensor,
|
|
|
|
|
paged_kv_indices=paged_kv_indices_tensor,
|
|
|
|
|
paged_kv_last_page_len=paged_kv_last_page_len_tensor,
|
2024-05-03 15:51:27 -07:00
|
|
|
|
num_qo_heads=self.model_config.get_num_attention_heads(
|
|
|
|
|
self.parallel_config),
|
|
|
|
|
num_kv_heads=self.model_config.get_num_kv_heads(
|
|
|
|
|
self.parallel_config),
|
|
|
|
|
head_dim=self.model_config.get_head_size(),
|
2024-06-28 15:28:49 -07:00
|
|
|
|
page_size=self.block_size,
|
2024-05-15 14:00:10 +09:00
|
|
|
|
seq_start_loc=seq_start_loc,
|
2024-06-28 15:28:49 -07:00
|
|
|
|
query_start_loc=query_start_loc,
|
|
|
|
|
device=self.device,
|
|
|
|
|
data_type=kv_cache_dtype,
|
|
|
|
|
use_cuda_graph=use_captured_graph)
|
2024-06-10 19:29:02 -07:00
|
|
|
|
|
2024-06-28 15:28:49 -07:00
|
|
|
|
else:
|
2024-05-03 15:51:27 -07:00
|
|
|
|
attn_metadata = self.attn_backend.make_metadata(
|
2024-05-15 14:00:10 +09:00
|
|
|
|
num_prefills=num_prefills,
|
|
|
|
|
slot_mapping=slot_mapping_tensor,
|
|
|
|
|
num_prefill_tokens=num_prefill_tokens,
|
|
|
|
|
num_decode_tokens=num_decode_tokens,
|
|
|
|
|
seq_lens=seq_lens,
|
2024-05-03 15:51:27 -07:00
|
|
|
|
seq_lens_tensor=seq_lens_tensor,
|
2024-05-15 14:00:10 +09:00
|
|
|
|
max_query_len=max_query_len,
|
|
|
|
|
max_prefill_seq_len=max_prefill_seq_len,
|
|
|
|
|
max_decode_seq_len=max_decode_seq_len,
|
|
|
|
|
query_start_loc=query_start_loc,
|
|
|
|
|
seq_start_loc=seq_start_loc,
|
|
|
|
|
context_lens_tensor=context_lens_tensor,
|
2024-05-03 15:51:27 -07:00
|
|
|
|
block_tables=block_tables,
|
|
|
|
|
use_cuda_graph=use_captured_graph,
|
|
|
|
|
)
|
2024-05-15 14:00:10 +09:00
|
|
|
|
|
|
|
|
|
if self.lora_config:
|
|
|
|
|
lora_mapping = LoRAMapping(
|
|
|
|
|
lora_index_mapping,
|
|
|
|
|
lora_prompt_mapping,
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
lora_mapping = None
|
|
|
|
|
|
2024-07-03 11:34:00 +08:00
|
|
|
|
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
|
|
|
|
|
device=self.device)
|
2024-07-03 02:11:29 +03:00
|
|
|
|
request_ids_to_seq_ids = {
|
|
|
|
|
seq_group_metadata.request_id:
|
|
|
|
|
list(seq_group_metadata.seq_data.keys())
|
|
|
|
|
for seq_group_metadata in seq_group_metadata_list
|
|
|
|
|
}
|
2024-06-25 20:30:03 -07:00
|
|
|
|
return self._model_input_cls(
|
2024-05-15 14:00:10 +09:00
|
|
|
|
input_tokens=input_tokens_tensor,
|
|
|
|
|
input_positions=input_positions_tensor,
|
2024-04-11 09:56:48 +09:00
|
|
|
|
attn_metadata=attn_metadata,
|
2024-05-15 14:00:10 +09:00
|
|
|
|
seq_lens=seq_lens,
|
|
|
|
|
query_lens=query_lens,
|
|
|
|
|
lora_mapping=lora_mapping,
|
2024-04-11 09:56:48 +09:00
|
|
|
|
lora_requests=lora_requests,
|
2024-06-03 13:56:41 +08:00
|
|
|
|
multi_modal_kwargs=multi_modal_kwargs,
|
2024-07-03 02:11:29 +03:00
|
|
|
|
request_ids_to_seq_ids=request_ids_to_seq_ids,
|
|
|
|
|
finished_requests_ids=finished_requests_ids)
|
2024-04-26 22:02:02 +09:00
|
|
|
|
|
2023-11-29 22:16:37 -08:00
|
|
|
|
@torch.inference_mode()
def profile_run(self) -> None:
    """Run one forward pass over a worst-case dummy batch for profiling.

    The dummy batch holds ``scheduler_config.max_num_batched_tokens``
    prompt tokens spread as evenly as possible over ``max_num_seqs``
    sequences (clamped for vision models). When LoRA is enabled, the
    sequences are spread round-robin over ``max_loras`` dummy adapters.
    The model output is discarded.
    """
    # Top-k sampling is enabled so the sampler's memory use is reflected
    # accurately.
    sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
    max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
    max_num_seqs = self.scheduler_config.max_num_seqs

    # Build one dummy LoRA per LoRA slot: the worst case for memory is as
    # many distinct adapters as can be active at once. The dummy requests
    # are copies built from the warmup path and are handed out to the
    # sequences round-robin.
    dummy_lora_requests: List[LoRARequest] = []
    dummy_lora_requests_per_seq: List[LoRARequest] = []
    if self.lora_config:
        assert self.lora_manager is not None
        with self.lora_manager.dummy_lora_cache():
            for lora_id in range(1, self.lora_config.max_loras + 1):
                warmup_request = LoRARequest(
                    lora_name=f"warmup_{lora_id}",
                    lora_int_id=lora_id,
                    lora_local_path="/not/a/real/path",
                )
                self.lora_manager.add_dummy_lora(warmup_request,
                                                 rank=LORA_WARMUP_RANK)
                dummy_lora_requests.append(warmup_request)
            num_dummy_loras = len(dummy_lora_requests)
            dummy_lora_requests_per_seq = [
                dummy_lora_requests[seq_idx % num_dummy_loras]
                for seq_idx in range(max_num_seqs)
            ]

    model_config = self.model_config
    seqs: List[SequenceGroupMetadata] = []

    # Vision encoding may need additional GPU memory, which has to be
    # accounted for when computing the number of GPU blocks for the vLLM
    # block manager. To exercise the worst case, pick the number of
    # sequences (batch size) that maximizes the number of images processed.
    if supports_vision(self.model):
        tokens_per_image_seq = MULTIMODAL_REGISTRY.get_num_input_tokens()
        max_num_seqs = max(
            1,
            min(max_num_seqs,
                int(max_num_batched_tokens / tokens_per_image_seq)))

    batch_size = 0
    for group_id in range(max_num_seqs):
        # Split the token budget evenly; the first `remainder` groups each
        # take one extra token so the lengths sum to the full budget.
        quotient, remainder = divmod(max_num_batched_tokens, max_num_seqs)
        seq_len = quotient + (group_id < remainder)
        batch_size += seq_len

        seq_data, dummy_multi_modal_data = INPUT_REGISTRY \
            .dummy_data_for_profiling(model_config, seq_len)

        # Having more tokens is over-conservative but otherwise fine
        num_dummy_tokens = len(seq_data.prompt_token_ids)
        assert num_dummy_tokens >= seq_len, (
            f"Expected at least {seq_len} dummy tokens for profiling, "
            f"but got: {num_dummy_tokens}")

        seqs.append(
            SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: seq_data},
                sampling_params=sampling_params,
                block_tables=None,
                lora_request=(dummy_lora_requests_per_seq[group_id]
                              if dummy_lora_requests_per_seq else None),
                multi_modal_data=dummy_multi_modal_data,
            ))

    # Profiling runs with a None placeholder per layer instead of real KV
    # cache tensors.
    num_layers = self.model_config.get_num_layers(self.parallel_config)
    kv_caches = [None] * num_layers
    finished_requests_ids = [seq.request_id for seq in seqs]
    model_input = self.prepare_model_input(
        seqs, finished_requests_ids=finished_requests_ids)

    # Ranks that are not the first pipeline stage are given an empty
    # intermediate-tensor input sized for the whole dummy batch.
    intermediate_tensors = (self.model.make_empty_intermediate_tensors(
        batch_size=batch_size,
        dtype=self.model_config.dtype,
        device=self.device) if not get_pp_group().is_first_rank else None)

    self.execute_model(model_input, kv_caches, intermediate_tensors)
    torch.cuda.synchronize()
|
|
|
|
|
|
2024-04-26 04:13:50 +09:00
|
|
|
|
def remove_all_loras(self):
    """Remove every LoRA adapter currently held by the LoRA manager.

    Raises:
        RuntimeError: if this runner has no LoRA manager (LoRA disabled).
    """
    if self.lora_manager:
        self.lora_manager.remove_all_loras()
        return
    raise RuntimeError("LoRA is not enabled.")
|
2024-01-24 00:26:37 +01:00
|
|
|
|
|
2024-04-18 09:28:43 +09:00
|
|
|
|
def set_active_loras(self, lora_requests: Set[LoRARequest],
                     lora_mapping: LoRAMapping) -> None:
    """Forward ``lora_requests`` and ``lora_mapping`` to the LoRA manager.

    Raises:
        RuntimeError: if this runner has no LoRA manager (LoRA disabled).
    """
    manager = self.lora_manager
    if not manager:
        raise RuntimeError("LoRA is not enabled.")
    manager.set_active_loras(lora_requests, lora_mapping)
|
|
|
|
|
|
|
|
|
|
def add_lora(self, lora_request: LoRARequest) -> bool:
    """Register the adapter described by ``lora_request``.

    Returns:
        The result of the LoRA manager's ``add_lora`` call.

    Raises:
        RuntimeError: if this runner has no LoRA manager (LoRA disabled).
    """
    if self.lora_manager:
        return self.lora_manager.add_lora(lora_request)
    raise RuntimeError("LoRA is not enabled.")
|
|
|
|
|
|
|
|
|
|
def remove_lora(self, lora_id: int) -> bool:
    """Remove the adapter with integer id ``lora_id``.

    Returns:
        The result of the LoRA manager's ``remove_lora`` call.

    Raises:
        RuntimeError: if this runner has no LoRA manager (LoRA disabled).
    """
    manager = self.lora_manager
    if manager:
        return manager.remove_lora(lora_id)
    raise RuntimeError("LoRA is not enabled.")
|
2024-06-21 15:42:46 -07:00
|
|
|
|
|
|
|
|
|
def pin_lora(self, lora_id: int) -> bool:
    """Ask the LoRA manager to pin the adapter with id ``lora_id``.

    Returns:
        The result of the LoRA manager's ``pin_lora`` call.

    Raises:
        RuntimeError: if this runner has no LoRA manager (LoRA disabled).
    """
    if not self.lora_manager:
        raise RuntimeError("LoRA is not enabled.")
    result = self.lora_manager.pin_lora(lora_id)
    return result
|
2024-01-24 00:26:37 +01:00
|
|
|
|
|
|
|
|
|
def list_loras(self) -> Set[int]:
    """Return the ids reported by the LoRA manager's ``list_loras``.

    Raises:
        RuntimeError: if this runner has no LoRA manager (LoRA disabled).
    """
    manager = self.lora_manager
    if not manager:
        raise RuntimeError("LoRA is not enabled.")
    return manager.list_loras()
|
|
|
|
|
|
2023-12-16 21:12:08 -08:00
|
|
|
|
@torch.inference_mode()
def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
    """Cuda graph capture a model.

    Note that CUDA graph's performance gain is negligible if number
    of batched tokens are larger than 200. And since CUDA graph
    requires fixed sized tensors, supporting large/variable batch
    size requires high GPU memory overhead. Thus, vLLM only captures
    decoding requests. Mixed batch (chunked prefill + decoding) or
    prefill requests are not captured.

    Since it is used for decoding-only, it assumes there's only 1 token
    per sequence in the batch.

    One graph is captured per (virtual engine, batch size) pair and stored
    in ``self.graph_runners[virtual_engine][batch_size]``.

    Args:
        kv_caches: KV cache tensors, indexed first by virtual engine.
    """
    assert not self.model_config.enforce_eager
    logger.info("Capturing the model for CUDA graphs. This may lead to "
                "unexpected consequences if the model is not static. To "
                "run the model in eager mode, set 'enforce_eager=True' or "
                "use '--enforce-eager' in the CLI.")
    logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. "
                "If you are running out of memory, consider decreasing "
                "`gpu_memory_utilization` or enforcing eager mode. "
                "You can also reduce the `max_num_seqs` as needed "
                "to decrease memory usage.")
    start_time = time.perf_counter()

    # Prepare dummy inputs. These will be reused for all batch sizes.
    max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
    input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
    input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
    slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda()
    slot_mapping.fill_(_PAD_SLOT_ID)
    seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
    block_tables = torch.from_numpy(self.graph_block_tables).cuda()
    intermediate_inputs = None
    if not get_pp_group().is_first_rank:
        intermediate_inputs = self.model.make_empty_intermediate_tensors(
            batch_size=max_batch_size,
            dtype=self.model_config.dtype,
            device=self.device)

    # Prepare buffer for outputs. These will be reused for all batch sizes.
    # Each virtual engine's entry is filled with the output of its first
    # (largest) graph capture; smaller batch sizes of the same virtual
    # engine are then passed a prefix slice of it.
    hidden_or_intermediate_states: List[Optional[Union[
        torch.Tensor, IntermediateTensors]]] = [
            None
        ] * self.parallel_config.pipeline_parallel_size

    graph_batch_size = _get_graph_batch_size(
        self.scheduler_config.max_num_seqs)
    batch_size_capture_list = [
        bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
    ]

    use_flashinfer = self.attn_backend.get_name() == "flashinfer"
    if use_flashinfer:
        # For flashinfer, different batch sizes will share the
        # same workspace buffer.
        decode_workspace_buffer = torch.empty(
            FLASHINFER_WORKSPACE_BUFFER_SIZE,
            dtype=torch.uint8,
            device=self.device)
        indices_buffer = torch.empty(max_batch_size *
                                     self.cache_config.num_gpu_blocks,
                                     dtype=torch.int32,
                                     device=self.device)
        # Full-size buffers. Every captured batch size gets a prefix view
        # of these. They are never rebound to a smaller view; otherwise the
        # next virtual engine's (larger) batch sizes would receive
        # truncated slices.
        full_indptr_buffer = torch.empty(max_batch_size + 1,
                                         dtype=torch.int32,
                                         device=self.device)
        full_last_page_len_buffer = torch.empty(max_batch_size,
                                                dtype=torch.int32,
                                                device=self.device)
        # These do not depend on the batch size; compute them once.
        num_qo_heads = self.model_config.get_num_attention_heads(
            self.parallel_config)
        num_kv_heads = self.model_config.get_num_kv_heads(
            self.parallel_config)
        use_tensor_cores = num_qo_heads // num_kv_heads >= 4
        kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype,
                                                  self.model_config.dtype)

    with graph_capture() as graph_capture_context:
        # NOTE: Capturing the largest batch size first may help reduce the
        # memory usage of CUDA graph.
        for virtual_engine in range(
                self.parallel_config.pipeline_parallel_size):
            for batch_size in reversed(batch_size_capture_list):
                if use_flashinfer:
                    indptr_buffer = full_indptr_buffer[:batch_size + 1]
                    last_page_len_buffer = \
                        full_last_page_len_buffer[:batch_size]

                    decode_wrapper = \
                        CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
                        decode_workspace_buffer, indptr_buffer,
                        indices_buffer, last_page_len_buffer, "NHD",
                        use_tensor_cores)

                    # Host-side dummy decode metadata: one page and one
                    # query token per sequence.
                    paged_kv_indptr_tensor_host = torch.arange(
                        0, batch_size + 1, dtype=torch.int32)
                    paged_kv_indices_tensor_host = torch.arange(
                        0, batch_size, dtype=torch.int32)
                    paged_kv_last_page_len_tensor_host = torch.full(
                        (batch_size, ), self.block_size, dtype=torch.int32)
                    query_start_loc_host = torch.arange(0,
                                                        batch_size + 1,
                                                        dtype=torch.int32)

                    attn_metadata = self.attn_backend.make_metadata(
                        num_prefills=0,
                        slot_mapping=slot_mapping[:batch_size],
                        num_prefill_tokens=0,
                        num_decode_tokens=batch_size,
                        max_prefill_seq_len=0,
                        block_tables=block_tables,
                        paged_kv_indptr=paged_kv_indptr_tensor_host,
                        paged_kv_indices=paged_kv_indices_tensor_host,
                        paged_kv_last_page_len=
                        paged_kv_last_page_len_tensor_host,
                        num_qo_heads=num_qo_heads,
                        num_kv_heads=num_kv_heads,
                        head_dim=self.model_config.get_head_size(),
                        page_size=self.block_size,
                        seq_start_loc=None,
                        query_start_loc=query_start_loc_host,
                        device=self.device,
                        data_type=kv_cache_dtype,
                        use_cuda_graph=True,
                        decode_wrapper=decode_wrapper,
                        prefill_wrapper=None)
                    attn_metadata.begin_forward()
                else:
                    attn_metadata = self.attn_backend.make_metadata(
                        num_prefills=0,
                        num_prefill_tokens=0,
                        num_decode_tokens=batch_size,
                        slot_mapping=slot_mapping[:batch_size],
                        seq_lens=None,
                        seq_lens_tensor=seq_lens[:batch_size],
                        max_query_len=None,
                        max_prefill_seq_len=0,
                        max_decode_seq_len=self.max_seq_len_to_capture,
                        query_start_loc=None,
                        seq_start_loc=None,
                        context_lens_tensor=None,
                        block_tables=block_tables[:batch_size],
                        use_cuda_graph=True,
                    )

                if self.lora_config:
                    lora_mapping = LoRAMapping(
                        [0] * batch_size,
                        [0] * batch_size,
                    )
                    self.set_active_loras(set(), lora_mapping)

                graph_runner = CUDAGraphRunner(
                    self.model, self.attn_backend.get_name())

                if use_flashinfer:
                    graph_runner.flashinfer_indptr_buffer = indptr_buffer
                    graph_runner.flashinfer_indices_buffer = indices_buffer
                    graph_runner.flashinfer_last_page_len_buffer = \
                        last_page_len_buffer
                    graph_runner.flashinfer_decode_workspace_buffer = \
                        decode_workspace_buffer
                    graph_runner.flashinfer_decode_wrapper = \
                        decode_wrapper

                output_buffer = hidden_or_intermediate_states[virtual_engine]
                capture_inputs = {
                    "input_ids":
                    input_tokens[:batch_size],
                    "positions":
                    input_positions[:batch_size],
                    "hidden_or_intermediate_states":
                    output_buffer[:batch_size]  # type: ignore
                    if output_buffer is not None else None,
                    "intermediate_inputs":
                    intermediate_inputs[:batch_size]
                    if intermediate_inputs is not None else None,
                    "kv_caches":
                    kv_caches[virtual_engine],
                    "attn_metadata":
                    attn_metadata,
                    "memory_pool":
                    self.graph_memory_pool,
                    "stream":
                    graph_capture_context.stream
                }
                if self.has_seqlen_agnostic:
                    # Only used by Mamba-based models CUDA graph atm (Jamba)
                    capture_inputs.update({
                        "seqlen_agnostic_capture_inputs":
                        self.model.get_seqlen_agnostic_capture_inputs(
                            batch_size)
                    })
                captured_output = graph_runner.capture(**capture_inputs)
                if output_buffer is None:
                    # First (largest) capture for this virtual engine: keep
                    # its output as the buffer that the smaller batch sizes
                    # will be given a slice of.
                    hidden_or_intermediate_states[virtual_engine] = (
                        captured_output)
                self.graph_memory_pool = graph_runner.graph.pool()
                self.graph_runners[virtual_engine][batch_size] = (
                    graph_runner)

    end_time = time.perf_counter()
    elapsed_time = end_time - start_time
    # This usually takes < 10 seconds.
    logger.info("Graph capturing finished in %.0f secs.", elapsed_time)
|
2023-12-16 21:12:08 -08:00
|
|
|
|
|
2024-03-08 23:32:46 -08:00
|
|
|
|
@property
def vocab_size(self) -> int:
    """Vocabulary size as reported by ``model_config.get_vocab_size()``."""
    config = self.model_config
    return config.get_vocab_size()
|
|
|
|
|
|
2023-12-16 21:12:08 -08:00
|
|
|
|
|
2024-06-25 20:30:03 -07:00
|
|
|
|
class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
    """
    GPU model runner with sampling step.
    """
    _model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = (
        ModelInputForGPUWithSamplingMetadata)

    def make_model_input_from_broadcasted_tensor_dict(
        self,
        tensor_dict: Dict[str, Any],
    ) -> ModelInputForGPUWithSamplingMetadata:
        """Rebuild a model input from a broadcast tensor dict.

        Delegates to
        ``ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict``
        using this runner's attention backend.
        """
        model_input = \
            ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict(
                tensor_dict,
                attn_backend=self.attn_backend,
            )
        return model_input

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForGPUWithSamplingMetadata:
        """Prepare the model input based on a given sequence group, including
        metadata for the sampling step.

        The API assumes seq_group_metadata_list is sorted by prefill -> decode.

        The result tensors and data structure also batches input in prefill
        -> decode order. For example,

        - input_tokens[:num_prefill_tokens] contains prefill tokens.
        - input_tokens[num_prefill_tokens:] contains decode tokens.

        If cuda graph is required, this API automatically pads inputs.

        ``is_prompt`` is taken from the first sequence group (``None`` for
        an empty list), and ``virtual_engine`` is recorded on the result.
        """
        model_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)
        sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
                                                     model_input.seq_lens,
                                                     model_input.query_lens,
                                                     self.device,
                                                     self.pin_memory)
        is_prompt = (seq_group_metadata_list[0].is_prompt
                     if seq_group_metadata_list else None)
        return dataclasses.replace(model_input,
                                   sampling_metadata=sampling_metadata,
                                   is_prompt=is_prompt,
                                   virtual_engine=virtual_engine)

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForGPUWithSamplingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
        """Run one forward pass and, on the last pipeline stage, sample.

        Args:
            model_input: prepared inputs (see ``prepare_model_input``).
            kv_caches: KV cache tensors passed to the model.
            intermediate_tensors: inputs from the previous pipeline stage,
                if any.
            num_steps: must be 1; multi-step execution is not supported.

        Returns:
            The model's hidden/intermediate states on a non-last pipeline
            stage, an empty list on a non-driver worker, and otherwise a
            one-element list holding the ``SamplerOutput``.

        Raises:
            ValueError: if ``num_steps > 1``.
        """
        if num_steps > 1:
            raise ValueError("num_steps > 1 is not supported in ModelRunner")

        if self.lora_config:
            assert model_input.lora_requests is not None
            assert model_input.lora_mapping is not None
            self.set_active_loras(model_input.lora_requests,
                                  model_input.lora_mapping)

        # TODO(andoorve): We can remove this once all
        # virtual engines share the same kv cache.
        virtual_engine = model_input.virtual_engine

        if self.attn_backend.get_name() == "flashinfer":
            assert model_input.attn_metadata is not None
            assert model_input.input_tokens is not None
            if self.flashinfer_decode_workspace_buffer is None:
                # Allocate the eager-mode workspaces and wrappers on first
                # use.
                self.flashinfer_decode_workspace_buffer = torch.empty(
                    FLASHINFER_WORKSPACE_BUFFER_SIZE,
                    dtype=torch.uint8,
                    device=self.device)
                self.flashinfer_decode_wrapper = \
                    BatchDecodeWithPagedKVCacheWrapper(
                    self.flashinfer_decode_workspace_buffer, "NHD")
                self.flashinfer_prefill_workspace_buffer = torch.empty(
                    FLASHINFER_WORKSPACE_BUFFER_SIZE,
                    dtype=torch.uint8,
                    device=self.device)
                self.flashinfer_prefill_wrapper = \
                    BatchPrefillWithPagedKVCacheWrapper(
                    self.flashinfer_prefill_workspace_buffer, "NHD")

            model_input.attn_metadata.prefill_wrapper = \
                self.flashinfer_prefill_wrapper
            if model_input.attn_metadata.use_cuda_graph:
                batch_size = model_input.input_tokens.shape[0]
                # graph_runners is indexed by virtual engine first, then by
                # batch size (see capture_model). The decode wrapper must
                # come from the runner whose graph will be replayed.
                model_input.attn_metadata.decode_wrapper = self.graph_runners[
                    virtual_engine][batch_size].flashinfer_decode_wrapper
            else:
                model_input.attn_metadata.decode_wrapper = \
                    self.flashinfer_decode_wrapper
            model_input.attn_metadata.begin_forward()

        # Currently cuda graph is only supported by the decode phase.
        assert model_input.attn_metadata is not None
        prefill_meta = model_input.attn_metadata.prefill_metadata
        decode_meta = model_input.attn_metadata.decode_metadata
        if prefill_meta is None and decode_meta.use_cuda_graph:
            assert model_input.input_tokens is not None
            graph_batch_size = model_input.input_tokens.shape[0]
            model_executable = self.graph_runners[virtual_engine][
                graph_batch_size]
        else:
            model_executable = self.model

        multi_modal_kwargs = model_input.multi_modal_kwargs or {}
        seqlen_agnostic_kwargs = {
            "finished_requests_ids": model_input.finished_requests_ids,
            "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
        } if self.has_seqlen_agnostic else {}
        hidden_or_intermediate_states = model_executable(
            input_ids=model_input.input_tokens,
            positions=model_input.input_positions,
            kv_caches=kv_caches,
            attn_metadata=model_input.attn_metadata,
            intermediate_tensors=intermediate_tensors,
            **multi_modal_kwargs,
            **seqlen_agnostic_kwargs)

        # Compute the logits in the last pipeline stage.
        if not get_pp_group().is_last_rank:
            return hidden_or_intermediate_states

        logits = self.model.compute_logits(hidden_or_intermediate_states,
                                           model_input.sampling_metadata)

        # Only the driver worker samples.
        if not self.is_driver_worker:
            return []

        # Sample the next token.
        output: SamplerOutput = self.model.sample(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )

        if self.return_hidden_states:
            # we only need to pass hidden states of most recent token
            assert model_input.sampling_metadata is not None
            indices = model_input.sampling_metadata.selected_token_indices
            if model_input.is_prompt:
                hidden_states = hidden_or_intermediate_states.index_select(
                    0, indices)
            elif decode_meta.use_cuda_graph:
                # CUDA-graph inputs are padded; drop the padding rows.
                hidden_states = hidden_or_intermediate_states[:len(indices)]
            else:
                hidden_states = hidden_or_intermediate_states

            output.hidden_states = hidden_states

        return [output]
|
2024-06-25 20:30:03 -07:00
|
|
|
|
|
|
|
|
|
|
2023-12-16 21:12:08 -08:00
|
|
|
|
class CUDAGraphRunner:
    """Captures a single model forward pass as a CUDA graph and replays it.

    ``capture`` runs warmup passes, then records one forward pass on a fixed
    set of input tensors and remembers those tensors as ``input_buffers``.
    ``forward`` copies fresh values into the same tensors, replays the
    recorded graph and returns the graph's output buffer(s). The returned
    tensors are the graph's static outputs, so presumably callers must
    consume them before the next replay — confirm at call sites.
    """

    def __init__(self, model: nn.Module, backend_name: str):
        self.model = model
        # Attention backend name. When it is "flashinfer", seq_lens_tensor and
        # block_tables are not kept as static graph inputs (see capture()).
        self.backend_name = backend_name

        # Static tensors the captured graph reads from.
        self.input_buffers: Dict[str, torch.Tensor] = {}
        # On the last pipeline-parallel rank this is {"hidden_states": tensor};
        # on other ranks capture() stores the IntermediateTensors object itself.
        self.output_buffers: Union[Dict[str, torch.Tensor],
                                   IntermediateTensors] = {}

        # Set exactly once by capture(); read through the `graph` property.
        self._graph: Optional[torch.cuda.CUDAGraph] = None

        # FlashInfer decode state. These are not assigned anywhere in this
        # class; presumably the owning model runner fills them in when
        # backend_name == "flashinfer" — TODO confirm.
        self.flashinfer_decode_workspace_buffer: Optional[torch.Tensor] = None
        self.flashinfer_indptr_buffer: Optional[torch.Tensor] = None
        self.flashinfer_indices_buffer: Optional[torch.Tensor] = None
        self.flashinfer_last_page_len_buffer: Optional[torch.Tensor] = None
        self.flashinfer_decode_wrapper: Optional[
            CUDAGraphBatchDecodeWithPagedKVCacheWrapper] = None

    @property
    def graph(self):
        """The captured CUDA graph. Asserts that ``capture`` has run."""
        assert self._graph is not None
        return self._graph

    def capture(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_or_intermediate_states: Optional[Union[IntermediateTensors,
                                                      torch.Tensor]],
        intermediate_inputs: Optional[IntermediateTensors],
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        memory_pool: Optional[Tuple[int, int]],
        stream: torch.cuda.Stream,
        **kwargs,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Warm up the model, then record one forward pass as a CUDA graph.

        Args:
            input_ids: Token-id tensor. It becomes a static graph input.
            positions: Position tensor. It becomes a static graph input.
            hidden_or_intermediate_states: Optional pre-allocated output.
                If not None, the model output is copied into it inside the
                graph. On the last pipeline rank this is a single
                ``copy_``; otherwise the copy is done per tensor key. If
                None, the captured model output itself becomes the output
                buffer.
            intermediate_inputs: Tensors from the previous pipeline stage, or
                None. When given, they are registered as extra input buffers.
            kv_caches: KV cache tensors passed to the model.
            attn_metadata: Attention metadata. ``slot_mapping`` is always a
                static input. For non-FlashInfer backends,
                ``decode_metadata.seq_lens_tensor`` and
                ``decode_metadata.block_tables`` are static inputs too.
            memory_pool: Pool handle forwarded to ``torch.cuda.graph``.
            stream: Stream forwarded to ``torch.cuda.graph``.
            **kwargs: Extra model inputs. They are passed to the model and
                also stored as input buffers.

        Returns:
            The output buffer(s) the graph writes into.
        """
        assert self._graph is None
        # Run the model a few times without capturing the graph.
        # This is to make sure that the captured graph does not include the
        # kernel launches for initial benchmarking (e.g., Triton autotune).
        # Note one iteration is not enough for torch.jit.script
        for _ in range(_NUM_WARMUP_ITERS):
            self.model(
                input_ids,
                positions,
                kv_caches,
                attn_metadata,
                intermediate_inputs,
                **kwargs,
            )
        torch.cuda.synchronize()

        # Capture the graph.
        self._graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
            output_hidden_or_intermediate_states = self.model(
                input_ids,
                positions,
                kv_caches,
                attn_metadata,
                intermediate_inputs,
                **kwargs,
            )
            if hidden_or_intermediate_states is not None:
                # Copy into the caller-provided buffer so the graph writes
                # its result there on every replay.
                if get_pp_group().is_last_rank:
                    hidden_or_intermediate_states.copy_(
                        output_hidden_or_intermediate_states)
                else:
                    for key in hidden_or_intermediate_states.tensors:
                        hidden_or_intermediate_states[key].copy_(
                            output_hidden_or_intermediate_states[key])
            else:
                hidden_or_intermediate_states = (
                    output_hidden_or_intermediate_states)

            del output_hidden_or_intermediate_states
            # make sure `output_hidden_states` is deleted
            # in the graph's memory pool
            gc.collect()
        torch.cuda.synchronize()

        # Save the input and output buffers. Key order and override semantics
        # match a single dict literal ending in **kwargs.
        input_buffers: Dict[str, torch.Tensor] = {
            "input_ids": input_ids,
            "positions": positions,
            "kv_caches": kv_caches,
            "slot_mapping": attn_metadata.slot_mapping,
        }
        if self.backend_name != "flashinfer":
            # These backends read the decode metadata tensors from inside the
            # graph, so they must be static inputs refreshed by forward().
            input_buffers["seq_lens_tensor"] = (
                attn_metadata.decode_metadata.seq_lens_tensor)
            input_buffers["block_tables"] = (
                attn_metadata.decode_metadata.block_tables)
        input_buffers.update(kwargs)
        self.input_buffers = input_buffers

        if intermediate_inputs is not None:
            self.input_buffers.update(intermediate_inputs.tensors)
        if get_pp_group().is_last_rank:
            self.output_buffers = {
                "hidden_states": hidden_or_intermediate_states
            }
        else:
            self.output_buffers = hidden_or_intermediate_states
        return hidden_or_intermediate_states

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        **kwargs,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Refresh the static inputs and replay the captured graph.

        Args:
            input_ids: New token ids, copied into the captured buffer.
            positions: New positions, copied into the captured buffer.
            kv_caches: Ignored; the captured KV cache tensors are reused.
            attn_metadata: Source of ``slot_mapping`` and, for non-FlashInfer
                backends, ``decode_metadata.seq_lens_tensor`` and
                ``decode_metadata.block_tables``.
            intermediate_tensors: Tensors from the previous pipeline stage, or
                None. Each is copied into the input buffer of the same key.
            **kwargs: Forwarded to the model's seqlen-agnostic copy hooks when
                those inputs were captured.

        Returns:
            On the last pipeline rank, the hidden-states output tensor.
            Otherwise, the captured IntermediateTensors output buffer.
        """
        # KV caches are fixed tensors, so we don't need to copy them.
        del kv_caches

        # Copy the input tensors to the input buffers.
        self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
        self.input_buffers["positions"].copy_(positions, non_blocking=True)
        self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
                                                 non_blocking=True)
        if self.backend_name != "flashinfer":
            self.input_buffers["seq_lens_tensor"].copy_(
                attn_metadata.decode_metadata.seq_lens_tensor,
                non_blocking=True)
            self.input_buffers["block_tables"].copy_(
                attn_metadata.decode_metadata.block_tables, non_blocking=True)
        if "seqlen_agnostic_capture_inputs" in self.input_buffers:
            self.model.copy_inputs_before_cuda_graphs(self.input_buffers,
                                                      **kwargs)
        if intermediate_tensors is not None:
            for key in intermediate_tensors.tensors:
                self.input_buffers[key].copy_(intermediate_tensors[key],
                                              non_blocking=True)
        # Run the graph.
        self.graph.replay()
        if "seqlen_agnostic_capture_inputs" in self.input_buffers:
            self.model.copy_outputs_after_cuda_graphs(self.input_buffers,
                                                      **kwargs)
        # Return the output tensor.
        if get_pp_group().is_last_rank:
            return self.output_buffers["hidden_states"]

        return self.output_buffers

    def __call__(self, *args, **kwargs):
        """Alias for ``forward`` so the runner can be called like a module."""
        return self.forward(*args, **kwargs)
|
|
|
|
|
|
2023-11-29 22:16:37 -08:00
|
|
|
|
|
2023-12-16 21:12:08 -08:00
|
|
|
|
def _get_graph_batch_size(batch_size: int) -> int:
    """Return the padded (CUDA-graph) batch size for ``batch_size``.

    The padded sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
    2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT, ...:
    - sizes up to 2 are returned unchanged;
    - 3 and 4 become 4;
    - anything larger is rounded up to the next multiple of
      ``_BATCH_SIZE_ALIGNMENT``.
    """
    if batch_size > 4:
        alignment = _BATCH_SIZE_ALIGNMENT
        # Ceiling division: how many aligned chunks are needed to cover it.
        num_chunks = (batch_size + alignment - 1) // alignment
        return num_chunks * alignment
    return 4 if batch_size > 2 else batch_size
|
2024-03-25 14:16:30 -07:00
|
|
|
|
|
|
|
|
|
|
2024-05-11 11:30:37 -07:00
|
|
|
|
def _is_block_tables_empty(block_tables: Union[None, Dict]):
    """Tell whether ``block_tables`` holds no block table at all.

    Returns True for ``None`` and for a dict whose values are all ``None``
    (so an empty dict also counts as empty). Returns False otherwise,
    including for any non-dict value.
    """
    if block_tables is None:
        return True
    if not isinstance(block_tables, dict):
        return False
    return all(table is None for table in block_tables.values())
|