[Model][MiniMaxText01] Support MiniMaxText01 model inference (#13454)

Signed-off-by: qscqesze <475517977@qq.com>
Co-authored-by: qingjun <qingjun@minimaxi.com>
Co-authored-by: qscqesze <475517977@qq.com>
Authored by Gerald on 2025-04-02 04:23:55 +08:00; committed by GitHub
parent 93491aefc7
commit 9ef98d527e
11 changed files with 2439 additions and 129 deletions

View File

@@ -503,6 +503,11 @@ See [this page](#generative-models) for more information on how to use generativ
   * `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc.
   * ✅︎
   * ✅︎
+- * `MiniMaxText01ForCausalLM`
+  * MiniMax-Text
+  * `MiniMaxAI/MiniMax-Text-01`, etc.
+  *
+  * ✅︎
 - * `Zamba2ForCausalLM`
   * Zamba2
   * `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc.
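With this entry in place, the model can be exercised through the standard offline-inference path. A minimal sketch, assuming a machine with enough GPUs for the checkpoint; the tensor_parallel_size and max_model_len values below are illustrative, not prescribed by this PR:

from vllm import LLM, SamplingParams

# MiniMax-Text-01 ships custom model/tokenizer code, hence trust_remote_code.
llm = LLM(model="MiniMaxAI/MiniMax-Text-01",
          trust_remote_code=True,
          tensor_parallel_size=8,      # illustrative
          max_model_len=4096)          # illustrative
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)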

View File

@@ -0,0 +1,286 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
from vllm.model_executor.layers.lightning_attn import (
linear_decode_forward_triton)
from vllm.platforms import current_platform
NUM_HEADS = [4, 8]
HEAD_SIZES = [64]
BATCH_SIZES = [1, 2]
SEQ_LENGTHS = [16]
DTYPES = [torch.float32]
def reference_lightning_attention(q, k, v, ed, block_size, kv_history):
"""Reference implementation of lightning attention core algorithm
The difference from the main implementation is that this processes
each step sequentially, instead of using parallelized triton kernels
"""
B, H, S, D = q.shape
E = v.shape[-1]
dtype = q.dtype
output = torch.zeros((B, H, S, E), dtype=dtype, device=q.device)
# Use clone() to ensure an independent copy
if kv_history is None:
kv_cache = torch.zeros((B, H, D, E), dtype=dtype, device=q.device)
else:
kv_cache = kv_history.clone()
# More efficient implementation
# Convert decay factors to matrix form
if ed.dim() == 1:
decay = torch.exp(-ed).view(1, -1, 1, 1)
else:
decay = torch.exp(-ed)
for b in range(B):
for step in range(S):
# Process all heads at once for this position
q_bs = q[b, :, step] # [H, D]
k_bs = k[b, :, step] # [H, D]
v_bs = v[b, :, step] # [H, E]
# Calculate KV outer products for all heads
for h in range(H):
# Calculate KV outer product
kv_outer = torch.outer(k_bs[h], v_bs[h])
# Update KV cache with decay
# Note: Using the same order as in the Triton kernel
kv_cache[b, h] = decay[0, h, 0, 0] * kv_cache[b, h] + kv_outer
# Calculate attention output
output[b, h, step] = torch.matmul(q_bs[h], kv_cache[b, h])
# Match the shape returned by the actual implementation
# The actual implementation returns a tensor of shape [B, H, 2, D, E]
# where dimension 2 contains both KV and KV history
kv_reshaped = kv_cache.unsqueeze(2) # [B, H, 1, D, E]
final_kv_cache = torch.cat([kv_reshaped, kv_reshaped],
dim=2) # [B, H, 2, D, E]
return output, final_kv_cache
def reference_linear_decode(q, k, v, kv_caches, slope_rate, slot_idx):
"""Reference implementation: linear attention decode function"""
B, H, _, D = q.shape
output = torch.zeros(B, H * D, dtype=q.dtype, device=q.device)
# Calculate decay factors once (more efficient)
decay = torch.exp(-slope_rate).view(-1, 1, 1) # [H, 1, 1]
# Process each batch
for b in range(B):
slot_id = slot_idx[b].item()
# Skip padding positions
if slot_id == -1:
continue
# Process all heads at once for this batch
q_b = q[b, :, 0] # [H, D]
k_b = k[b, :, 0] # [H, D]
v_b = v[b, :, 0] # [H, D]
# Process each attention head
for h in range(H):
# Get current query, key and value
q_bh = q_b[h]
k_bh = k_b[h]
v_bh = v_b[h]
# Get cache
kv_cache_old = kv_caches[b, h]
# Calculate new key-value outer product
kv_outer = torch.outer(k_bh, v_bh)
# Apply decay and update cache
kv_new = kv_outer + decay[h, 0, 0] * kv_cache_old
# Calculate output
out_h = torch.matmul(q_bh, kv_new)
# Update output and cache
output[b, h * D:(h + 1) * D] = out_h
kv_caches[b, h] = kv_new
return output
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode()
def test_linear_decode_forward_triton(
batch_size: int,
num_heads: int,
head_size: int,
dtype: torch.dtype,
):
torch.set_default_device("cuda")
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
current_platform.seed_everything(42)
base = 0.01
q = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
k = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
kv_caches = base * torch.randn(batch_size,
num_heads,
head_size,
head_size,
dtype=dtype,
device="cuda")
kv_caches_copy = kv_caches.clone()
slope_rate = torch.zeros(num_heads, device="cuda")
for h in range(num_heads):
slope_rate[h] = 0.1 * (h + 1)
slot_idx = torch.arange(batch_size, device="cuda")
triton_output = linear_decode_forward_triton(q, k, v, kv_caches,
slope_rate, slot_idx)
reference_output = reference_linear_decode(q, k, v, kv_caches_copy,
slope_rate, slot_idx)
torch.testing.assert_close(triton_output,
reference_output,
rtol=1e-1,
atol=1e-1)
torch.testing.assert_close(kv_caches, kv_caches_copy, rtol=1e-1, atol=1e-1)
assert triton_output.shape == (batch_size, num_heads * head_size)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode()
def test_linear_decode_forward_triton_with_padding(
num_heads: int,
head_size: int,
dtype: torch.dtype,
):
torch.set_default_device("cuda")
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
current_platform.seed_everything(42)
batch_size = 4
base = 0.01
q = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
k = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
kv_caches = base * torch.randn(batch_size,
num_heads,
head_size,
head_size,
dtype=dtype,
device="cuda")
kv_caches_copy = kv_caches.clone()
slope_rate = torch.zeros(num_heads, device="cuda")
for h in range(num_heads):
slope_rate[h] = 0.1 * (h + 1)
slot_idx = torch.tensor([0, 1, -1, 2], device="cuda")
triton_output = linear_decode_forward_triton(q, k, v, kv_caches,
slope_rate, slot_idx)
reference_output = reference_linear_decode(q, k, v, kv_caches_copy,
slope_rate, slot_idx)
padding_mask = (slot_idx
!= -1).unsqueeze(1).expand(-1, num_heads * head_size)
triton_masked = triton_output[padding_mask]
reference_masked = reference_output[padding_mask]
atol, rtol = 1.5e-1, 1.5e-1
valid_indices = slot_idx != -1
for i in range(batch_size):
if valid_indices[i] > 0:
torch.testing.assert_close(kv_caches[i],
kv_caches_copy[i],
rtol=rtol,
atol=atol)
torch.testing.assert_close(triton_masked,
reference_masked,
rtol=rtol,
atol=atol)
assert triton_output.shape == (batch_size, num_heads * head_size)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENGTHS)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode()
def test_lightning_attention_reference(
batch_size: int,
num_heads: int,
head_size: int,
seq_len: int,
dtype: torch.dtype,
):
torch.set_default_device("cuda")
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
current_platform.seed_everything(42)
base = 0.01
q = base * torch.randn(
batch_size, num_heads, seq_len, head_size, dtype=dtype)
k = base * torch.randn(
batch_size, num_heads, seq_len, head_size, dtype=dtype)
v = base * torch.randn(
batch_size, num_heads, seq_len, head_size, dtype=dtype)
ed = torch.zeros(num_heads, device="cuda")
for h in range(num_heads):
ed[h] = 0.1 * (h + 1)
kv_history = base * torch.randn(batch_size,
num_heads,
head_size,
head_size,
dtype=dtype,
device="cuda")
kv_history_clone = kv_history.clone()
ref_output, ref_kv_cache = reference_lightning_attention(
q, k, v, ed, 256, kv_history)
from vllm.model_executor.layers.lightning_attn import lightning_attention
actual_output, actual_kv_cache = lightning_attention(
q, k, v, ed, 256, kv_history_clone)
atol, rtol = 1.5e-1, 1.5e-1
torch.testing.assert_close(ref_output, actual_output, rtol=rtol, atol=atol)
torch.testing.assert_close(ref_kv_cache,
actual_kv_cache,
rtol=rtol,
atol=atol)
assert ref_output.shape == (batch_size, num_heads, seq_len, head_size)
assert ref_kv_cache.shape == actual_kv_cache.shape
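Both reference functions above reduce to the same per-head recurrence: decay the running KV state, add the new key-value outer product, then read it out with the query. A standalone PyTorch restatement for a single head and a single decode step (shapes and values are illustrative):

import torch

d = 64                                 # head_size
s = torch.tensor(0.1)                  # per-head decay rate (slope)
kv = torch.zeros(d, d)                 # running KV state for this head
q, k, v = (torch.randn(d) for _ in range(3))

kv = torch.exp(-s) * kv + torch.outer(k, v)   # decayed state + new outer product
out = q @ kv                                  # [d] output for this token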

View File

@@ -176,6 +176,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
                                             trust_remote_code=True),
     "MiniCPM3ForCausalLM": _HfExamplesInfo("openbmb/MiniCPM3-4B",
                                            trust_remote_code=True),
+    "MiniMaxText01ForCausalLM": _HfExamplesInfo("MiniMaxAI/MiniMax-Text-01",
+                                                trust_remote_code=True),
     "MistralForCausalLM": _HfExamplesInfo("mistralai/Mistral-7B-Instruct-v0.1"),
     "MixtralForCausalLM": _HfExamplesInfo("mistralai/Mixtral-8x7B-Instruct-v0.1"),  # noqa: E501
     "QuantMixtralForCausalLM": _HfExamplesInfo("mistral-community/Mixtral-8x22B-v0.1-AWQ"),  # noqa: E501

View File

@@ -971,26 +971,34 @@ class ModelConfig:
             return sum(not bc.attention.no_op
                        for bc in block_configs[start:end])
         else:
-            # Hybrid model
+            # Hybrid model Jamba
             layers_block_type_value = getattr(self.hf_config,
                                               "layers_block_type", None)
-            if layers_block_type_value is None:
-                raise ValueError("The model is an hybrid without a "
-                                 "layers_block_type in the hf_config, "
-                                 "cannot determine the num of "
-                                 f"{block_type.value} layers")
-
-            if hasattr(self.hf_text_config,
-                       "model_type") and (self.hf_text_config.model_type
-                                          == "zamba2"):
-                if attn_block_type:
-                    return sum(t == "hybrid"
-                               for t in layers_block_type_value[start:end])
-                else:
-                    return self.get_num_layers(parallel_config)
-
-            return sum(t == block_type.value
-                       for t in layers_block_type_value[start:end])
+            if layers_block_type_value is not None:
+                if hasattr(self.hf_text_config,
+                           "model_type") and (self.hf_text_config.model_type
+                                              == "zamba2"):
+                    if attn_block_type:
+                        return sum(t == "hybrid"
+                                   for t in layers_block_type_value[start:end])
+                    else:
+                        return self.get_num_layers(parallel_config)
+                return sum(t == block_type.value
+                           for t in layers_block_type_value[start:end])
+
+            # Hybrid model Minimax
+            attn_type_list = getattr(self.hf_config, "attn_type_list", None)
+            if attn_type_list:
+                return sum(t == 1 for t in attn_type_list[start:end])
+
+            if layers_block_type_value is None and attn_type_list is None:
+                raise ValueError("The model is a hybrid without a "
+                                 "layers_block_type or an attn_type_list in "
+                                 "the hf_config, cannot determine the num of "
+                                 f"{block_type.value} layers")
+
+            return sum(t == 1 for t in attn_type_list[start:end])

     def get_multimodal_config(self) -> "MultiModalConfig":
         """

View File

@@ -303,8 +303,11 @@ class _AsyncLLMEngine(LLMEngine):
             ctx.seq_group_metadata_list = seq_group_metadata_list
             ctx.scheduler_outputs = scheduler_outputs

-            finished_requests_ids = self.scheduler[
-                virtual_engine].get_and_reset_finished_requests_ids()
+            if not scheduler_outputs.is_empty():
+                # Consuming finished_requests_ids on an empty step would leave
+                # mamba_cache/minimax_cache unable to release finished requests.
+                finished_requests_ids = self.scheduler[
+                    virtual_engine].get_and_reset_finished_requests_ids()

             # Maybe switch from async mode to sync mode
             if not allow_async_output_proc and len(ctx.output_queue) > 0:
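The guard matters because get_and_reset_finished_requests_ids() drains the scheduler's list. A toy sketch (hypothetical names) of why draining it on a step that schedules nothing loses the ids:

class FinishedIds:                       # stand-in for the scheduler's tracking
    def __init__(self):
        self._ids: list[str] = []

    def add(self, rid: str):
        self._ids.append(rid)

    def get_and_reset(self) -> list[str]:
        ids, self._ids = self._ids, []
        return ids

tracker = FinishedIds()
tracker.add("req-0")
_ = tracker.get_and_reset()              # called on an empty step: ids are dropped
assert tracker.get_and_reset() == []     # the next real step sees nothing, so the
                                         # cache slot for "req-0" is never freed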

View File

@@ -0,0 +1,651 @@
# SPDX-License-Identifier: Apache-2.0
import torch
import triton
import triton.language as tl
from einops import rearrange
@triton.jit
def _fwd_diag_kernel(Q, K, V, Out, S, b: tl.constexpr, h: tl.constexpr, n,
d: tl.constexpr, e: tl.constexpr, BLOCK: tl.constexpr,
NUM_BLOCK, CBLOCK: tl.constexpr):
# This kernel computes the diagonal blocks of the attention matrix
# Each diagonal block represents attention
# where queries attend to keys in the same block
off = tl.program_id(0)
off_bh = off // NUM_BLOCK # batch-head index
off_block = off % NUM_BLOCK # block index within the sequence
off_cblock = tl.program_id(1) # sub-block index within a block
off_h = off_bh % h # head index
# Calculate base offsets for the current batch and head
qk_offset = off_bh * n * d
v_offset = off_bh * n * e
o_offset = off_bh * n * e
# Calculate offsets for the current block
block_offset = off_block * BLOCK
qk_block_offset = block_offset * d
v_block_offset = block_offset * e
o_block_offset = block_offset * e
# Calculate offsets for the current sub-block
cblock_offset = off_cblock * CBLOCK
q_cblock_offset = cblock_offset * d
o_cblock_offset = cblock_offset * e
# Calculate pointers to the query, key, value, and output tensors
Q_block_ptr = (Q + qk_offset + qk_block_offset + q_cblock_offset +
tl.arange(0, CBLOCK)[:, None] * d +
tl.arange(0, d)[None, :])
K_trans_block_ptr = (K + qk_offset + qk_block_offset +
tl.arange(0, CBLOCK)[None, :] * d +
tl.arange(0, d)[:, None])
V_block_ptr = (V + v_offset + v_block_offset +
tl.arange(0, CBLOCK)[:, None] * e +
tl.arange(0, e)[None, :])
O_block_ptr = (Out + o_offset + o_block_offset + o_cblock_offset +
tl.arange(0, CBLOCK)[:, None] * e +
tl.arange(0, e)[None, :])
# Load the decay rate for the current head
S_block_ptr = S + off_h
s = tl.load(S_block_ptr)
i = off_cblock
q_index = tl.arange(0, CBLOCK) + i * CBLOCK
# Load query values
q = tl.load(Q_block_ptr,
mask=block_offset + q_index[:, None] < n,
other=0.0).to(tl.float32)
# Initialize output accumulator
qkv = tl.zeros([CBLOCK, e], dtype=tl.float32)
# Process all sub-blocks up to and
# including the current one (causal attention)
for j in range(i + 1):
kv_index = tl.arange(0, CBLOCK) + j * CBLOCK
diff = q_index[:, None] - kv_index[None, :]
s_index = s * diff
# Apply causal mask: only attend to positions at or before the current one
s_index = tl.where(diff >= 0, -s_index, float("-inf"))
decay = tl.exp(s_index)
# Load key and value
k_trans = tl.load(
K_trans_block_ptr,
mask=block_offset + kv_index[None, :] < n,
other=0.0,
).to(tl.float32)
v = tl.load(
V_block_ptr,
mask=block_offset + kv_index[:, None] < n,
other=0.0,
).to(tl.float32)
# Compute attention scores and apply decay
qk = tl.dot(q, k_trans) * decay
# Compute weighted values and accumulate
qkv += tl.dot(qk, v)
# Move to the next sub-block
K_trans_block_ptr += CBLOCK * d
V_block_ptr += CBLOCK * e
# Store the result
tl.store(
O_block_ptr,
qkv.to(O_block_ptr.dtype.element_ty),
mask=block_offset + q_index[:, None] < n,
)
@triton.jit
def _fwd_kv_parallel(
K,
V,
K_decay,
KV,
b: tl.constexpr,
h: tl.constexpr,
n,
d: tl.constexpr,
e: tl.constexpr,
BLOCK: tl.constexpr,
NUM_BLOCK,
D_FBLOCK: tl.constexpr,
E_FBLOCK: tl.constexpr,
NUM_FBLOCK: tl.constexpr,
CBLOCK: tl.constexpr,
NUM_CBLOCK: tl.constexpr,
):
# This kernel computes the key-value outer
# products for each block in parallel
off_bh = tl.program_id(0) # batch-head index
off_block = tl.program_id(1) # block index
off_h = off_bh % h # head index
block_offset = off_block * BLOCK
# Calculate offsets for the current block
k_block_offset = block_offset * d
v_block_offset = block_offset * e
kv_block_offset = off_block * d * e
# Calculate base offsets for the current batch and head
k_offset = off_bh * n * d
v_offset = off_bh * n * e
kv_offset = off_bh * NUM_BLOCK * d * e
# Calculate pointers to the key, value, and key-value tensors
K_trans_block_ptr = (K + k_offset + k_block_offset +
tl.arange(0, CBLOCK)[None, :] * d +
tl.arange(0, D_FBLOCK)[:, None])
V_block_ptr = (V + v_offset + v_block_offset +
tl.arange(0, CBLOCK)[:, None] * e +
tl.arange(0, E_FBLOCK)[None, :])
KV_block_ptr = (KV + kv_offset + kv_block_offset +
tl.arange(0, D_FBLOCK)[:, None] * e +
tl.arange(0, E_FBLOCK)[None, :])
# Load the decay factors for the current head and block
k_decay_ptr = (K_decay + off_h * BLOCK + tl.arange(0, CBLOCK)[None, :])
kv_index = tl.arange(0, CBLOCK)
# Initialize the key-value outer product accumulator
kv = tl.zeros([D_FBLOCK, E_FBLOCK], dtype=tl.float32)
# Handle the last block which might be smaller than BLOCK
if off_block == NUM_BLOCK - 1:
split_n = n - (NUM_BLOCK - 1) * BLOCK
else:
split_n = BLOCK
left_shift = tl.cdiv(split_n, CBLOCK) * CBLOCK - split_n
num_blocks = min(tl.cdiv(split_n, CBLOCK), NUM_CBLOCK)
k_decay_ptr += (NUM_CBLOCK - num_blocks) * CBLOCK
# Process all sub-blocks in the current block
for j in range(num_blocks):
left_bound = (1 - j) * left_shift
# Load key and value, handling boundary conditions
k_trans = tl.load(K_trans_block_ptr - left_shift * d,
mask=kv_index[None, :] >= left_bound,
other=0.0)
v = tl.load(V_block_ptr - left_shift * e,
mask=kv_index[:, None] >= left_bound,
other=0.0)
# Load decay factor and compute weighted key-value outer product
k_decay = tl.load(k_decay_ptr)
kv += tl.dot(k_trans * k_decay, v)
# Move to the next sub-block
K_trans_block_ptr += CBLOCK * d
V_block_ptr += CBLOCK * e
k_decay_ptr += CBLOCK
# Store the result
tl.store(KV_block_ptr, kv.to(KV_block_ptr.dtype.element_ty))
@triton.jit
def _fwd_kv_reduce(S, KV, KV_HISTORY, b: tl.constexpr, h: tl.constexpr, n,
d: tl.constexpr, e: tl.constexpr, BLOCK: tl.constexpr,
NUM_BLOCK, D_FBLOCK: tl.constexpr, E_FBLOCK: tl.constexpr):
# This kernel reduces the key-value outer products
# across blocks and updates the KV history
off_bh = tl.program_id(0) # batch-head index
off_h = off_bh % h # head index
kv_offset = off_bh * NUM_BLOCK * d * e
# Calculate pointer to the key-value tensor
KV_block_ptr = (KV + kv_offset + tl.arange(0, D_FBLOCK)[:, None] * e +
tl.arange(0, E_FBLOCK)[None, :])
# Load the decay rate for the current head
s_ptrs = S + off_h
s = tl.load(s_ptrs)
# Calculate pointer to the key-value history tensor
kv_history_offset = off_bh * d * e
KV_HISTORY_block_ptr = (KV_HISTORY + kv_history_offset +
tl.arange(0, D_FBLOCK)[:, None] * e +
tl.arange(0, E_FBLOCK)[None, :])
# Load the previous key-value history
kv_pre = tl.load(KV_HISTORY_block_ptr).to(tl.float32)
# Process all blocks in reverse order to compute the prefix sum
for i in range(NUM_BLOCK):
block_size = min(n - i * BLOCK, BLOCK)
# Compute decay factor for the current block
block_decay = tl.exp(-s.to(tl.float32) * block_size)
# Load the current key-value outer product
kv_cur = tl.load(KV_block_ptr).to(tl.float32)
# Store the previous key-value history to the current block
tl.store(KV_block_ptr, kv_pre.to(KV_block_ptr.dtype.element_ty))
# Update the key-value history with the current block
kv_pre = block_decay * kv_pre + kv_cur
KV_block_ptr += d * e
# Store the updated key-value history
tl.store(KV_HISTORY_block_ptr, kv_pre)
@triton.jit
def _fwd_none_diag_kernel(
Q,
Out,
S,
KV,
b: tl.constexpr,
h: tl.constexpr,
n,
d: tl.constexpr,
e: tl.constexpr,
BLOCK: tl.constexpr,
NUM_BLOCK,
E_FBLOCK: tl.constexpr,
CBLOCK: tl.constexpr,
NUM_CBLOCK: tl.constexpr,
):
# This kernel computes the non-diagonal blocks of the attention matrix
# Each non-diagonal block represents attention
# where queries attend to keys in different blocks
off_bh = tl.program_id(0) # batch-head index
off_h = off_bh % h # head index
off_nc = tl.program_id(1)
off_n = off_nc // NUM_CBLOCK # block index
off_c = off_nc % NUM_CBLOCK # sub-block index
off_e = tl.program_id(2) # output feature block index
n_offset = off_n * BLOCK
c_offset = off_c * CBLOCK
e_offset = off_e * E_FBLOCK
block_offset = n_offset + c_offset
# Calculate offsets for the current batch, head, and block
q_offset = off_bh * n * d + (n_offset + c_offset) * d
o_offset = off_bh * n * e + (n_offset + c_offset) * e + e_offset
kv_offset = off_bh * NUM_BLOCK * d * e + off_n * d * e + e_offset
# Calculate pointers to the query, output, and key-value tensors
Q_block_ptr = (Q + q_offset + tl.arange(0, CBLOCK)[:, None] * d +
tl.arange(0, d)[None, :])
O_block_ptr = (Out + o_offset + tl.arange(0, CBLOCK)[:, None] * e +
tl.arange(0, E_FBLOCK)[None, :])
KV_block_ptr = (KV + kv_offset + tl.arange(0, d)[:, None] * e +
tl.arange(0, E_FBLOCK)[None, :])
# Load the decay rate for the current head
S_block_ptr = S + off_h
s = tl.load(S_block_ptr)
c_array = tl.arange(0, CBLOCK)
# Load the key-value outer product for the current block
kv = tl.load(KV_block_ptr).to(tl.float32)
q_index = block_offset + tl.arange(0, CBLOCK)
# Load query values
q = tl.load(Q_block_ptr, mask=q_index[:, None] < n,
other=0.).to(tl.float32)
# Compute decay factors for the current sub-block
q_decay = tl.exp(-s.to(tl.float32) * (off_c * CBLOCK + c_array[:, None]))
# Compute non-diagonal attention output
qkv_none_diag = tl.dot(q, kv) * q_decay
# Load diagonal attention output (computed by _fwd_diag_kernel)
qkv_diag = tl.load(O_block_ptr, mask=q_index[:, None] < n,
other=0.).to(tl.float32)
# Combine diagonal and non-diagonal attention outputs
qkv = qkv_diag + qkv_none_diag
# Store the result
tl.store(O_block_ptr,
qkv.to(O_block_ptr.dtype.element_ty),
mask=q_index[:, None] < n)
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, s, kv_history):
# Forward pass of the lightning attention algorithm
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
s = s.contiguous()
# Check CUDA compute capability
capability = torch.cuda.get_device_capability()
if capability[0] < 8:
raise RuntimeError("Lightning attention is currently only "
                   "supported for compute capability >= 8.0")
# Get input dimensions
b, h, n, d = q.shape
e = v.shape[-1]
# Initialize output tensor
o = torch.empty((b, h, n, e), dtype=q.dtype, device=q.device)
# Set block sizes
BLOCK = 256
NUM_BLOCK = triton.cdiv(n, BLOCK)
CBLOCK = 32
NUM_CBLOCK = BLOCK // CBLOCK
assert BLOCK % CBLOCK == 0, "BLOCK must be a multiple of CBLOCK"
# Compute decay factors for keys
array = torch.arange(0, BLOCK, device=q.device) + 1
k_decay = torch.exp(-s * (BLOCK - array.reshape(1, -1)))
# Step 1: Compute diagonal blocks of attention
grid = (b * h * NUM_BLOCK, NUM_CBLOCK)
_fwd_diag_kernel[grid](q,
k,
v,
o,
s,
b,
h,
n,
d,
e,
BLOCK=BLOCK,
NUM_BLOCK=NUM_BLOCK,
CBLOCK=CBLOCK)
# Set feature block sizes
NUM_FBLOCK = 1
D_FBLOCK = d // NUM_FBLOCK
assert d % NUM_FBLOCK == 0
E_FBLOCK = e // NUM_FBLOCK
assert e % NUM_FBLOCK == 0
CBLOCK = 64
NUM_CBLOCK = BLOCK // CBLOCK
assert BLOCK % CBLOCK == 0, "BLOCK must be a multiple of CBLOCK"
# Step 2: Compute key-value outer products for each block in parallel
kv = torch.empty((b, h, NUM_BLOCK, d, e),
dtype=torch.float32,
device=q.device)
grid = (b * h, NUM_BLOCK)
_fwd_kv_parallel[grid](
k,
v,
k_decay,
kv,
b,
h,
n,
d,
e,
BLOCK=BLOCK,
NUM_BLOCK=NUM_BLOCK,
D_FBLOCK=D_FBLOCK,
E_FBLOCK=E_FBLOCK,
NUM_FBLOCK=NUM_FBLOCK,
CBLOCK=CBLOCK,
NUM_CBLOCK=NUM_CBLOCK,
)
# Step 3: Reduce key-value outer products
# across blocks and update KV history
grid = (b * h, NUM_FBLOCK)
_fwd_kv_reduce[grid](s,
kv,
kv_history,
b,
h,
n,
d,
e,
BLOCK=BLOCK,
NUM_BLOCK=NUM_BLOCK,
D_FBLOCK=D_FBLOCK,
E_FBLOCK=E_FBLOCK)
# Step 4: Compute non-diagonal blocks of attention
grid = (b * h, NUM_BLOCK * NUM_CBLOCK)
_fwd_none_diag_kernel[grid](
q,
o,
s,
kv,
b,
h,
n,
d,
e,
BLOCK=BLOCK,
NUM_BLOCK=NUM_BLOCK,
E_FBLOCK=E_FBLOCK,
CBLOCK=CBLOCK,
NUM_CBLOCK=NUM_CBLOCK,
)
# Save tensors for backward pass
ctx.save_for_backward(q, k, v, s, kv)
ctx.BLOCK = BLOCK
return o, torch.cat([kv, kv_history.unsqueeze(2)], dim=2)
# Apply the lightning attention function
lightning_attention_ = _attention.apply
def lightning_attention(q, k, v, ed, block_size=256, kv_history=None):
"""
Apply lightning attention algorithm
to compute attention efficiently.
Args:
q: Query tensor of shape [batch, heads, seq_len, dim]
k: Key tensor of shape [batch, heads, seq_len, dim]
v: Value tensor of shape [batch, heads, seq_len, dim_v]
ed: Decay rate tensor of shape [heads]
block_size: Size of blocks for block-sparse attention
kv_history: Optional key-value history from previous computations
Returns:
output: Attention output
kv: Updated key-value history
"""
d = q.shape[-1]
e = v.shape[-1]
if ed.dim() == 1:
ed = ed.view(1, -1, 1, 1)
# Split the computation into chunks for better parallelism
m = 128 if d >= 128 else 64
assert d % m == 0, f"Dimension d ({d}) must be divisible by m ({m})"
arr = [m * i for i in range(d // m + 1)]
if arr[-1] != d:
arr.append(d)
n = len(arr)
output = 0
# Initialize or clone key-value history
if kv_history is None:
kv_history = torch.zeros((q.shape[0], q.shape[1], d, e),
dtype=torch.float32,
device=q.device)
else:
kv_history = kv_history.clone().contiguous()
# Process each chunk and accumulate results
for i in range(n - 1):
s = arr[i]
e = arr[i + 1]
q1 = q[..., s:e]
k1 = k[..., s:e]
o, kv = lightning_attention_(q1, k1, v, ed, kv_history)
output = output + o
return output, kv
@triton.jit
def _linear_attn_decode_kernel(
q_ptr,
k_ptr,
v_ptr,
kv_cache_ptr,
slope_rate,
slot_idx,
output_ptr,
D: tl.constexpr,
qkv_b_stride,
qkv_h_stride,
cache_b_stride,
cache_h_stride,
cache_d0_stride,
cache_d1_stride,
BLOCK_SIZE: tl.constexpr,
):
"""
Kernel for linear attention decoding with KV cache.
This kernel computes attention for a single token using the KV cache.
"""
pid_b = tl.program_id(0) # batch index
pid_h = tl.program_id(1) # head index
pid_d = tl.program_id(2) # dimension block index
# Load slot index for the current batch
slot_id = tl.load(slot_idx + pid_b)
# Skip if slot_id is -1 (padding)
if slot_id == -1:
return
batch_id = pid_b
head_id = pid_h
# Load decay rate for the current head
ratio = tl.load(slope_rate + pid_h)
# Calculate offsets for dimensions
qk_d_offsets = tl.arange(0, D)
v_d_offsets = tl.arange(0, BLOCK_SIZE) + pid_d * BLOCK_SIZE
cache_d_offsets = qk_d_offsets[:, None] * cache_d0_stride + v_d_offsets[
None, :] * cache_d1_stride
# Calculate offsets for the current batch and head
q_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride
k_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride
v_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride
cache_offset = slot_id * cache_b_stride + head_id * cache_h_stride
# Create masks for loading tensors
qk_mask = qk_d_offsets < D
v_mask = v_d_offsets < D
# Load query, key, and value tensors
q = tl.load(q_ptr + q_offset + qk_d_offsets, mask=qk_mask, other=0.0)
k = tl.load(k_ptr + k_offset + qk_d_offsets, mask=qk_mask, other=0.0)
v = tl.load(v_ptr + v_offset + v_d_offsets, mask=v_mask, other=0.0)
# Compute key-value outer product
kv_outer = k[:, None] * v[None, :]
kv_mask = qk_mask[:, None] & v_mask[None, :]
# Apply decay to previous KV cache
ratio = tl.exp(-ratio)
kv_ptr = kv_cache_ptr + cache_offset + cache_d_offsets
kv_cache_old = tl.load(kv_ptr, mask=kv_mask, other=0.0)
kv_outer = kv_outer + ratio * kv_cache_old
# Compute attention output
output = q[:, None].to(tl.float32) * kv_outer
output = tl.sum(output, axis=0)
# Update KV cache and store output
tl.store(kv_ptr, kv_outer, mask=kv_mask)
tl.store(output_ptr + q_offset + v_d_offsets, output, mask=v_mask)
def linear_decode_forward_triton(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
kv_caches: torch.Tensor,
slope_rate: torch.Tensor,
slot_idx: torch.Tensor,
BLOCK_SIZE: int = 32,
) -> torch.Tensor:
"""
Perform linear attention decoding using Triton kernels.
Args:
q: Query tensor of shape [B, H, 1, D]
k: Key tensor of shape [B, H, 1, D]
v: Value tensor of shape [B, H, 1, D]
kv_caches: Key-value cache tensor
slope_rate: Decay rate tensor
slot_idx: Slot indices for batches
BLOCK_SIZE: Size of blocks for processing
Returns:
output: Attention output tensor
"""
B, H, _, D = q.shape
assert k.shape == (B, H, 1, D)
assert v.shape == (B, H, 1, D)
# Initialize output tensor
output = torch.empty_like(q)
# Set grid dimensions for the kernel
grid = (B, H, D // BLOCK_SIZE)
# Calculate strides for tensors
qkv_b_stride = q.stride(0)
qkv_h_stride = q.stride(1)
cache_b_stride = kv_caches.stride(0)
cache_h_stride = kv_caches.stride(1)
cache_d0_stride = kv_caches.stride(2)
cache_d1_stride = kv_caches.stride(3)
# Launch the kernel
_linear_attn_decode_kernel[grid](
q,
k,
v,
kv_caches,
slope_rate,
slot_idx,
output,
D,
qkv_b_stride,
qkv_h_stride,
cache_b_stride,
cache_h_stride,
cache_d0_stride,
cache_d1_stride,
BLOCK_SIZE=BLOCK_SIZE,
)
# Reshape output and return
output = rearrange(output, "b h n d -> b n (h d)")
return output.squeeze(1).contiguous()
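For completeness, a sketch of calling the prefill entry point directly; it assumes a CUDA GPU with compute capability >= 8.0 (enforced by the check in _attention.forward) and uses shapes small enough that NUM_BLOCK == 1:

import torch
from vllm.model_executor.layers.lightning_attn import lightning_attention

b, h, n, d = 1, 8, 256, 64
q = torch.randn(b, h, n, d, device="cuda")
k = torch.randn(b, h, n, d, device="cuda")
v = torch.randn(b, h, n, d, device="cuda")
ed = 0.1 * (torch.arange(h, device="cuda", dtype=torch.float32) + 1)  # per-head decay rates

out, kv = lightning_attention(q, k, v, ed)
# out: [b, h, n, d]; kv stacks the per-block KV states plus the updated
# KV history as the last entry along dim 2 (here [b, h, 2, d, d]).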

View File

@@ -0,0 +1,136 @@
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Tuple
import torch
from vllm.attention.backends.utils import PAD_SLOT_ID
class ConstantSizeCache(ABC):
"""
Abstract base class for managing constant size caches
like Mamba and Minimax.
"""
def __init__(self, max_batch_size: int):
# Maps between the request id and a dict that maps between the seq_id
# and its index inside the cache
self.cache_indices_mapping: Dict[str, Dict[int, int]] = {}
self.free_cache_indices = list(range(max_batch_size))
@property
@abstractmethod
def cache(self) -> Any:
"""Return the underlying cache tensor(s)"""
pass
@abstractmethod
def _copy_cache(self, from_index: int, to_index: int):
"""Copy cache data from one index to another"""
pass
def current_run_tensors(self, **kwargs) -> Tuple:
"""
Return the cache tensors and state indices tensor for the current run.
"""
if "seqlen_agnostic_capture_inputs" not in kwargs:
# We get here only on Prefill/Eager mode runs
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
finished_requests_ids = kwargs["finished_requests_ids"]
self._release_finished_requests(finished_requests_ids)
state_indices = self._prepare_current_run_cache(
request_ids_to_seq_ids, finished_requests_ids)
state_indices_tensor = torch.as_tensor(state_indices,
dtype=torch.int32,
device="cuda")
cache_tensors = self.cache
else:
# CUDA graph capturing runs
cache_tensors, state_indices_tensor = kwargs[
"seqlen_agnostic_capture_inputs"]
return (cache_tensors, state_indices_tensor)
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
"""
Copy the relevant state_indices into the CUDA graph input buffer
"""
assert all(
key in kwargs
for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
finished_requests_ids = kwargs["finished_requests_ids"]
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
assert "seqlen_agnostic_capture_inputs" in input_buffers
_, input_state_indices_buffer = input_buffers[
"seqlen_agnostic_capture_inputs"]
self._release_finished_requests(finished_requests_ids)
state_indices = self._prepare_current_run_cache(
request_ids_to_seq_ids, finished_requests_ids)
cuda_graph_pad_len = input_state_indices_buffer.shape[0] - len(
state_indices)
state_indices.extend([PAD_SLOT_ID] * cuda_graph_pad_len)
input_state_indices_buffer.copy_(
torch.as_tensor(state_indices, dtype=torch.int32, device="cuda"))
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
"""
Provide the CUDA graph capture runs with a buffer in adjusted size.
The buffer is used to maintain the Cache during the CUDA graph replay
runs.
"""
state_indices_tensor = torch.as_tensor([PAD_SLOT_ID] * batch_size,
dtype=torch.int32,
device="cuda")
return (self.cache, state_indices_tensor)
def _assign_seq_id_to_cache_index(self, cur_rid: str, seq_id: int,
finished_requests_ids) -> int:
"""
Assign the (req_id, seq_id) pair to a cache index: finished requests map
to PAD_SLOT_ID, unseen requests get a free index, and sibling seq_ids
(parallel sampling) get a copy of the existing cache at a new index.
"""
if cur_rid in finished_requests_ids:
# set as pad, do not allocate destination index
return PAD_SLOT_ID
elif cur_rid not in self.cache_indices_mapping:
destination_index = self.free_cache_indices.pop()
self.cache_indices_mapping[cur_rid] = {seq_id: destination_index}
return destination_index
elif seq_id not in (seq_ids2indices :=
self.cache_indices_mapping[cur_rid]):
# Parallel sampling (n > 1): prefill is assumed to have already
# happened, so copy the existing cache into the sibling seq_ids'
# caches.
index_exists = next(iter(seq_ids2indices.values()))
# case of decoding n>1, copy prefill cache to decoding indices
destination_index = self.free_cache_indices.pop()
self._copy_cache(from_index=index_exists,
to_index=destination_index)
self.cache_indices_mapping[cur_rid][seq_id] = destination_index
return destination_index
else:
return self.cache_indices_mapping[cur_rid][seq_id]
def _prepare_current_run_cache(
self, request_ids_to_seq_ids: Dict[str, list[int]],
finished_requests_ids: List[str]) -> List[int]:
return [
self._assign_seq_id_to_cache_index(req_id, seq_id,
finished_requests_ids)
for req_id, seq_ids in request_ids_to_seq_ids.items()
for seq_id in seq_ids
]
def _release_finished_requests(self,
finished_seq_groups_req_ids: List[str]):
for req_id in finished_seq_groups_req_ids:
if req_id in self.cache_indices_mapping:
for seq_id in self.cache_indices_mapping[req_id]:
self.free_cache_indices.append(
self.cache_indices_mapping[req_id][seq_id])
self.cache_indices_mapping.pop(req_id)
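A concrete subclass only has to provide the storage: the `cache` property and `_copy_cache`. A deliberately tiny, hypothetical example (single state tensor, slots on dim 0) just to show the contract; the real users are MambaCacheManager and MinimaxCacheManager below:

import torch
from vllm.model_executor.models.constant_size_cache import ConstantSizeCache

class ToyCache(ConstantSizeCache):
    def __init__(self, max_batch_size: int, state_dim: int):
        super().__init__(max_batch_size)
        self._state = torch.zeros(max_batch_size, state_dim, device="cuda")

    @property
    def cache(self):
        return self._state

    def _copy_cache(self, from_index: int, to_index: int):
        # Duplicate one slot's state into another (used for parallel sampling).
        self._state[to_index].copy_(self._state[from_index], non_blocking=True)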

View File

@@ -1,12 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0

 from dataclasses import dataclass
-from typing import Dict, List, Tuple
+from typing import Tuple

 import torch

 from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.config import VllmConfig
+from vllm.model_executor.models.constant_size_cache import ConstantSizeCache


 @dataclass
@@ -21,7 +22,7 @@ class MambaCacheParams:
                                 self.state_indices_tensor)


-class MambaCacheManager:
+class MambaCacheManager(ConstantSizeCache):

     def __init__(self, vllm_config: VllmConfig, dtype: torch.dtype,
                  num_mamba_layers: int, conv_state_shape: Tuple[int, int],
@@ -32,6 +33,9 @@ class MambaCacheManager:
         if not vllm_config.model_config.enforce_eager:
             max_batch_size = vllm_config.pad_for_cudagraph(max_batch_size)

+        # Initialize parent class
+        super().__init__(max_batch_size)
+
         conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
                                  conv_state_shape,
                                  dtype=dtype,
@@ -41,126 +45,32 @@ class MambaCacheManager:
                                      dtype=dtype,
                                      device="cuda")

-        self.mamba_cache = (conv_state, temporal_state)
+        self._mamba_cache = (conv_state, temporal_state)

-        # Maps between the request id and a dict that maps between the seq_id
-        # and its index inside the self.mamba_cache
-        self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {}
-        self.free_cache_indices = list(range(max_batch_size))
+    @property
+    def cache(self):
+        return self._mamba_cache
+
+    def _copy_cache(self, from_index: int, to_index: int):
+        for cache_t in self.cache:
+            cache_t[:, to_index].copy_(cache_t[:, from_index],
+                                       non_blocking=True)

     def current_run_tensors(self, **kwargs) -> MambaCacheParams:
         """
         Return the tensors for the current run's conv and ssm state.
         """
-        if "seqlen_agnostic_capture_inputs" not in kwargs:
-            # We get here only on Prefill/Eager mode runs
-            request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
-            finished_requests_ids = kwargs["finished_requests_ids"]
-
-            self._release_finished_requests(finished_requests_ids)
-            state_indices = self._prepare_current_run_mamba_cache(
-                request_ids_to_seq_ids, finished_requests_ids)
-            state_indices_tensor = torch.as_tensor(state_indices,
-                                                   dtype=torch.int32,
-                                                   device="cuda")
-            mamba_cache_tensors = self.mamba_cache
-        else:
-            # CUDA graph capturing runs
-            (mamba_cache_tensors,
-             state_indices_tensor) = kwargs["seqlen_agnostic_capture_inputs"]
-
-        return MambaCacheParams(mamba_cache_tensors[0], mamba_cache_tensors[1],
+        cache_tensors, state_indices_tensor = super().current_run_tensors(
+            **kwargs)
+        return MambaCacheParams(cache_tensors[0], cache_tensors[1],
                                 state_indices_tensor)

-    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
-        """
-        Copy the relevant state_indices into the CUDA graph input buffer
-        """
-        assert all(
-            key in kwargs
-            for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
-        finished_requests_ids = kwargs["finished_requests_ids"]
-        request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
-        assert "seqlen_agnostic_capture_inputs" in input_buffers
-        _, input_state_indices_buffer = input_buffers[
-            "seqlen_agnostic_capture_inputs"]
-
-        self._release_finished_requests(finished_requests_ids)
-        state_indices = self._prepare_current_run_mamba_cache(
-            request_ids_to_seq_ids, finished_requests_ids)
-        cuda_graph_pad_len = input_state_indices_buffer.shape[0] - len(
-            state_indices)
-        state_indices.extend([PAD_SLOT_ID] * cuda_graph_pad_len)
-
-        input_state_indices_buffer.copy_(
-            torch.as_tensor(state_indices, dtype=torch.int32, device="cuda"))
-
     def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
         """
         Provide the CUDA graph capture runs with a buffer in adjusted size.
         The buffer is used to maintain the Mamba Cache during the CUDA graph
         replay runs.
         """
-        state_indices_tensor = torch.as_tensor([PAD_SLOT_ID] * batch_size,
-                                               dtype=torch.int32,
-                                               device="cuda")
-        return (self.mamba_cache, state_indices_tensor)
-
-    def _copy_mamba_cache(self, from_index: int, to_index: int):
-        assert len(self.mamba_cache) > 0
-        for cache_t in self.mamba_cache:
-            cache_t[:, to_index].copy_(cache_t[:, from_index],
-                                       non_blocking=True)
-
-    def _assign_seq_id_to_cache_index(self, cur_rid: str, seq_id: int,
-                                      finished_requests_ids) -> int:
-        """
-        Assign (req_id,seq_id) pair to a `destination_index` index, if
-        already occupied, move the occupying index to a free index.
-        """
-        if cur_rid in finished_requests_ids:
-            # set as pad, do not allocate destination index
-            return PAD_SLOT_ID
-        elif cur_rid not in self.mamba_cache_indices_mapping:
-            destination_index = self.free_cache_indices.pop()
-            self.mamba_cache_indices_mapping[cur_rid] = {
-                seq_id: destination_index
-            }
-            return destination_index
-        elif seq_id not in (seq_ids2indices :=
-                            self.mamba_cache_indices_mapping[cur_rid]):
-            # parallel sampling , where n > 1, assume prefill have
-            # already happened, so we copy the
-            # existing cache into the siblings seq_ids caches
-            index_exists = next(iter(seq_ids2indices.values()))
-            # case of decoding n>1, copy prefill cache to decoding indices
-            destination_index = self.free_cache_indices.pop()
-            self._copy_mamba_cache(from_index=index_exists,
-                                   to_index=destination_index)
-            self.mamba_cache_indices_mapping[cur_rid][
-                seq_id] = destination_index
-            return destination_index
-        else:
-            # already exists
-            return self.mamba_cache_indices_mapping[cur_rid][seq_id]
-
-    def _prepare_current_run_mamba_cache(
-            self, request_ids_to_seq_ids: Dict[str, list[int]],
-            finished_requests_ids: List[str]) -> List[int]:
-        return [
-            self._assign_seq_id_to_cache_index(req_id, seq_id,
-                                               finished_requests_ids)
-            for req_id, seq_ids in request_ids_to_seq_ids.items()
-            for seq_id in seq_ids
-        ]
-
-    def _release_finished_requests(self,
-                                   finished_seq_groups_req_ids: List[str]):
-        for req_id in finished_seq_groups_req_ids:
-            if req_id in self.mamba_cache_indices_mapping:
-                for seq_id in self.mamba_cache_indices_mapping[req_id]:
-                    self.free_cache_indices.append(
-                        self.mamba_cache_indices_mapping[req_id][seq_id])
-                self.mamba_cache_indices_mapping.pop(req_id)
+        return self._mamba_cache, torch.as_tensor([PAD_SLOT_ID] * batch_size,
+                                                  dtype=torch.int32,
+                                                  device="cuda")

View File

@@ -0,0 +1,35 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
import torch
from vllm.model_executor.models.constant_size_cache import ConstantSizeCache
@dataclass
class MinimaxCacheParams:
minimax_cache: torch.Tensor = torch.Tensor()
state_indices_tensor: torch.Tensor = torch.Tensor()
def at_layer_idx(self, layer_idx):
return MinimaxCacheParams(self.minimax_cache[layer_idx, ...],
self.state_indices_tensor)
class MinimaxCacheManager(ConstantSizeCache):
def __init__(self, dtype, cache_shape):
super().__init__(cache_shape[1]) # max_batch_size is cache_shape[1]
self._minimax_cache = torch.empty(size=cache_shape,
dtype=dtype,
device="cuda")
@property
def cache(self):
return self._minimax_cache
def _copy_cache(self, from_index: int, to_index: int):
assert len(self.cache) > 0
for cache_t in self.cache:
cache_t[:, to_index].copy_(cache_t[:, from_index],
non_blocking=True)
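A short usage sketch of how the per-layer view is obtained; the cache shape below is an illustrative assumption laid out as (num_linear_layers, max_batch_size, ...), since max_batch_size is read from cache_shape[1]:

import torch
from vllm.model_executor.models.minimax_cache import (MinimaxCacheManager,
                                                      MinimaxCacheParams)

manager = MinimaxCacheManager(dtype=torch.float32,
                              cache_shape=(4, 8, 16, 64, 64))  # illustrative shape
slot_ids = torch.arange(3, dtype=torch.int32, device="cuda")
params = MinimaxCacheParams(manager.cache, slot_ids)
layer0 = params.at_layer_idx(0)   # layer-0 slice of the cache, same slot indices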

File diff suppressed because it is too large

View File

@@ -35,6 +35,7 @@ _TEXT_GENERATION_MODELS = {
     "AquilaModel": ("llama", "LlamaForCausalLM"),
     "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
     "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
+    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
     # baichuan-7b, upper case 'C' in the class name
     "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
     # baichuan-13b, lower case 'c' in the class name
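Once registered here, the architecture should be visible through the public registry; a quick sanity-check sketch:

from vllm import ModelRegistry

assert "MiniMaxText01ForCausalLM" in ModelRegistry.get_supported_archs()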