Add fused top-K softmax kernel for MoE (#2769)
parent 2ccee3def6
commit f0d4e14557
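
Note: for orientation, a minimal sketch of how the new binding is driven from Python. The buffer shapes, dtypes, and the module name vllm._moe_C mirror what fused_moe.py and setup.py in this diff use; the dimensions themselves are illustrative only.

import torch
import vllm._moe_C as moe_kernels

num_tokens, num_experts, topk = 4, 8, 2
# Raw router logits; the kernel expects float32 gating outputs.
gating_output = torch.randn(num_tokens, num_experts,
                            dtype=torch.float32, device="cuda")

# Output buffers, written in place by the fused kernel.
topk_weights = torch.empty(num_tokens, topk, dtype=torch.float32, device="cuda")
topk_ids = torch.empty(num_tokens, topk, dtype=torch.int32, device="cuda")
token_expert_indices = torch.empty(num_tokens, topk, dtype=torch.int32, device="cuda")

# Fused softmax-over-experts + top-k selection in a single kernel launch.
moe_kernels.topk_softmax(topk_weights, topk_ids, token_expert_indices, gating_output)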
csrc/moe/moe_ops.cpp  (new file, +7)
@@ -0,0 +1,7 @@
#include "moe_ops.h"

#include <torch/extension.h>

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("topk_softmax", &topk_softmax, "Apply topk softmax to the gating outputs.");
}
csrc/moe/moe_ops.h  (new file, +9)
@@ -0,0 +1,9 @@
#pragma once

#include <torch/extension.h>

void topk_softmax(
    torch::Tensor& topk_weights,
    torch::Tensor& topk_indices,
    torch::Tensor& token_expert_indices,
    torch::Tensor& gating_output);
csrc/moe/topk_softmax_kernels.cu  (new file, +499)
@@ -0,0 +1,499 @@
/*
 * Adapted from https://github.com/NVIDIA/TensorRT-LLM/blob/v0.7.1/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu
 * Copyright (c) 2024, The vLLM team.
 * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <cub/cub.cuh>
#include <cub/util_type.cuh>

namespace vllm {
namespace moe {

static constexpr int WARP_SIZE = 32;

/// Aligned array type
template <
    typename T,
    /// Number of elements in the array
    int N,
    /// Alignment requirement in bytes
    int Alignment = sizeof(T) * N
>
class alignas(Alignment) AlignedArray {
    T data[N];
};

// ====================== Softmax things ===============================
// We have our own implementation of softmax here so we can support transposing the output
// in the softmax kernel when we extend this module to support expert-choice routing.
template <int TPB>
__launch_bounds__(TPB) __global__
    void moeSoftmax(const float* input, const bool* finished, float* output, const int num_cols)
{
    using BlockReduce = cub::BlockReduce<float, TPB>;
    __shared__ typename BlockReduce::TempStorage tmpStorage;

    __shared__ float normalizing_factor;
    __shared__ float float_max;

    const int thread_row_offset = blockIdx.x * num_cols;

    cub::Sum sum;
    float threadData(-FLT_MAX);

    // Don't touch finished rows.
    if ((finished != nullptr) && finished[blockIdx.x])
    {
        return;
    }

    for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
    {
        const int idx = thread_row_offset + ii;
        threadData = max(static_cast<float>(input[idx]), threadData);
    }

    const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
    if (threadIdx.x == 0)
    {
        float_max = maxElem;
    }
    __syncthreads();

    threadData = 0;

    for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
    {
        const int idx = thread_row_offset + ii;
        threadData += exp((static_cast<float>(input[idx]) - float_max));
    }

    const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);

    if (threadIdx.x == 0)
    {
        normalizing_factor = 1.f / Z;
    }
    __syncthreads();

    for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
    {
        const int idx = thread_row_offset + ii;
        const float val = exp((static_cast<float>(input[idx]) - float_max)) * normalizing_factor;
        output[idx] = val;
    }
}

template <int TPB>
__launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax, const bool* finished, float* output,
    int* indices, int* source_rows, const int num_experts, const int k, const int start_expert, const int end_expert)
{

    using cub_kvp = cub::KeyValuePair<int, float>;
    using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
    __shared__ typename BlockReduce::TempStorage tmpStorage;

    cub_kvp thread_kvp;
    cub::ArgMax arg_max;

    const int num_rows = gridDim.x;
    const int block_row = blockIdx.x;

    const bool row_is_active = finished ? !finished[block_row] : true;
    const int thread_read_offset = blockIdx.x * num_experts;
    for (int k_idx = 0; k_idx < k; ++k_idx)
    {
        thread_kvp.key = 0;
        thread_kvp.value = -1.f; // This is OK because inputs are probabilities

        cub_kvp inp_kvp;
        for (int expert = threadIdx.x; expert < num_experts; expert += TPB)
        {
            const int idx = thread_read_offset + expert;
            inp_kvp.key = expert;
            inp_kvp.value = inputs_after_softmax[idx];

            for (int prior_k = 0; prior_k < k_idx; ++prior_k)
            {
                const int prior_winning_expert = indices[k * block_row + prior_k];

                if (prior_winning_expert == expert)
                {
                    inp_kvp = thread_kvp;
                }
            }

            thread_kvp = arg_max(inp_kvp, thread_kvp);
        }

        const cub_kvp result_kvp = BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
        if (threadIdx.x == 0)
        {
            // Ignore experts the node isn't responsible for with expert parallelism
            const int expert = result_kvp.key;
            const bool node_uses_expert = expert >= start_expert && expert < end_expert;
            const bool should_process_row = row_is_active && node_uses_expert;

            const int idx = k * block_row + k_idx;
            output[idx] = result_kvp.value;
            indices[idx] = should_process_row ? (expert - start_expert) : num_experts;
            assert(indices[idx] >= 0);
            source_rows[idx] = k_idx * num_rows + block_row;
        }
        __syncthreads();
    }
}

// ====================== TopK softmax things ===============================

/*
    A Top-K gating softmax written to exploit the case where the number of experts in the
    MoE layers is a small power of 2. This allows us to cleanly share the rows among the
    threads in a single warp and eliminate communication between warps (so no need to use shared mem).

    It fuses the softmax, max and argmax into a single kernel.

    Limitations:
    1) This implementation is intended for when the number of experts is a small power of 2.
    2) This implementation assumes k is small, but will work for any k.
*/

template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG>
__launch_bounds__(WARPS_PER_CTA * WARP_SIZE) __global__
    void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, int* indices,
        int* source_rows, const int k, const int start_expert, const int end_expert)
{
    // We begin by enforcing compile time assertions and setting up compile time constants.
    static_assert(VPT == (VPT & -VPT), "VPT must be power of 2");
    static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2");
    static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2");
    static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16");

    // Number of bytes each thread pulls in per load
    static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
    static constexpr int ELTS_PER_ROW = NUM_EXPERTS;
    static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT;
    static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG;

    // Restrictions based on previous section.
    static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg");
    static_assert(WARP_SIZE % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp");
    static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW), "THREADS_PER_ROW must be power of 2");
    static_assert(THREADS_PER_ROW <= WARP_SIZE, "THREADS_PER_ROW can be at most warp size");

    // We have NUM_EXPERTS elements per row. We specialize for small #experts
    static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT;
    static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW;
    static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP;

    // Restrictions for previous section.
    static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0, "The elts per row must cleanly divide the total elt per warp");

    // ===================== From this point, we finally start computing run-time variables. ========================

    // Compute CTA and warp rows. We pack multiple rows into a single warp, and a block contains WARPS_PER_CTA warps.
    // Thus, each block processes a chunk of rows. We start by computing the start row for each block.
    const int cta_base_row = blockIdx.x * ROWS_PER_CTA;

    // Now, using the base row per thread block, we compute the base row per warp.
    const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP;

    // The threads in a warp are split into sub-groups that will work on a row.
    // We compute the row offset for each thread sub-group.
    const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW;
    const int thread_row = warp_base_row + thread_row_in_warp;

    // Threads with indices out of bounds should early exit here.
    if (thread_row >= num_rows)
    {
        return;
    }
    const bool row_is_active = finished ? !finished[thread_row] : true;

    // We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the
    // row it will read.
    const float* thread_row_ptr = input + thread_row * ELTS_PER_ROW;

    // Now, we compute the group each thread belongs to in order to determine the first column to start loads from.
    const int thread_group_idx = threadIdx.x % THREADS_PER_ROW;
    const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG;
    const float* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;

    // Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. In theory,
    // this can support all powers of 2 up to 16.
    // NOTE(woosuk): The original implementation uses a CUTLASS aligned array here.
    // We define our own aligned array and use it here to avoid the dependency on CUTLASS.
    using AccessType = AlignedArray<float, ELTS_PER_LDG>;

    // Finally, we pull in the data from global memory.
    float row_chunk[VPT];
    AccessType* row_chunk_vec_ptr = reinterpret_cast<AccessType*>(&row_chunk);
    const AccessType* vec_thread_read_ptr = reinterpret_cast<const AccessType*>(thread_read_ptr);
#pragma unroll
    for (int ii = 0; ii < LDG_PER_THREAD; ++ii)
    {
        row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW];
    }

    // First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just
    // convert to float afterwards for the exp + sum reduction.
    float thread_max = row_chunk[0];
#pragma unroll
    for (int ii = 1; ii < VPT; ++ii)
    {
        thread_max = max(thread_max, row_chunk[ii]);
    }

    // Now, we find the max within the thread group and distribute among the threads. We use a butterfly reduce.
#pragma unroll
    for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
    {
        thread_max = max(thread_max, __shfl_xor_sync(0xFFFFFFFF, thread_max, mask, THREADS_PER_ROW));
    }

    // From this point, thread_max in every thread holds the max within the row.
    // Now, we subtract the max from each element in the thread and take the exp. We also compute the thread local sum.
    float row_sum = 0;
#pragma unroll
    for (int ii = 0; ii < VPT; ++ii)
    {
        row_chunk[ii] = expf(row_chunk[ii] - thread_max);
        row_sum += row_chunk[ii];
    }

    // Now, we perform the sum reduce within each thread group. Similar to the max reduce, we use a butterfly pattern.
#pragma unroll
    for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
    {
        row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, THREADS_PER_ROW);
    }

    // From this point, all threads have the max and the sum for their rows in the thread_max and row_sum variables
    // respectively. Finally, we can scale the rows for the softmax. Technically, for top-k gating we don't need to
    // compute the entire softmax row. We could likely look at the maxes and only compute for the top-k values in the row.
    // However, this kernel will likely not be a bottleneck and it seems better to match torch more closely and find the
    // argmax after computing the softmax.
    const float reciprocal_row_sum = 1.f / row_sum;

#pragma unroll
    for (int ii = 0; ii < VPT; ++ii)
    {
        row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum;
    }

    // Now, row_chunk contains the softmax of the row chunk. Next, we want to find the top-k elements in each row,
    // along with the max index.
    int start_col = first_elt_read_by_thread;
    static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW;

    for (int k_idx = 0; k_idx < k; ++k_idx)
    {
        // First, each thread does the local argmax
        float max_val = row_chunk[0];
        int expert = start_col;
#pragma unroll
        for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD; ++ldg, col += COLS_PER_GROUP_LDG)
        {
#pragma unroll
            for (int ii = 0; ii < ELTS_PER_LDG; ++ii)
            {
                float val = row_chunk[ldg * ELTS_PER_LDG + ii];

                // No check on the experts here since columns with the smallest index are processed first and only
                // updated if > (not >=)
                if (val > max_val)
                {
                    max_val = val;
                    expert = col + ii;
                }
            }
        }

        // Now, we perform the argmax reduce. We use the butterfly pattern so threads reach consensus about the max.
        // This will be useful for K > 1 so that the threads can agree on "who" had the max value. That thread can
        // then blank out their max with -inf and the warp can run more iterations...
#pragma unroll
        for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
        {
            float other_max = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, THREADS_PER_ROW);
            int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, THREADS_PER_ROW);

            // We want lower indices to "win" in every thread so we break ties this way
            if (other_max > max_val || (other_max == max_val && other_expert < expert))
            {
                max_val = other_max;
                expert = other_expert;
            }
        }

        // Write the max for this k iteration to global memory.
        if (thread_group_idx == 0)
        {
            // Add a guard to ignore experts not included by this node
            const bool node_uses_expert = expert >= start_expert && expert < end_expert;
            const bool should_process_row = row_is_active && node_uses_expert;

            // The lead thread from each sub-group will write out the final results to global memory.
            // (This will be a single thread per row of the input/output matrices.)
            const int idx = k * thread_row + k_idx;
            output[idx] = max_val;
            indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS;
            source_rows[idx] = k_idx * num_rows + thread_row;
        }

        // Finally, we clear the value in the thread with the current max if there is another iteration to run.
        if (k_idx + 1 < k)
        {
            const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG;
            const int thread_to_clear_in_group = (expert / ELTS_PER_LDG) % THREADS_PER_ROW;

            // Only the thread in the group which produced the max will reset the "winning" value to -inf.
            if (thread_group_idx == thread_to_clear_in_group)
            {
                const int offset_for_expert = expert % ELTS_PER_LDG;
                // Safe to set to any negative value since row_chunk values must be between 0 and 1.
                row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = -10000.f;
            }
        }
    }
}

namespace detail
{
// Constructs some constants needed to partition the work across threads at compile time.
template <int EXPERTS, int BYTES_PER_LDG>
struct TopkConstants
{
    static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
    static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, "");
    static constexpr int VECs_PER_THREAD = std::max(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
    static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
    static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
    static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;
};
} // namespace detail

template <int EXPERTS, int WARPS_PER_TB>
void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, int* indices,
    int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
{
    static constexpr std::size_t MAX_BYTES_PER_LDG = 16;

    static constexpr int BYTES_PER_LDG = std::min(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
    using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG>;
    static constexpr int VPT = Constants::VPT;
    static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
    const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
    const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;

    dim3 block_dim(WARP_SIZE, WARPS_PER_TB);
    topkGatingSoftmax<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG><<<num_blocks, block_dim, 0, stream>>>(
        input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert);
}

#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB)                  \
    topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB>(    \
        gating_output, nullptr, topk_weights, topk_indicies,       \
        token_expert_indices, num_tokens, topk, 0, num_experts,    \
        stream);

void topkGatingSoftmaxKernelLauncher(
    const float* gating_output,
    float* topk_weights,
    int* topk_indicies,
    int* token_expert_indices,
    float* softmax_workspace,
    const int num_tokens,
    const int num_experts,
    const int topk,
    cudaStream_t stream) {
    static constexpr int WARPS_PER_TB = 4;
    switch (num_experts) {
    case 1:
        LAUNCH_SOFTMAX(1, WARPS_PER_TB);
        break;
    case 2:
        LAUNCH_SOFTMAX(2, WARPS_PER_TB);
        break;
    case 4:
        LAUNCH_SOFTMAX(4, WARPS_PER_TB);
        break;
    case 8:
        LAUNCH_SOFTMAX(8, WARPS_PER_TB);
        break;
    case 16:
        LAUNCH_SOFTMAX(16, WARPS_PER_TB);
        break;
    case 32:
        LAUNCH_SOFTMAX(32, WARPS_PER_TB);
        break;
    case 64:
        LAUNCH_SOFTMAX(64, WARPS_PER_TB);
        break;
    case 128:
        LAUNCH_SOFTMAX(128, WARPS_PER_TB);
        break;
    case 256:
        LAUNCH_SOFTMAX(256, WARPS_PER_TB);
        break;
    default: {
        TORCH_CHECK(softmax_workspace != nullptr,
            "softmax_workspace must be provided for num_experts that are not a power of 2.");
        static constexpr int TPB = 256;
        moeSoftmax<TPB><<<num_tokens, TPB, 0, stream>>>(
            gating_output, nullptr, softmax_workspace, num_experts);
        moeTopK<TPB><<<num_tokens, TPB, 0, stream>>>(
            softmax_workspace, nullptr, topk_weights, topk_indicies, token_expert_indices,
            num_experts, topk, 0, num_experts);
    }
    }
}

} // namespace moe
} // namespace vllm

void topk_softmax(
    torch::Tensor& topk_weights,          // [num_tokens, topk]
    torch::Tensor& topk_indices,          // [num_tokens, topk]
    torch::Tensor& token_expert_indices,  // [num_tokens, topk]
    torch::Tensor& gating_output)         // [num_tokens, num_experts]
{
    const int num_experts = gating_output.size(-1);
    const int num_tokens = gating_output.numel() / num_experts;
    const int topk = topk_weights.size(-1);

    const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
    const bool needs_workspace = !is_pow_2 || num_experts > 256;
    const int64_t workspace_size = needs_workspace ? num_tokens * num_experts : 0;

    const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options());
    vllm::moe::topkGatingSoftmaxKernelLauncher(
        gating_output.data_ptr<float>(),
        topk_weights.data_ptr<float>(),
        topk_indices.data_ptr<int>(),
        token_expert_indices.data_ptr<int>(),
        softmax_workspace.data_ptr<float>(),
        num_tokens,
        num_experts,
        topk,
        stream);
}
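
For reference, a plain PyTorch sketch of what topk_softmax() is expected to produce (ignoring expert parallelism and the token_expert_indices output); this matches the ROCm fallback added to fused_moe.py below, up to tie-breaking:

import torch

def topk_softmax_reference(gating_output: torch.Tensor, topk: int):
    # Softmax over the expert dimension, then keep the k largest probabilities.
    probs = torch.softmax(gating_output, dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(probs, topk, dim=-1)
    return topk_weights, topk_ids.to(torch.int32)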
@@ -48,8 +48,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    &rotary_embedding,
    "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");

#ifndef USE_ROCM
  // Quantization ops
#ifndef USE_ROCM
  ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
  ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
#endif

setup.py  (+11)
@@ -339,6 +339,17 @@ if _is_cuda():
    vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu")
    vllm_extension_sources.append("csrc/custom_all_reduce.cu")

    # Add MoE kernels.
    ext_modules.append(
        CUDAExtension(
            name="vllm._moe_C",
            sources=glob("csrc/moe/*.cu") + glob("csrc/moe/*.cpp"),
            extra_compile_args={
                "cxx": CXX_FLAGS,
                "nvcc": NVCC_FLAGS,
            },
        ))

if not _is_neuron():
    vllm_extension = CUDAExtension(
        name="vllm._C",

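The MoE kernels now build into their own extension module, separate from vllm._C. A quick, hypothetical sanity check that the extension was built and is importable:

import importlib.util

# vllm._moe_C is the extension name registered in setup.py above.
assert importlib.util.find_spec("vllm._moe_C") is not None, "MoE extension not built"
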
@@ -2,10 +2,8 @@

Run `pytest tests/kernels/test_moe.py`.
"""

import pytest
import torch

from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

@@ -14,22 +12,21 @@ from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.models.mixtral import MixtralMoE


def torch_moe(a, w1, w2, topk_weight, topk_ids):
def torch_moe(a, w1, w2, score, topk):
    B, D = a.shape
    a = a.view(B, -1, D).repeat(1, topk_ids.shape[1], 1).reshape(-1, D)
    out = torch.zeros(B * topk_ids.shape[1],
                      w2.shape[1],
                      dtype=a.dtype,
                      device=a.device)
    topk_ids = topk_ids.view(-1)
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            out[mask] = SiluAndMul()(
                a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
    return (out.view(B, -1, w2.shape[1]) *
            topk_weight.view(B, -1, 1)).sum(dim=1)
            topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)


@pytest.mark.parametrize("m", [512, 222, 33, 1])

@@ -51,11 +48,8 @@ def test_fused_moe(
    w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10

    score = torch.randn((m, e), device='cuda', dtype=dtype)
    score = torch.softmax(score, dim=-1)
    topk_weight, topk_ids = torch.topk(score, topk)

    triton_output = fused_moe(a, w1, w2, topk_weight, topk_ids, False)
    torch_output = torch_moe(a, w1, w2, topk_weight, topk_ids)
    triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
    torch_output = torch_moe(a, w1, w2, score, topk)
    assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0)


@@ -75,7 +69,7 @@ def test_mixtral_moe(dtype: torch.dtype):
        intermediate_size=config.intermediate_size,
        params_dtype=dtype,
        tp_size=1,
    )
    ).cuda()

    # Load the weights
    vllm_moe.gate.linear_weights["weight"][:] = hf_moe.gate.weight.data

@@ -4,6 +4,7 @@ import triton
import triton.language as tl

from vllm._C import ops
from vllm.utils import is_hip


@triton.jit
@@ -177,7 +178,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
                            expert_ids: torch.Tensor,
                            num_tokens_post_padded: torch.Tensor,
                            mul_routed_weight: bool, top_k: int, config: dict):

    assert topk_weights.stride(1) == 1
    assert sorted_token_ids.stride(0) == 1

@@ -210,12 +210,15 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
    )


def fused_moe(hidden_states: torch.Tensor,
def fused_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace=False):
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    inplace: bool = False,
) -> torch.Tensor:
    """
    This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and a top-k gating mechanism.

@@ -223,15 +226,19 @@ def fused_moe(hidden_states: torch.Tensor,
    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
    - w1 (torch.Tensor): The first set of expert weights.
    - w2 (torch.Tensor): The second set of expert weights.
    - topk_weights (torch.Tensor): The weights for the top-k selected experts.
    - topk_ids (torch.Tensor): The indices of the top-k selected experts.
    - gating_output (torch.Tensor): The output of the gating operation (before softmax).
    - topk (int): The number of top-k experts to select.
    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
    - inplace (bool): If True, perform the operation in-place. Defaults to False.

    Returns:
    - torch.Tensor: The output tensor after applying the MoE layer.
    """
    # Check constraints.
    assert hidden_states.shape[1] == w1.shape[2], "Incompatible dimensions"
    assert hidden_states.shape[0] == gating_output.shape[0], (
        "Number of tokens mismatch")
    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
    assert w2.is_contiguous(), "Expert weights2 must be contiguous"

@@ -241,6 +248,37 @@ def fused_moe(hidden_states: torch.Tensor,
    M, _ = hidden_states.shape
    E, N, _ = w1.shape

    if is_hip():
        # The MoE kernels are not yet supported on ROCm.
        routing_weights = torch.softmax(gating_output,
                                        dim=-1,
                                        dtype=torch.float32)
        topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
    else:
        import vllm._moe_C as moe_kernels

        topk_weights = torch.empty(M,
                                   topk,
                                   dtype=torch.float32,
                                   device=hidden_states.device)
        topk_ids = torch.empty(M,
                               topk,
                               dtype=torch.int32,
                               device=hidden_states.device)
        token_expert_indicies = torch.empty(M,
                                            topk,
                                            dtype=torch.int32,
                                            device=hidden_states.device)
        moe_kernels.topk_softmax(
            topk_weights,
            topk_ids,
            token_expert_indicies,
            gating_output.float(),  # TODO(woosuk): Optimize this.
        )
        del token_expert_indicies  # Not used. Will be used in the future.
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    config = {
        'BLOCK_SIZE_M': 64,
        'BLOCK_SIZE_N': 64,

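To make the interface change concrete, a hypothetical call under the new fused_moe() signature; the import path and tensor shapes here are illustrative, mirroring the checks asserted above. Routing now happens inside fused_moe from the raw router logits instead of being precomputed by the caller.

import torch
from vllm.model_executor.layers.fused_moe import fused_moe  # assumed module path

E, H, I = 8, 1024, 2048   # experts, hidden size, intermediate size (illustrative)
hidden_states = torch.randn(16, H, dtype=torch.float16, device="cuda")
w1 = torch.randn(E, 2 * I, H, dtype=torch.float16, device="cuda")  # fused gate/up projection
w2 = torch.randn(E, H, I, dtype=torch.float16, device="cuda")      # down projection
router_logits = torch.randn(16, E, dtype=torch.float16, device="cuda")

out = fused_moe(hidden_states, w1, w2, router_logits, 2, renormalize=True)
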
@@ -25,7 +25,6 @@ from typing import Any, Dict, List, Optional, Tuple

import torch
from torch import nn
import torch.nn.functional as F
from transformers import PretrainedConfig

from vllm.model_executor.input_metadata import InputMetadata

@@ -155,20 +154,12 @@ class DeepseekMoE(nn.Module):
        shared_output = self.shared_experts(hidden_states)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits, _ = self.gate(hidden_states)

        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights,
                                                       self.top_k,
                                                       dim=-1)

        if self.config.norm_topk_prob:
            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

        final_hidden_states = fused_moe(hidden_states,
                                        self.w1,
                                        self.w2,
                                        routing_weights,
                                        selected_experts,
                                        router_logits,
                                        self.top_k,
                                        renormalize=self.config.norm_topk_prob,
                                        inplace=True)

        if self.config.n_shared_experts is not None:

@@ -24,8 +24,6 @@
from typing import List, Optional, Tuple

import torch
import torch.nn.functional as F

from torch import nn
from transformers import MixtralConfig

@@ -128,18 +126,12 @@ class MixtralMoE(nn.Module):
        hidden_states = hidden_states.view(-1, self.hidden_size)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits, _ = self.gate(hidden_states)

        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights,
                                                       self.top_k,
                                                       dim=-1)
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

        final_hidden_states = fused_moe(hidden_states,
                                        self.ws,
                                        self.w2s,
                                        routing_weights,
                                        selected_experts,
                                        router_logits,
                                        self.top_k,
                                        renormalize=True,
                                        inplace=True)

        if self.tp_size > 1:
