diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 1b742717..af0f7304 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -503,6 +503,11 @@ See [this page](#generative-models) for more information on how to use generativ * `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc. * ✅︎ * ✅︎ +- * `MiniMaxText01ForCausalLM` + * MiniMax-Text + * `MiniMaxAI/MiniMax-Text-01`, etc. + * + * ✅︎ - * `Zamba2ForCausalLM` * Zamba2 * `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc. diff --git a/tests/kernels/test_lightning_attn.py b/tests/kernels/test_lightning_attn.py new file mode 100644 index 00000000..fbad5298 --- /dev/null +++ b/tests/kernels/test_lightning_attn.py @@ -0,0 +1,286 @@ +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch + +from vllm.model_executor.layers.lightning_attn import ( + linear_decode_forward_triton) +from vllm.platforms import current_platform + +NUM_HEADS = [4, 8] +HEAD_SIZES = [64] +BATCH_SIZES = [1, 2] +SEQ_LENGTHS = [16] +DTYPES = [torch.float32] + + +def reference_lightning_attention(q, k, v, ed, block_size, kv_history): + """Reference implementation of lightning attention core algorithm + + The difference from the main implementation is that this processes + each step sequentially, instead of using parallelized triton kernels + """ + B, H, S, D = q.shape + E = v.shape[-1] + dtype = q.dtype + output = torch.zeros((B, H, S, E), dtype=dtype, device=q.device) + + # Use clone() to ensure an independent copy + if kv_history is None: + kv_cache = torch.zeros((B, H, D, E), dtype=dtype, device=q.device) + else: + kv_cache = kv_history.clone() + + # More efficient implementation + # Convert decay factors to matrix form + if ed.dim() == 1: + decay = torch.exp(-ed).view(1, -1, 1, 1) + else: + decay = torch.exp(-ed) + + for b in range(B): + for step in range(S): + # Process all heads at once for this position + q_bs = q[b, :, step] # [H, D] + k_bs = k[b, :, step] # [H, D] + v_bs = v[b, :, step] # [H, E] + + # Calculate KV outer products for all heads + for h in range(H): + # Calculate KV outer product + kv_outer = torch.outer(k_bs[h], v_bs[h]) + + # Update KV cache with decay + # Note: Using the same order as in the Triton kernel + kv_cache[b, h] = decay[0, h, 0, 0] * kv_cache[b, h] + kv_outer + + # Calculate attention output + output[b, h, step] = torch.matmul(q_bs[h], kv_cache[b, h]) + + # Match the shape returned by the actual implementation + # The actual implementation returns a tensor of shape [B, H, 2, D, E] + # where dimension 2 contains both KV and KV history + kv_reshaped = kv_cache.unsqueeze(2) # [B, H, 1, D, E] + final_kv_cache = torch.cat([kv_reshaped, kv_reshaped], + dim=2) # [B, H, 2, D, E] + + return output, final_kv_cache + + +def reference_linear_decode(q, k, v, kv_caches, slope_rate, slot_idx): + """Reference implementation: linear attention decode function""" + B, H, _, D = q.shape + output = torch.zeros(B, H * D, dtype=q.dtype, device=q.device) + + # Calculate decay factors once (more efficient) + decay = torch.exp(-slope_rate).view(-1, 1, 1) # [H, 1, 1] + + # Process each batch + for b in range(B): + slot_id = slot_idx[b].item() + + # Skip padding positions + if slot_id == -1: + continue + + # Process all heads at once for this batch + q_b = q[b, :, 0] # [H, D] + k_b = k[b, :, 0] # [H, D] + v_b = v[b, :, 0] # [H, D] + + # Process each attention head + for h in range(H): + # Get 
current query, key and value + q_bh = q_b[h] + k_bh = k_b[h] + v_bh = v_b[h] + + # Get cache + kv_cache_old = kv_caches[b, h] + + # Calculate new key-value outer product + kv_outer = torch.outer(k_bh, v_bh) + + # Apply decay and update cache + kv_new = kv_outer + decay[h, 0, 0] * kv_cache_old + + # Calculate output + out_h = torch.matmul(q_bh, kv_new) + + # Update output and cache + output[b, h * D:(h + 1) * D] = out_h + kv_caches[b, h] = kv_new + + return output + + +@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@torch.inference_mode() +def test_linear_decode_forward_triton( + batch_size: int, + num_heads: int, + head_size: int, + dtype: torch.dtype, +): + torch.set_default_device("cuda") + torch.manual_seed(42) + torch.cuda.manual_seed_all(42) + current_platform.seed_everything(42) + base = 0.01 + q = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype) + k = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype) + v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype) + + kv_caches = base * torch.randn(batch_size, + num_heads, + head_size, + head_size, + dtype=dtype, + device="cuda") + + kv_caches_copy = kv_caches.clone() + + slope_rate = torch.zeros(num_heads, device="cuda") + for h in range(num_heads): + slope_rate[h] = 0.1 * (h + 1) + + slot_idx = torch.arange(batch_size, device="cuda") + + triton_output = linear_decode_forward_triton(q, k, v, kv_caches, + slope_rate, slot_idx) + + reference_output = reference_linear_decode(q, k, v, kv_caches_copy, + slope_rate, slot_idx) + torch.testing.assert_close(triton_output, + reference_output, + rtol=1e-1, + atol=1e-1) + torch.testing.assert_close(kv_caches, kv_caches_copy, rtol=1e-1, atol=1e-1) + + assert triton_output.shape == (batch_size, num_heads * head_size) + + +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@torch.inference_mode() +def test_linear_decode_forward_triton_with_padding( + num_heads: int, + head_size: int, + dtype: torch.dtype, +): + torch.set_default_device("cuda") + torch.manual_seed(42) + torch.cuda.manual_seed_all(42) + current_platform.seed_everything(42) + + batch_size = 4 + base = 0.01 + q = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype) + k = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype) + v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype) + + kv_caches = base * torch.randn(batch_size, + num_heads, + head_size, + head_size, + dtype=dtype, + device="cuda") + + kv_caches_copy = kv_caches.clone() + + slope_rate = torch.zeros(num_heads, device="cuda") + for h in range(num_heads): + slope_rate[h] = 0.1 * (h + 1) + + slot_idx = torch.tensor([0, 1, -1, 2], device="cuda") + + triton_output = linear_decode_forward_triton(q, k, v, kv_caches, + slope_rate, slot_idx) + + reference_output = reference_linear_decode(q, k, v, kv_caches_copy, + slope_rate, slot_idx) + + padding_mask = (slot_idx + != -1).unsqueeze(1).expand(-1, num_heads * head_size) + + triton_masked = triton_output[padding_mask] + reference_masked = reference_output[padding_mask] + + atol, rtol = 1.5e-1, 1.5e-1 + + valid_indices = slot_idx != -1 + + for i in range(batch_size): + if valid_indices[i] > 0: + torch.testing.assert_close(kv_caches[i], + kv_caches_copy[i], + rtol=rtol, + atol=atol) + + 
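+    # Only rows with a valid slot (slot_idx != -1) are compared below; for
+    # padded rows the kernel returns early, so their outputs and caches are
+    # intentionally excluded from the check.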
torch.testing.assert_close(triton_masked, + reference_masked, + rtol=rtol, + atol=atol) + + assert triton_output.shape == (batch_size, num_heads * head_size) + + +@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("seq_len", SEQ_LENGTHS) +@pytest.mark.parametrize("dtype", DTYPES) +@torch.inference_mode() +def test_lightning_attention_reference( + batch_size: int, + num_heads: int, + head_size: int, + seq_len: int, + dtype: torch.dtype, +): + torch.set_default_device("cuda") + torch.manual_seed(42) + torch.cuda.manual_seed_all(42) + current_platform.seed_everything(42) + + base = 0.01 + q = base * torch.randn( + batch_size, num_heads, seq_len, head_size, dtype=dtype) + k = base * torch.randn( + batch_size, num_heads, seq_len, head_size, dtype=dtype) + v = base * torch.randn( + batch_size, num_heads, seq_len, head_size, dtype=dtype) + + ed = torch.zeros(num_heads, device="cuda") + for h in range(num_heads): + ed[h] = 0.1 * (h + 1) + + kv_history = base * torch.randn(batch_size, + num_heads, + head_size, + head_size, + dtype=dtype, + device="cuda") + + kv_history_clone = kv_history.clone() + + ref_output, ref_kv_cache = reference_lightning_attention( + q, k, v, ed, 256, kv_history) + + from vllm.model_executor.layers.lightning_attn import lightning_attention + actual_output, actual_kv_cache = lightning_attention( + q, k, v, ed, 256, kv_history_clone) + + atol, rtol = 1.5e-1, 1.5e-1 + torch.testing.assert_close(ref_output, actual_output, rtol=rtol, atol=atol) + torch.testing.assert_close(ref_kv_cache, + actual_kv_cache, + rtol=rtol, + atol=atol) + + assert ref_output.shape == (batch_size, num_heads, seq_len, head_size) + assert ref_kv_cache.shape == actual_kv_cache.shape diff --git a/tests/models/registry.py b/tests/models/registry.py index 137f1418..39e104a1 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -176,6 +176,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { trust_remote_code=True), "MiniCPM3ForCausalLM": _HfExamplesInfo("openbmb/MiniCPM3-4B", trust_remote_code=True), + "MiniMaxText01ForCausalLM": _HfExamplesInfo("MiniMaxAI/MiniMax-Text-01", + trust_remote_code=True), "MistralForCausalLM": _HfExamplesInfo("mistralai/Mistral-7B-Instruct-v0.1"), "MixtralForCausalLM": _HfExamplesInfo("mistralai/Mixtral-8x7B-Instruct-v0.1"), # noqa: E501 "QuantMixtralForCausalLM": _HfExamplesInfo("mistral-community/Mixtral-8x22B-v0.1-AWQ"), # noqa: E501 diff --git a/vllm/config.py b/vllm/config.py index 6ec5d1bc..ba20e3fd 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -971,26 +971,34 @@ class ModelConfig: return sum(not bc.attention.no_op for bc in block_configs[start:end]) else: - # Hybrid model + # Hybrid model Jamba layers_block_type_value = getattr(self.hf_config, "layers_block_type", None) - if layers_block_type_value is None: - raise ValueError("The model is an hybrid without a " - "layers_block_type in the hf_config, " - "cannot determine the num of " - f"{block_type.value} layers") + if layers_block_type_value is not None: + if hasattr(self.hf_text_config, + "model_type") and (self.hf_text_config.model_type + == "zamba2"): + if attn_block_type: + return sum(t == "hybrid" + for t in layers_block_type_value[start:end]) + else: + return self.get_num_layers(parallel_config) + return sum(t == block_type.value + for t in layers_block_type_value[start:end]) - if hasattr(self.hf_text_config, - "model_type") and (self.hf_text_config.model_type - == "zamba2"): - 
if attn_block_type: - return sum(t == "hybrid" - for t in layers_block_type_value[start:end]) - else: - return self.get_num_layers(parallel_config) + # Hybrid model Minimax + attn_type_list = getattr(self.hf_config, "attn_type_list", None) + if attn_type_list: + return sum(t == 1 for t in attn_type_list[start:end]) - return sum(t == block_type.value - for t in layers_block_type_value[start:end]) + if layers_block_type_value is None and attn_type_list is None: + raise ValueError( + "The model is an hybrid without a" + "layers_block_type or an attn_type_list in the hf_config," + "cannot determine the num of " + f"{block_type.value} layers") + + return sum(t == 1 for t in attn_type_list[start:end]) def get_multimodal_config(self) -> "MultiModalConfig": """ diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 079e2a08..3e337731 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -303,8 +303,11 @@ class _AsyncLLMEngine(LLMEngine): ctx.seq_group_metadata_list = seq_group_metadata_list ctx.scheduler_outputs = scheduler_outputs - finished_requests_ids = self.scheduler[ - virtual_engine].get_and_reset_finished_requests_ids() + if not scheduler_outputs.is_empty(): + # this will cause mamba_cache/minimax_cache failed + # to release finished_requests_ids of the last steps + finished_requests_ids = self.scheduler[ + virtual_engine].get_and_reset_finished_requests_ids() # Maybe switch from async mode to sync mode if not allow_async_output_proc and len(ctx.output_queue) > 0: diff --git a/vllm/model_executor/layers/lightning_attn.py b/vllm/model_executor/layers/lightning_attn.py new file mode 100644 index 00000000..de360778 --- /dev/null +++ b/vllm/model_executor/layers/lightning_attn.py @@ -0,0 +1,651 @@ +# SPDX-License-Identifier: Apache-2.0 +import torch +import triton +import triton.language as tl +from einops import rearrange + + +@triton.jit +def _fwd_diag_kernel(Q, K, V, Out, S, b: tl.constexpr, h: tl.constexpr, n, + d: tl.constexpr, e: tl.constexpr, BLOCK: tl.constexpr, + NUM_BLOCK, CBLOCK: tl.constexpr): + # This kernel computes the diagonal blocks of the attention matrix + # Each diagonal block represents attention + # where queries attend to keys in the same block + off = tl.program_id(0) + off_bh = off // NUM_BLOCK # batch-head index + off_block = off % NUM_BLOCK # block index within the sequence + off_cblock = tl.program_id(1) # sub-block index within a block + + off_h = off_bh % h # head index + + # Calculate base offsets for the current batch and head + qk_offset = off_bh * n * d + v_offset = off_bh * n * e + o_offset = off_bh * n * e + + # Calculate offsets for the current block + block_offset = off_block * BLOCK + qk_block_offset = block_offset * d + v_block_offset = block_offset * e + o_block_offset = block_offset * e + + # Calculate offsets for the current sub-block + cblock_offset = off_cblock * CBLOCK + q_cblock_offset = cblock_offset * d + o_cblock_offset = cblock_offset * e + + # Calculate pointers to the query, key, value, and output tensors + Q_block_ptr = (Q + qk_offset + qk_block_offset + q_cblock_offset + + tl.arange(0, CBLOCK)[:, None] * d + + tl.arange(0, d)[None, :]) + K_trans_block_ptr = (K + qk_offset + qk_block_offset + + tl.arange(0, CBLOCK)[None, :] * d + + tl.arange(0, d)[:, None]) + V_block_ptr = (V + v_offset + v_block_offset + + tl.arange(0, CBLOCK)[:, None] * e + + tl.arange(0, e)[None, :]) + O_block_ptr = (Out + o_offset + o_block_offset + o_cblock_offset + + tl.arange(0, CBLOCK)[:, None] * e + + 
tl.arange(0, e)[None, :]) + + # Load the decay rate for the current head + S_block_ptr = S + off_h + s = tl.load(S_block_ptr) + + i = off_cblock + q_index = tl.arange(0, CBLOCK) + i * CBLOCK + + # Load query values + q = tl.load(Q_block_ptr, + mask=block_offset + q_index[:, None] < n, + other=0.0).to(tl.float32) + + # Initialize output accumulator + qkv = tl.zeros([CBLOCK, e], dtype=tl.float32) + + # Process all sub-blocks up to and + # including the current one (causal attention) + for j in range(i + 1): + kv_index = tl.arange(0, CBLOCK) + j * CBLOCK + diff = q_index[:, None] - kv_index[None, :] + s_index = s * diff + # Apply causal mask: only attend to positions before the current one + s_index = tl.where(diff >= 0, -s_index, float("-inf")) + decay = tl.exp(s_index) + + # Load key and value + k_trans = tl.load( + K_trans_block_ptr, + mask=block_offset + kv_index[None, :] < n, + other=0.0, + ).to(tl.float32) + v = tl.load( + V_block_ptr, + mask=block_offset + kv_index[:, None] < n, + other=0.0, + ).to(tl.float32) + + # Compute attention scores and apply decay + qk = tl.dot(q, k_trans) * decay + + # Compute weighted values and accumulate + qkv += tl.dot(qk, v) + + # Move to the next sub-block + K_trans_block_ptr += CBLOCK * d + V_block_ptr += CBLOCK * e + + # Store the result + tl.store( + O_block_ptr, + qkv.to(O_block_ptr.dtype.element_ty), + mask=block_offset + q_index[:, None] < n, + ) + + +@triton.jit +def _fwd_kv_parallel( + K, + V, + K_decay, + KV, + b: tl.constexpr, + h: tl.constexpr, + n, + d: tl.constexpr, + e: tl.constexpr, + BLOCK: tl.constexpr, + NUM_BLOCK, + D_FBLOCK: tl.constexpr, + E_FBLOCK: tl.constexpr, + NUM_FBLOCK: tl.constexpr, + CBLOCK: tl.constexpr, + NUM_CBLOCK: tl.constexpr, +): + # This kernel computes the key-value outer + # products for each block in parallel + off_bh = tl.program_id(0) # batch-head index + off_block = tl.program_id(1) # block index + + off_h = off_bh % h # head index + + block_offset = off_block * BLOCK + + # Calculate offsets for the current block + k_block_offset = block_offset * d + v_block_offset = block_offset * e + kv_block_offset = off_block * d * e + + # Calculate base offsets for the current batch and head + k_offset = off_bh * n * d + v_offset = off_bh * n * e + kv_offset = off_bh * NUM_BLOCK * d * e + + # Calculate pointers to the key, value, and key-value tensors + K_trans_block_ptr = (K + k_offset + k_block_offset + + tl.arange(0, CBLOCK)[None, :] * d + + tl.arange(0, D_FBLOCK)[:, None]) + V_block_ptr = (V + v_offset + v_block_offset + + tl.arange(0, CBLOCK)[:, None] * e + + tl.arange(0, E_FBLOCK)[None, :]) + KV_block_ptr = (KV + kv_offset + kv_block_offset + + tl.arange(0, D_FBLOCK)[:, None] * e + + tl.arange(0, E_FBLOCK)[None, :]) + + # Load the decay factors for the current head and block + k_decay_ptr = (K_decay + off_h * BLOCK + tl.arange(0, CBLOCK)[None, :]) + + kv_index = tl.arange(0, CBLOCK) + + # Initialize the key-value outer product accumulator + kv = tl.zeros([D_FBLOCK, E_FBLOCK], dtype=tl.float32) + + # Handle the last block which might be smaller than BLOCK + if off_block == NUM_BLOCK - 1: + split_n = n - (NUM_BLOCK - 1) * BLOCK + else: + split_n = BLOCK + left_shift = tl.cdiv(split_n, CBLOCK) * CBLOCK - split_n + num_blocks = min(tl.cdiv(split_n, CBLOCK), NUM_CBLOCK) + k_decay_ptr += (NUM_CBLOCK - num_blocks) * CBLOCK + + # Process all sub-blocks in the current block + for j in range(num_blocks): + left_bound = (1 - j) * left_shift + # Load key and value, handling boundary conditions + k_trans = 
tl.load(K_trans_block_ptr - left_shift * d, + mask=kv_index[None, :] >= left_bound, + other=0.0) + v = tl.load(V_block_ptr - left_shift * e, + mask=kv_index[:, None] >= left_bound, + other=0.0) + + # Load decay factor and compute weighted key-value outer product + k_decay = tl.load(k_decay_ptr) + kv += tl.dot(k_trans * k_decay, v) + + # Move to the next sub-block + K_trans_block_ptr += CBLOCK * d + V_block_ptr += CBLOCK * e + k_decay_ptr += CBLOCK + + # Store the result + tl.store(KV_block_ptr, kv.to(KV_block_ptr.dtype.element_ty)) + + +@triton.jit +def _fwd_kv_reduce(S, KV, KV_HISTORY, b: tl.constexpr, h: tl.constexpr, n, + d: tl.constexpr, e: tl.constexpr, BLOCK: tl.constexpr, + NUM_BLOCK, D_FBLOCK: tl.constexpr, E_FBLOCK: tl.constexpr): + # This kernel reduces the key-value outer products + # across blocks and updates the KV history + off_bh = tl.program_id(0) # batch-head index + off_h = off_bh % h # head index + + kv_offset = off_bh * NUM_BLOCK * d * e + + # Calculate pointer to the key-value tensor + KV_block_ptr = (KV + kv_offset + tl.arange(0, D_FBLOCK)[:, None] * e + + tl.arange(0, E_FBLOCK)[None, :]) + + # Load the decay rate for the current head + s_ptrs = S + off_h + s = tl.load(s_ptrs) + + # Calculate pointer to the key-value history tensor + kv_history_offset = off_bh * d * e + KV_HISTORY_block_ptr = (KV_HISTORY + kv_history_offset + + tl.arange(0, D_FBLOCK)[:, None] * e + + tl.arange(0, E_FBLOCK)[None, :]) + + # Load the previous key-value history + kv_pre = tl.load(KV_HISTORY_block_ptr).to(tl.float32) + + # Process all blocks in reverse order to compute the prefix sum + for i in range(NUM_BLOCK): + block_size = min(n - i * BLOCK, BLOCK) + # Compute decay factor for the current block + block_decay = tl.exp(-s.to(tl.float32) * block_size) + + # Load the current key-value outer product + kv_cur = tl.load(KV_block_ptr).to(tl.float32) + # Store the previous key-value history to the current block + tl.store(KV_block_ptr, kv_pre.to(KV_block_ptr.dtype.element_ty)) + + # Update the key-value history with the current block + kv_pre = block_decay * kv_pre + kv_cur + KV_block_ptr += d * e + + # Store the updated key-value history + tl.store(KV_HISTORY_block_ptr, kv_pre) + + +@triton.jit +def _fwd_none_diag_kernel( + Q, + Out, + S, + KV, + b: tl.constexpr, + h: tl.constexpr, + n, + d: tl.constexpr, + e: tl.constexpr, + BLOCK: tl.constexpr, + NUM_BLOCK, + E_FBLOCK: tl.constexpr, + CBLOCK: tl.constexpr, + NUM_CBLOCK: tl.constexpr, +): + # This kernel computes the non-diagonal blocks of the attention matrix + # Each non-diagonal block represents attention + # where queries attend to keys in different blocks + off_bh = tl.program_id(0) # batch-head index + off_h = off_bh % h # head index + + off_nc = tl.program_id(1) + off_n = off_nc // NUM_CBLOCK # block index + off_c = off_nc % NUM_CBLOCK # sub-block index + off_e = tl.program_id(2) # output feature block index + + n_offset = off_n * BLOCK + c_offset = off_c * CBLOCK + e_offset = off_e * E_FBLOCK + block_offset = n_offset + c_offset + + # Calculate offsets for the current batch, head, and block + q_offset = off_bh * n * d + (n_offset + c_offset) * d + o_offset = off_bh * n * e + (n_offset + c_offset) * e + e_offset + kv_offset = off_bh * NUM_BLOCK * d * e + off_n * d * e + e_offset + + # Calculate pointers to the query, output, and key-value tensors + Q_block_ptr = (Q + q_offset + tl.arange(0, CBLOCK)[:, None] * d + + tl.arange(0, d)[None, :]) + O_block_ptr = (Out + o_offset + tl.arange(0, CBLOCK)[:, None] * e + + tl.arange(0, 
E_FBLOCK)[None, :]) + KV_block_ptr = (KV + kv_offset + tl.arange(0, d)[:, None] * e + + tl.arange(0, E_FBLOCK)[None, :]) + + # Load the decay rate for the current head + S_block_ptr = S + off_h + s = tl.load(S_block_ptr) + + c_array = tl.arange(0, CBLOCK) + + # Load the key-value outer product for the current block + kv = tl.load(KV_block_ptr).to(tl.float32) + q_index = block_offset + tl.arange(0, CBLOCK) + + # Load query values + q = tl.load(Q_block_ptr, mask=q_index[:, None] < n, + other=0.).to(tl.float32) + + # Compute decay factors for the current sub-block + q_decay = tl.exp(-s.to(tl.float32) * (off_c * CBLOCK + c_array[:, None])) + + # Compute non-diagonal attention output + qkv_none_diag = tl.dot(q, kv) * q_decay + + # Load diagonal attention output (computed by _fwd_diag_kernel) + qkv_diag = tl.load(O_block_ptr, mask=q_index[:, None] < n, + other=0.).to(tl.float32) + + # Combine diagonal and non-diagonal attention outputs + qkv = qkv_diag + qkv_none_diag + + # Store the result + tl.store(O_block_ptr, + qkv.to(O_block_ptr.dtype.element_ty), + mask=q_index[:, None] < n) + + +class _attention(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, s, kv_history): + # Forward pass of the lightning attention algorithm + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + s = s.contiguous() + + # Check CUDA compute capability + capability = torch.cuda.get_device_capability() + if capability[0] < 8: + raise RuntimeError("Flash attention currently only supported", + "for compute capability >= 80") + + # Get input dimensions + b, h, n, d = q.shape + e = v.shape[-1] + + # Initialize output tensor + o = torch.empty((b, h, n, e), dtype=q.dtype, device=q.device) + + # Set block sizes + BLOCK = 256 + NUM_BLOCK = triton.cdiv(n, BLOCK) + + CBLOCK = 32 + NUM_CBLOCK = BLOCK // CBLOCK + assert BLOCK % CBLOCK == 0, "BLOCK must be a multiple of CBLOCK" + + # Compute decay factors for keys + array = torch.arange(0, BLOCK, device=q.device) + 1 + k_decay = torch.exp(-s * (BLOCK - array.reshape(1, -1))) + + # Step 1: Compute diagonal blocks of attention + grid = (b * h * NUM_BLOCK, NUM_CBLOCK) + _fwd_diag_kernel[grid](q, + k, + v, + o, + s, + b, + h, + n, + d, + e, + BLOCK=BLOCK, + NUM_BLOCK=NUM_BLOCK, + CBLOCK=CBLOCK) + + # Set feature block sizes + NUM_FBLOCK = 1 + D_FBLOCK = d // NUM_FBLOCK + assert d % NUM_FBLOCK == 0 + E_FBLOCK = e // NUM_FBLOCK + assert e % NUM_FBLOCK == 0 + + CBLOCK = 64 + NUM_CBLOCK = BLOCK // CBLOCK + assert BLOCK % CBLOCK == 0, "BLOCK must be a multiple of CBLOCK" + + # Step 2: Compute key-value outer products for each block in parallel + kv = torch.empty((b, h, NUM_BLOCK, d, e), + dtype=torch.float32, + device=q.device) + grid = (b * h, NUM_BLOCK) + _fwd_kv_parallel[grid]( + k, + v, + k_decay, + kv, + b, + h, + n, + d, + e, + BLOCK=BLOCK, + NUM_BLOCK=NUM_BLOCK, + D_FBLOCK=D_FBLOCK, + E_FBLOCK=E_FBLOCK, + NUM_FBLOCK=NUM_FBLOCK, + CBLOCK=CBLOCK, + NUM_CBLOCK=NUM_CBLOCK, + ) + + # Step 3: Reduce key-value outer products + # across blocks and update KV history + grid = (b * h, NUM_FBLOCK) + _fwd_kv_reduce[grid](s, + kv, + kv_history, + b, + h, + n, + d, + e, + BLOCK=BLOCK, + NUM_BLOCK=NUM_BLOCK, + D_FBLOCK=D_FBLOCK, + E_FBLOCK=E_FBLOCK) + + # Step 4: Compute non-diagonal blocks of attention + grid = (b * h, NUM_BLOCK * NUM_CBLOCK) + _fwd_none_diag_kernel[grid]( + q, + o, + s, + kv, + b, + h, + n, + d, + e, + BLOCK=BLOCK, + NUM_BLOCK=NUM_BLOCK, + E_FBLOCK=E_FBLOCK, + CBLOCK=CBLOCK, + NUM_CBLOCK=NUM_CBLOCK, + ) + + # Save tensors for backward pass + 
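+        # Note: at this point `kv` holds, for each block i, the KV state
+        # accumulated from the initial history and all blocks before i (the
+        # exclusive prefix produced by the reduce step above), which is what
+        # the non-diagonal kernel consumed.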
ctx.save_for_backward(q, k, v, s, kv) + ctx.BLOCK = BLOCK + + return o, torch.cat([kv, kv_history.unsqueeze(2)], dim=2) + + +# Apply the lightning attention function +lightning_attention_ = _attention.apply + + +def lightning_attention(q, k, v, ed, block_size=256, kv_history=None): + """ + Apply lightning attention algorithm + to compute attention efficiently. + + Args: + q: Query tensor of shape [batch, heads, seq_len, dim] + k: Key tensor of shape [batch, heads, seq_len, dim] + v: Value tensor of shape [batch, heads, seq_len, dim_v] + ed: Decay rate tensor of shape [heads] + block_size: Size of blocks for block-sparse attention + kv_history: Optional key-value history from previous computations + + Returns: + output: Attention output + kv: Updated key-value history + """ + d = q.shape[-1] + e = v.shape[-1] + + if ed.dim() == 1: + ed = ed.view(1, -1, 1, 1) + + # Split the computation into chunks for better parallelism + m = 128 if d >= 128 else 64 + assert d % m == 0, f"Dimension d ({d}) must be divisible by m ({m})" + arr = [m * i for i in range(d // m + 1)] + if arr[-1] != d: + arr.append(d) + n = len(arr) + output = 0 + + # Initialize or clone key-value history + if kv_history is None: + kv_history = torch.zeros((q.shape[0], q.shape[1], d, e), + dtype=torch.float32, + device=q.device) + else: + kv_history = kv_history.clone().contiguous() + + # Process each chunk and accumulate results + for i in range(n - 1): + s = arr[i] + e = arr[i + 1] + q1 = q[..., s:e] + k1 = k[..., s:e] + o, kv = lightning_attention_(q1, k1, v, ed, kv_history) + output = output + o + return output, kv + + +@triton.jit +def _linear_attn_decode_kernel( + q_ptr, + k_ptr, + v_ptr, + kv_cache_ptr, + slope_rate, + slot_idx, + output_ptr, + D: tl.constexpr, + qkv_b_stride, + qkv_h_stride, + cache_b_stride, + cache_h_stride, + cache_d0_stride, + cache_d1_stride, + BLOCK_SIZE: tl.constexpr, +): + """ + Kernel for linear attention decoding with KV cache. + + This kernel computes attention for a single token using the KV cache. 
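+
+    For each (batch, head) pair it performs one step of the linear-attention
+    recurrence; as a plain-PyTorch sketch of what the Triton code below
+    computes (not additional functionality):
+
+        kv_cache = exp(-slope_rate) * kv_cache + torch.outer(k, v)
+        out      = q @ kv_cache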
+ """ + pid_b = tl.program_id(0) # batch index + pid_h = tl.program_id(1) # head index + pid_d = tl.program_id(2) # dimension block index + + # Load slot index for the current batch + slot_id = tl.load(slot_idx + pid_b) + + # Skip if slot_id is -1 (padding) + if slot_id == -1: + return + + batch_id = pid_b + head_id = pid_h + + # Load decay rate for the current head + ratio = tl.load(slope_rate + pid_h) + + # Calculate offsets for dimensions + qk_d_offsets = tl.arange(0, D) + v_d_offsets = tl.arange(0, BLOCK_SIZE) + pid_d * BLOCK_SIZE + cache_d_offsets = qk_d_offsets[:, None] * cache_d0_stride + v_d_offsets[ + None, :] * cache_d1_stride + + # Calculate offsets for the current batch and head + q_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride + k_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride + v_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride + + cache_offset = slot_id * cache_b_stride + head_id * cache_h_stride + + # Create masks for loading tensors + qk_mask = qk_d_offsets < D + v_mask = v_d_offsets < D + + # Load query, key, and value tensors + q = tl.load(q_ptr + q_offset + qk_d_offsets, mask=qk_mask, other=0.0) + k = tl.load(k_ptr + k_offset + qk_d_offsets, mask=qk_mask, other=0.0) + v = tl.load(v_ptr + v_offset + v_d_offsets, mask=v_mask, other=0.0) + + # Compute key-value outer product + kv_outer = k[:, None] * v[None, :] + kv_mask = qk_mask[:, None] & v_mask[None, :] + + # Apply decay to previous KV cache + ratio = tl.exp(-ratio) + kv_ptr = kv_cache_ptr + cache_offset + cache_d_offsets + kv_cache_old = tl.load(kv_ptr, mask=kv_mask, other=0.0) + kv_outer = kv_outer + ratio * kv_cache_old + + # Compute attention output + output = q[:, None].to(tl.float32) * kv_outer + output = tl.sum(output, axis=0) + + # Update KV cache and store output + tl.store(kv_ptr, kv_outer, mask=kv_mask) + tl.store(output_ptr + q_offset + v_d_offsets, output, mask=v_mask) + + +def linear_decode_forward_triton( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + kv_caches: torch.Tensor, + slope_rate: torch.Tensor, + slot_idx: torch.Tensor, + BLOCK_SIZE: int = 32, +) -> torch.Tensor: + """ + Perform linear attention decoding using Triton kernels. 
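+
+    The KV caches are updated in place for the cache slots referenced by
+    `slot_idx`; entries equal to -1 are treated as padding and skipped.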
+ + Args: + q: Query tensor of shape [B, H, 1, D] + k: Key tensor of shape [B, H, 1, D] + v: Value tensor of shape [B, H, 1, D] + kv_caches: Key-value cache tensor + slope_rate: Decay rate tensor + slot_idx: Slot indices for batches + BLOCK_SIZE: Size of blocks for processing + + Returns: + output: Attention output tensor + """ + B, H, _, D = q.shape + assert k.shape == (B, H, 1, D) + assert v.shape == (B, H, 1, D) + + # Initialize output tensor + output = torch.empty_like(q) + + # Set grid dimensions for the kernel + grid = (B, H, D // BLOCK_SIZE) + + # Calculate strides for tensors + qkv_b_stride = q.stride(0) + qkv_h_stride = q.stride(1) + + cache_b_stride = kv_caches.stride(0) + cache_h_stride = kv_caches.stride(1) + cache_d0_stride = kv_caches.stride(2) + cache_d1_stride = kv_caches.stride(3) + + # Launch the kernel + _linear_attn_decode_kernel[grid]( + q, + k, + v, + kv_caches, + slope_rate, + slot_idx, + output, + D, + qkv_b_stride, + qkv_h_stride, + cache_b_stride, + cache_h_stride, + cache_d0_stride, + cache_d1_stride, + BLOCK_SIZE=BLOCK_SIZE, + ) + + # Reshape output and return + output = rearrange(output, "b h n d -> b n (h d)") + return output.squeeze(1).contiguous() diff --git a/vllm/model_executor/models/constant_size_cache.py b/vllm/model_executor/models/constant_size_cache.py new file mode 100644 index 00000000..d073a7de --- /dev/null +++ b/vllm/model_executor/models/constant_size_cache.py @@ -0,0 +1,136 @@ +# SPDX-License-Identifier: Apache-2.0 +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Tuple + +import torch + +from vllm.attention.backends.utils import PAD_SLOT_ID + + +class ConstantSizeCache(ABC): + """ + Abstract base class for managing constant size caches + like Mamba and Minimax. + """ + + def __init__(self, max_batch_size: int): + # Maps between the request id and a dict that maps between the seq_id + # and its index inside the cache + self.cache_indices_mapping: Dict[str, Dict[int, int]] = {} + self.free_cache_indices = list(range(max_batch_size)) + + @property + @abstractmethod + def cache(self) -> Any: + """Return the underlying cache tensor(s)""" + pass + + @abstractmethod + def _copy_cache(self, from_index: int, to_index: int): + """Copy cache data from one index to another""" + pass + + def current_run_tensors(self, **kwargs) -> Tuple: + """ + Return the tensors for the current run's conv and ssm state. 
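+        For this generic base class, the returned tuple is whatever the
+        subclass exposes via `self.cache`, together with an int32 tensor of
+        per-sequence state indices for the current batch.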
+ """ + if "seqlen_agnostic_capture_inputs" not in kwargs: + # We get here only on Prefill/Eager mode runs + request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] + finished_requests_ids = kwargs["finished_requests_ids"] + + self._release_finished_requests(finished_requests_ids) + state_indices = self._prepare_current_run_cache( + request_ids_to_seq_ids, finished_requests_ids) + + state_indices_tensor = torch.as_tensor(state_indices, + dtype=torch.int32, + device="cuda") + cache_tensors = self.cache + else: + # CUDA graph capturing runs + cache_tensors, state_indices_tensor = kwargs[ + "seqlen_agnostic_capture_inputs"] + + return (cache_tensors, state_indices_tensor) + + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + """ + Copy the relevant state_indices into the CUDA graph input buffer + """ + assert all( + key in kwargs + for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) + finished_requests_ids = kwargs["finished_requests_ids"] + request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] + assert "seqlen_agnostic_capture_inputs" in input_buffers + _, input_state_indices_buffer = input_buffers[ + "seqlen_agnostic_capture_inputs"] + + self._release_finished_requests(finished_requests_ids) + state_indices = self._prepare_current_run_cache( + request_ids_to_seq_ids, finished_requests_ids) + cuda_graph_pad_len = input_state_indices_buffer.shape[0] - len( + state_indices) + state_indices.extend([PAD_SLOT_ID] * cuda_graph_pad_len) + + input_state_indices_buffer.copy_( + torch.as_tensor(state_indices, dtype=torch.int32, device="cuda")) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + """ + Provide the CUDA graph capture runs with a buffer in adjusted size. + The buffer is used to maintain the Cache during the CUDA graph replay + runs. + """ + state_indices_tensor = torch.as_tensor([PAD_SLOT_ID] * batch_size, + dtype=torch.int32, + device="cuda") + return (self.cache, state_indices_tensor) + + def _assign_seq_id_to_cache_index(self, cur_rid: str, seq_id: int, + finished_requests_ids) -> int: + """ + Assign (req_id,seq_id) pair to a `destination_index` index, if + already occupied, move the occupying index to a free index. 
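+        Finished requests map to PAD_SLOT_ID, known (req_id, seq_id) pairs
+        reuse their existing index, and a new seq_id of a known request
+        (parallel sampling, n > 1) gets a fresh index with the existing
+        cache copied into it.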
+ """ + if cur_rid in finished_requests_ids: + # set as pad, do not allocate destination index + return PAD_SLOT_ID + elif cur_rid not in self.cache_indices_mapping: + destination_index = self.free_cache_indices.pop() + self.cache_indices_mapping[cur_rid] = {seq_id: destination_index} + return destination_index + elif seq_id not in (seq_ids2indices := + self.cache_indices_mapping[cur_rid]): + # parallel sampling , where n > 1, assume prefill have + # already happened, so we copy the + # existing cache into the siblings seq_ids caches + index_exists = next(iter(seq_ids2indices.values())) + # case of decoding n>1, copy prefill cache to decoding indices + destination_index = self.free_cache_indices.pop() + self._copy_cache(from_index=index_exists, + to_index=destination_index) + self.cache_indices_mapping[cur_rid][seq_id] = destination_index + return destination_index + else: + return self.cache_indices_mapping[cur_rid][seq_id] + + def _prepare_current_run_cache( + self, request_ids_to_seq_ids: Dict[str, list[int]], + finished_requests_ids: List[str]) -> List[int]: + return [ + self._assign_seq_id_to_cache_index(req_id, seq_id, + finished_requests_ids) + for req_id, seq_ids in request_ids_to_seq_ids.items() + for seq_id in seq_ids + ] + + def _release_finished_requests(self, + finished_seq_groups_req_ids: List[str]): + for req_id in finished_seq_groups_req_ids: + if req_id in self.cache_indices_mapping: + for seq_id in self.cache_indices_mapping[req_id]: + self.free_cache_indices.append( + self.cache_indices_mapping[req_id][seq_id]) + self.cache_indices_mapping.pop(req_id) diff --git a/vllm/model_executor/models/mamba_cache.py b/vllm/model_executor/models/mamba_cache.py index d5298330..25839727 100644 --- a/vllm/model_executor/models/mamba_cache.py +++ b/vllm/model_executor/models/mamba_cache.py @@ -1,12 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass -from typing import Dict, List, Tuple +from typing import Tuple import torch from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.config import VllmConfig +from vllm.model_executor.models.constant_size_cache import ConstantSizeCache @dataclass @@ -21,7 +22,7 @@ class MambaCacheParams: self.state_indices_tensor) -class MambaCacheManager: +class MambaCacheManager(ConstantSizeCache): def __init__(self, vllm_config: VllmConfig, dtype: torch.dtype, num_mamba_layers: int, conv_state_shape: Tuple[int, int], @@ -32,6 +33,9 @@ class MambaCacheManager: if not vllm_config.model_config.enforce_eager: max_batch_size = vllm_config.pad_for_cudagraph(max_batch_size) + # Initialize parent class + super().__init__(max_batch_size) + conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) + conv_state_shape, dtype=dtype, @@ -41,126 +45,32 @@ class MambaCacheManager: dtype=dtype, device="cuda") - self.mamba_cache = (conv_state, temporal_state) + self._mamba_cache = (conv_state, temporal_state) - # Maps between the request id and a dict that maps between the seq_id - # and its index inside the self.mamba_cache - self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {} - self.free_cache_indices = list(range(max_batch_size)) + @property + def cache(self): + return self._mamba_cache + + def _copy_cache(self, from_index: int, to_index: int): + for cache_t in self.cache: + cache_t[:, to_index].copy_(cache_t[:, from_index], + non_blocking=True) def current_run_tensors(self, **kwargs) -> MambaCacheParams: """ Return the tensors for the current run's conv and ssm state. 
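+        The index bookkeeping is delegated to ConstantSizeCache; this
+        override only repackages the result as MambaCacheParams.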
""" - if "seqlen_agnostic_capture_inputs" not in kwargs: - # We get here only on Prefill/Eager mode runs - request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] - finished_requests_ids = kwargs["finished_requests_ids"] - - self._release_finished_requests(finished_requests_ids) - state_indices = self._prepare_current_run_mamba_cache( - request_ids_to_seq_ids, finished_requests_ids) - - state_indices_tensor = torch.as_tensor(state_indices, - dtype=torch.int32, - device="cuda") - mamba_cache_tensors = self.mamba_cache - - else: - # CUDA graph capturing runs - (mamba_cache_tensors, - state_indices_tensor) = kwargs["seqlen_agnostic_capture_inputs"] - - return MambaCacheParams(mamba_cache_tensors[0], mamba_cache_tensors[1], + cache_tensors, state_indices_tensor = super().current_run_tensors( + **kwargs) + return MambaCacheParams(cache_tensors[0], cache_tensors[1], state_indices_tensor) - def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): - """ - Copy the relevant state_indices into the CUDA graph input buffer - """ - assert all( - key in kwargs - for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) - finished_requests_ids = kwargs["finished_requests_ids"] - request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] - assert "seqlen_agnostic_capture_inputs" in input_buffers - _, input_state_indices_buffer = input_buffers[ - "seqlen_agnostic_capture_inputs"] - - self._release_finished_requests(finished_requests_ids) - state_indices = self._prepare_current_run_mamba_cache( - request_ids_to_seq_ids, finished_requests_ids) - cuda_graph_pad_len = input_state_indices_buffer.shape[0] - len( - state_indices) - state_indices.extend([PAD_SLOT_ID] * cuda_graph_pad_len) - - input_state_indices_buffer.copy_( - torch.as_tensor(state_indices, dtype=torch.int32, device="cuda")) - def get_seqlen_agnostic_capture_inputs(self, batch_size: int): """ Provide the CUDA graph capture runs with a buffer in adjusted size. The buffer is used to maintain the Mamba Cache during the CUDA graph replay runs. """ - state_indices_tensor = torch.as_tensor([PAD_SLOT_ID] * batch_size, - dtype=torch.int32, - device="cuda") - return (self.mamba_cache, state_indices_tensor) - - def _copy_mamba_cache(self, from_index: int, to_index: int): - assert len(self.mamba_cache) > 0 - for cache_t in self.mamba_cache: - cache_t[:, to_index].copy_(cache_t[:, from_index], - non_blocking=True) - - def _assign_seq_id_to_cache_index(self, cur_rid: str, seq_id: int, - finished_requests_ids) -> int: - """ - Assign (req_id,seq_id) pair to a `destination_index` index, if - already occupied, move the occupying index to a free index. 
- """ - if cur_rid in finished_requests_ids: - # set as pad, do not allocate destination index - return PAD_SLOT_ID - elif cur_rid not in self.mamba_cache_indices_mapping: - destination_index = self.free_cache_indices.pop() - self.mamba_cache_indices_mapping[cur_rid] = { - seq_id: destination_index - } - return destination_index - elif seq_id not in (seq_ids2indices := - self.mamba_cache_indices_mapping[cur_rid]): - # parallel sampling , where n > 1, assume prefill have - # already happened, so we copy the - # existing cache into the siblings seq_ids caches - index_exists = next(iter(seq_ids2indices.values())) - # case of decoding n>1, copy prefill cache to decoding indices - destination_index = self.free_cache_indices.pop() - self._copy_mamba_cache(from_index=index_exists, - to_index=destination_index) - self.mamba_cache_indices_mapping[cur_rid][ - seq_id] = destination_index - return destination_index - else: - # already exists - return self.mamba_cache_indices_mapping[cur_rid][seq_id] - - def _prepare_current_run_mamba_cache( - self, request_ids_to_seq_ids: Dict[str, list[int]], - finished_requests_ids: List[str]) -> List[int]: - return [ - self._assign_seq_id_to_cache_index(req_id, seq_id, - finished_requests_ids) - for req_id, seq_ids in request_ids_to_seq_ids.items() - for seq_id in seq_ids - ] - - def _release_finished_requests(self, - finished_seq_groups_req_ids: List[str]): - for req_id in finished_seq_groups_req_ids: - if req_id in self.mamba_cache_indices_mapping: - for seq_id in self.mamba_cache_indices_mapping[req_id]: - self.free_cache_indices.append( - self.mamba_cache_indices_mapping[req_id][seq_id]) - self.mamba_cache_indices_mapping.pop(req_id) + return self._mamba_cache, torch.as_tensor([PAD_SLOT_ID] * batch_size, + dtype=torch.int32, + device="cuda") diff --git a/vllm/model_executor/models/minimax_cache.py b/vllm/model_executor/models/minimax_cache.py new file mode 100644 index 00000000..c95cbb41 --- /dev/null +++ b/vllm/model_executor/models/minimax_cache.py @@ -0,0 +1,35 @@ +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass + +import torch + +from vllm.model_executor.models.constant_size_cache import ConstantSizeCache + + +@dataclass +class MinimaxCacheParams: + minimax_cache: torch.Tensor = torch.Tensor() + state_indices_tensor: torch.Tensor = torch.Tensor() + + def at_layer_idx(self, layer_idx): + return MinimaxCacheParams(self.minimax_cache[layer_idx, ...], + self.state_indices_tensor) + + +class MinimaxCacheManager(ConstantSizeCache): + + def __init__(self, dtype, cache_shape): + super().__init__(cache_shape[1]) # max_batch_size is cache_shape[1] + self._minimax_cache = torch.empty(size=cache_shape, + dtype=dtype, + device="cuda") + + @property + def cache(self): + return self._minimax_cache + + def _copy_cache(self, from_index: int, to_index: int): + assert len(self.cache) > 0 + for cache_t in self.cache: + cache_t[:, to_index].copy_(cache_t[:, from_index], + non_blocking=True) diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py new file mode 100644 index 00000000..7562aa67 --- /dev/null +++ b/vllm/model_executor/models/minimax_text_01.py @@ -0,0 +1,1273 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Inference-only MiniMaxText01 model.""" +import copy +import math +import re +from typing import Dict, Iterable, List, Optional, Tuple, Union + +import torch +import torch.distributed +import torch.nn.functional as F +from einops import rearrange +from torch import nn +from 
transformers.configuration_utils import PretrainedConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed.communication_op import tensor_model_parallel_all_reduce +from vllm.distributed.parallel_state import ( + get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.forward_context import get_forward_context +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.lightning_attn import ( + lightning_attention, linear_decode_forward_triton) +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.utils import maybe_prefix +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import HasInnerState, IsHybrid, SupportsV0Only +from .minimax_cache import MinimaxCacheManager, MinimaxCacheParams +from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers + + +def replace_weight_name(name: str, + key: str = None, + to: str = None, + count: int = None, + prefix: str = None) -> str: + name = name.replace(key, to) if count is None else \ + name.replace(key, to, count) + return name + + +def weight_loader_with_alias(alias: str): + + def wrapper(func: callable): + + def inner_func(param: torch.Tensor, + loaded_weight: torch.Tensor, + *args, + prefix: str = None, + **kwargs): + value = func(param, loaded_weight, *args, **kwargs) + return value + + return inner_func + + return wrapper + + +class MiniMaxText01RMSNormTP(CustomOp): + name = "MiniMaxText01RMSNormTP" + + def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: + super().__init__() + self.tp_world = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.weight = nn.Parameter(torch.ones(int(hidden_size / + self.tp_world))) + + self.weight.weight_loader = self.weight_loader + self.variance_epsilon = eps + return + + @staticmethod + def weight_loader( + param: nn.Parameter, + loaded_weight: torch.Tensor, + ) -> None: + tp_world = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + + shard_size = loaded_weight.shape[0] // tp_world + shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) + param.data.copy_(loaded_weight[shard]) + return + + def _forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + orig_dtype = x.dtype + x = x.to(torch.float32) + variance = x.pow(2).mean(dim=-1, keepdim=True, dtype=torch.float32) + if self.tp_world > 1: + variance = tensor_model_parallel_all_reduce( + variance) / self.tp_world + x = x * torch.rsqrt(variance + self.variance_epsilon) + x = x.to(orig_dtype) * self.weight + return x + + def forward( + self, + x: torch.Tensor, + residual: 
Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + assert residual is None, "RMSNorm does not support residual connection." + return self._forward(x) + + +class MiniMaxText01RotaryEmbedding(CustomOp): + name = "MiniMaxText01RotaryEmbedding" + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position: int, + base: int, + is_neox_style: bool, + cache_dtype: torch.dtype, + ) -> None: + super().__init__() + self.head_size = head_size + self.rotary_dim = rotary_dim + self.max_position_embeddings = max_position + self.base = base + self.is_neox_style = is_neox_style + self.cache_dtype = cache_dtype + cache = self._compute_cos_sin_cache().to(cache_dtype) + self.register_buffer("cos_sin_cache", cache, persistent=False) + + def _compute_inv_freq( + self, + base: Union[int, float], + ) -> torch.Tensor: + """Compute the inverse frequency.""" + inv_freq = 1.0 / (base**(torch.arange( + 0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim)) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + """Compute the cos and sin cache.""" + inv_freq = self._compute_inv_freq(self.base) + t = torch.arange(self.max_position_embeddings, dtype=torch.float) + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + return cache + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + from vllm import _custom_ops as ops + self.cos_sin_cache = self.cos_sin_cache.to(positions.device) + query_cast = query.to(self.cache_dtype) + key_cast = key.to(self.cache_dtype) + ops.rotary_embedding(positions, query_cast, key_cast, self.head_size, + self.cos_sin_cache, self.is_neox_style) + query = query_cast.to(query.dtype) + key = key_cast.to(key.dtype) + return query, key + + +class MiniMaxText01MLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + quant_config: Optional[QuantizationConfig] = None, + layer_idx: int = None, + prefix: str = "mlp", + ) -> None: + super().__init__() + self.layer_idx = layer_idx + + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) + self.act_fn = SiluAndMul() + return + + def forward(self, x: torch.Tensor) -> torch.Tensor: + + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class MiniMaxText01MoE(nn.Module): + + def __init__( + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + layer_idx: int = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "moe", + ) -> None: + super().__init__() + + self.layer_idx = layer_idx + self.tp_size = get_tensor_model_parallel_world_size() + self.num_total_experts = num_experts + self.top_k = top_k + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size // self.tp_size + self.quant_config = quant_config + + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + + self.gate = ReplicatedLinear( + self.hidden_size, + self.num_total_experts, + bias=False, + params_dtype=torch.float32, + 
quant_config=None, + prefix=f"{prefix}.gate", + ) + self.gate.weight.weight_loader = MiniMaxText01MoE.gate_weight_loader + + self.experts = FusedMoE( + num_experts=self.num_total_experts, + top_k=self.top_k, + hidden_size=self.hidden_size, + intermediate_size=self.intermediate_size * self.tp_size, + params_dtype=self.params_dtype, + reduce_results=True, + renormalize=True, + quant_config=self.quant_config, + tp_size=self.tp_size, + prefix=f"{prefix}.experts", + ) + return + + @staticmethod + def gate_weight_loader(param: nn.Parameter, + loaded_weight: torch.Tensor) -> None: + assert param.size() == loaded_weight.size() + param.data.copy_(loaded_weight.to(torch.float32)) + return + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_size = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + router_logits_fp32, _ = self.gate(hidden_states.to(torch.float32)) + final_hidden_states = self.experts( + hidden_states, router_logits_fp32.to(hidden_states.dtype)) + final_hidden = final_hidden_states.view(num_tokens, hidden_size) + return final_hidden + + +class MiniMaxText01LinearKernel: + + @staticmethod + def jit_linear_forward_prefix(q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + kv_caches: torch.Tensor, + slope_rate: torch.Tensor, + block_size: int, + layer_idx: int = None, + **kwargs) -> torch.Tensor: + + slope_rate = slope_rate.to(torch.float32) + should_pad_dim = q.dim() == 3 + if should_pad_dim: + q = q.unsqueeze(0) + k = k.unsqueeze(0) + v = v.unsqueeze(0) + b, h, n, d = q.shape + e = d + kv_history = kv_caches.reshape(1, h, d, e).contiguous() + output, kv_history = lightning_attention(q, + k, + v, + slope_rate, + block_size=block_size, + kv_history=kv_history) + kv_caches.copy_(kv_history[:, :, -1, :, :].reshape(h, d, e)) + assert output.shape[0] == 1, "batch size must be 1" + return rearrange(output.squeeze(0), "h n d -> n (h d)") + + +class MiniMaxText01LinearAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + hidden_inner_size: int, + num_heads: int, + head_dim: int, + max_position: int, + block_size: int, + num_hidden_layer: int, + quant_config: Optional[QuantizationConfig] = None, + layer_idx: int = 0, + linear_layer_idx: int = 0, + prefix: str = "linear_attn", + ) -> None: + super().__init__() + + self.layer_idx = layer_idx + self.BLOCK = block_size + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = head_dim + self.total_num_heads = num_heads + self.hidden_inner_size = hidden_inner_size + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + + assert self.total_num_heads % self.tp_size == 0 + self.tp_heads = self.total_num_heads // self.tp_size + self.qkv_size = self.num_heads * self.head_dim + self.tp_hidden = self.head_dim * self.tp_heads + + self.qkv_proj = ColumnParallelLinear( + hidden_size, + self.hidden_inner_size * 3, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.output_gate = ColumnParallelLinear( + hidden_size, + self.hidden_inner_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.output_gate", + ) + self.out_proj = RowParallelLinear( + self.hidden_inner_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + self.norm = MiniMaxText01RMSNormTP( + self.hidden_inner_size, + eps=1e-5, + ) + + slope_rate = MiniMaxText01LinearAttention._build_slope_tensor( + self.num_heads) + if num_hidden_layer <= 1: + 
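+        # ALiBi-style decay slopes, scaled per layer: deeper layers get a
+        # smaller slope (slower decay, longer memory); the 1e-5 term keeps
+        # the factor strictly positive even for the last layer.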
self.slope_rate = slope_rate * (1 + 1e-5) + else: + self.slope_rate = slope_rate * (1 - layer_idx / + (num_hidden_layer - 1) + 1e-5) + self.tp_slope = self.slope_rate[self.tp_rank * + self.tp_heads:(self.tp_rank + 1) * + self.tp_heads].contiguous() + + @staticmethod + def weight_direct_load(param: torch.Tensor, + loaded_weight: torch.Tensor) -> None: + assert param.size() == loaded_weight.size() + param.data.copy_(loaded_weight) + return + + @staticmethod + def _build_slope_tensor(n_attention_heads: int): + + def get_slopes(n): + + def get_slopes_power_of_2(n): + start = 2**(-(2**-(math.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] + + if math.log2(n).is_integer(): + return get_slopes_power_of_2(n) + else: + closest_power_of_2 = 2**math.floor(math.log2(n)) + return (get_slopes_power_of_2(closest_power_of_2) + get_slopes( + 2 * closest_power_of_2)[0::2][:n - closest_power_of_2]) + + slopes = torch.tensor(get_slopes(n_attention_heads), + dtype=torch.float32).reshape( + n_attention_heads, 1, 1) + return slopes + + def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor, + attn_metadata): + hidden = [] + for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)): + _start = attn_metadata.query_start_loc[_prefill_idx] + _end = attn_metadata.query_start_loc[_prefill_idx + 1] + slot_id = state_indices_tensor[_prefill_idx] + qs = q[_start:_end].transpose(0, 1).contiguous() + ks = k[_start:_end].transpose(0, 1).contiguous() + vs = v[_start:_end].transpose(0, 1).contiguous() + slot_id = state_indices_tensor[_prefill_idx] + slice_layer_cache = kv_cache[slot_id, ...] + + out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix( + qs, + ks, + vs, + slice_layer_cache, + self.tp_slope, + self.BLOCK, + layer_idx=self.layer_idx) + hidden.append(out_slice.contiguous()) + if attn_metadata.num_decode_tokens > 0: + hidden.append( + self._decode_infer(q, k, v, kv_cache, state_indices_tensor, + attn_metadata)) + hidden = torch.concat(hidden, dim=0).contiguous() + return hidden + + def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor, + attn_metadata): + q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous() + k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous() + v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous() + slot_id = state_indices_tensor[getattr(attn_metadata, "num_prefills", 0 + ):] + hidden = linear_decode_forward_triton(q, k, v, kv_cache, self.tp_slope, + slot_id, 32) + return hidden + + def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor, + kv_caches: MinimaxCacheParams, **kwargs) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + qkv32 = qkv.to(torch.float32) + qkvact = torch.nn.functional.silu(qkv32) + qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1)) + q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1) + forward_context = get_forward_context() + attn_metadata = forward_context.attn_metadata + kv_cache = kv_caches.minimax_cache + state_indices_tensor = kv_caches.state_indices_tensor + + decode_only = getattr(attn_metadata, "num_prefills", 0) == 0 + if not decode_only: + hidden = self._prefill_and_mix_infer(q, k, v, kv_cache, + state_indices_tensor, + attn_metadata) + else: + hidden = self._decode_infer(q, k, v, kv_cache, + state_indices_tensor, attn_metadata) + + hidden = self.norm._forward(hidden) + gate, _ = self.output_gate(hidden_states) + hidden = F.sigmoid(gate) * hidden + hidden = hidden.to(hidden_states.dtype) + hidden, _ = 
self.out_proj(hidden) + return hidden + + +class MiniMaxText01Attention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + head_dim: int, + num_kv_heads: int, + rotary_dim: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + sliding_window: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + layer_idx: int = None, + cache_config: Optional[CacheConfig] = None, + prefix: str = "mha", + ) -> None: + super().__init__() + self.layer_idx = layer_idx + + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + assert self.total_num_kv_heads % tp_size == 0 + else: + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.sliding_window = sliding_window + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) + return + + def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor, + **kwargs) -> torch.Tensor: + forward_context = get_forward_context() + attn_metadata = forward_context.attn_metadata + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = attn_metadata.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class MiniMaxText01DecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + expert_num: int = 1, + layer_id: int = None, + linear_layer_id: Optional[int] = None, + prefix: str = "decoder", + ) -> None: + self._ilayer = layer_id + self._irank = get_tensor_model_parallel_rank() + super().__init__() + + self.hidden_size = config.hidden_size + self.expert_num = expert_num + + rope_theta = getattr(config, "rope_theta", 10000) + + head_dim = getattr(config, "head_dim", + config.hidden_size // config.num_attention_heads) + if hasattr(config, "max_model_len") and isinstance( + config.max_model_len, int): + max_position_embeddings = min(config.max_position_embeddings, + config.max_model_len) + if config.attention_type == 0: + use_headxdim = True + hidden_inner = (head_dim * config.num_attention_heads + if use_headxdim else config.hidden_size) + self.self_attn = MiniMaxText01LinearAttention( + hidden_size=self.hidden_size, + hidden_inner_size=hidden_inner, + num_heads=config.num_attention_heads, + head_dim=head_dim, + max_position=max_position_embeddings, + block_size=config.block if hasattr(config, "block") else 256, + num_hidden_layer=config.num_hidden_layers, + 
quant_config=quant_config, + layer_idx=self._ilayer, + linear_layer_idx=linear_layer_id, + prefix=prefix) + elif config.attention_type == 1: + self.self_attn = MiniMaxText01Attention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + head_dim=head_dim, + rotary_dim=config.rotary_dim + if hasattr(config, "rotary_dim") else head_dim, + num_kv_heads=config.num_key_value_heads, + max_position=max_position_embeddings, + rope_theta=rope_theta, + sliding_window=config.sliding_window, + quant_config=quant_config, + layer_idx=self._ilayer, + cache_config=cache_config, + prefix=prefix) + else: + raise ValueError( + f"Unsupported attention type: {self.config.attention_type}") + + if expert_num == 1: + self.mlp = MiniMaxText01MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + quant_config=quant_config, + layer_idx=self._ilayer, + prefix=prefix) + else: + self.block_sparse_moe = MiniMaxText01MoE( + num_experts=expert_num, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + layer_idx=self._ilayer, + quant_config=quant_config, + prefix=prefix) + + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + if config.attention_type == 0: + self.layernorm_attention_alpha = getattr( + config, 'layernorm_linear_attention_alpha', 1) + self.layernorm_attention_beta = getattr( + config, 'layernorm_linear_attention_beta', 1) + else: + self.layernorm_attention_alpha = getattr( + config, 'layernorm_full_attention_alpha', 1) + self.layernorm_attention_beta = getattr( + config, 'layernorm_full_attention_beta', 1) + self.layernorm_mlp_alpha = getattr(config, 'layernorm_mlp_alpha', 1) + self.layernorm_mlp_beta = getattr(config, 'layernorm_mlp_beta', 1) + self.postnorm = getattr(config, 'postnorm', False) + self.shared_moe = False + + shared_intermediate = getattr(config, 'shared_intermediate_size', 0) + if shared_intermediate > 0: + self.shared_moe = True + self.shared_mlp = MiniMaxText01MLP( + hidden_size=self.hidden_size, + intermediate_size=shared_intermediate, + quant_config=quant_config, + layer_idx=self._ilayer, + prefix=prefix) + self.coefficient = ReplicatedLinear( + self.hidden_size, + 1, + bias=False, + quant_config=quant_config, + params_dtype=torch.float32, + ) + self.coefficient.weight.weight_loader = ( + self.shared_moe_coefficient_loader) + self.shared_moe_mode = getattr(config, 'shared_moe_mode', + 'softmax') + return + + def forward(self, + hidden_states: torch.Tensor, + positions: torch.Tensor, + kv_caches: Union[List[Dict], Optional[torch.Tensor]], + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + is_warmup: bool = False, + **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: + + forward_context = get_forward_context() + attn_metadata = forward_context.attn_metadata + layernorm_input = hidden_states + layernorm_output = self.input_layernorm(layernorm_input) + residual = layernorm_output if self.postnorm else layernorm_input + self_attention_output = self.self_attn( + hidden_states=layernorm_output, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + ) + + residual = residual * self.layernorm_attention_alpha + self_attention_output = (self_attention_output * + self.layernorm_attention_beta) + + layernorm_input = residual + self_attention_output + layernorm_output = self.post_attention_layernorm(layernorm_input) + residual = 
layernorm_output if self.postnorm else layernorm_input + + if self.expert_num == 1: + hidden_states = self.mlp(layernorm_output) + else: + moe_hidden_states = self.block_sparse_moe( + copy.deepcopy(layernorm_output)) + if self.shared_moe: + before_moe_dtype = layernorm_output.dtype + moe_hidden_fp32 = moe_hidden_states.to(torch.float32) + output_mlp = self.shared_mlp(layernorm_output).to( + torch.float32) + + coef, _ = self.coefficient(layernorm_output.to(torch.float32)) + + if self.shared_moe_mode == 'softmax': + coef = torch.nn.functional.softmax(coef, dim=-1) + hidden_states = moe_hidden_fp32 * ( + 1 - coef) + output_mlp * coef + elif self.shared_moe_mode == 'sigmoid': + coef = torch.nn.functional.sigmoid(coef) + hidden_states = moe_hidden_fp32 * ( + 1 - coef) + output_mlp * coef + + hidden_states = hidden_states.to(before_moe_dtype) + else: + hidden_states = moe_hidden_states + + residual = residual * self.layernorm_mlp_alpha + hidden_states = hidden_states * self.layernorm_mlp_beta + + hidden_states = residual + hidden_states + + return hidden_states, None + + @staticmethod + def shared_moe_coefficient_loader(param: torch.Tensor, + loaded_weight: torch.Tensor) -> None: + assert param.size() == loaded_weight.size() + + param.data.copy_(loaded_weight.to(torch.float32)) + return + + +class MiniMaxText01Model(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + cache_config: Optional[CacheConfig] = None, + scheduler_config=None, + prefix: str = "", + ) -> None: + super().__init__() + + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.decoder_attention_types = getattr( + config, "attn_type_list", False) or getattr( + config, "decoder_attention_types", False) + if not self.decoder_attention_types: + self.decoder_attention_types = [1] * config.num_hidden_layers + self.num_layers = config.num_hidden_layers + + self._layer_barrier = False + if get_pp_group().is_first_rank: + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=self.vocab_size, + ) + else: + self.embed_tokens = PPMissingLayer() + + def layer_fn(prefix): + layer_idx = int(prefix.split('.')[-1]) + layer_config = config + layer_config.attention_type = self.decoder_attention_types[ + layer_idx] + layer_config.layer_idx = layer_idx + + decoder_kwargs = { + "quant_config": quant_config, + "layer_id": layer_idx, + "cache_config": cache_config + } + + if layer_config.attention_type == 0: + decoder_kwargs["linear_layer_id"] = sum( + 1 for i in range(layer_idx) + if self.decoder_attention_types[i] == 0) + else: + decoder_kwargs["linear_layer_id"] = None + + if hasattr(config, "num_local_experts") and isinstance( + config.num_local_experts, list): + decoder_kwargs["expert_num"] = config.num_local_experts[ + layer_idx] + elif hasattr(config, "num_local_experts") and isinstance( + config.num_local_experts, int): + decoder_kwargs["expert_num"] = config.num_local_experts + else: + decoder_kwargs["expert_num"] = 1 + + return MiniMaxText01DecoderLayer(layer_config, + **decoder_kwargs, + prefix=prefix) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, layer_fn, prefix=f"{prefix}.layers") + + linear_layer_nums = sum(1 for i in range(config.num_hidden_layers) + if self.decoder_attention_types[i] == 0) + max_slots_number = scheduler_config.max_num_seqs + self.cache_shape = (linear_layer_nums, max_slots_number, + config.num_attention_heads // + 
get_tensor_model_parallel_world_size(), + config.head_dim, config.head_dim) + _dummy = torch.zeros(1) + self._dtype = _dummy.dtype + del _dummy + + self.minimax_cache = MinimaxCacheManager(dtype=self._dtype, + cache_shape=self.cache_shape) + + rope_theta = getattr(config, "rope_theta", 10000) + head_dim = getattr(config, "head_dim", + config.hidden_size // config.num_attention_heads) + if hasattr(config, "max_model_len") and isinstance( + config.max_model_len, int): + max_position_embeddings = min(config.max_position_embeddings, + config.max_model_len) + self.rotary_emb = MiniMaxText01RotaryEmbedding( + head_dim, + rotary_dim=config.rotary_dim + if hasattr(config, "rotary_dim") else head_dim, + max_position=max_position_embeddings, + base=int(rope_theta), + is_neox_style=True, + cache_dtype=torch.float32, + ) + + norm_kwargs = {} + if hasattr(config, "rms_norm_eps"): + norm_kwargs["eps"] = config.rms_norm_eps + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, **norm_kwargs) + else: + self.norm = PPMissingLayer() + self.embed_scale = 1.0 + return + + def _clear_prefill_cache(self, attn_metadata, + minimax_cache_tensors: torch.Tensor, **kwargs): + seq_to_slot_maps = {} + seq_id_map = sum(list(kwargs["request_ids_to_seq_ids"].values()), []) + for _, seq_to_slot_map in ( + self.minimax_cache.cache_indices_mapping.items()): + seq_to_slot_maps.update(seq_to_slot_map) + + slots_to_clear = [] + for _prefill_id in range(getattr(attn_metadata, "num_prefills", 0)): + seq_id = seq_id_map[_prefill_id] + if attn_metadata.context_lens_tensor[ + _prefill_id] == 0 and seq_id in seq_to_slot_maps: + slots_to_clear.append(seq_to_slot_maps[seq_id]) + + if slots_to_clear: + slots_tensor = torch.tensor(slots_to_clear, + device=minimax_cache_tensors.device, + dtype=torch.long) + minimax_cache_tensors[:, slots_tensor, ...] 
= 0 + + def forward(self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + intermediate_tensors=None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs) -> torch.Tensor: + forward_context = get_forward_context() + attn_metadata = forward_context.attn_metadata + if attn_metadata is None: + return None + if "request_ids_to_seq_ids" not in kwargs: + kwargs["request_ids_to_seq_ids"] = {} + if "finished_requests_ids" not in kwargs: + kwargs["finished_requests_ids"] = [] + ( + minimax_cache_tensors, + state_indices_tensor, + ) = self.minimax_cache.current_run_tensors(**kwargs) + if getattr(attn_metadata, "num_prefills", 0) > 0: + self._clear_prefill_cache(attn_metadata, minimax_cache_tensors, + **kwargs) + + minimax_cache_params = MinimaxCacheParams(minimax_cache_tensors, + state_indices_tensor) + if get_pp_group().is_first_rank: + if inputs_embeds is None: + hidden_states = self.embed_scale * self.embed_tokens(input_ids) + else: + hidden_states = inputs_embeds + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + kv_cache_index = 0 + minimax_cache_index = 0 + attn_metadata.rotary_emb = self.rotary_emb + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + _caches = None + if isinstance(layer.self_attn, MiniMaxText01Attention): + _caches = kv_caches[kv_cache_index] + kv_cache_index += 1 + if isinstance(layer.self_attn, MiniMaxText01LinearAttention): + current_state_layer = minimax_cache_index + _caches = minimax_cache_params.at_layer_idx( + current_state_layer) + minimax_cache_index += 1 + hidden_states, residual = layer( + hidden_states=hidden_states, + positions=positions, + kv_caches=_caches, + attn_metadata=attn_metadata, + residual=residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + if residual is not None: + hidden_states, _ = self.norm(hidden_states, residual) + else: + hidden_states = self.norm(hidden_states) + + return hidden_states + + +class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid, + SupportsV0Only): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + self.config = config + self.lora_config = lora_config + + if not hasattr(config, "sliding_window"): + config.sliding_window = None + + self.CONCAT_FFN = True + + self.unpadded_vocab_size = self.config.vocab_size + if hasattr(vllm_config.model_config, "max_model_len"): + self.config.max_model_len = vllm_config.model_config.max_model_len + self.model = MiniMaxText01Model( + self.config, + quant_config, + cache_config=vllm_config.cache_config, + scheduler_config=vllm_config.scheduler_config, + prefix=maybe_prefix(prefix, "model")) + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + self.config.hidden_size, + org_num_embeddings=self.config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + ) + + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + self.config.vocab_size) + + self.sampler = Sampler() + else: + self.lm_head = PPMissingLayer() + + flash_layer_count = sum(1 for attn_type in self.config.attn_type_list + if attn_type == 1) + self.kv_cache = [torch.tensor([]) for _ in 
range(flash_layer_count)] + return + + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + return self.model.minimax_cache.copy_inputs_before_cuda_graphs( + input_buffers, **kwargs) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + return self.model.minimax_cache.get_seqlen_agnostic_capture_inputs( + batch_size) + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, self.kv_cache, + intermediate_tensors, inputs_embeds, + **kwargs) + + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ): + + next_tokens = self.sampler(logits, sampling_metadata) + + return next_tokens + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + "residual": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + }) + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> None: + params_dict = dict(self.named_parameters()) + + def which_layer(name: str) -> int: + if "layers" in name: + after_layer = name.split("layers")[-1] + return int(after_layer.split(".")[1]) + return None + + def is_linear_attn_layer(layer_idx: int) -> bool: + if layer_idx is None or not hasattr(self.config, "attn_type_list"): + return False + return self.config.attn_type_list[layer_idx] == 0 + + def is_moe_weight(name: str) -> bool: + return "block_sparse_moe" in name and not name.endswith(".bias") + + def get_expert_id(param_name): + pattern = r'model\.layers\.\d+\.block_sparse_moe\.experts\.(\d+)\.' 
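+            # Descriptive note (added comment): the expert index is encoded
+            # in the checkpoint parameter name, e.g.
+            # "model.layers.3.block_sparse_moe.experts.7.w1.weight" -> "7";
+            # the regex above captures that trailing experts.<id> component.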
+ match = re.search(pattern, param_name) + if match: + return match.group(1) + return None + + def load_sparse_moe_weight(name: str, loaded_weight: torch.Tensor, + self) -> None: + if isinstance(self.config.num_local_experts, list): + expert_params_mapping = [ + ("w13_weight" + if weight_name in ["w1", "w3"] else "w2_weight", + f"experts.{expert_id}.{weight_name}.weight", expert_id) + for expert_id in range(max(self.config.num_local_experts)) + for weight_name in ["w1", "w2", "w3"] + ] + else: + expert_params_mapping = [ + ("w13_scale" if weight_name in ["w1", "w3"] else + "w2_scale", f"{expert_id}.{weight_name}.weight_scale", + expert_id, weight_name) + for expert_id in range(self.config.num_local_experts) + for weight_name in ["w1", "w2", "w3"] + ] + [("w13_weight" if weight_name in ["w1", "w3"] else + "w2_weight", f"{expert_id}.{weight_name}.weight", + expert_id, weight_name) + for expert_id in range(self.config.num_local_experts) + for weight_name in ["w1", "w2", "w3"]] + for (param_name, weight_name, expert_id, + shard_id) in expert_params_mapping: + name_expert_id = get_expert_id(name) + if name_expert_id is not None and int(name_expert_id) != int( + expert_id): + continue + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + return + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader = weight_loader_with_alias(name)(weight_loader) + weight_loader(param, + loaded_weight, + weight_name, + expert_id=expert_id, + shard_id=shard_id) + break + else: + if is_pp_missing_parameter(name, self): + return + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader = weight_loader_with_alias(name)(weight_loader) + weight_loader(param, loaded_weight) + return + + def is_shared_mlp_weight(name: str) -> bool: + return "shared_mlp" in name and not name.endswith(".bias") + + def load_shared_mlp_weight(name: str, loaded_weight: torch.Tensor, + self) -> None: + if not self.CONCAT_FFN: + if "gate_proj" in name: + name = name.replace("gate_proj", "w1", 1) + elif "up_proj" in name: + name = name.replace("up_proj", "w3", 1) + elif "down_proj" in name: + name = name.replace("down_proj", "w2", 1) + else: + if "gate_proj" in name: + name = name.replace("gate_proj", "gate_up_proj", 1) + loaded_shard_id = 0 + elif "up_proj" in name: + name = name.replace("up_proj", "gate_up_proj", 1) + loaded_shard_id = 1 + if is_pp_missing_parameter(name, self): + return + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader = weight_loader_with_alias(name)(weight_loader) + if not self.CONCAT_FFN: + weight_loader(param, loaded_weight) + else: + if "gate_up_proj" in name: + weight_loader(param, loaded_weight, loaded_shard_id) + elif "down_proj" in name: + weight_loader(param, loaded_weight) + else: + raise AssertionError( + "MLP weight not in [gate_up_proj, down_proj]") + return + + def is_mha_weight(name: str) -> bool: + return "self_attn" in name and not name.endswith(".bias") + + def load_linear_attn_weight(name: str, loaded_weight: torch.Tensor, + self) -> None: + if is_pp_missing_parameter(name, self): + return + param = params_dict[name] + + weight_loader = getattr( + param, "weight_loader", + MiniMaxText01LinearAttention.weight_direct_load) + weight_loader = weight_loader_with_alias(name)(weight_loader) + weight_loader(param, loaded_weight) + return + + def load_flash_attn_weight(name: str, 
loaded_weight: torch.Tensor, + self) -> None: + + flash_mha_params_mapping = [ + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + for (param_name, weight_name, + shard_id) in flash_mha_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + return + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader = weight_loader_with_alias(name)(weight_loader) + weight_loader(param, loaded_weight, shard_id) + break + else: + if is_pp_missing_parameter(name, self): + return + param = params_dict[name] + + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader = weight_loader_with_alias(name)(weight_loader) + weight_loader(param, loaded_weight) + return + + def is_layer_norm_weight(name: str) -> bool: + return "norm" in name and not name.endswith( + ".bias") and name in params_dict + + def load_layer_norm_weight(name: str, loaded_weight: torch.Tensor, + self) -> None: + if is_pp_missing_parameter(name, self): + return + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader = weight_loader_with_alias(name)(weight_loader) + weight_loader(param, loaded_weight) + return + + def load_basic_weight(name: str, loaded_weight: torch.Tensor, + self) -> None: + if is_pp_missing_parameter(name, self): + return + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader = weight_loader_with_alias(name)(weight_loader) + weight_loader(param, loaded_weight) + return + + for name, loaded_weight in weights: + weight_at_layer = which_layer(name) + if weight_at_layer and weight_at_layer >= len( + self.config.attn_type_list): + continue + + if is_layer_norm_weight(name): + load_layer_norm_weight(name, loaded_weight, self) + continue + if is_mha_weight(name): + if is_linear_attn_layer(weight_at_layer): + load_linear_attn_weight(name, loaded_weight, self) + else: + load_flash_attn_weight(name, loaded_weight, self) + continue + if is_moe_weight(name): + load_sparse_moe_weight(name, loaded_weight, self) + continue + if is_shared_mlp_weight(name): + load_shared_mlp_weight(name, loaded_weight, self) + continue + + if "rotary_emb.inv_freq" in name: + continue + + load_basic_weight(name, loaded_weight, self) + return diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 2f1827c1..6ead6509 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -35,6 +35,7 @@ _TEXT_GENERATION_MODELS = { "AquilaModel": ("llama", "LlamaForCausalLM"), "AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2 "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"), + "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"), # baichuan-7b, upper case 'C' in the class name "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-13b, lower case 'c' in the class name