[Model][MiniMaxText01] Support MiniMaxText01 model inference (#13454)
Signed-off-by: qscqesze <475517977@qq.com> Co-authored-by: qingjun <qingjun@minimaxi.com> Co-authored-by: qscqesze <475517977@qq.com>
This commit is contained in:
parent
93491aefc7
commit
9ef98d527e
@ -503,6 +503,11 @@ See [this page](#generative-models) for more information on how to use generativ
|
||||
* `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc.
|
||||
* ✅︎
|
||||
* ✅︎
|
||||
- * `MiniMaxText01ForCausalLM`
|
||||
* MiniMax-Text
|
||||
* `MiniMaxAI/MiniMax-Text-01`, etc.
|
||||
*
|
||||
* ✅︎
|
||||
- * `Zamba2ForCausalLM`
|
||||
* Zamba2
|
||||
* `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc.
|
||||
|
286
tests/kernels/test_lightning_attn.py
Normal file
286
tests/kernels/test_lightning_attn.py
Normal file
@ -0,0 +1,286 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.lightning_attn import (
|
||||
linear_decode_forward_triton)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
NUM_HEADS = [4, 8]
|
||||
HEAD_SIZES = [64]
|
||||
BATCH_SIZES = [1, 2]
|
||||
SEQ_LENGTHS = [16]
|
||||
DTYPES = [torch.float32]
|
||||
|
||||
|
||||
def reference_lightning_attention(q, k, v, ed, block_size, kv_history):
|
||||
"""Reference implementation of lightning attention core algorithm
|
||||
|
||||
The difference from the main implementation is that this processes
|
||||
each step sequentially, instead of using parallelized triton kernels
|
||||
"""
|
||||
B, H, S, D = q.shape
|
||||
E = v.shape[-1]
|
||||
dtype = q.dtype
|
||||
output = torch.zeros((B, H, S, E), dtype=dtype, device=q.device)
|
||||
|
||||
# Use clone() to ensure an independent copy
|
||||
if kv_history is None:
|
||||
kv_cache = torch.zeros((B, H, D, E), dtype=dtype, device=q.device)
|
||||
else:
|
||||
kv_cache = kv_history.clone()
|
||||
|
||||
# More efficient implementation
|
||||
# Convert decay factors to matrix form
|
||||
if ed.dim() == 1:
|
||||
decay = torch.exp(-ed).view(1, -1, 1, 1)
|
||||
else:
|
||||
decay = torch.exp(-ed)
|
||||
|
||||
for b in range(B):
|
||||
for step in range(S):
|
||||
# Process all heads at once for this position
|
||||
q_bs = q[b, :, step] # [H, D]
|
||||
k_bs = k[b, :, step] # [H, D]
|
||||
v_bs = v[b, :, step] # [H, E]
|
||||
|
||||
# Calculate KV outer products for all heads
|
||||
for h in range(H):
|
||||
# Calculate KV outer product
|
||||
kv_outer = torch.outer(k_bs[h], v_bs[h])
|
||||
|
||||
# Update KV cache with decay
|
||||
# Note: Using the same order as in the Triton kernel
|
||||
kv_cache[b, h] = decay[0, h, 0, 0] * kv_cache[b, h] + kv_outer
|
||||
|
||||
# Calculate attention output
|
||||
output[b, h, step] = torch.matmul(q_bs[h], kv_cache[b, h])
|
||||
|
||||
# Match the shape returned by the actual implementation
|
||||
# The actual implementation returns a tensor of shape [B, H, 2, D, E]
|
||||
# where dimension 2 contains both KV and KV history
|
||||
kv_reshaped = kv_cache.unsqueeze(2) # [B, H, 1, D, E]
|
||||
final_kv_cache = torch.cat([kv_reshaped, kv_reshaped],
|
||||
dim=2) # [B, H, 2, D, E]
|
||||
|
||||
return output, final_kv_cache
|
||||
|
||||
|
||||
def reference_linear_decode(q, k, v, kv_caches, slope_rate, slot_idx):
|
||||
"""Reference implementation: linear attention decode function"""
|
||||
B, H, _, D = q.shape
|
||||
output = torch.zeros(B, H * D, dtype=q.dtype, device=q.device)
|
||||
|
||||
# Calculate decay factors once (more efficient)
|
||||
decay = torch.exp(-slope_rate).view(-1, 1, 1) # [H, 1, 1]
|
||||
|
||||
# Process each batch
|
||||
for b in range(B):
|
||||
slot_id = slot_idx[b].item()
|
||||
|
||||
# Skip padding positions
|
||||
if slot_id == -1:
|
||||
continue
|
||||
|
||||
# Process all heads at once for this batch
|
||||
q_b = q[b, :, 0] # [H, D]
|
||||
k_b = k[b, :, 0] # [H, D]
|
||||
v_b = v[b, :, 0] # [H, D]
|
||||
|
||||
# Process each attention head
|
||||
for h in range(H):
|
||||
# Get current query, key and value
|
||||
q_bh = q_b[h]
|
||||
k_bh = k_b[h]
|
||||
v_bh = v_b[h]
|
||||
|
||||
# Get cache
|
||||
kv_cache_old = kv_caches[b, h]
|
||||
|
||||
# Calculate new key-value outer product
|
||||
kv_outer = torch.outer(k_bh, v_bh)
|
||||
|
||||
# Apply decay and update cache
|
||||
kv_new = kv_outer + decay[h, 0, 0] * kv_cache_old
|
||||
|
||||
# Calculate output
|
||||
out_h = torch.matmul(q_bh, kv_new)
|
||||
|
||||
# Update output and cache
|
||||
output[b, h * D:(h + 1) * D] = out_h
|
||||
kv_caches[b, h] = kv_new
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@torch.inference_mode()
|
||||
def test_linear_decode_forward_triton(
|
||||
batch_size: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
torch.set_default_device("cuda")
|
||||
torch.manual_seed(42)
|
||||
torch.cuda.manual_seed_all(42)
|
||||
current_platform.seed_everything(42)
|
||||
base = 0.01
|
||||
q = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
|
||||
k = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
|
||||
v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
|
||||
|
||||
kv_caches = base * torch.randn(batch_size,
|
||||
num_heads,
|
||||
head_size,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device="cuda")
|
||||
|
||||
kv_caches_copy = kv_caches.clone()
|
||||
|
||||
slope_rate = torch.zeros(num_heads, device="cuda")
|
||||
for h in range(num_heads):
|
||||
slope_rate[h] = 0.1 * (h + 1)
|
||||
|
||||
slot_idx = torch.arange(batch_size, device="cuda")
|
||||
|
||||
triton_output = linear_decode_forward_triton(q, k, v, kv_caches,
|
||||
slope_rate, slot_idx)
|
||||
|
||||
reference_output = reference_linear_decode(q, k, v, kv_caches_copy,
|
||||
slope_rate, slot_idx)
|
||||
torch.testing.assert_close(triton_output,
|
||||
reference_output,
|
||||
rtol=1e-1,
|
||||
atol=1e-1)
|
||||
torch.testing.assert_close(kv_caches, kv_caches_copy, rtol=1e-1, atol=1e-1)
|
||||
|
||||
assert triton_output.shape == (batch_size, num_heads * head_size)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@torch.inference_mode()
|
||||
def test_linear_decode_forward_triton_with_padding(
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
torch.set_default_device("cuda")
|
||||
torch.manual_seed(42)
|
||||
torch.cuda.manual_seed_all(42)
|
||||
current_platform.seed_everything(42)
|
||||
|
||||
batch_size = 4
|
||||
base = 0.01
|
||||
q = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
|
||||
k = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
|
||||
v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
|
||||
|
||||
kv_caches = base * torch.randn(batch_size,
|
||||
num_heads,
|
||||
head_size,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device="cuda")
|
||||
|
||||
kv_caches_copy = kv_caches.clone()
|
||||
|
||||
slope_rate = torch.zeros(num_heads, device="cuda")
|
||||
for h in range(num_heads):
|
||||
slope_rate[h] = 0.1 * (h + 1)
|
||||
|
||||
slot_idx = torch.tensor([0, 1, -1, 2], device="cuda")
|
||||
|
||||
triton_output = linear_decode_forward_triton(q, k, v, kv_caches,
|
||||
slope_rate, slot_idx)
|
||||
|
||||
reference_output = reference_linear_decode(q, k, v, kv_caches_copy,
|
||||
slope_rate, slot_idx)
|
||||
|
||||
padding_mask = (slot_idx
|
||||
!= -1).unsqueeze(1).expand(-1, num_heads * head_size)
|
||||
|
||||
triton_masked = triton_output[padding_mask]
|
||||
reference_masked = reference_output[padding_mask]
|
||||
|
||||
atol, rtol = 1.5e-1, 1.5e-1
|
||||
|
||||
valid_indices = slot_idx != -1
|
||||
|
||||
for i in range(batch_size):
|
||||
if valid_indices[i] > 0:
|
||||
torch.testing.assert_close(kv_caches[i],
|
||||
kv_caches_copy[i],
|
||||
rtol=rtol,
|
||||
atol=atol)
|
||||
|
||||
torch.testing.assert_close(triton_masked,
|
||||
reference_masked,
|
||||
rtol=rtol,
|
||||
atol=atol)
|
||||
|
||||
assert triton_output.shape == (batch_size, num_heads * head_size)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("seq_len", SEQ_LENGTHS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@torch.inference_mode()
|
||||
def test_lightning_attention_reference(
|
||||
batch_size: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
seq_len: int,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
torch.set_default_device("cuda")
|
||||
torch.manual_seed(42)
|
||||
torch.cuda.manual_seed_all(42)
|
||||
current_platform.seed_everything(42)
|
||||
|
||||
base = 0.01
|
||||
q = base * torch.randn(
|
||||
batch_size, num_heads, seq_len, head_size, dtype=dtype)
|
||||
k = base * torch.randn(
|
||||
batch_size, num_heads, seq_len, head_size, dtype=dtype)
|
||||
v = base * torch.randn(
|
||||
batch_size, num_heads, seq_len, head_size, dtype=dtype)
|
||||
|
||||
ed = torch.zeros(num_heads, device="cuda")
|
||||
for h in range(num_heads):
|
||||
ed[h] = 0.1 * (h + 1)
|
||||
|
||||
kv_history = base * torch.randn(batch_size,
|
||||
num_heads,
|
||||
head_size,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device="cuda")
|
||||
|
||||
kv_history_clone = kv_history.clone()
|
||||
|
||||
ref_output, ref_kv_cache = reference_lightning_attention(
|
||||
q, k, v, ed, 256, kv_history)
|
||||
|
||||
from vllm.model_executor.layers.lightning_attn import lightning_attention
|
||||
actual_output, actual_kv_cache = lightning_attention(
|
||||
q, k, v, ed, 256, kv_history_clone)
|
||||
|
||||
atol, rtol = 1.5e-1, 1.5e-1
|
||||
torch.testing.assert_close(ref_output, actual_output, rtol=rtol, atol=atol)
|
||||
torch.testing.assert_close(ref_kv_cache,
|
||||
actual_kv_cache,
|
||||
rtol=rtol,
|
||||
atol=atol)
|
||||
|
||||
assert ref_output.shape == (batch_size, num_heads, seq_len, head_size)
|
||||
assert ref_kv_cache.shape == actual_kv_cache.shape
|
@ -176,6 +176,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
trust_remote_code=True),
|
||||
"MiniCPM3ForCausalLM": _HfExamplesInfo("openbmb/MiniCPM3-4B",
|
||||
trust_remote_code=True),
|
||||
"MiniMaxText01ForCausalLM": _HfExamplesInfo("MiniMaxAI/MiniMax-Text-01",
|
||||
trust_remote_code=True),
|
||||
"MistralForCausalLM": _HfExamplesInfo("mistralai/Mistral-7B-Instruct-v0.1"),
|
||||
"MixtralForCausalLM": _HfExamplesInfo("mistralai/Mixtral-8x7B-Instruct-v0.1"), # noqa: E501
|
||||
"QuantMixtralForCausalLM": _HfExamplesInfo("mistral-community/Mixtral-8x22B-v0.1-AWQ"), # noqa: E501
|
||||
|
@ -971,26 +971,34 @@ class ModelConfig:
|
||||
return sum(not bc.attention.no_op
|
||||
for bc in block_configs[start:end])
|
||||
else:
|
||||
# Hybrid model
|
||||
# Hybrid model Jamba
|
||||
layers_block_type_value = getattr(self.hf_config,
|
||||
"layers_block_type", None)
|
||||
if layers_block_type_value is None:
|
||||
raise ValueError("The model is an hybrid without a "
|
||||
"layers_block_type in the hf_config, "
|
||||
"cannot determine the num of "
|
||||
f"{block_type.value} layers")
|
||||
if layers_block_type_value is not None:
|
||||
if hasattr(self.hf_text_config,
|
||||
"model_type") and (self.hf_text_config.model_type
|
||||
== "zamba2"):
|
||||
if attn_block_type:
|
||||
return sum(t == "hybrid"
|
||||
for t in layers_block_type_value[start:end])
|
||||
else:
|
||||
return self.get_num_layers(parallel_config)
|
||||
return sum(t == block_type.value
|
||||
for t in layers_block_type_value[start:end])
|
||||
|
||||
if hasattr(self.hf_text_config,
|
||||
"model_type") and (self.hf_text_config.model_type
|
||||
== "zamba2"):
|
||||
if attn_block_type:
|
||||
return sum(t == "hybrid"
|
||||
for t in layers_block_type_value[start:end])
|
||||
else:
|
||||
return self.get_num_layers(parallel_config)
|
||||
# Hybrid model Minimax
|
||||
attn_type_list = getattr(self.hf_config, "attn_type_list", None)
|
||||
if attn_type_list:
|
||||
return sum(t == 1 for t in attn_type_list[start:end])
|
||||
|
||||
return sum(t == block_type.value
|
||||
for t in layers_block_type_value[start:end])
|
||||
if layers_block_type_value is None and attn_type_list is None:
|
||||
raise ValueError(
|
||||
"The model is an hybrid without a"
|
||||
"layers_block_type or an attn_type_list in the hf_config,"
|
||||
"cannot determine the num of "
|
||||
f"{block_type.value} layers")
|
||||
|
||||
return sum(t == 1 for t in attn_type_list[start:end])
|
||||
|
||||
def get_multimodal_config(self) -> "MultiModalConfig":
|
||||
"""
|
||||
|
@ -303,8 +303,11 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
ctx.seq_group_metadata_list = seq_group_metadata_list
|
||||
ctx.scheduler_outputs = scheduler_outputs
|
||||
|
||||
finished_requests_ids = self.scheduler[
|
||||
virtual_engine].get_and_reset_finished_requests_ids()
|
||||
if not scheduler_outputs.is_empty():
|
||||
# this will cause mamba_cache/minimax_cache failed
|
||||
# to release finished_requests_ids of the last steps
|
||||
finished_requests_ids = self.scheduler[
|
||||
virtual_engine].get_and_reset_finished_requests_ids()
|
||||
|
||||
# Maybe switch from async mode to sync mode
|
||||
if not allow_async_output_proc and len(ctx.output_queue) > 0:
|
||||
|
651
vllm/model_executor/layers/lightning_attn.py
Normal file
651
vllm/model_executor/layers/lightning_attn.py
Normal file
@ -0,0 +1,651 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fwd_diag_kernel(Q, K, V, Out, S, b: tl.constexpr, h: tl.constexpr, n,
|
||||
d: tl.constexpr, e: tl.constexpr, BLOCK: tl.constexpr,
|
||||
NUM_BLOCK, CBLOCK: tl.constexpr):
|
||||
# This kernel computes the diagonal blocks of the attention matrix
|
||||
# Each diagonal block represents attention
|
||||
# where queries attend to keys in the same block
|
||||
off = tl.program_id(0)
|
||||
off_bh = off // NUM_BLOCK # batch-head index
|
||||
off_block = off % NUM_BLOCK # block index within the sequence
|
||||
off_cblock = tl.program_id(1) # sub-block index within a block
|
||||
|
||||
off_h = off_bh % h # head index
|
||||
|
||||
# Calculate base offsets for the current batch and head
|
||||
qk_offset = off_bh * n * d
|
||||
v_offset = off_bh * n * e
|
||||
o_offset = off_bh * n * e
|
||||
|
||||
# Calculate offsets for the current block
|
||||
block_offset = off_block * BLOCK
|
||||
qk_block_offset = block_offset * d
|
||||
v_block_offset = block_offset * e
|
||||
o_block_offset = block_offset * e
|
||||
|
||||
# Calculate offsets for the current sub-block
|
||||
cblock_offset = off_cblock * CBLOCK
|
||||
q_cblock_offset = cblock_offset * d
|
||||
o_cblock_offset = cblock_offset * e
|
||||
|
||||
# Calculate pointers to the query, key, value, and output tensors
|
||||
Q_block_ptr = (Q + qk_offset + qk_block_offset + q_cblock_offset +
|
||||
tl.arange(0, CBLOCK)[:, None] * d +
|
||||
tl.arange(0, d)[None, :])
|
||||
K_trans_block_ptr = (K + qk_offset + qk_block_offset +
|
||||
tl.arange(0, CBLOCK)[None, :] * d +
|
||||
tl.arange(0, d)[:, None])
|
||||
V_block_ptr = (V + v_offset + v_block_offset +
|
||||
tl.arange(0, CBLOCK)[:, None] * e +
|
||||
tl.arange(0, e)[None, :])
|
||||
O_block_ptr = (Out + o_offset + o_block_offset + o_cblock_offset +
|
||||
tl.arange(0, CBLOCK)[:, None] * e +
|
||||
tl.arange(0, e)[None, :])
|
||||
|
||||
# Load the decay rate for the current head
|
||||
S_block_ptr = S + off_h
|
||||
s = tl.load(S_block_ptr)
|
||||
|
||||
i = off_cblock
|
||||
q_index = tl.arange(0, CBLOCK) + i * CBLOCK
|
||||
|
||||
# Load query values
|
||||
q = tl.load(Q_block_ptr,
|
||||
mask=block_offset + q_index[:, None] < n,
|
||||
other=0.0).to(tl.float32)
|
||||
|
||||
# Initialize output accumulator
|
||||
qkv = tl.zeros([CBLOCK, e], dtype=tl.float32)
|
||||
|
||||
# Process all sub-blocks up to and
|
||||
# including the current one (causal attention)
|
||||
for j in range(i + 1):
|
||||
kv_index = tl.arange(0, CBLOCK) + j * CBLOCK
|
||||
diff = q_index[:, None] - kv_index[None, :]
|
||||
s_index = s * diff
|
||||
# Apply causal mask: only attend to positions before the current one
|
||||
s_index = tl.where(diff >= 0, -s_index, float("-inf"))
|
||||
decay = tl.exp(s_index)
|
||||
|
||||
# Load key and value
|
||||
k_trans = tl.load(
|
||||
K_trans_block_ptr,
|
||||
mask=block_offset + kv_index[None, :] < n,
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
v = tl.load(
|
||||
V_block_ptr,
|
||||
mask=block_offset + kv_index[:, None] < n,
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
|
||||
# Compute attention scores and apply decay
|
||||
qk = tl.dot(q, k_trans) * decay
|
||||
|
||||
# Compute weighted values and accumulate
|
||||
qkv += tl.dot(qk, v)
|
||||
|
||||
# Move to the next sub-block
|
||||
K_trans_block_ptr += CBLOCK * d
|
||||
V_block_ptr += CBLOCK * e
|
||||
|
||||
# Store the result
|
||||
tl.store(
|
||||
O_block_ptr,
|
||||
qkv.to(O_block_ptr.dtype.element_ty),
|
||||
mask=block_offset + q_index[:, None] < n,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fwd_kv_parallel(
|
||||
K,
|
||||
V,
|
||||
K_decay,
|
||||
KV,
|
||||
b: tl.constexpr,
|
||||
h: tl.constexpr,
|
||||
n,
|
||||
d: tl.constexpr,
|
||||
e: tl.constexpr,
|
||||
BLOCK: tl.constexpr,
|
||||
NUM_BLOCK,
|
||||
D_FBLOCK: tl.constexpr,
|
||||
E_FBLOCK: tl.constexpr,
|
||||
NUM_FBLOCK: tl.constexpr,
|
||||
CBLOCK: tl.constexpr,
|
||||
NUM_CBLOCK: tl.constexpr,
|
||||
):
|
||||
# This kernel computes the key-value outer
|
||||
# products for each block in parallel
|
||||
off_bh = tl.program_id(0) # batch-head index
|
||||
off_block = tl.program_id(1) # block index
|
||||
|
||||
off_h = off_bh % h # head index
|
||||
|
||||
block_offset = off_block * BLOCK
|
||||
|
||||
# Calculate offsets for the current block
|
||||
k_block_offset = block_offset * d
|
||||
v_block_offset = block_offset * e
|
||||
kv_block_offset = off_block * d * e
|
||||
|
||||
# Calculate base offsets for the current batch and head
|
||||
k_offset = off_bh * n * d
|
||||
v_offset = off_bh * n * e
|
||||
kv_offset = off_bh * NUM_BLOCK * d * e
|
||||
|
||||
# Calculate pointers to the key, value, and key-value tensors
|
||||
K_trans_block_ptr = (K + k_offset + k_block_offset +
|
||||
tl.arange(0, CBLOCK)[None, :] * d +
|
||||
tl.arange(0, D_FBLOCK)[:, None])
|
||||
V_block_ptr = (V + v_offset + v_block_offset +
|
||||
tl.arange(0, CBLOCK)[:, None] * e +
|
||||
tl.arange(0, E_FBLOCK)[None, :])
|
||||
KV_block_ptr = (KV + kv_offset + kv_block_offset +
|
||||
tl.arange(0, D_FBLOCK)[:, None] * e +
|
||||
tl.arange(0, E_FBLOCK)[None, :])
|
||||
|
||||
# Load the decay factors for the current head and block
|
||||
k_decay_ptr = (K_decay + off_h * BLOCK + tl.arange(0, CBLOCK)[None, :])
|
||||
|
||||
kv_index = tl.arange(0, CBLOCK)
|
||||
|
||||
# Initialize the key-value outer product accumulator
|
||||
kv = tl.zeros([D_FBLOCK, E_FBLOCK], dtype=tl.float32)
|
||||
|
||||
# Handle the last block which might be smaller than BLOCK
|
||||
if off_block == NUM_BLOCK - 1:
|
||||
split_n = n - (NUM_BLOCK - 1) * BLOCK
|
||||
else:
|
||||
split_n = BLOCK
|
||||
left_shift = tl.cdiv(split_n, CBLOCK) * CBLOCK - split_n
|
||||
num_blocks = min(tl.cdiv(split_n, CBLOCK), NUM_CBLOCK)
|
||||
k_decay_ptr += (NUM_CBLOCK - num_blocks) * CBLOCK
|
||||
|
||||
# Process all sub-blocks in the current block
|
||||
for j in range(num_blocks):
|
||||
left_bound = (1 - j) * left_shift
|
||||
# Load key and value, handling boundary conditions
|
||||
k_trans = tl.load(K_trans_block_ptr - left_shift * d,
|
||||
mask=kv_index[None, :] >= left_bound,
|
||||
other=0.0)
|
||||
v = tl.load(V_block_ptr - left_shift * e,
|
||||
mask=kv_index[:, None] >= left_bound,
|
||||
other=0.0)
|
||||
|
||||
# Load decay factor and compute weighted key-value outer product
|
||||
k_decay = tl.load(k_decay_ptr)
|
||||
kv += tl.dot(k_trans * k_decay, v)
|
||||
|
||||
# Move to the next sub-block
|
||||
K_trans_block_ptr += CBLOCK * d
|
||||
V_block_ptr += CBLOCK * e
|
||||
k_decay_ptr += CBLOCK
|
||||
|
||||
# Store the result
|
||||
tl.store(KV_block_ptr, kv.to(KV_block_ptr.dtype.element_ty))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fwd_kv_reduce(S, KV, KV_HISTORY, b: tl.constexpr, h: tl.constexpr, n,
|
||||
d: tl.constexpr, e: tl.constexpr, BLOCK: tl.constexpr,
|
||||
NUM_BLOCK, D_FBLOCK: tl.constexpr, E_FBLOCK: tl.constexpr):
|
||||
# This kernel reduces the key-value outer products
|
||||
# across blocks and updates the KV history
|
||||
off_bh = tl.program_id(0) # batch-head index
|
||||
off_h = off_bh % h # head index
|
||||
|
||||
kv_offset = off_bh * NUM_BLOCK * d * e
|
||||
|
||||
# Calculate pointer to the key-value tensor
|
||||
KV_block_ptr = (KV + kv_offset + tl.arange(0, D_FBLOCK)[:, None] * e +
|
||||
tl.arange(0, E_FBLOCK)[None, :])
|
||||
|
||||
# Load the decay rate for the current head
|
||||
s_ptrs = S + off_h
|
||||
s = tl.load(s_ptrs)
|
||||
|
||||
# Calculate pointer to the key-value history tensor
|
||||
kv_history_offset = off_bh * d * e
|
||||
KV_HISTORY_block_ptr = (KV_HISTORY + kv_history_offset +
|
||||
tl.arange(0, D_FBLOCK)[:, None] * e +
|
||||
tl.arange(0, E_FBLOCK)[None, :])
|
||||
|
||||
# Load the previous key-value history
|
||||
kv_pre = tl.load(KV_HISTORY_block_ptr).to(tl.float32)
|
||||
|
||||
# Process all blocks in reverse order to compute the prefix sum
|
||||
for i in range(NUM_BLOCK):
|
||||
block_size = min(n - i * BLOCK, BLOCK)
|
||||
# Compute decay factor for the current block
|
||||
block_decay = tl.exp(-s.to(tl.float32) * block_size)
|
||||
|
||||
# Load the current key-value outer product
|
||||
kv_cur = tl.load(KV_block_ptr).to(tl.float32)
|
||||
# Store the previous key-value history to the current block
|
||||
tl.store(KV_block_ptr, kv_pre.to(KV_block_ptr.dtype.element_ty))
|
||||
|
||||
# Update the key-value history with the current block
|
||||
kv_pre = block_decay * kv_pre + kv_cur
|
||||
KV_block_ptr += d * e
|
||||
|
||||
# Store the updated key-value history
|
||||
tl.store(KV_HISTORY_block_ptr, kv_pre)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fwd_none_diag_kernel(
|
||||
Q,
|
||||
Out,
|
||||
S,
|
||||
KV,
|
||||
b: tl.constexpr,
|
||||
h: tl.constexpr,
|
||||
n,
|
||||
d: tl.constexpr,
|
||||
e: tl.constexpr,
|
||||
BLOCK: tl.constexpr,
|
||||
NUM_BLOCK,
|
||||
E_FBLOCK: tl.constexpr,
|
||||
CBLOCK: tl.constexpr,
|
||||
NUM_CBLOCK: tl.constexpr,
|
||||
):
|
||||
# This kernel computes the non-diagonal blocks of the attention matrix
|
||||
# Each non-diagonal block represents attention
|
||||
# where queries attend to keys in different blocks
|
||||
off_bh = tl.program_id(0) # batch-head index
|
||||
off_h = off_bh % h # head index
|
||||
|
||||
off_nc = tl.program_id(1)
|
||||
off_n = off_nc // NUM_CBLOCK # block index
|
||||
off_c = off_nc % NUM_CBLOCK # sub-block index
|
||||
off_e = tl.program_id(2) # output feature block index
|
||||
|
||||
n_offset = off_n * BLOCK
|
||||
c_offset = off_c * CBLOCK
|
||||
e_offset = off_e * E_FBLOCK
|
||||
block_offset = n_offset + c_offset
|
||||
|
||||
# Calculate offsets for the current batch, head, and block
|
||||
q_offset = off_bh * n * d + (n_offset + c_offset) * d
|
||||
o_offset = off_bh * n * e + (n_offset + c_offset) * e + e_offset
|
||||
kv_offset = off_bh * NUM_BLOCK * d * e + off_n * d * e + e_offset
|
||||
|
||||
# Calculate pointers to the query, output, and key-value tensors
|
||||
Q_block_ptr = (Q + q_offset + tl.arange(0, CBLOCK)[:, None] * d +
|
||||
tl.arange(0, d)[None, :])
|
||||
O_block_ptr = (Out + o_offset + tl.arange(0, CBLOCK)[:, None] * e +
|
||||
tl.arange(0, E_FBLOCK)[None, :])
|
||||
KV_block_ptr = (KV + kv_offset + tl.arange(0, d)[:, None] * e +
|
||||
tl.arange(0, E_FBLOCK)[None, :])
|
||||
|
||||
# Load the decay rate for the current head
|
||||
S_block_ptr = S + off_h
|
||||
s = tl.load(S_block_ptr)
|
||||
|
||||
c_array = tl.arange(0, CBLOCK)
|
||||
|
||||
# Load the key-value outer product for the current block
|
||||
kv = tl.load(KV_block_ptr).to(tl.float32)
|
||||
q_index = block_offset + tl.arange(0, CBLOCK)
|
||||
|
||||
# Load query values
|
||||
q = tl.load(Q_block_ptr, mask=q_index[:, None] < n,
|
||||
other=0.).to(tl.float32)
|
||||
|
||||
# Compute decay factors for the current sub-block
|
||||
q_decay = tl.exp(-s.to(tl.float32) * (off_c * CBLOCK + c_array[:, None]))
|
||||
|
||||
# Compute non-diagonal attention output
|
||||
qkv_none_diag = tl.dot(q, kv) * q_decay
|
||||
|
||||
# Load diagonal attention output (computed by _fwd_diag_kernel)
|
||||
qkv_diag = tl.load(O_block_ptr, mask=q_index[:, None] < n,
|
||||
other=0.).to(tl.float32)
|
||||
|
||||
# Combine diagonal and non-diagonal attention outputs
|
||||
qkv = qkv_diag + qkv_none_diag
|
||||
|
||||
# Store the result
|
||||
tl.store(O_block_ptr,
|
||||
qkv.to(O_block_ptr.dtype.element_ty),
|
||||
mask=q_index[:, None] < n)
|
||||
|
||||
|
||||
class _attention(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, q, k, v, s, kv_history):
|
||||
# Forward pass of the lightning attention algorithm
|
||||
q = q.contiguous()
|
||||
k = k.contiguous()
|
||||
v = v.contiguous()
|
||||
s = s.contiguous()
|
||||
|
||||
# Check CUDA compute capability
|
||||
capability = torch.cuda.get_device_capability()
|
||||
if capability[0] < 8:
|
||||
raise RuntimeError("Flash attention currently only supported",
|
||||
"for compute capability >= 80")
|
||||
|
||||
# Get input dimensions
|
||||
b, h, n, d = q.shape
|
||||
e = v.shape[-1]
|
||||
|
||||
# Initialize output tensor
|
||||
o = torch.empty((b, h, n, e), dtype=q.dtype, device=q.device)
|
||||
|
||||
# Set block sizes
|
||||
BLOCK = 256
|
||||
NUM_BLOCK = triton.cdiv(n, BLOCK)
|
||||
|
||||
CBLOCK = 32
|
||||
NUM_CBLOCK = BLOCK // CBLOCK
|
||||
assert BLOCK % CBLOCK == 0, "BLOCK must be a multiple of CBLOCK"
|
||||
|
||||
# Compute decay factors for keys
|
||||
array = torch.arange(0, BLOCK, device=q.device) + 1
|
||||
k_decay = torch.exp(-s * (BLOCK - array.reshape(1, -1)))
|
||||
|
||||
# Step 1: Compute diagonal blocks of attention
|
||||
grid = (b * h * NUM_BLOCK, NUM_CBLOCK)
|
||||
_fwd_diag_kernel[grid](q,
|
||||
k,
|
||||
v,
|
||||
o,
|
||||
s,
|
||||
b,
|
||||
h,
|
||||
n,
|
||||
d,
|
||||
e,
|
||||
BLOCK=BLOCK,
|
||||
NUM_BLOCK=NUM_BLOCK,
|
||||
CBLOCK=CBLOCK)
|
||||
|
||||
# Set feature block sizes
|
||||
NUM_FBLOCK = 1
|
||||
D_FBLOCK = d // NUM_FBLOCK
|
||||
assert d % NUM_FBLOCK == 0
|
||||
E_FBLOCK = e // NUM_FBLOCK
|
||||
assert e % NUM_FBLOCK == 0
|
||||
|
||||
CBLOCK = 64
|
||||
NUM_CBLOCK = BLOCK // CBLOCK
|
||||
assert BLOCK % CBLOCK == 0, "BLOCK must be a multiple of CBLOCK"
|
||||
|
||||
# Step 2: Compute key-value outer products for each block in parallel
|
||||
kv = torch.empty((b, h, NUM_BLOCK, d, e),
|
||||
dtype=torch.float32,
|
||||
device=q.device)
|
||||
grid = (b * h, NUM_BLOCK)
|
||||
_fwd_kv_parallel[grid](
|
||||
k,
|
||||
v,
|
||||
k_decay,
|
||||
kv,
|
||||
b,
|
||||
h,
|
||||
n,
|
||||
d,
|
||||
e,
|
||||
BLOCK=BLOCK,
|
||||
NUM_BLOCK=NUM_BLOCK,
|
||||
D_FBLOCK=D_FBLOCK,
|
||||
E_FBLOCK=E_FBLOCK,
|
||||
NUM_FBLOCK=NUM_FBLOCK,
|
||||
CBLOCK=CBLOCK,
|
||||
NUM_CBLOCK=NUM_CBLOCK,
|
||||
)
|
||||
|
||||
# Step 3: Reduce key-value outer products
|
||||
# across blocks and update KV history
|
||||
grid = (b * h, NUM_FBLOCK)
|
||||
_fwd_kv_reduce[grid](s,
|
||||
kv,
|
||||
kv_history,
|
||||
b,
|
||||
h,
|
||||
n,
|
||||
d,
|
||||
e,
|
||||
BLOCK=BLOCK,
|
||||
NUM_BLOCK=NUM_BLOCK,
|
||||
D_FBLOCK=D_FBLOCK,
|
||||
E_FBLOCK=E_FBLOCK)
|
||||
|
||||
# Step 4: Compute non-diagonal blocks of attention
|
||||
grid = (b * h, NUM_BLOCK * NUM_CBLOCK)
|
||||
_fwd_none_diag_kernel[grid](
|
||||
q,
|
||||
o,
|
||||
s,
|
||||
kv,
|
||||
b,
|
||||
h,
|
||||
n,
|
||||
d,
|
||||
e,
|
||||
BLOCK=BLOCK,
|
||||
NUM_BLOCK=NUM_BLOCK,
|
||||
E_FBLOCK=E_FBLOCK,
|
||||
CBLOCK=CBLOCK,
|
||||
NUM_CBLOCK=NUM_CBLOCK,
|
||||
)
|
||||
|
||||
# Save tensors for backward pass
|
||||
ctx.save_for_backward(q, k, v, s, kv)
|
||||
ctx.BLOCK = BLOCK
|
||||
|
||||
return o, torch.cat([kv, kv_history.unsqueeze(2)], dim=2)
|
||||
|
||||
|
||||
# Apply the lightning attention function
|
||||
lightning_attention_ = _attention.apply
|
||||
|
||||
|
||||
def lightning_attention(q, k, v, ed, block_size=256, kv_history=None):
|
||||
"""
|
||||
Apply lightning attention algorithm
|
||||
to compute attention efficiently.
|
||||
|
||||
Args:
|
||||
q: Query tensor of shape [batch, heads, seq_len, dim]
|
||||
k: Key tensor of shape [batch, heads, seq_len, dim]
|
||||
v: Value tensor of shape [batch, heads, seq_len, dim_v]
|
||||
ed: Decay rate tensor of shape [heads]
|
||||
block_size: Size of blocks for block-sparse attention
|
||||
kv_history: Optional key-value history from previous computations
|
||||
|
||||
Returns:
|
||||
output: Attention output
|
||||
kv: Updated key-value history
|
||||
"""
|
||||
d = q.shape[-1]
|
||||
e = v.shape[-1]
|
||||
|
||||
if ed.dim() == 1:
|
||||
ed = ed.view(1, -1, 1, 1)
|
||||
|
||||
# Split the computation into chunks for better parallelism
|
||||
m = 128 if d >= 128 else 64
|
||||
assert d % m == 0, f"Dimension d ({d}) must be divisible by m ({m})"
|
||||
arr = [m * i for i in range(d // m + 1)]
|
||||
if arr[-1] != d:
|
||||
arr.append(d)
|
||||
n = len(arr)
|
||||
output = 0
|
||||
|
||||
# Initialize or clone key-value history
|
||||
if kv_history is None:
|
||||
kv_history = torch.zeros((q.shape[0], q.shape[1], d, e),
|
||||
dtype=torch.float32,
|
||||
device=q.device)
|
||||
else:
|
||||
kv_history = kv_history.clone().contiguous()
|
||||
|
||||
# Process each chunk and accumulate results
|
||||
for i in range(n - 1):
|
||||
s = arr[i]
|
||||
e = arr[i + 1]
|
||||
q1 = q[..., s:e]
|
||||
k1 = k[..., s:e]
|
||||
o, kv = lightning_attention_(q1, k1, v, ed, kv_history)
|
||||
output = output + o
|
||||
return output, kv
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _linear_attn_decode_kernel(
|
||||
q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
kv_cache_ptr,
|
||||
slope_rate,
|
||||
slot_idx,
|
||||
output_ptr,
|
||||
D: tl.constexpr,
|
||||
qkv_b_stride,
|
||||
qkv_h_stride,
|
||||
cache_b_stride,
|
||||
cache_h_stride,
|
||||
cache_d0_stride,
|
||||
cache_d1_stride,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Kernel for linear attention decoding with KV cache.
|
||||
|
||||
This kernel computes attention for a single token using the KV cache.
|
||||
"""
|
||||
pid_b = tl.program_id(0) # batch index
|
||||
pid_h = tl.program_id(1) # head index
|
||||
pid_d = tl.program_id(2) # dimension block index
|
||||
|
||||
# Load slot index for the current batch
|
||||
slot_id = tl.load(slot_idx + pid_b)
|
||||
|
||||
# Skip if slot_id is -1 (padding)
|
||||
if slot_id == -1:
|
||||
return
|
||||
|
||||
batch_id = pid_b
|
||||
head_id = pid_h
|
||||
|
||||
# Load decay rate for the current head
|
||||
ratio = tl.load(slope_rate + pid_h)
|
||||
|
||||
# Calculate offsets for dimensions
|
||||
qk_d_offsets = tl.arange(0, D)
|
||||
v_d_offsets = tl.arange(0, BLOCK_SIZE) + pid_d * BLOCK_SIZE
|
||||
cache_d_offsets = qk_d_offsets[:, None] * cache_d0_stride + v_d_offsets[
|
||||
None, :] * cache_d1_stride
|
||||
|
||||
# Calculate offsets for the current batch and head
|
||||
q_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride
|
||||
k_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride
|
||||
v_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride
|
||||
|
||||
cache_offset = slot_id * cache_b_stride + head_id * cache_h_stride
|
||||
|
||||
# Create masks for loading tensors
|
||||
qk_mask = qk_d_offsets < D
|
||||
v_mask = v_d_offsets < D
|
||||
|
||||
# Load query, key, and value tensors
|
||||
q = tl.load(q_ptr + q_offset + qk_d_offsets, mask=qk_mask, other=0.0)
|
||||
k = tl.load(k_ptr + k_offset + qk_d_offsets, mask=qk_mask, other=0.0)
|
||||
v = tl.load(v_ptr + v_offset + v_d_offsets, mask=v_mask, other=0.0)
|
||||
|
||||
# Compute key-value outer product
|
||||
kv_outer = k[:, None] * v[None, :]
|
||||
kv_mask = qk_mask[:, None] & v_mask[None, :]
|
||||
|
||||
# Apply decay to previous KV cache
|
||||
ratio = tl.exp(-ratio)
|
||||
kv_ptr = kv_cache_ptr + cache_offset + cache_d_offsets
|
||||
kv_cache_old = tl.load(kv_ptr, mask=kv_mask, other=0.0)
|
||||
kv_outer = kv_outer + ratio * kv_cache_old
|
||||
|
||||
# Compute attention output
|
||||
output = q[:, None].to(tl.float32) * kv_outer
|
||||
output = tl.sum(output, axis=0)
|
||||
|
||||
# Update KV cache and store output
|
||||
tl.store(kv_ptr, kv_outer, mask=kv_mask)
|
||||
tl.store(output_ptr + q_offset + v_d_offsets, output, mask=v_mask)
|
||||
|
||||
|
||||
def linear_decode_forward_triton(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
kv_caches: torch.Tensor,
|
||||
slope_rate: torch.Tensor,
|
||||
slot_idx: torch.Tensor,
|
||||
BLOCK_SIZE: int = 32,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Perform linear attention decoding using Triton kernels.
|
||||
|
||||
Args:
|
||||
q: Query tensor of shape [B, H, 1, D]
|
||||
k: Key tensor of shape [B, H, 1, D]
|
||||
v: Value tensor of shape [B, H, 1, D]
|
||||
kv_caches: Key-value cache tensor
|
||||
slope_rate: Decay rate tensor
|
||||
slot_idx: Slot indices for batches
|
||||
BLOCK_SIZE: Size of blocks for processing
|
||||
|
||||
Returns:
|
||||
output: Attention output tensor
|
||||
"""
|
||||
B, H, _, D = q.shape
|
||||
assert k.shape == (B, H, 1, D)
|
||||
assert v.shape == (B, H, 1, D)
|
||||
|
||||
# Initialize output tensor
|
||||
output = torch.empty_like(q)
|
||||
|
||||
# Set grid dimensions for the kernel
|
||||
grid = (B, H, D // BLOCK_SIZE)
|
||||
|
||||
# Calculate strides for tensors
|
||||
qkv_b_stride = q.stride(0)
|
||||
qkv_h_stride = q.stride(1)
|
||||
|
||||
cache_b_stride = kv_caches.stride(0)
|
||||
cache_h_stride = kv_caches.stride(1)
|
||||
cache_d0_stride = kv_caches.stride(2)
|
||||
cache_d1_stride = kv_caches.stride(3)
|
||||
|
||||
# Launch the kernel
|
||||
_linear_attn_decode_kernel[grid](
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
kv_caches,
|
||||
slope_rate,
|
||||
slot_idx,
|
||||
output,
|
||||
D,
|
||||
qkv_b_stride,
|
||||
qkv_h_stride,
|
||||
cache_b_stride,
|
||||
cache_h_stride,
|
||||
cache_d0_stride,
|
||||
cache_d1_stride,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
)
|
||||
|
||||
# Reshape output and return
|
||||
output = rearrange(output, "b h n d -> b n (h d)")
|
||||
return output.squeeze(1).contiguous()
|
136
vllm/model_executor/models/constant_size_cache.py
Normal file
136
vllm/model_executor/models/constant_size_cache.py
Normal file
@ -0,0 +1,136 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||
|
||||
|
||||
class ConstantSizeCache(ABC):
|
||||
"""
|
||||
Abstract base class for managing constant size caches
|
||||
like Mamba and Minimax.
|
||||
"""
|
||||
|
||||
def __init__(self, max_batch_size: int):
|
||||
# Maps between the request id and a dict that maps between the seq_id
|
||||
# and its index inside the cache
|
||||
self.cache_indices_mapping: Dict[str, Dict[int, int]] = {}
|
||||
self.free_cache_indices = list(range(max_batch_size))
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def cache(self) -> Any:
|
||||
"""Return the underlying cache tensor(s)"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _copy_cache(self, from_index: int, to_index: int):
|
||||
"""Copy cache data from one index to another"""
|
||||
pass
|
||||
|
||||
def current_run_tensors(self, **kwargs) -> Tuple:
|
||||
"""
|
||||
Return the tensors for the current run's conv and ssm state.
|
||||
"""
|
||||
if "seqlen_agnostic_capture_inputs" not in kwargs:
|
||||
# We get here only on Prefill/Eager mode runs
|
||||
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
|
||||
finished_requests_ids = kwargs["finished_requests_ids"]
|
||||
|
||||
self._release_finished_requests(finished_requests_ids)
|
||||
state_indices = self._prepare_current_run_cache(
|
||||
request_ids_to_seq_ids, finished_requests_ids)
|
||||
|
||||
state_indices_tensor = torch.as_tensor(state_indices,
|
||||
dtype=torch.int32,
|
||||
device="cuda")
|
||||
cache_tensors = self.cache
|
||||
else:
|
||||
# CUDA graph capturing runs
|
||||
cache_tensors, state_indices_tensor = kwargs[
|
||||
"seqlen_agnostic_capture_inputs"]
|
||||
|
||||
return (cache_tensors, state_indices_tensor)
|
||||
|
||||
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
|
||||
"""
|
||||
Copy the relevant state_indices into the CUDA graph input buffer
|
||||
"""
|
||||
assert all(
|
||||
key in kwargs
|
||||
for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
|
||||
finished_requests_ids = kwargs["finished_requests_ids"]
|
||||
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
|
||||
assert "seqlen_agnostic_capture_inputs" in input_buffers
|
||||
_, input_state_indices_buffer = input_buffers[
|
||||
"seqlen_agnostic_capture_inputs"]
|
||||
|
||||
self._release_finished_requests(finished_requests_ids)
|
||||
state_indices = self._prepare_current_run_cache(
|
||||
request_ids_to_seq_ids, finished_requests_ids)
|
||||
cuda_graph_pad_len = input_state_indices_buffer.shape[0] - len(
|
||||
state_indices)
|
||||
state_indices.extend([PAD_SLOT_ID] * cuda_graph_pad_len)
|
||||
|
||||
input_state_indices_buffer.copy_(
|
||||
torch.as_tensor(state_indices, dtype=torch.int32, device="cuda"))
|
||||
|
||||
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
|
||||
"""
|
||||
Provide the CUDA graph capture runs with a buffer in adjusted size.
|
||||
The buffer is used to maintain the Cache during the CUDA graph replay
|
||||
runs.
|
||||
"""
|
||||
state_indices_tensor = torch.as_tensor([PAD_SLOT_ID] * batch_size,
|
||||
dtype=torch.int32,
|
||||
device="cuda")
|
||||
return (self.cache, state_indices_tensor)
|
||||
|
||||
def _assign_seq_id_to_cache_index(self, cur_rid: str, seq_id: int,
|
||||
finished_requests_ids) -> int:
|
||||
"""
|
||||
Assign (req_id,seq_id) pair to a `destination_index` index, if
|
||||
already occupied, move the occupying index to a free index.
|
||||
"""
|
||||
if cur_rid in finished_requests_ids:
|
||||
# set as pad, do not allocate destination index
|
||||
return PAD_SLOT_ID
|
||||
elif cur_rid not in self.cache_indices_mapping:
|
||||
destination_index = self.free_cache_indices.pop()
|
||||
self.cache_indices_mapping[cur_rid] = {seq_id: destination_index}
|
||||
return destination_index
|
||||
elif seq_id not in (seq_ids2indices :=
|
||||
self.cache_indices_mapping[cur_rid]):
|
||||
# parallel sampling , where n > 1, assume prefill have
|
||||
# already happened, so we copy the
|
||||
# existing cache into the siblings seq_ids caches
|
||||
index_exists = next(iter(seq_ids2indices.values()))
|
||||
# case of decoding n>1, copy prefill cache to decoding indices
|
||||
destination_index = self.free_cache_indices.pop()
|
||||
self._copy_cache(from_index=index_exists,
|
||||
to_index=destination_index)
|
||||
self.cache_indices_mapping[cur_rid][seq_id] = destination_index
|
||||
return destination_index
|
||||
else:
|
||||
return self.cache_indices_mapping[cur_rid][seq_id]
|
||||
|
||||
def _prepare_current_run_cache(
|
||||
self, request_ids_to_seq_ids: Dict[str, list[int]],
|
||||
finished_requests_ids: List[str]) -> List[int]:
|
||||
return [
|
||||
self._assign_seq_id_to_cache_index(req_id, seq_id,
|
||||
finished_requests_ids)
|
||||
for req_id, seq_ids in request_ids_to_seq_ids.items()
|
||||
for seq_id in seq_ids
|
||||
]
|
||||
|
||||
def _release_finished_requests(self,
|
||||
finished_seq_groups_req_ids: List[str]):
|
||||
for req_id in finished_seq_groups_req_ids:
|
||||
if req_id in self.cache_indices_mapping:
|
||||
for seq_id in self.cache_indices_mapping[req_id]:
|
||||
self.free_cache_indices.append(
|
||||
self.cache_indices_mapping[req_id][seq_id])
|
||||
self.cache_indices_mapping.pop(req_id)
|
@ -1,12 +1,13 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.models.constant_size_cache import ConstantSizeCache
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -21,7 +22,7 @@ class MambaCacheParams:
|
||||
self.state_indices_tensor)
|
||||
|
||||
|
||||
class MambaCacheManager:
|
||||
class MambaCacheManager(ConstantSizeCache):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, dtype: torch.dtype,
|
||||
num_mamba_layers: int, conv_state_shape: Tuple[int, int],
|
||||
@ -32,6 +33,9 @@ class MambaCacheManager:
|
||||
if not vllm_config.model_config.enforce_eager:
|
||||
max_batch_size = vllm_config.pad_for_cudagraph(max_batch_size)
|
||||
|
||||
# Initialize parent class
|
||||
super().__init__(max_batch_size)
|
||||
|
||||
conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
|
||||
conv_state_shape,
|
||||
dtype=dtype,
|
||||
@ -41,126 +45,32 @@ class MambaCacheManager:
|
||||
dtype=dtype,
|
||||
device="cuda")
|
||||
|
||||
self.mamba_cache = (conv_state, temporal_state)
|
||||
self._mamba_cache = (conv_state, temporal_state)
|
||||
|
||||
# Maps between the request id and a dict that maps between the seq_id
|
||||
# and its index inside the self.mamba_cache
|
||||
self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {}
|
||||
self.free_cache_indices = list(range(max_batch_size))
|
||||
@property
|
||||
def cache(self):
|
||||
return self._mamba_cache
|
||||
|
||||
def _copy_cache(self, from_index: int, to_index: int):
|
||||
for cache_t in self.cache:
|
||||
cache_t[:, to_index].copy_(cache_t[:, from_index],
|
||||
non_blocking=True)
|
||||
|
||||
def current_run_tensors(self, **kwargs) -> MambaCacheParams:
|
||||
"""
|
||||
Return the tensors for the current run's conv and ssm state.
|
||||
"""
|
||||
if "seqlen_agnostic_capture_inputs" not in kwargs:
|
||||
# We get here only on Prefill/Eager mode runs
|
||||
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
|
||||
finished_requests_ids = kwargs["finished_requests_ids"]
|
||||
|
||||
self._release_finished_requests(finished_requests_ids)
|
||||
state_indices = self._prepare_current_run_mamba_cache(
|
||||
request_ids_to_seq_ids, finished_requests_ids)
|
||||
|
||||
state_indices_tensor = torch.as_tensor(state_indices,
|
||||
dtype=torch.int32,
|
||||
device="cuda")
|
||||
mamba_cache_tensors = self.mamba_cache
|
||||
|
||||
else:
|
||||
# CUDA graph capturing runs
|
||||
(mamba_cache_tensors,
|
||||
state_indices_tensor) = kwargs["seqlen_agnostic_capture_inputs"]
|
||||
|
||||
return MambaCacheParams(mamba_cache_tensors[0], mamba_cache_tensors[1],
|
||||
cache_tensors, state_indices_tensor = super().current_run_tensors(
|
||||
**kwargs)
|
||||
return MambaCacheParams(cache_tensors[0], cache_tensors[1],
|
||||
state_indices_tensor)
|
||||
|
||||
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
|
||||
"""
|
||||
Copy the relevant state_indices into the CUDA graph input buffer
|
||||
"""
|
||||
assert all(
|
||||
key in kwargs
|
||||
for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
|
||||
finished_requests_ids = kwargs["finished_requests_ids"]
|
||||
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
|
||||
assert "seqlen_agnostic_capture_inputs" in input_buffers
|
||||
_, input_state_indices_buffer = input_buffers[
|
||||
"seqlen_agnostic_capture_inputs"]
|
||||
|
||||
self._release_finished_requests(finished_requests_ids)
|
||||
state_indices = self._prepare_current_run_mamba_cache(
|
||||
request_ids_to_seq_ids, finished_requests_ids)
|
||||
cuda_graph_pad_len = input_state_indices_buffer.shape[0] - len(
|
||||
state_indices)
|
||||
state_indices.extend([PAD_SLOT_ID] * cuda_graph_pad_len)
|
||||
|
||||
input_state_indices_buffer.copy_(
|
||||
torch.as_tensor(state_indices, dtype=torch.int32, device="cuda"))
|
||||
|
||||
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
|
||||
"""
|
||||
Provide the CUDA graph capture runs with a buffer in adjusted size.
|
||||
The buffer is used to maintain the Mamba Cache during the CUDA graph
|
||||
replay runs.
|
||||
"""
|
||||
state_indices_tensor = torch.as_tensor([PAD_SLOT_ID] * batch_size,
|
||||
dtype=torch.int32,
|
||||
device="cuda")
|
||||
return (self.mamba_cache, state_indices_tensor)
|
||||
|
||||
def _copy_mamba_cache(self, from_index: int, to_index: int):
|
||||
assert len(self.mamba_cache) > 0
|
||||
for cache_t in self.mamba_cache:
|
||||
cache_t[:, to_index].copy_(cache_t[:, from_index],
|
||||
non_blocking=True)
|
||||
|
||||
def _assign_seq_id_to_cache_index(self, cur_rid: str, seq_id: int,
|
||||
finished_requests_ids) -> int:
|
||||
"""
|
||||
Assign (req_id,seq_id) pair to a `destination_index` index, if
|
||||
already occupied, move the occupying index to a free index.
|
||||
"""
|
||||
if cur_rid in finished_requests_ids:
|
||||
# set as pad, do not allocate destination index
|
||||
return PAD_SLOT_ID
|
||||
elif cur_rid not in self.mamba_cache_indices_mapping:
|
||||
destination_index = self.free_cache_indices.pop()
|
||||
self.mamba_cache_indices_mapping[cur_rid] = {
|
||||
seq_id: destination_index
|
||||
}
|
||||
return destination_index
|
||||
elif seq_id not in (seq_ids2indices :=
|
||||
self.mamba_cache_indices_mapping[cur_rid]):
|
||||
# parallel sampling , where n > 1, assume prefill have
|
||||
# already happened, so we copy the
|
||||
# existing cache into the siblings seq_ids caches
|
||||
index_exists = next(iter(seq_ids2indices.values()))
|
||||
# case of decoding n>1, copy prefill cache to decoding indices
|
||||
destination_index = self.free_cache_indices.pop()
|
||||
self._copy_mamba_cache(from_index=index_exists,
|
||||
to_index=destination_index)
|
||||
self.mamba_cache_indices_mapping[cur_rid][
|
||||
seq_id] = destination_index
|
||||
return destination_index
|
||||
else:
|
||||
# already exists
|
||||
return self.mamba_cache_indices_mapping[cur_rid][seq_id]
|
||||
|
||||
def _prepare_current_run_mamba_cache(
|
||||
self, request_ids_to_seq_ids: Dict[str, list[int]],
|
||||
finished_requests_ids: List[str]) -> List[int]:
|
||||
return [
|
||||
self._assign_seq_id_to_cache_index(req_id, seq_id,
|
||||
finished_requests_ids)
|
||||
for req_id, seq_ids in request_ids_to_seq_ids.items()
|
||||
for seq_id in seq_ids
|
||||
]
|
||||
|
||||
def _release_finished_requests(self,
|
||||
finished_seq_groups_req_ids: List[str]):
|
||||
for req_id in finished_seq_groups_req_ids:
|
||||
if req_id in self.mamba_cache_indices_mapping:
|
||||
for seq_id in self.mamba_cache_indices_mapping[req_id]:
|
||||
self.free_cache_indices.append(
|
||||
self.mamba_cache_indices_mapping[req_id][seq_id])
|
||||
self.mamba_cache_indices_mapping.pop(req_id)
|
||||
return self._mamba_cache, torch.as_tensor([PAD_SLOT_ID] * batch_size,
|
||||
dtype=torch.int32,
|
||||
device="cuda")
|
||||
|
35
vllm/model_executor/models/minimax_cache.py
Normal file
35
vllm/model_executor/models/minimax_cache.py
Normal file
@ -0,0 +1,35 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.models.constant_size_cache import ConstantSizeCache
|
||||
|
||||
|
||||
@dataclass
|
||||
class MinimaxCacheParams:
|
||||
minimax_cache: torch.Tensor = torch.Tensor()
|
||||
state_indices_tensor: torch.Tensor = torch.Tensor()
|
||||
|
||||
def at_layer_idx(self, layer_idx):
|
||||
return MinimaxCacheParams(self.minimax_cache[layer_idx, ...],
|
||||
self.state_indices_tensor)
|
||||
|
||||
|
||||
class MinimaxCacheManager(ConstantSizeCache):
|
||||
|
||||
def __init__(self, dtype, cache_shape):
|
||||
super().__init__(cache_shape[1]) # max_batch_size is cache_shape[1]
|
||||
self._minimax_cache = torch.empty(size=cache_shape,
|
||||
dtype=dtype,
|
||||
device="cuda")
|
||||
|
||||
@property
|
||||
def cache(self):
|
||||
return self._minimax_cache
|
||||
|
||||
def _copy_cache(self, from_index: int, to_index: int):
|
||||
assert len(self.cache) > 0
|
||||
for cache_t in self.cache:
|
||||
cache_t[:, to_index].copy_(cache_t[:, from_index],
|
||||
non_blocking=True)
|
1273
vllm/model_executor/models/minimax_text_01.py
Normal file
1273
vllm/model_executor/models/minimax_text_01.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -35,6 +35,7 @@ _TEXT_GENERATION_MODELS = {
|
||||
"AquilaModel": ("llama", "LlamaForCausalLM"),
|
||||
"AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2
|
||||
"ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
|
||||
"MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
|
||||
# baichuan-7b, upper case 'C' in the class name
|
||||
"BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
|
||||
# baichuan-13b, lower case 'c' in the class name
|
||||
|
Loading…
x
Reference in New Issue
Block a user