[Kernel] Use flashinfer for decoding (#4353)
Co-authored-by: LiuXiaoxuanPKU <llilyliupku@gmail.com>
This commit is contained in:
parent
f8e7adda21
commit
43c413ec57
@ -24,6 +24,14 @@ void reshape_and_cache(
|
||||
const std::string& kv_cache_dtype,
|
||||
const float kv_scale);
|
||||
|
||||
void reshape_and_cache_flash(
|
||||
torch::Tensor& key,
|
||||
torch::Tensor& value,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
torch::Tensor& slot_mapping,
|
||||
const std::string& kv_cache_dtype);
|
||||
|
||||
// Just for unittest
|
||||
void convert_fp8(
|
||||
torch::Tensor& src_cache,
|
||||
|
@ -215,6 +215,41 @@ __global__ void reshape_and_cache_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
__global__ void reshape_and_cache_flash_kernel(
|
||||
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
|
||||
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
|
||||
scalar_t* __restrict__ k_cache, // [num_blocks, block_size, num_heads, head_size]
|
||||
scalar_t* __restrict__ v_cache, // [num_blocks, block_size, num_heads, head_size]
|
||||
const int64_t* __restrict__ slot_mapping, // [num_tokens]
|
||||
const int block_stride,
|
||||
const int key_stride,
|
||||
const int value_stride,
|
||||
const int num_heads,
|
||||
const int head_size,
|
||||
const int block_size) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
const int64_t slot_idx = slot_mapping[token_idx];
|
||||
// NOTE: slot_idx can be -1 if the token is padded
|
||||
if (slot_idx < 0) {
|
||||
return;
|
||||
}
|
||||
const int64_t block_idx = slot_idx / block_size;
|
||||
const int64_t block_offset = slot_idx % block_size;
|
||||
const int n = num_heads * head_size;
|
||||
for (int i = threadIdx.x; i < n; i += blockDim.x) {
|
||||
const int64_t src_key_idx = token_idx * key_stride + i;
|
||||
const int64_t src_value_idx = token_idx * value_stride + i;
|
||||
const int head_idx = i / head_size;
|
||||
const int head_offset = i % head_size;
|
||||
const int64_t tgt_value_idx = block_idx * block_stride
|
||||
+ block_offset * num_heads * head_size
|
||||
+ head_idx * head_size
|
||||
+ head_offset;
|
||||
k_cache[tgt_value_idx] = key[src_key_idx];
|
||||
v_cache[tgt_value_idx] = value[src_value_idx];
|
||||
}
|
||||
}
|
||||
} // namespace vllm
|
||||
|
||||
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_KV_CACHE) \
|
||||
@ -275,6 +310,51 @@ void reshape_and_cache(
|
||||
}
|
||||
}
|
||||
|
||||
void reshape_and_cache_flash(
|
||||
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
||||
torch::Tensor& value, // [num_tokens, num_heads, head_size]
|
||||
torch::Tensor& k_cache, // [num_blocks, block_size, num_heads, head_size]
|
||||
torch::Tensor& v_cache, // [num_blocks, block_size, num_heads, head_size]
|
||||
torch::Tensor& slot_mapping, // [num_tokens]
|
||||
const std::string& kv_cache_dtype)
|
||||
{
|
||||
// FIXME: only support auto datatype, does not support fp8
|
||||
if (kv_cache_dtype != "auto") {
|
||||
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
|
||||
}
|
||||
int num_tokens = key.size(0);
|
||||
int num_heads = key.size(1);
|
||||
int head_size = key.size(2);
|
||||
int block_size = k_cache.size(1);
|
||||
|
||||
int key_stride = key.stride(0);
|
||||
int value_stride = value.stride(0);
|
||||
int block_stride = k_cache.stride(0);
|
||||
TORCH_CHECK(k_cache.stride(0) == v_cache.stride(0));
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(num_heads * head_size, 512));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
key.scalar_type(),
|
||||
"reshape_and_cache_flash",
|
||||
[&] {
|
||||
vllm::reshape_and_cache_flash_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
key.data_ptr<scalar_t>(),
|
||||
value.data_ptr<scalar_t>(),
|
||||
k_cache.data_ptr<scalar_t>(),
|
||||
v_cache.data_ptr<scalar_t>(),
|
||||
slot_mapping.data_ptr<int64_t>(),
|
||||
block_stride,
|
||||
key_stride,
|
||||
value_stride,
|
||||
num_heads,
|
||||
head_size,
|
||||
block_size);
|
||||
});
|
||||
}
|
||||
|
||||
namespace vllm {
|
||||
|
||||
template<typename Tout, typename Tin>
|
||||
|
@ -96,6 +96,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
"reshape_and_cache",
|
||||
&reshape_and_cache,
|
||||
"Reshape the key and value tensors and cache them");
|
||||
cache_ops.def(
|
||||
"reshape_and_cache_flash",
|
||||
&reshape_and_cache_flash,
|
||||
"Reshape the key and value tensors and cache them");
|
||||
cache_ops.def(
|
||||
"convert_fp8",
|
||||
&convert_fp8,
|
||||
|
@ -2,12 +2,15 @@
|
||||
|
||||
Run `pytest tests/basic_correctness/test_basic_correctness.py`.
|
||||
"""
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
MODELS = [
|
||||
"facebook/opt-125m",
|
||||
"meta-llama/Llama-2-7b-hf",
|
||||
]
|
||||
VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@ -23,11 +26,18 @@ def test_models(
|
||||
max_tokens: int,
|
||||
enforce_eager: bool,
|
||||
) -> None:
|
||||
backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
|
||||
if backend_by_env_var == "FLASHINFER" and enforce_eager is False:
|
||||
pytest.skip("Skipping non-eager test for FlashInferBackend.")
|
||||
|
||||
hf_model = hf_runner(model, dtype=dtype)
|
||||
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||
del hf_model
|
||||
|
||||
vllm_model = vllm_runner(model, dtype=dtype, enforce_eager=enforce_eager)
|
||||
vllm_model = vllm_runner(model,
|
||||
dtype=dtype,
|
||||
enforce_eager=enforce_eager,
|
||||
gpu_memory_utilization=0.7)
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
del vllm_model
|
||||
|
||||
|
@ -18,6 +18,7 @@ import torch
|
||||
MODELS = [
|
||||
os.environ["TEST_DIST_MODEL"],
|
||||
]
|
||||
VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
||||
@ -33,16 +34,19 @@ def test_models(
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
) -> None:
|
||||
enforce_eager = False
|
||||
backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
|
||||
if backend_by_env_var == "FLASHINFER":
|
||||
enforce_eager = True
|
||||
|
||||
hf_model = hf_runner(model, dtype=dtype)
|
||||
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||
del hf_model
|
||||
|
||||
vllm_model = vllm_runner(
|
||||
model,
|
||||
dtype=dtype,
|
||||
tensor_parallel_size=2,
|
||||
)
|
||||
vllm_model = vllm_runner(model,
|
||||
dtype=dtype,
|
||||
tensor_parallel_size=2,
|
||||
enforce_eager=enforce_eager)
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
del vllm_model
|
||||
|
||||
|
@ -1,8 +1,14 @@
|
||||
import pytest
|
||||
|
||||
from vllm.utils import create_kv_caches_with_random
|
||||
from vllm.utils import (create_kv_caches_with_random,
|
||||
create_kv_caches_with_random_flash)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def kv_cache_factory():
|
||||
return create_kv_caches_with_random
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def kv_cache_factory_flashinfer():
|
||||
return create_kv_caches_with_random_flash
|
||||
|
@ -5,6 +5,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm._C import cache_ops
|
||||
from vllm.utils import is_hip
|
||||
|
||||
COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
|
||||
@ -191,6 +192,82 @@ def test_reshape_and_cache(
|
||||
assert torch.allclose(value_cache, cloned_value_cache)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
|
||||
@torch.inference_mode()
|
||||
def test_reshape_and_cache_flash(
|
||||
kv_cache_factory_flashinfer,
|
||||
num_tokens: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
block_size: int,
|
||||
num_blocks: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: str,
|
||||
kv_cache_dtype: str,
|
||||
) -> None:
|
||||
if kv_cache_dtype == "fp8":
|
||||
pytest.skip()
|
||||
random.seed(seed)
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
# Create a random slot mapping.
|
||||
num_slots = block_size * num_blocks
|
||||
slot_mapping = random.sample(range(num_slots), num_tokens)
|
||||
slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device='cuda')
|
||||
|
||||
qkv = torch.randn(num_tokens,
|
||||
3,
|
||||
num_heads,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
_, key, value = qkv.unbind(dim=1)
|
||||
|
||||
# Create the KV caches.
|
||||
key_caches, value_caches = kv_cache_factory_flashinfer(
|
||||
num_blocks,
|
||||
block_size,
|
||||
1,
|
||||
num_heads,
|
||||
head_size,
|
||||
kv_cache_dtype,
|
||||
dtype,
|
||||
)
|
||||
key_cache, value_cache = key_caches[0], value_caches[0]
|
||||
|
||||
# Clone the KV caches.
|
||||
cloned_key_cache = key_cache.clone()
|
||||
cloned_value_cache = value_cache.clone()
|
||||
|
||||
# Call the reshape_and_cache kernel.
|
||||
cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
|
||||
slot_mapping, kv_cache_dtype)
|
||||
|
||||
# Run the reference implementation.
|
||||
block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor')
|
||||
block_indicies = block_indicies.cpu().tolist()
|
||||
block_offsets = slot_mapping % block_size
|
||||
block_offsets = block_offsets.cpu().tolist()
|
||||
for i in range(num_tokens):
|
||||
block_idx = block_indicies[i]
|
||||
block_offset = block_offsets[i]
|
||||
cloned_key_cache[block_idx, block_offset, :, :] = key[i]
|
||||
cloned_value_cache[block_idx, block_offset, :, :] = value[i]
|
||||
|
||||
assert torch.allclose(key_cache, cloned_key_cache)
|
||||
assert torch.allclose(value_cache, cloned_value_cache)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("direction", COPYING_DIRECTION)
|
||||
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
|
@ -222,6 +222,18 @@ def reshape_and_cache(
|
||||
slot_mapping, kv_cache_dtype, kv_scale)
|
||||
|
||||
|
||||
def reshape_and_cache_flash(
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
kv_cache_dtype: str,
|
||||
) -> None:
|
||||
vllm_cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
|
||||
slot_mapping, kv_cache_dtype)
|
||||
|
||||
|
||||
def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor,
|
||||
block_mapping: torch.Tensor) -> None:
|
||||
vllm_cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
|
||||
|
@ -1,6 +1,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, fields
|
||||
from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar
|
||||
from typing import (Any, Dict, Generic, List, Optional, Set, Tuple, Type,
|
||||
TypeVar)
|
||||
|
||||
import torch
|
||||
|
||||
@ -15,7 +16,7 @@ class AttentionBackend(ABC):
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def make_metadata(*args, **kwargs) -> "AttentionMetadata":
|
||||
def make_metadata(*args, **kwargs) -> "AttentionMetadataPerStage":
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
@ -50,13 +51,17 @@ class AttentionBackend(ABC):
|
||||
class AttentionMetadataPerStage:
|
||||
"""Attention metadata for a specific stage. I.e., prefill or decode."""
|
||||
|
||||
def asdict_zerocopy(self) -> Dict[str, Any]:
|
||||
def asdict_zerocopy(self,
|
||||
skip_fields: Optional[Set[str]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Similar to dataclasses.asdict, but avoids deepcopying."""
|
||||
if skip_fields is None:
|
||||
skip_fields = set()
|
||||
# Note that if we add dataclasses as fields, they will need
|
||||
# similar handling.
|
||||
return {
|
||||
field.name: getattr(self, field.name)
|
||||
for field in fields(self)
|
||||
for field in fields(self) if field.name not in skip_fields
|
||||
}
|
||||
|
||||
|
||||
|
220
vllm/attention/backends/flashinfer.py
Normal file
220
vllm/attention/backends/flashinfer.py
Normal file
@ -0,0 +1,220 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Type
|
||||
|
||||
try:
|
||||
import flashinfer
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
|
||||
except ImportError:
|
||||
flashinfer = None
|
||||
flash_attn_varlen_func = None
|
||||
BatchDecodeWithPagedKVCacheWrapper = None
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata,
|
||||
AttentionMetadataPerStage)
|
||||
|
||||
|
||||
class FlashInferBackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["FlashInferImpl"]:
|
||||
return FlashInferImpl
|
||||
|
||||
@staticmethod
|
||||
def make_metadata(*args, **kwargs) -> "FlashInferMetadata":
|
||||
return FlashInferMetadata(*args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return (num_blocks, 2, block_size, num_kv_heads, head_size)
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
src_kv_cache: torch.Tensor,
|
||||
dst_kv_cache: torch.Tensor,
|
||||
src_to_dst: Dict[int, int],
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def copy_blocks(
|
||||
kv_caches: List[torch.Tensor],
|
||||
src_to_dists: Dict[int, List[int]],
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def get_supported_head_sizes() -> List[int]:
|
||||
return [64, 128, 256]
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashInferMetadata(AttentionMetadataPerStage):
|
||||
|
||||
is_prompt: bool
|
||||
|
||||
use_cuda_graph: bool = False
|
||||
|
||||
decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
|
||||
|
||||
# Metadata for the prefill stage since we still
|
||||
# use flash attention for prefill.
|
||||
seq_start_loc: Optional[torch.Tensor] = None
|
||||
max_seq_len: Optional[int] = None
|
||||
block_tables: Optional[torch.Tensor] = None
|
||||
|
||||
# Metadata for the decode stage
|
||||
# Workspace buffer required by the kernel, the buffer should not
|
||||
# be allocated/deacollated by the FalshInfermetadata object.
|
||||
workspace_buffer: Optional[torch.Tensor] = None
|
||||
# An example for paged_kv_indices, paged_kv_indptr:
|
||||
# request 1, page indices [0, 5, 8]
|
||||
# request 2, page indices [1, 6, 7]
|
||||
# request 3, page indices [3, 4]
|
||||
# paged_kv_indices is a concatenation of page indices of all requests:
|
||||
# [0, 5, 8, 1, 6, 7, 3, 4]
|
||||
# paged_kv_indptr is used to index into paged_kv_indices:
|
||||
# [0, 3, 6, 8]
|
||||
# The indptr of the paged kv cache, shape: [batch_size + 1]
|
||||
paged_kv_indptr: Optional[torch.Tensor] = None
|
||||
# The page indices of the paged kv cache
|
||||
paged_kv_indices: Optional[torch.Tensor] = None
|
||||
# The number of entries in the last page of each request in
|
||||
# the paged kv cache, shape: [batch_size]
|
||||
paged_kv_last_page_len: Optional[torch.Tensor] = None
|
||||
# The number of query/output heads
|
||||
num_qo_heads: Optional[int] = None
|
||||
# The number of key/value heads
|
||||
num_kv_heads: Optional[int] = None
|
||||
# The dimension of the attention heads
|
||||
head_dim: Optional[int] = None
|
||||
# Block size of vllm
|
||||
page_size: Optional[int] = None
|
||||
# The data type of the paged kv cache
|
||||
data_type: torch.dtype = None
|
||||
|
||||
def __post_init__(self):
|
||||
# Refer to
|
||||
# https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
|
||||
supported_head_sizes = FlashInferBackend.get_supported_head_sizes()
|
||||
if self.head_dim is not None and self.head_dim \
|
||||
not in supported_head_sizes:
|
||||
raise ValueError(
|
||||
f"Only {supported_head_sizes} are supported for head_dim,",
|
||||
f"received {self.head_dim}.")
|
||||
|
||||
# When using flashinfer, we are also creating the FlashInferMetadata,
|
||||
# which will also call post_init by default, here we want to skip the
|
||||
# post_init if it's the prefill phase.
|
||||
if not self.is_prompt:
|
||||
self.decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
|
||||
self.workspace_buffer, "NHD")
|
||||
self.decode_wrapper.begin_forward(
|
||||
self.paged_kv_indptr,
|
||||
self.paged_kv_indices,
|
||||
self.paged_kv_last_page_len,
|
||||
self.num_qo_heads,
|
||||
self.num_kv_heads,
|
||||
self.head_dim,
|
||||
self.page_size,
|
||||
# Disable flashinfer's pos encoding and use vllm's rope.
|
||||
pos_encoding_mode="NONE",
|
||||
data_type=self.data_type)
|
||||
|
||||
def asdict_zerocopy(self,
|
||||
skip_fields: Optional[Set[str]] = None
|
||||
) -> Dict[str, Any]:
|
||||
if skip_fields is None:
|
||||
skip_fields = set()
|
||||
# We need to skip the decode_wrapper field since it cannot be
|
||||
# broadcasted with nccl when TP is enabled.
|
||||
skip_fields.add('decode_wrapper')
|
||||
return super().asdict_zerocopy(skip_fields)
|
||||
|
||||
|
||||
class FlashInferImpl(AttentionImpl):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: Optional[int] = None,
|
||||
alibi_slopes: Optional[List[float]] = None,
|
||||
sliding_window: Optional[int] = None,
|
||||
) -> None:
|
||||
if sliding_window is not None:
|
||||
raise ValueError("Sliding window is not supported in FlashInfer.")
|
||||
self.sliding_window = (-1, -1)
|
||||
self.alibi_slopes = alibi_slopes
|
||||
self.scale = scale
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
||||
|
||||
def forward(self, query: torch.Tensor, key: torch.Tensor,
|
||||
value: torch.Tensor, kv_cache: Optional[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata[FlashInferMetadata],
|
||||
kv_scale: float):
|
||||
num_tokens, hidden_size = query.shape
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||
value = value.view(-1, self.num_kv_heads, self.head_size)
|
||||
|
||||
if attn_metadata.num_prefill_tokens > 0:
|
||||
assert attn_metadata.num_decode_tokens == 0, (
|
||||
"Chunked prefill is not supported with flashinfer yet.")
|
||||
if attn_metadata.num_decode_tokens > 0:
|
||||
assert attn_metadata.num_prefill_tokens == 0, (
|
||||
"Chunked prefill is not supported with flashinfer yet.")
|
||||
|
||||
if kv_cache is not None:
|
||||
# Use the same reshape and cache kernel as flash attention.
|
||||
ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
kv_cache[:, 0],
|
||||
kv_cache[:, 1],
|
||||
attn_metadata.slot_mapping.flatten(),
|
||||
attn_metadata.kv_cache_dtype,
|
||||
)
|
||||
|
||||
if prefill_meta := attn_metadata.prefill_metadata:
|
||||
assert prefill_meta.block_tables is not None
|
||||
if kv_cache is None or prefill_meta.block_tables.numel() == 0:
|
||||
output = flash_attn_varlen_func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
cu_seqlens_q=prefill_meta.seq_start_loc,
|
||||
cu_seqlens_k=prefill_meta.seq_start_loc,
|
||||
max_seqlen_q=prefill_meta.max_seq_len,
|
||||
max_seqlen_k=prefill_meta.max_seq_len,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
window_size=self.sliding_window,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Prefix caching is not supported with flashinfer yet.")
|
||||
else:
|
||||
assert attn_metadata.decode_metadata is not None
|
||||
assert attn_metadata.decode_metadata.decode_wrapper is not None
|
||||
query = query.contiguous(
|
||||
) # Flashinfer requires query to be contiguous
|
||||
output = attn_metadata.decode_metadata.decode_wrapper.forward(
|
||||
query,
|
||||
kv_cache,
|
||||
sm_scale=self.scale,
|
||||
)
|
||||
return output.view(num_tokens, hidden_size)
|
@ -17,6 +17,7 @@ class _Backend(enum.Enum):
|
||||
XFORMERS = enum.auto()
|
||||
ROCM_FLASH = enum.auto()
|
||||
TORCH_SDPA = enum.auto()
|
||||
FLASHINFER = enum.auto()
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
@ -41,6 +42,11 @@ def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
|
||||
logger.info("Using Torch SDPA backend.")
|
||||
from vllm.attention.backends.torch_sdpa import TorchSDPABackend
|
||||
return TorchSDPABackend
|
||||
elif backend == _Backend.FLASHINFER:
|
||||
logger.info("Using Flashinfer backend.")
|
||||
logger.warning("Eager mode is enforced for the Flashinfer backend. ")
|
||||
from vllm.attention.backends.flashinfer import FlashInferBackend
|
||||
return FlashInferBackend
|
||||
else:
|
||||
raise ValueError("Invalid attention backend.")
|
||||
|
||||
|
@ -298,6 +298,11 @@ class ModelConfig:
|
||||
return max(1,
|
||||
total_num_kv_heads // parallel_config.tensor_parallel_size)
|
||||
|
||||
def get_num_attention_heads(self,
|
||||
parallel_config: "ParallelConfig") -> int:
|
||||
return self.hf_text_config.num_attention_heads // \
|
||||
parallel_config.tensor_parallel_size
|
||||
|
||||
def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
|
||||
total_num_hidden_layers = self.hf_text_config.num_hidden_layers
|
||||
return total_num_hidden_layers // parallel_config.pipeline_parallel_size
|
||||
|
@ -579,8 +579,10 @@ class SequenceGroupMetadata:
|
||||
query tokens for prefill, we don't need sampling.
|
||||
token_chunk_size: The number of tokens to be processed (per sequence).
|
||||
None if chunking is not required.
|
||||
state: Internal state tied to this sequence group.
|
||||
lora_request: LoRA request.
|
||||
computed_block_nums: The block numbers that are already computed,
|
||||
used in prefix caching.
|
||||
state: Internal state tied to this sequence group.
|
||||
multi_modal_data: Multi modal data.
|
||||
"""
|
||||
|
||||
|
@ -355,21 +355,9 @@ def _generate_random_fp8(
|
||||
del tensor_tmp
|
||||
|
||||
|
||||
def create_kv_caches_with_random(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_layers: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
cache_dtype: Optional[Union[str, torch.dtype]],
|
||||
model_dtype: Optional[Union[str, torch.dtype]] = None,
|
||||
seed: int = 0,
|
||||
device: Optional[str] = "cuda",
|
||||
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||
torch.random.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
def get_kv_cache_torch_dtype(
|
||||
cache_dtype: Optional[Union[str, torch.dtype]],
|
||||
model_dtype: Optional[Union[str, torch.dtype]] = None) -> torch.dtype:
|
||||
if isinstance(cache_dtype, str):
|
||||
if cache_dtype == "auto":
|
||||
if isinstance(model_dtype, str):
|
||||
@ -388,6 +376,55 @@ def create_kv_caches_with_random(
|
||||
torch_dtype = cache_dtype
|
||||
else:
|
||||
raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
|
||||
return torch_dtype
|
||||
|
||||
|
||||
def create_kv_caches_with_random_flash(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_layers: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
cache_dtype: Optional[Union[str, torch.dtype]],
|
||||
model_dtype: Optional[Union[str, torch.dtype]] = None,
|
||||
seed: int = 0,
|
||||
device: Optional[str] = "cuda",
|
||||
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||
assert cache_dtype != "fp8"
|
||||
torch.random.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
|
||||
key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
|
||||
scale = head_size**-0.5
|
||||
key_caches, value_caches = [], []
|
||||
for _ in range(num_layers):
|
||||
key_value_cache = torch.empty(size=key_value_cache_shape,
|
||||
dtype=torch_dtype,
|
||||
device=device)
|
||||
key_value_cache.uniform_(-scale, scale)
|
||||
key_caches.append(key_value_cache[:, 0])
|
||||
value_caches.append(key_value_cache[:, 1])
|
||||
return key_caches, value_caches
|
||||
|
||||
|
||||
def create_kv_caches_with_random(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_layers: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
cache_dtype: Optional[Union[str, torch.dtype]],
|
||||
model_dtype: Optional[Union[str, torch.dtype]] = None,
|
||||
seed: int = 0,
|
||||
device: Optional[str] = "cuda",
|
||||
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||
torch.random.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
|
||||
|
||||
scale = head_size**-0.5
|
||||
x = 16 // torch.tensor([], dtype=torch_dtype).element_size()
|
||||
|
@ -9,6 +9,7 @@ import torch.nn as nn
|
||||
|
||||
from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage,
|
||||
get_attn_backend)
|
||||
from vllm.attention.backends.flashinfer import FlashInferBackend
|
||||
from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
|
||||
ParallelConfig, SchedulerConfig, VisionLanguageConfig)
|
||||
from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce
|
||||
@ -23,8 +24,8 @@ from vllm.model_executor.model_loader import get_model
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData,
|
||||
SequenceGroupMetadata)
|
||||
from vllm.utils import (CudaMemoryProfiler, is_hip, is_pin_memory_available,
|
||||
make_tensor_with_pad)
|
||||
from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip,
|
||||
is_pin_memory_available, make_tensor_with_pad)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -155,6 +156,9 @@ class ModelRunner:
|
||||
# (max batch size to capture, max context len to capture / block size).
|
||||
self.graph_block_tables: torch.Tensor # Set after initial profiling.
|
||||
|
||||
# Set if the backend is flashinfer.
|
||||
self.flashinfer_workspace_buffer: torch.Tensor
|
||||
|
||||
def load_model(self) -> None:
|
||||
with CudaMemoryProfiler() as m:
|
||||
self.model = get_model(
|
||||
@ -315,6 +319,7 @@ class ModelRunner:
|
||||
|
||||
# Compute the slot mapping.
|
||||
block_table = seq_group_metadata.block_tables[seq_id]
|
||||
|
||||
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
|
||||
# where start_idx is max(0, seq_len - sliding_window).
|
||||
# For example, if the prompt len is 10, sliding window is 8, and
|
||||
@ -390,18 +395,26 @@ class ModelRunner:
|
||||
dtype=seq_start_loc.dtype,
|
||||
out=seq_start_loc[1:])
|
||||
|
||||
attn_metadata = self.attn_backend.make_metadata(
|
||||
is_prompt=True,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_query_len=max_query_len,
|
||||
max_seq_len=max_seq_len,
|
||||
subquery_start_loc=subquery_start_loc,
|
||||
seq_start_loc=seq_start_loc,
|
||||
context_lens_tensor=context_lens_tensor,
|
||||
block_tables=block_tables,
|
||||
use_cuda_graph=False,
|
||||
)
|
||||
if self.attn_backend is FlashInferBackend:
|
||||
attn_metadata = self.attn_backend.make_metadata(
|
||||
is_prompt=True,
|
||||
use_cuda_graph=False,
|
||||
seq_start_loc=seq_start_loc,
|
||||
max_seq_len=max_seq_len,
|
||||
block_tables=block_tables)
|
||||
else:
|
||||
attn_metadata = self.attn_backend.make_metadata(
|
||||
is_prompt=True,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_query_len=max_query_len,
|
||||
max_seq_len=max_seq_len,
|
||||
subquery_start_loc=subquery_start_loc,
|
||||
seq_start_loc=seq_start_loc,
|
||||
context_lens_tensor=context_lens_tensor,
|
||||
block_tables=block_tables,
|
||||
use_cuda_graph=False,
|
||||
)
|
||||
|
||||
return PreparePromptMetadata(
|
||||
input_tokens=input_tokens,
|
||||
@ -429,6 +442,24 @@ class ModelRunner:
|
||||
lora_prompt_mapping: List[int] = []
|
||||
lora_requests: Set[LoRARequest] = set()
|
||||
|
||||
# The following fields are only for flashinfer
|
||||
# Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
|
||||
# for the precise definition of the following fields.
|
||||
# An example:
|
||||
# request 1, page indices [0, 5, 8]
|
||||
# request 2, page indices [1, 6, 7]
|
||||
# request 3, page indices [3, 4]
|
||||
# paged_kv_indices is a concatenation of page indices of all requests:
|
||||
# [0, 5, 8, 1, 6, 7, 3, 4]
|
||||
# paged_kv_indptr is used to index into paged_kv_indices:
|
||||
# [0, 3, 6, 8]
|
||||
paged_kv_indices: List[int] = []
|
||||
# 0 at the beginning of paged_kv_indptr indicates the start of the
|
||||
# first request’s page indices in the paged_kv_indices list.
|
||||
paged_kv_indptr: List[int] = [0]
|
||||
# paged_kv_last_page_len is the length of the last page of each request
|
||||
paged_kv_last_page_len: List[int] = []
|
||||
|
||||
if len(seq_group_metadata_list) == 0:
|
||||
return PrepareDecodeMetadata.empty()
|
||||
|
||||
@ -469,6 +500,13 @@ class ModelRunner:
|
||||
block_table = block_table[-sliding_window_blocks:]
|
||||
block_tables.append(block_table)
|
||||
|
||||
paged_kv_indices.extend(block_table)
|
||||
paged_kv_indptr.append(paged_kv_indptr[-1] + len(block_table))
|
||||
last_page_len = seq_data.get_len() % self.block_size
|
||||
if last_page_len == 0:
|
||||
last_page_len = self.block_size
|
||||
paged_kv_last_page_len.append(last_page_len)
|
||||
|
||||
# vLLM uses cuda graph only for decoding requests.
|
||||
# See `capture_model` API for more details.
|
||||
# For decoding requests, batch_size == input_tokens.
|
||||
@ -518,18 +556,51 @@ class ModelRunner:
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
attn_metadata = self.attn_backend.make_metadata(
|
||||
is_prompt=False,
|
||||
seq_lens=None,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_query_len=None,
|
||||
max_seq_len=max_seq_len,
|
||||
subquery_start_loc=None,
|
||||
seq_start_loc=None,
|
||||
context_lens_tensor=None,
|
||||
block_tables=block_tables,
|
||||
use_cuda_graph=use_captured_graph,
|
||||
)
|
||||
if self.attn_backend is FlashInferBackend:
|
||||
if not hasattr(self, "flashinfer_workspace_buffer"):
|
||||
# Allocate 16MB workspace buffer
|
||||
# Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html
|
||||
self.flashinfer_workspace_buffer = torch.empty(
|
||||
16 * 1024 * 1024, dtype=torch.uint8, device=self.device)
|
||||
paged_kv_indptr = torch.tensor(paged_kv_indptr,
|
||||
dtype=torch.int,
|
||||
device=self.device)
|
||||
paged_kv_indices = torch.tensor(paged_kv_indices,
|
||||
dtype=torch.int,
|
||||
device=self.device)
|
||||
paged_kv_last_page_len = torch.tensor(paged_kv_last_page_len,
|
||||
dtype=torch.int,
|
||||
device=self.device)
|
||||
kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype,
|
||||
self.model_config.dtype)
|
||||
|
||||
attn_metadata = self.attn_backend.make_metadata(
|
||||
is_prompt=False,
|
||||
use_cuda_graph=False,
|
||||
workspace_buffer=self.flashinfer_workspace_buffer,
|
||||
paged_kv_indptr=paged_kv_indptr,
|
||||
paged_kv_indices=paged_kv_indices,
|
||||
paged_kv_last_page_len=paged_kv_last_page_len,
|
||||
num_qo_heads=self.model_config.get_num_attention_heads(
|
||||
self.parallel_config),
|
||||
num_kv_heads=self.model_config.get_num_kv_heads(
|
||||
self.parallel_config),
|
||||
head_dim=self.model_config.get_head_size(),
|
||||
page_size=self.block_size,
|
||||
data_type=kv_cache_dtype)
|
||||
else:
|
||||
attn_metadata = self.attn_backend.make_metadata(
|
||||
is_prompt=False,
|
||||
seq_lens=None,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_query_len=None,
|
||||
max_seq_len=max_seq_len,
|
||||
subquery_start_loc=None,
|
||||
seq_start_loc=None,
|
||||
context_lens_tensor=None,
|
||||
block_tables=block_tables,
|
||||
use_cuda_graph=use_captured_graph,
|
||||
)
|
||||
return PrepareDecodeMetadata(
|
||||
input_tokens=input_tokens,
|
||||
input_positions=input_positions,
|
||||
|
Loading…
x
Reference in New Issue
Block a user