From 5d60def02cb5a43fa5864fcb123909b101df9ec5 Mon Sep 17 00:00:00 2001 From: wangding zeng <155410488+zwd003@users.noreply.github.com> Date: Tue, 30 Jan 2024 13:19:48 +0800 Subject: [PATCH] DeepseekMoE support with Fused MoE kernel (#2453) Co-authored-by: roy --- csrc/dispatch_utils.h | 11 + csrc/moe_align_block_size_kernels.cu | 108 ++++++ csrc/ops.h | 9 + csrc/pybind.cpp | 4 + setup.py | 1 + tests/kernels/test_fused_moe.py | 50 +++ vllm/model_executor/layers/fused_moe.py | 287 +++++++++++++++ vllm/model_executor/models/__init__.py | 1 + vllm/model_executor/models/deepseek.py | 453 ++++++++++++++++++++++++ 9 files changed, 924 insertions(+) create mode 100644 csrc/moe_align_block_size_kernels.cu create mode 100644 tests/kernels/test_fused_moe.py create mode 100644 vllm/model_executor/layers/fused_moe.py create mode 100644 vllm/model_executor/models/deepseek.py diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index 85fdfc09..91abd9e8 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -24,3 +24,14 @@ #define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH( \ TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__)) + +#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) + +#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) diff --git a/csrc/moe_align_block_size_kernels.cu b/csrc/moe_align_block_size_kernels.cu new file mode 100644 index 00000000..81cc6dd6 --- /dev/null +++ b/csrc/moe_align_block_size_kernels.cu @@ -0,0 +1,108 @@ +#include +#include + +#include +#include + +#include "cuda_compat.h" +#include "dispatch_utils.h" + +const static size_t NUM_MAX_EXPERTS = 64; +#define CEILDIV(x,y) (((x) + (y) - 1) / (y)) + +namespace vllm { +template +__global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids, + int32_t *sorted_token_ids, + int32_t *expert_ids, + int32_t *total_tokens_post_pad, + int32_t num_experts, + int32_t block_size, + size_t numel) { + const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); + const size_t start_idx = threadIdx.x * tokens_per_thread; + __shared__ int32_t tokens_cnts[NUM_MAX_EXPERTS + 1][NUM_MAX_EXPERTS]; + __shared__ int32_t cumsum[NUM_MAX_EXPERTS + 1]; + for (int i = 0; i < num_experts; ++i) { + tokens_cnts[threadIdx.x + 1][i] = 0; + } + + /** + * In the first step we compute token_cnts[thread_index + 1][expert_index], + * which counts how many tokens in the token shard of thread_index are assigned + * to expert expert_index. + */ + for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { + ++tokens_cnts[threadIdx.x + 1][topk_ids[i]]; + } + + __syncthreads(); + + // For each expert we accumulate the token counts from the different threads. + tokens_cnts[0][threadIdx.x] = 0; + for (int i = 1; i <= blockDim.x; ++i) { + tokens_cnts[i][threadIdx.x] += tokens_cnts[i-1][threadIdx.x]; + } + + __syncthreads(); + + // We accumulate the token counts of all experts in thread 0. 
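+  // Each expert's count is rounded up to a multiple of block_size while
+  // accumulating. For example, with block_size = 4 and per-expert totals
+  // [3, 3, 3, 3], cumsum becomes [0, 4, 8, 12, 16] and
+  // *total_tokens_post_pad is 16.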
+ if (threadIdx.x == 0) { + cumsum[0] = 0; + for (int i = 1; i <= num_experts; ++i) { + cumsum[i] = cumsum[i-1] + CEILDIV(tokens_cnts[blockDim.x][i - 1], block_size) * block_size; + } + *total_tokens_post_pad = cumsum[num_experts]; + } + + __syncthreads(); + + /** + * For each expert, each thread processes the tokens of the corresponding blocks + * and stores the corresponding expert_id for each block. + */ + for (int i = cumsum[threadIdx.x];i < cumsum[threadIdx.x + 1];i += block_size) { + expert_ids[i / block_size] = threadIdx.x; + } + + /** + * Each thread processes a token shard, calculating the index of each token after + * sorting by expert number. Given the example topk_ids = [0,1,2,1,2,3,0,3,4] and + * block_size = 4, then the output would be [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], + * where * represents a padding value(preset in python). + */ + for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { + int32_t expert_id = topk_ids[i]; + /** The cumsum[expert_id] stores the starting index of the tokens that the + * expert with expert_id needs to process, and tokens_cnts[threadIdx.x][expert_id] + * stores the indices of the tokens processed by the expert with expert_id within + * the current thread's token shard. + */ + int32_t rank_post_pad = tokens_cnts[threadIdx.x][expert_id] + cumsum[expert_id]; + sorted_token_ids[rank_post_pad] = i; + ++tokens_cnts[threadIdx.x][expert_id]; + } +} +} + +void moe_align_block_size( + torch::Tensor topk_ids, + int num_experts, + int block_size, + torch::Tensor sorted_token_ids, + torch::Tensor experts_ids, + torch::Tensor num_tokens_post_pad) { + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + assert(num_experts <= NUM_MAX_EXPERTS); + VLLM_DISPATCH_INTEGRAL_TYPES( + topk_ids.scalar_type(), "moe_alig_block_size_kernel", [&] { + vllm::moe_align_block_size_kernel<<<1, num_experts, 0, stream>>>( + topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), + experts_ids.data_ptr(), + num_tokens_post_pad.data_ptr(), + num_experts, + block_size, + topk_ids.numel()); + }); +} diff --git a/csrc/ops.h b/csrc/ops.h index ce77dd47..6e52dd81 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -121,3 +121,12 @@ std::pair, std::vector> get_graph_buffer_ipc_meta( void register_graph_buffers(fptr_t _fa, const std::vector &handles, const std::vector> &offsets); #endif + +void moe_align_block_size( + torch::Tensor topk_ids, + int num_experts, + int block_size, + torch::Tensor sorted_token_ids, + torch::Tensor experts_ids, + torch::Tensor num_tokens_post_pad + ); diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index db2da8f0..a8a99883 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -56,6 +56,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ"); ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ"); ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM"); + ops.def( + "moe_align_block_size", + &moe_align_block_size, + "Aligning the number of tokens to be processed by each expert such that it is divisible by the block size."); // Cache ops pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); diff --git a/setup.py b/setup.py index 8fad433f..3e212785 100644 --- a/setup.py +++ b/setup.py @@ -309,6 +309,7 @@ vllm_extension_sources = [ "csrc/quantization/squeezellm/quant_cuda_kernel.cu", "csrc/quantization/gptq/q_gemm.cu", "csrc/cuda_utils_kernels.cu", + "csrc/moe_align_block_size_kernels.cu", "csrc/pybind.cpp", ] 
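For readers following the kernel above, the expected behavior of moe_align_block_size can be mirrored in a few lines of plain PyTorch. This is a minimal reference sketch of the semantics only (the helper name and the use of torch.bincount are illustrative and not part of the patch; the CUDA path allocates slightly larger buffers, and its within-expert ordering only needs to be a valid grouping, not necessarily ascending):

import torch


def moe_align_block_size_ref(topk_ids: torch.Tensor, block_size: int,
                             num_experts: int):
    # Flatten: each token appears once per expert it is routed to.
    flat = topk_ids.flatten()
    numel = flat.numel()
    counts = torch.bincount(flat, minlength=num_experts)
    # Round each expert's token count up to a multiple of block_size.
    padded = ((counts + block_size - 1) // block_size) * block_size
    total = int(padded.sum().item())
    # Padding slots hold the out-of-range index `numel`, matching
    # sorted_ids.fill_(topk_ids.numel()) in the Python wrapper of this patch.
    sorted_ids = torch.full((total, ), numel,
                            dtype=torch.int32, device=flat.device)
    expert_ids = torch.empty((total // block_size, ),
                             dtype=torch.int32, device=flat.device)
    offset = 0
    for e in range(num_experts):
        idx = (flat == e).nonzero(as_tuple=True)[0]
        sorted_ids[offset:offset + idx.numel()] = idx.to(torch.int32)
        expert_ids[offset // block_size:(offset + int(padded[e])) //
                   block_size] = e
        offset += int(padded[e])
    return sorted_ids, expert_ids, torch.tensor([total], dtype=torch.int32,
                                                device=flat.device)

Against the worked example in the kernel comments (topk_ids = [0, 1, 2, 1, 2, 3, 0, 3, 4], block_size = 4, num_experts = 5), this returns sorted_ids [0, 6, 9, 9, 1, 3, 9, 9, 2, 4, 9, 9, 5, 7, 9, 9, 8, 9, 9, 9], i.e. the layout shown above with 9 (= topk_ids.numel()) as the padding value, and expert_ids [0, 1, 2, 3, 4].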
diff --git a/tests/kernels/test_fused_moe.py b/tests/kernels/test_fused_moe.py new file mode 100644 index 00000000..80a0349d --- /dev/null +++ b/tests/kernels/test_fused_moe.py @@ -0,0 +1,50 @@ +import pytest +import torch + +from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.activation import SiluAndMul + + +def torch_moe(a, w1, w2, topk_weight, topk_ids): + B, D = a.shape + a = a.view(B, -1, D).repeat(1, topk_ids.shape[1], 1).reshape(-1, D) + out = torch.zeros(B * topk_ids.shape[1], + w2.shape[1], + dtype=a.dtype, + device=a.device) + topk_ids = topk_ids.view(-1) + topk_weight = topk_weight.view(-1) + for i in range(w1.shape[0]): + mask = topk_ids == i + if mask.sum(): + out[mask] = SiluAndMul()( + a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) + return (out.view(B, -1, w2.shape[1]) * + topk_weight.view(B, -1, 1)).sum(dim=1) + + +@pytest.mark.parametrize("m", [512, 222, 33, 1]) +@pytest.mark.parametrize("n", [2048, 256, 1024]) +@pytest.mark.parametrize("k", [128, 511, 1024]) +@pytest.mark.parametrize("e", [8, 64]) +@pytest.mark.parametrize("topk", [2, 6]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_fused_moe( + m: int, + n: int, + k: int, + e: int, + topk: int, + dtype: torch.dtype, +): + a = torch.randn((m, k), device='cuda', dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10 + + score = torch.randn((m, e), device='cuda', dtype=dtype) + score = torch.softmax(score, dim=-1) + topk_weight, topk_ids = torch.topk(score, topk) + + triton_output = fused_moe(a, w1, w2, topk_weight, topk_ids, False) + torch_output = torch_moe(a, w1, w2, topk_weight, topk_ids) + assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0) diff --git a/vllm/model_executor/layers/fused_moe.py b/vllm/model_executor/layers/fused_moe.py new file mode 100644 index 00000000..998062d8 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe.py @@ -0,0 +1,287 @@ +"""Fused MoE kernel.""" +import torch +import triton +import triton.language as tl + +from vllm._C import ops + + +@triton.jit +def fused_moe_kernel( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + N, + K, + EM, + num_valid_tokens, + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` + # by to get the element one row down (A has M rows). + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, +): + """ + Implements the fused computation for a Mixture of Experts (MOE) using token and expert matrices. + + Key Parameters: + - A: The input tensor representing tokens with shape (*, K), where '*' can be any shape representing batches and K is the feature dimension of each token. + - B: The stacked MOE weight tensor with shape (E, N, K), where E is the number of experts, K is the input feature dimension, and N is the output feature dimension. 
+ - C: The output cache tensor with shape (M, topk, N), where M is the total number of tokens post padding, topk is the number of times each token is repeated, + and N is the output feature dimension. + - sorted_token_ids: A tensor containing the sorted indices of tokens, repeated topk times and arranged by the expert index they are assigned to. + - expert_ids: A tensor containing the indices of the expert for each block. It determines which expert matrix from B should be used for each block in A. + This kernel performs the multiplication of a token by its corresponding expert matrix as determined by `expert_ids`. The sorting of `sorted_token_ids` + by expert index and padding ensures divisibility by BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix multiplication across different blocks processed by the same expert. + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: + return + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + token_mask = offs_token < num_valid_tokens + + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + + offs_k[None, :] * stride_ak) + + off_experts = tl.load(expert_ids_ptr + pid_m) + b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + + offs_bn[None, :] * stride_bn) + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the K dimension. + a = tl.load(a_ptrs, + mask=token_mask[:, None] & + (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0) + b = tl.load(b_ptrs, + mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, + other=0.0) + # We accumulate along the K dimension. + accumulator += tl.dot(a, b) + # Advance the ptrs to the next K block. 
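+        # Only the K offsets advance here; the token rows of A (selected via
+        # offs_token // top_k) and the expert slice of B picked by off_experts
+        # stay fixed for this program.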
+ a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, + mask=token_mask, + other=0) + accumulator = accumulator * moe_weight[:, None] + + accumulator = accumulator.to(compute_type) + # ----------------------------------------------------------- + # Write back the block of the output + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[ + None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + +def moe_align_block_size( + topk_ids: torch.Tensor, block_size: int, + num_experts: int) -> (torch.Tensor, torch.Tensor, torch.Tensor): + """ + Aligns the token distribution across experts to be compatible with block size for matrix multiplication. + + Parameters: + - topk_ids: A tensor of shape [total_tokens, top_k] representing the top-k expert indices for each token. + - block_size: The block size used in block matrix multiplication. + - num_experts: The total number of experts. + + Returns: + - sorted_token_ids: A tensor containing the sorted token indices according to their allocated expert. + - expert_ids: A tensor indicating the assigned expert index for each block. + - num_tokens_post_padded: The total number of tokens after padding, ensuring divisibility by block_size. + + This function pads the number of tokens that each expert needs to process so that it is divisible by block_size. + Padding ensures that during block matrix multiplication, the dimensions align correctly. + + Example: + Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]], block_size = 4, and num_experts = 4: + - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts, with each expert needing to process 3 tokens. + - As block_size is 4, we pad 1 token for each expert. + - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3]. + - Then append padding tokens [12, 12, 12, 12] for each block. + - After sorting by expert index, we obtain token_ids [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12]. + Tokens 12 are non-existent (padding) and are ignored in the subsequent matrix multiplication. + - The padding ensures that the total number of tokens is now divisible by block_size for proper block matrix operations. 
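+    Note: sorted_token_ids is pre-filled with the value topk_ids.numel(), an
+    out-of-range token index, so padding slots are masked out inside
+    fused_moe_kernel via the num_valid_tokens check.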
+ """ + sorted_ids = torch.empty( + (topk_ids.numel() + num_experts * (block_size - 1), ), + dtype=torch.int32, + device=topk_ids.device) + expert_ids = torch.empty((topk_ids.numel() + num_experts, ), + dtype=torch.int32, + device=topk_ids.device) + sorted_ids.fill_(topk_ids.numel()) + num_tokens_post_pad = torch.empty((1), + dtype=torch.int32, + device=topk_ids.device) + ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, + expert_ids, num_tokens_post_pad) + return sorted_ids, expert_ids, num_tokens_post_pad + + +def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + mul_routed_weight: bool, top_k: int, config: dict): + + assert topk_weights.stride(1) == 1 + assert sorted_token_ids.stride(0) == 1 + + grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[ + 'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), ) + + fused_moe_kernel[grid]( + A, + B, + C, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.shape[1], + B.shape[2], + sorted_token_ids.shape[0], + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=tl.bfloat16 if A.dtype == torch.bfloat16 else tl.float16, + **config, + ) + + +def fused_moe(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace=False): + """ + This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism. + + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. + - w1 (torch.Tensor): The first set of expert weights. + - w2 (torch.Tensor): The second set of expert weights. + - topk_weights (torch.Tensor): The weights for the top-k selected experts. + - topk_ids (torch.Tensor): The indices of the top-k selected experts. + - inplace (bool): If True, perform the operation in-place. Defaults to False. + + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + # Check constraints. 
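+    # Expected shapes: hidden_states (num_tokens, hidden_size),
+    # w1 (num_experts, 2 * intermediate_size, hidden_size) with the gate and
+    # up projections stacked along dim 1, and w2 (num_experts, hidden_size,
+    # intermediate_size).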
+ assert hidden_states.shape[1] == w1.shape[2], "Incompatible dimensions" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.is_contiguous(), "Expert weights1 must be contiguous" + assert w2.is_contiguous(), "Expert weights2 must be contiguous" + assert hidden_states.dtype in [torch.float16, torch.bfloat16] + M, _ = hidden_states.shape + E, N, _ = w1.shape + + config = { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + } + + if topk_ids.numel() <= w1.shape[0]: + config = { + 'BLOCK_SIZE_M': 16, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 64, + 'GROUP_SIZE_M': 1 + } + + intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N), + device=hidden_states.device, + dtype=hidden_states.dtype) + intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2), + device=hidden_states.device, + dtype=hidden_states.dtype) + intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]), + device=hidden_states.device, + dtype=hidden_states.dtype) + + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( + topk_ids, config['BLOCK_SIZE_M'], E) + + invoke_fused_moe_kernel(hidden_states, w1, intermediate_cache1, + topk_weights, topk_ids, sorted_token_ids, + expert_ids, num_tokens_post_padded, False, + topk_ids.shape[1], config) + + ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) + + invoke_fused_moe_kernel(intermediate_cache2, w2, intermediate_cache3, + topk_weights, topk_ids, sorted_token_ids, + expert_ids, num_tokens_post_padded, True, 1, + config) + + if inplace: + return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), + dim=1, + out=hidden_states) + return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), + dim=1) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index b1d74b51..93631d26 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -18,6 +18,7 @@ _MODELS = { "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"), "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"), "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"), + "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"), "FalconForCausalLM": ("falcon", "FalconForCausalLM"), "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"), "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"), diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py new file mode 100644 index 00000000..fc727b8e --- /dev/null +++ b/vllm/model_executor/models/deepseek.py @@ -0,0 +1,453 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Deepseek model.""" +from typing import Any, Dict, List, Optional, Tuple + +import torch +from torch import nn +import torch.nn.functional as F +from transformers import PretrainedConfig + +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (LinearMethodBase, + MergedColumnParallelLinear, + ReplicatedLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, ParallelLMHead) +from vllm.model_executor.parallel_utils.communication_op import ( + tensor_model_parallel_all_reduce) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) +from vllm.sequence import SamplerOutput + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class DeepseekMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + linear_method: Optional[LinearMethodBase] = None, + reduce_results: bool = True, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + linear_method=linear_method) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + linear_method=linear_method, + reduce_results=reduce_results) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. 
" + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class DeepseekMoE(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.config = config + self.rank = get_tensor_model_parallel_rank() + self.tp_size = get_tensor_model_parallel_world_size() + self.n_routed_experts = config.n_routed_experts + self.top_k = config.num_experts_per_tok + if self.tp_size > self.n_routed_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {self.n_routed_experts}.") + + self.experts = nn.ModuleList([ + DeepseekMLP(hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + hidden_act=config.hidden_act, + linear_method=linear_method, + reduce_results=False) + for idx in range(self.n_routed_experts) + ]) + self.pack_params() + + self.gate = ReplicatedLinear(config.hidden_size, + self.n_routed_experts, + bias=False, + linear_method=None) + + if config.n_shared_experts is not None: + intermediate_size = config.moe_intermediate_size * config.n_shared_experts + self.shared_experts = DeepseekMLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + linear_method=linear_method, + reduce_results=False, + ) + + def pack_params(self): + w1 = [] + w2 = [] + for expert in self.experts: + w1.append(expert.gate_up_proj.weight) + w2.append(expert.down_proj.weight) + self.w1 = torch._utils._flatten_dense_tensors(w1) + w1s = torch._utils._unflatten_dense_tensors(self.w1, w1) + for data, param in zip(w1s, w1): + param.data = data + self.w1 = self.w1.view(len(w1), *w1s[0].shape) + + self.w2 = torch._utils._flatten_dense_tensors(w2) + w2s = torch._utils._unflatten_dense_tensors(self.w2, w2) + for data, param in zip(w2s, w2): + param.data = data + + self.w2 = self.w2.view(len(w2), *w2s[0].shape) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + if self.config.n_shared_experts is not None: + shared_output = self.shared_experts(hidden_states) + # router_logits: (batch * sequence_length, n_experts) + router_logits, _ = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, + self.top_k, + dim=-1) + + if self.config.norm_topk_prob: + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + + final_hidden_states = fused_moe(hidden_states, + self.w1, + self.w2, + routing_weights, + selected_experts, + inplace=True) + + if self.config.n_shared_experts is not None: + final_hidden_states = final_hidden_states + shared_output + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + + return final_hidden_states.view(batch_size, sequence_length, + hidden_dim) + + +class DeepseekAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = 
num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + linear_method=linear_method, + ) + + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + linear_method=linear_method, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = PagedAttention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + k_cache, v_cache = kv_cache + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class DeepseekDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + layer_idx: int, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.self_attn = DeepseekAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + linear_method=linear_method, + ) + if (config.n_routed_experts is not None and \ + layer_idx >= config.first_k_dense_replace and layer_idx % config.moe_layer_freq == 0): + self.mlp = DeepseekMoE(config=config, linear_method=linear_method) + else: + self.mlp = DeepseekMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + linear_method=linear_method, + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, 
residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class DeepseekModel(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.layers = nn.ModuleList([ + DeepseekDecoderLayer(config, + layer_idx, + linear_method=linear_method) + for layer_idx in range(config.num_hidden_layers) + ]) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + residual = None + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer(positions, hidden_states, + kv_caches[i], input_metadata, + residual) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class DeepseekForCausalLM(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.config = config + self.linear_method = linear_method + self.model = DeepseekModel(config, linear_method) + self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + self.sampler = Sampler(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + input_metadata) + return hidden_states + + def sample( + self, + hidden_states: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(self.lm_head.weight, hidden_states, + sampling_metadata) + return next_tokens + + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, + cache_dir, + load_format, + revision, + fall_back_to_pt=False): + if "rotary_emb.inv_freq" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip experts that are not assigned to this worker. + if (("mlp.experts." in name or "mlp.shared_experts." 
in name) + and name not in params_dict): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip experts that are not assigned to this worker. + if (("mlp.experts." in name or "mlp.shared_experts." in name) + and name not in params_dict): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight)
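For completeness, once the architecture is registered in _MODELS, the model is reachable through vLLM's standard Python entry points. A minimal usage sketch, assuming a DeepSeek-MoE checkpoint is available (the checkpoint name below is only an example) and that its remote code requires trust_remote_code:

from vllm import LLM, SamplingParams

# Example checkpoint name; any config whose "architectures" list contains
# "DeepseekForCausalLM" is dispatched to the model file added above.
llm = LLM(model="deepseek-ai/deepseek-moe-16b-base", trust_remote_code=True)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)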