[Kernel] Use flashinfer for decoding (#4353)

Co-authored-by: LiuXiaoxuanPKU <llilyliupku@gmail.com>
This commit is contained in:
Lily Liu 2024-05-03 15:51:27 -07:00 committed by GitHub
parent f8e7adda21
commit 43c413ec57
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 600 additions and 53 deletions

View File

@ -24,6 +24,14 @@ void reshape_and_cache(
const std::string& kv_cache_dtype,
const float kv_scale);
void reshape_and_cache_flash(
torch::Tensor& key,
torch::Tensor& value,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype);
// Just for unittest
void convert_fp8(
torch::Tensor& src_cache,

View File

@ -215,6 +215,41 @@ __global__ void reshape_and_cache_kernel(
}
}
template<typename scalar_t>
__global__ void reshape_and_cache_flash_kernel(
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
scalar_t* __restrict__ k_cache, // [num_blocks, block_size, num_heads, head_size]
scalar_t* __restrict__ v_cache, // [num_blocks, block_size, num_heads, head_size]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int block_stride,
const int key_stride,
const int value_stride,
const int num_heads,
const int head_size,
const int block_size) {
const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx];
// NOTE: slot_idx can be -1 if the token is padded
if (slot_idx < 0) {
return;
}
const int64_t block_idx = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;
const int n = num_heads * head_size;
for (int i = threadIdx.x; i < n; i += blockDim.x) {
const int64_t src_key_idx = token_idx * key_stride + i;
const int64_t src_value_idx = token_idx * value_stride + i;
const int head_idx = i / head_size;
const int head_offset = i % head_size;
const int64_t tgt_value_idx = block_idx * block_stride
+ block_offset * num_heads * head_size
+ head_idx * head_size
+ head_offset;
k_cache[tgt_value_idx] = key[src_key_idx];
v_cache[tgt_value_idx] = value[src_value_idx];
}
}
} // namespace vllm
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_KV_CACHE) \
@ -275,6 +310,51 @@ void reshape_and_cache(
}
}
void reshape_and_cache_flash(
torch::Tensor& key, // [num_tokens, num_heads, head_size]
torch::Tensor& value, // [num_tokens, num_heads, head_size]
torch::Tensor& k_cache, // [num_blocks, block_size, num_heads, head_size]
torch::Tensor& v_cache, // [num_blocks, block_size, num_heads, head_size]
torch::Tensor& slot_mapping, // [num_tokens]
const std::string& kv_cache_dtype)
{
// FIXME: only support auto datatype, does not support fp8
if (kv_cache_dtype != "auto") {
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
}
int num_tokens = key.size(0);
int num_heads = key.size(1);
int head_size = key.size(2);
int block_size = k_cache.size(1);
int key_stride = key.stride(0);
int value_stride = value.stride(0);
int block_stride = k_cache.stride(0);
TORCH_CHECK(k_cache.stride(0) == v_cache.stride(0));
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
key.scalar_type(),
"reshape_and_cache_flash",
[&] {
vllm::reshape_and_cache_flash_kernel<scalar_t><<<grid, block, 0, stream>>>(
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
k_cache.data_ptr<scalar_t>(),
v_cache.data_ptr<scalar_t>(),
slot_mapping.data_ptr<int64_t>(),
block_stride,
key_stride,
value_stride,
num_heads,
head_size,
block_size);
});
}
namespace vllm {
template<typename Tout, typename Tin>

View File

@ -96,6 +96,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"reshape_and_cache",
&reshape_and_cache,
"Reshape the key and value tensors and cache them");
cache_ops.def(
"reshape_and_cache_flash",
&reshape_and_cache_flash,
"Reshape the key and value tensors and cache them");
cache_ops.def(
"convert_fp8",
&convert_fp8,

View File

@ -2,12 +2,15 @@
Run `pytest tests/basic_correctness/test_basic_correctness.py`.
"""
import os
import pytest
MODELS = [
"facebook/opt-125m",
"meta-llama/Llama-2-7b-hf",
]
VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"
@pytest.mark.parametrize("model", MODELS)
@ -23,11 +26,18 @@ def test_models(
max_tokens: int,
enforce_eager: bool,
) -> None:
backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
if backend_by_env_var == "FLASHINFER" and enforce_eager is False:
pytest.skip("Skipping non-eager test for FlashInferBackend.")
hf_model = hf_runner(model, dtype=dtype)
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
del hf_model
vllm_model = vllm_runner(model, dtype=dtype, enforce_eager=enforce_eager)
vllm_model = vllm_runner(model,
dtype=dtype,
enforce_eager=enforce_eager,
gpu_memory_utilization=0.7)
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
del vllm_model

View File

@ -18,6 +18,7 @@ import torch
MODELS = [
os.environ["TEST_DIST_MODEL"],
]
VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"
@pytest.mark.skipif(torch.cuda.device_count() < 2,
@ -33,16 +34,19 @@ def test_models(
dtype: str,
max_tokens: int,
) -> None:
enforce_eager = False
backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
if backend_by_env_var == "FLASHINFER":
enforce_eager = True
hf_model = hf_runner(model, dtype=dtype)
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
del hf_model
vllm_model = vllm_runner(
model,
dtype=dtype,
tensor_parallel_size=2,
)
vllm_model = vllm_runner(model,
dtype=dtype,
tensor_parallel_size=2,
enforce_eager=enforce_eager)
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
del vllm_model

View File

@ -1,8 +1,14 @@
import pytest
from vllm.utils import create_kv_caches_with_random
from vllm.utils import (create_kv_caches_with_random,
create_kv_caches_with_random_flash)
@pytest.fixture()
def kv_cache_factory():
return create_kv_caches_with_random
@pytest.fixture()
def kv_cache_factory_flashinfer():
return create_kv_caches_with_random_flash

View File

@ -5,6 +5,7 @@ import pytest
import torch
from vllm import _custom_ops as ops
from vllm._C import cache_ops
from vllm.utils import is_hip
COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
@ -191,6 +192,82 @@ def test_reshape_and_cache(
assert torch.allclose(value_cache, cloned_value_cache)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_reshape_and_cache_flash(
kv_cache_factory_flashinfer,
num_tokens: int,
num_heads: int,
head_size: int,
block_size: int,
num_blocks: int,
dtype: torch.dtype,
seed: int,
device: str,
kv_cache_dtype: str,
) -> None:
if kv_cache_dtype == "fp8":
pytest.skip()
random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Create a random slot mapping.
num_slots = block_size * num_blocks
slot_mapping = random.sample(range(num_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device='cuda')
qkv = torch.randn(num_tokens,
3,
num_heads,
head_size,
dtype=dtype,
device=device)
_, key, value = qkv.unbind(dim=1)
# Create the KV caches.
key_caches, value_caches = kv_cache_factory_flashinfer(
num_blocks,
block_size,
1,
num_heads,
head_size,
kv_cache_dtype,
dtype,
)
key_cache, value_cache = key_caches[0], value_caches[0]
# Clone the KV caches.
cloned_key_cache = key_cache.clone()
cloned_value_cache = value_cache.clone()
# Call the reshape_and_cache kernel.
cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
slot_mapping, kv_cache_dtype)
# Run the reference implementation.
block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor')
block_indicies = block_indicies.cpu().tolist()
block_offsets = slot_mapping % block_size
block_offsets = block_offsets.cpu().tolist()
for i in range(num_tokens):
block_idx = block_indicies[i]
block_offset = block_offsets[i]
cloned_key_cache[block_idx, block_offset, :, :] = key[i]
cloned_value_cache[block_idx, block_offset, :, :] = value[i]
assert torch.allclose(key_cache, cloned_key_cache)
assert torch.allclose(value_cache, cloned_value_cache)
@pytest.mark.parametrize("direction", COPYING_DIRECTION)
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)

View File

@ -222,6 +222,18 @@ def reshape_and_cache(
slot_mapping, kv_cache_dtype, kv_scale)
def reshape_and_cache_flash(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
) -> None:
vllm_cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
slot_mapping, kv_cache_dtype)
def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor,
block_mapping: torch.Tensor) -> None:
vllm_cache_ops.copy_blocks(key_caches, value_caches, block_mapping)

View File

@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar
from typing import (Any, Dict, Generic, List, Optional, Set, Tuple, Type,
TypeVar)
import torch
@ -15,7 +16,7 @@ class AttentionBackend(ABC):
@staticmethod
@abstractmethod
def make_metadata(*args, **kwargs) -> "AttentionMetadata":
def make_metadata(*args, **kwargs) -> "AttentionMetadataPerStage":
raise NotImplementedError
@staticmethod
@ -50,13 +51,17 @@ class AttentionBackend(ABC):
class AttentionMetadataPerStage:
"""Attention metadata for a specific stage. I.e., prefill or decode."""
def asdict_zerocopy(self) -> Dict[str, Any]:
def asdict_zerocopy(self,
skip_fields: Optional[Set[str]] = None
) -> Dict[str, Any]:
"""Similar to dataclasses.asdict, but avoids deepcopying."""
if skip_fields is None:
skip_fields = set()
# Note that if we add dataclasses as fields, they will need
# similar handling.
return {
field.name: getattr(self, field.name)
for field in fields(self)
for field in fields(self) if field.name not in skip_fields
}

View File

@ -0,0 +1,220 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Tuple, Type
try:
import flashinfer
from flash_attn import flash_attn_varlen_func
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
except ImportError:
flashinfer = None
flash_attn_varlen_func = None
BatchDecodeWithPagedKVCacheWrapper = None
import torch
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata,
AttentionMetadataPerStage)
class FlashInferBackend(AttentionBackend):
@staticmethod
def get_impl_cls() -> Type["FlashInferImpl"]:
return FlashInferImpl
@staticmethod
def make_metadata(*args, **kwargs) -> "FlashInferMetadata":
return FlashInferMetadata(*args, **kwargs)
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return (num_blocks, 2, block_size, num_kv_heads, head_size)
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: Dict[int, int],
) -> None:
raise NotImplementedError
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
) -> None:
raise NotImplementedError
@staticmethod
def get_supported_head_sizes() -> List[int]:
return [64, 128, 256]
@dataclass
class FlashInferMetadata(AttentionMetadataPerStage):
is_prompt: bool
use_cuda_graph: bool = False
decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
# Metadata for the prefill stage since we still
# use flash attention for prefill.
seq_start_loc: Optional[torch.Tensor] = None
max_seq_len: Optional[int] = None
block_tables: Optional[torch.Tensor] = None
# Metadata for the decode stage
# Workspace buffer required by the kernel, the buffer should not
# be allocated/deacollated by the FalshInfermetadata object.
workspace_buffer: Optional[torch.Tensor] = None
# An example for paged_kv_indices, paged_kv_indptr:
# request 1, page indices [0, 5, 8]
# request 2, page indices [1, 6, 7]
# request 3, page indices [3, 4]
# paged_kv_indices is a concatenation of page indices of all requests:
# [0, 5, 8, 1, 6, 7, 3, 4]
# paged_kv_indptr is used to index into paged_kv_indices:
# [0, 3, 6, 8]
# The indptr of the paged kv cache, shape: [batch_size + 1]
paged_kv_indptr: Optional[torch.Tensor] = None
# The page indices of the paged kv cache
paged_kv_indices: Optional[torch.Tensor] = None
# The number of entries in the last page of each request in
# the paged kv cache, shape: [batch_size]
paged_kv_last_page_len: Optional[torch.Tensor] = None
# The number of query/output heads
num_qo_heads: Optional[int] = None
# The number of key/value heads
num_kv_heads: Optional[int] = None
# The dimension of the attention heads
head_dim: Optional[int] = None
# Block size of vllm
page_size: Optional[int] = None
# The data type of the paged kv cache
data_type: torch.dtype = None
def __post_init__(self):
# Refer to
# https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
supported_head_sizes = FlashInferBackend.get_supported_head_sizes()
if self.head_dim is not None and self.head_dim \
not in supported_head_sizes:
raise ValueError(
f"Only {supported_head_sizes} are supported for head_dim,",
f"received {self.head_dim}.")
# When using flashinfer, we are also creating the FlashInferMetadata,
# which will also call post_init by default, here we want to skip the
# post_init if it's the prefill phase.
if not self.is_prompt:
self.decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
self.workspace_buffer, "NHD")
self.decode_wrapper.begin_forward(
self.paged_kv_indptr,
self.paged_kv_indices,
self.paged_kv_last_page_len,
self.num_qo_heads,
self.num_kv_heads,
self.head_dim,
self.page_size,
# Disable flashinfer's pos encoding and use vllm's rope.
pos_encoding_mode="NONE",
data_type=self.data_type)
def asdict_zerocopy(self,
skip_fields: Optional[Set[str]] = None
) -> Dict[str, Any]:
if skip_fields is None:
skip_fields = set()
# We need to skip the decode_wrapper field since it cannot be
# broadcasted with nccl when TP is enabled.
skip_fields.add('decode_wrapper')
return super().asdict_zerocopy(skip_fields)
class FlashInferImpl(AttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
) -> None:
if sliding_window is not None:
raise ValueError("Sliding window is not supported in FlashInfer.")
self.sliding_window = (-1, -1)
self.alibi_slopes = alibi_slopes
self.scale = scale
self.num_heads = num_heads
self.head_size = head_size
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
def forward(self, query: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, kv_cache: Optional[torch.Tensor],
attn_metadata: AttentionMetadata[FlashInferMetadata],
kv_scale: float):
num_tokens, hidden_size = query.shape
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
if attn_metadata.num_prefill_tokens > 0:
assert attn_metadata.num_decode_tokens == 0, (
"Chunked prefill is not supported with flashinfer yet.")
if attn_metadata.num_decode_tokens > 0:
assert attn_metadata.num_prefill_tokens == 0, (
"Chunked prefill is not supported with flashinfer yet.")
if kv_cache is not None:
# Use the same reshape and cache kernel as flash attention.
ops.reshape_and_cache_flash(
key,
value,
kv_cache[:, 0],
kv_cache[:, 1],
attn_metadata.slot_mapping.flatten(),
attn_metadata.kv_cache_dtype,
)
if prefill_meta := attn_metadata.prefill_metadata:
assert prefill_meta.block_tables is not None
if kv_cache is None or prefill_meta.block_tables.numel() == 0:
output = flash_attn_varlen_func(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_seq_len,
max_seqlen_k=prefill_meta.max_seq_len,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
)
else:
raise NotImplementedError(
"Prefix caching is not supported with flashinfer yet.")
else:
assert attn_metadata.decode_metadata is not None
assert attn_metadata.decode_metadata.decode_wrapper is not None
query = query.contiguous(
) # Flashinfer requires query to be contiguous
output = attn_metadata.decode_metadata.decode_wrapper.forward(
query,
kv_cache,
sm_scale=self.scale,
)
return output.view(num_tokens, hidden_size)

View File

@ -17,6 +17,7 @@ class _Backend(enum.Enum):
XFORMERS = enum.auto()
ROCM_FLASH = enum.auto()
TORCH_SDPA = enum.auto()
FLASHINFER = enum.auto()
@lru_cache(maxsize=None)
@ -41,6 +42,11 @@ def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
logger.info("Using Torch SDPA backend.")
from vllm.attention.backends.torch_sdpa import TorchSDPABackend
return TorchSDPABackend
elif backend == _Backend.FLASHINFER:
logger.info("Using Flashinfer backend.")
logger.warning("Eager mode is enforced for the Flashinfer backend. ")
from vllm.attention.backends.flashinfer import FlashInferBackend
return FlashInferBackend
else:
raise ValueError("Invalid attention backend.")

View File

@ -298,6 +298,11 @@ class ModelConfig:
return max(1,
total_num_kv_heads // parallel_config.tensor_parallel_size)
def get_num_attention_heads(self,
parallel_config: "ParallelConfig") -> int:
return self.hf_text_config.num_attention_heads // \
parallel_config.tensor_parallel_size
def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
total_num_hidden_layers = self.hf_text_config.num_hidden_layers
return total_num_hidden_layers // parallel_config.pipeline_parallel_size

View File

@ -579,8 +579,10 @@ class SequenceGroupMetadata:
query tokens for prefill, we don't need sampling.
token_chunk_size: The number of tokens to be processed (per sequence).
None if chunking is not required.
state: Internal state tied to this sequence group.
lora_request: LoRA request.
computed_block_nums: The block numbers that are already computed,
used in prefix caching.
state: Internal state tied to this sequence group.
multi_modal_data: Multi modal data.
"""

View File

@ -355,21 +355,9 @@ def _generate_random_fp8(
del tensor_tmp
def create_kv_caches_with_random(
num_blocks: int,
block_size: int,
num_layers: int,
num_heads: int,
head_size: int,
cache_dtype: Optional[Union[str, torch.dtype]],
model_dtype: Optional[Union[str, torch.dtype]] = None,
seed: int = 0,
device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
def get_kv_cache_torch_dtype(
cache_dtype: Optional[Union[str, torch.dtype]],
model_dtype: Optional[Union[str, torch.dtype]] = None) -> torch.dtype:
if isinstance(cache_dtype, str):
if cache_dtype == "auto":
if isinstance(model_dtype, str):
@ -388,6 +376,55 @@ def create_kv_caches_with_random(
torch_dtype = cache_dtype
else:
raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
return torch_dtype
def create_kv_caches_with_random_flash(
num_blocks: int,
block_size: int,
num_layers: int,
num_heads: int,
head_size: int,
cache_dtype: Optional[Union[str, torch.dtype]],
model_dtype: Optional[Union[str, torch.dtype]] = None,
seed: int = 0,
device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
assert cache_dtype != "fp8"
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
scale = head_size**-0.5
key_caches, value_caches = [], []
for _ in range(num_layers):
key_value_cache = torch.empty(size=key_value_cache_shape,
dtype=torch_dtype,
device=device)
key_value_cache.uniform_(-scale, scale)
key_caches.append(key_value_cache[:, 0])
value_caches.append(key_value_cache[:, 1])
return key_caches, value_caches
def create_kv_caches_with_random(
num_blocks: int,
block_size: int,
num_layers: int,
num_heads: int,
head_size: int,
cache_dtype: Optional[Union[str, torch.dtype]],
model_dtype: Optional[Union[str, torch.dtype]] = None,
seed: int = 0,
device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
scale = head_size**-0.5
x = 16 // torch.tensor([], dtype=torch_dtype).element_size()

View File

@ -9,6 +9,7 @@ import torch.nn as nn
from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage,
get_attn_backend)
from vllm.attention.backends.flashinfer import FlashInferBackend
from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, VisionLanguageConfig)
from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce
@ -23,8 +24,8 @@ from vllm.model_executor.model_loader import get_model
from vllm.sampling_params import SamplingParams
from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData,
SequenceGroupMetadata)
from vllm.utils import (CudaMemoryProfiler, is_hip, is_pin_memory_available,
make_tensor_with_pad)
from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip,
is_pin_memory_available, make_tensor_with_pad)
logger = init_logger(__name__)
@ -155,6 +156,9 @@ class ModelRunner:
# (max batch size to capture, max context len to capture / block size).
self.graph_block_tables: torch.Tensor # Set after initial profiling.
# Set if the backend is flashinfer.
self.flashinfer_workspace_buffer: torch.Tensor
def load_model(self) -> None:
with CudaMemoryProfiler() as m:
self.model = get_model(
@ -315,6 +319,7 @@ class ModelRunner:
# Compute the slot mapping.
block_table = seq_group_metadata.block_tables[seq_id]
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
# where start_idx is max(0, seq_len - sliding_window).
# For example, if the prompt len is 10, sliding window is 8, and
@ -390,18 +395,26 @@ class ModelRunner:
dtype=seq_start_loc.dtype,
out=seq_start_loc[1:])
attn_metadata = self.attn_backend.make_metadata(
is_prompt=True,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
max_seq_len=max_seq_len,
subquery_start_loc=subquery_start_loc,
seq_start_loc=seq_start_loc,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=False,
)
if self.attn_backend is FlashInferBackend:
attn_metadata = self.attn_backend.make_metadata(
is_prompt=True,
use_cuda_graph=False,
seq_start_loc=seq_start_loc,
max_seq_len=max_seq_len,
block_tables=block_tables)
else:
attn_metadata = self.attn_backend.make_metadata(
is_prompt=True,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
max_seq_len=max_seq_len,
subquery_start_loc=subquery_start_loc,
seq_start_loc=seq_start_loc,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=False,
)
return PreparePromptMetadata(
input_tokens=input_tokens,
@ -429,6 +442,24 @@ class ModelRunner:
lora_prompt_mapping: List[int] = []
lora_requests: Set[LoRARequest] = set()
# The following fields are only for flashinfer
# Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
# for the precise definition of the following fields.
# An example:
# request 1, page indices [0, 5, 8]
# request 2, page indices [1, 6, 7]
# request 3, page indices [3, 4]
# paged_kv_indices is a concatenation of page indices of all requests:
# [0, 5, 8, 1, 6, 7, 3, 4]
# paged_kv_indptr is used to index into paged_kv_indices:
# [0, 3, 6, 8]
paged_kv_indices: List[int] = []
# 0 at the beginning of paged_kv_indptr indicates the start of the
# first requests page indices in the paged_kv_indices list.
paged_kv_indptr: List[int] = [0]
# paged_kv_last_page_len is the length of the last page of each request
paged_kv_last_page_len: List[int] = []
if len(seq_group_metadata_list) == 0:
return PrepareDecodeMetadata.empty()
@ -469,6 +500,13 @@ class ModelRunner:
block_table = block_table[-sliding_window_blocks:]
block_tables.append(block_table)
paged_kv_indices.extend(block_table)
paged_kv_indptr.append(paged_kv_indptr[-1] + len(block_table))
last_page_len = seq_data.get_len() % self.block_size
if last_page_len == 0:
last_page_len = self.block_size
paged_kv_last_page_len.append(last_page_len)
# vLLM uses cuda graph only for decoding requests.
# See `capture_model` API for more details.
# For decoding requests, batch_size == input_tokens.
@ -518,18 +556,51 @@ class ModelRunner:
device=self.device,
)
attn_metadata = self.attn_backend.make_metadata(
is_prompt=False,
seq_lens=None,
seq_lens_tensor=seq_lens_tensor,
max_query_len=None,
max_seq_len=max_seq_len,
subquery_start_loc=None,
seq_start_loc=None,
context_lens_tensor=None,
block_tables=block_tables,
use_cuda_graph=use_captured_graph,
)
if self.attn_backend is FlashInferBackend:
if not hasattr(self, "flashinfer_workspace_buffer"):
# Allocate 16MB workspace buffer
# Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html
self.flashinfer_workspace_buffer = torch.empty(
16 * 1024 * 1024, dtype=torch.uint8, device=self.device)
paged_kv_indptr = torch.tensor(paged_kv_indptr,
dtype=torch.int,
device=self.device)
paged_kv_indices = torch.tensor(paged_kv_indices,
dtype=torch.int,
device=self.device)
paged_kv_last_page_len = torch.tensor(paged_kv_last_page_len,
dtype=torch.int,
device=self.device)
kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype,
self.model_config.dtype)
attn_metadata = self.attn_backend.make_metadata(
is_prompt=False,
use_cuda_graph=False,
workspace_buffer=self.flashinfer_workspace_buffer,
paged_kv_indptr=paged_kv_indptr,
paged_kv_indices=paged_kv_indices,
paged_kv_last_page_len=paged_kv_last_page_len,
num_qo_heads=self.model_config.get_num_attention_heads(
self.parallel_config),
num_kv_heads=self.model_config.get_num_kv_heads(
self.parallel_config),
head_dim=self.model_config.get_head_size(),
page_size=self.block_size,
data_type=kv_cache_dtype)
else:
attn_metadata = self.attn_backend.make_metadata(
is_prompt=False,
seq_lens=None,
seq_lens_tensor=seq_lens_tensor,
max_query_len=None,
max_seq_len=max_seq_len,
subquery_start_loc=None,
seq_start_loc=None,
context_lens_tensor=None,
block_tables=block_tables,
use_cuda_graph=use_captured_graph,
)
return PrepareDecodeMetadata(
input_tokens=input_tokens,
input_positions=input_positions,