[Kernel] port sgl moe_align_block_size kernels (#12574)

sgl_moe_align_block_size is based on: ded9fcd09a
moe_align_block_size is based on: ba5112ff69

Signed-off-by: Yang Chen <yangche@fb.com>

parent 326fcc8b9f
commit 95460fc513
@@ -197,6 +197,72 @@ __global__ void moe_align_block_size_global_mem_kernel(
  }
}

// taken from
// https://github.com/sgl-project/sglang/commit/ded9fcd09a43d5e7d5bb31a2bc3e9fc21bf65d2a
template <typename scalar_t>
__global__ void sgl_moe_align_block_size_kernel(
    scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
    int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts,
    int32_t block_size, size_t numel, int32_t* cumsum) {
  __shared__ int32_t shared_counts[32][8];
  __shared__ int32_t local_offsets[256];

  const int warp_id = threadIdx.x / WARP_SIZE;
  const int lane_id = threadIdx.x % WARP_SIZE;
  const int experts_per_warp = 8;
  const int my_expert_start = warp_id * experts_per_warp;

  // Each warp zeroes the counters for its slice of experts.
  for (int i = 0; i < experts_per_warp; ++i) {
    if (my_expert_start + i < num_experts) {
      shared_counts[warp_id][i] = 0;
    }
  }

  const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
  const size_t start_idx = threadIdx.x * tokens_per_thread;

  // Histogram: count how many tokens are routed to each expert.
  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    int expert_id = topk_ids[i];
    int warp_idx = expert_id / experts_per_warp;
    int expert_offset = expert_id % experts_per_warp;
    atomicAdd(&shared_counts[warp_idx][expert_offset], 1);
  }

  __syncthreads();

  // A single thread computes the block-size-padded prefix sum.
  if (threadIdx.x == 0) {
    cumsum[0] = 0;
    for (int i = 1; i <= num_experts; ++i) {
      int expert_count = 0;
      int warp_idx = (i - 1) / experts_per_warp;
      int expert_offset = (i - 1) % experts_per_warp;
      expert_count = shared_counts[warp_idx][expert_offset];

      cumsum[i] =
          cumsum[i - 1] + CEILDIV(expert_count, block_size) * block_size;
    }
    *total_tokens_post_pad = cumsum[num_experts];
  }

  __syncthreads();

  // One thread per expert stamps expert_ids for its padded blocks.
  if (threadIdx.x < num_experts) {
    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
         i += block_size) {
      expert_ids[i / block_size] = threadIdx.x;
    }
    local_offsets[threadIdx.x] = cumsum[threadIdx.x];
  }

  __syncthreads();

  // Scatter token indices to their padded, expert-sorted positions.
  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    int32_t expert_id = topk_ids[i];
    int32_t rank_post_pad = atomicAdd(&local_offsets[expert_id], 1);
    sorted_token_ids[rank_post_pad] = i;
  }
}

template <typename scalar_t, int TOPK>
__global__ void moe_sum_kernel(
    scalar_t* __restrict__ out,  // [..., d]
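For intuition, the kernel's contract can be restated as a short sequential Python reference (an illustrative sketch, not part of the commit; it assumes the caller pre-fills sorted_token_ids with the sentinel value numel, as the dispatcher in fused_moe.py does):

    def align_block_size_reference(topk_ids, num_experts, block_size):
        # Histogram of tokens per expert.
        counts = [0] * num_experts
        for e in topk_ids:
            counts[e] += 1
        # Prefix sum with each expert's count rounded up to a multiple of
        # block_size (the padding the kernel computes with CEILDIV).
        cumsum = [0]
        for c in counts:
            cumsum.append(cumsum[-1] + -(-c // block_size) * block_size)
        total_tokens_post_pad = cumsum[-1]
        # One expert id per block of block_size slots.
        expert_ids = [0] * (total_tokens_post_pad // block_size)
        for e in range(num_experts):
            for i in range(cumsum[e], cumsum[e + 1], block_size):
                expert_ids[i // block_size] = e
        # Scatter token indices; padding slots keep the sentinel numel.
        sorted_token_ids = [len(topk_ids)] * total_tokens_post_pad
        offsets = cumsum[:num_experts]
        for i, e in enumerate(topk_ids):
            sorted_token_ids[offsets[e]] = i
            offsets[e] += 1
        return sorted_token_ids, expert_ids, total_tokens_post_pad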
@@ -305,6 +371,32 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
  }
}

void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                              int64_t block_size,
                              torch::Tensor sorted_token_ids,
                              torch::Tensor experts_ids,
                              torch::Tensor num_tokens_post_pad) {
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_INTEGRAL_TYPES(
      topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] {
        // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
        // tensors
        auto options_int =
            torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device());
        // torch::Tensor token_cnts_buffer =
        //     torch::empty({(num_experts + 1) * num_experts}, options_int);
        torch::Tensor cumsum_buffer =
            torch::empty({num_experts + 1}, options_int);

        auto kernel = vllm::moe::sgl_moe_align_block_size_kernel<scalar_t>;
        kernel<<<1, 1024, 0, stream>>>(
            topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
            experts_ids.data_ptr<int32_t>(),
            num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
            topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>());
      });
}

void moe_sum(torch::Tensor& input,   // [num_tokens, topk, hidden_size]
             torch::Tensor& output)  // [num_tokens, hidden_size]
{
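The host wrapper launches a single block of 1024 threads, and the kernel above reserves shared_counts[32][8] (32 warps times 8 experts each) plus local_offsets[256] in shared memory, so the op can serve at most 256 experts. A hypothetical caller-side guard makes that assumption explicit:

    num_experts = 256  # illustrative upper bound
    # Hypothetical guard: shared_counts[32][8] and local_offsets[256] cap the
    # number of experts this kernel can serve at 256.
    assert num_experts <= 256, "sgl_moe_align_block_size assumes num_experts <= 256"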
@@ -12,3 +12,9 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                          int64_t block_size, torch::Tensor sorted_token_ids,
                          torch::Tensor experts_ids,
                          torch::Tensor num_tokens_post_pad);

void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                              int64_t block_size,
                              torch::Tensor sorted_token_ids,
                              torch::Tensor experts_ids,
                              torch::Tensor num_tokens_post_pad);
@@ -22,6 +22,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
      " Tensor! num_tokens_post_pad) -> ()");
  m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);

  // temporarily adapted from
  // https://github.com/sgl-project/sglang/commit/ded9fcd09a43d5e7d5bb31a2bc3e9fc21bf65d2a
  m.def(
      "sgl_moe_align_block_size(Tensor topk_ids, int num_experts,"
      " int block_size, Tensor! sorted_token_ids,"
      " Tensor! experts_ids,"
      " Tensor! num_tokens_post_pad) -> ()");
  m.impl("sgl_moe_align_block_size", torch::kCUDA, &sgl_moe_align_block_size);

#ifndef USE_ROCM
  m.def(
      "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
@@ -952,6 +952,15 @@ def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
                                          num_tokens_post_pad)


def sgl_moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
                             block_size: int, sorted_token_ids: torch.Tensor,
                             experts_ids: torch.Tensor,
                             num_tokens_post_pad: torch.Tensor) -> None:
    torch.ops._moe_C.sgl_moe_align_block_size(topk_ids, num_experts,
                                              block_size, sorted_token_ids,
                                              experts_ids, num_tokens_post_pad)


def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                 token_expert_indicies: torch.Tensor,
                 gating_output: float) -> None:
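The wrapper writes into caller-allocated buffers. A minimal allocation sketch (sizes are an assumption mirroring the worst case in which every expert's segment gains up to block_size - 1 padding slots; names are illustrative, and it assumes the wrapper above is importable, e.g. from vllm._custom_ops):

    import torch

    num_experts, block_size, topk = 256, 128, 8
    topk_ids = torch.randint(0, num_experts, (512, topk),
                             dtype=torch.int32, device="cuda")

    # Worst case: every expert's segment is padded by up to block_size - 1 slots.
    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    sorted_token_ids = torch.full((max_num_tokens_padded, ), topk_ids.numel(),
                                  dtype=torch.int32, device="cuda")
    expert_ids = torch.zeros((-(-max_num_tokens_padded // block_size), ),
                             dtype=torch.int32, device="cuda")
    num_tokens_post_pad = torch.empty((1, ), dtype=torch.int32, device="cuda")

    sgl_moe_align_block_size(topk_ids, num_experts, block_size, sorted_token_ids,
                             expert_ids, num_tokens_post_pad)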
@@ -82,6 +82,7 @@ if TYPE_CHECKING:
    VLLM_MLA_DISABLE: bool = False
    VLLM_MLA_PERFORM_MATRIX_ABSORPTION: bool = True
    VLLM_MLA_DISABLE_REQUANTIZATION: bool = False
    VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: bool = False


def get_default_cache_root():
@@ -531,7 +532,13 @@ environment_variables: Dict[str, Callable[[], Any]] = {
    # matrices to match the activation type. This can lead to higher memory and
    # compute usage but better preserves the accuracy of the original model.
    "VLLM_MLA_DISABLE_REQUANTIZATION":
-   lambda: bool(int(os.getenv("VLLM_MLA_DISABLE_REQUANTIZATION", "0")))
+   lambda: bool(int(os.getenv("VLLM_MLA_DISABLE_REQUANTIZATION", "0"))),

    # If set, vLLM will use the Triton implementation of moe_align_block_size,
    # i.e. moe_align_block_size_triton in fused_moe.py.
    "VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON":
    lambda: bool(int(os.getenv("VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
                 ),
}

# end-env-vars-definition
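The flag is evaluated lazily on each vllm.envs attribute access, so it can be enabled from the environment before the first MoE forward pass, e.g.:

    import os

    # The lambda above re-reads the environment on each vllm.envs attribute
    # access, so set the flag before the first MoE layer executes.
    os.environ["VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON"] = "1"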
@@ -405,6 +405,144 @@ def fused_moe_kernel(
    tl.store(c_ptrs, accumulator, mask=c_mask)


def ceil_div(a, b):
    return (a + b - 1) // b


@triton.jit
def moe_align_block_size_stage1(
    topk_ids_ptr,
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
    numel: tl.constexpr,
    tokens_per_thread: tl.constexpr,
):
    pid = tl.program_id(0)

    start_idx = pid * tokens_per_thread

    off_c = (pid + 1) * num_experts

    for i in range(tokens_per_thread):
        if start_idx + i < numel:
            idx = tl.load(topk_ids_ptr + start_idx + i)
            token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
            tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)


@triton.jit
def moe_align_block_size_stage2(
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
):
    pid = tl.program_id(0)

    last_cnt = 0
    for i in range(1, num_experts + 1):
        token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
        last_cnt = last_cnt + token_cnt
        tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
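For intuition, stages 1 and 2 amount to per-program histograms followed by an inclusive running sum down each expert's column of the (num_experts + 1, num_experts) counts matrix. A PyTorch sketch of the same computation (illustrative names; chunk boundaries approximate the per-program split):

    import torch

    def stages_1_2_reference(topk_ids: torch.Tensor, num_experts: int) -> torch.Tensor:
        # Row p + 1 holds the histogram of the chunk scanned by program p;
        # row 0 stays zero so row p later reads an exclusive prefix.
        tokens_cnts = torch.zeros(num_experts + 1, num_experts, dtype=torch.int32)
        for p, chunk in enumerate(topk_ids.flatten().chunk(num_experts)):
            for e in chunk.tolist():
                tokens_cnts[p + 1, e] += 1
        # Stage 2: inclusive running sum down each expert's column.
        return torch.cumsum(tokens_cnts, dim=0, dtype=torch.int32)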
@triton.jit
def moe_align_block_size_stage3(
    total_tokens_post_pad_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
):
    last_cumsum = 0
    off_cnt = num_experts * num_experts
    for i in range(1, num_experts + 1):
        token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
        last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
        tl.store(cumsum_ptr + i, last_cumsum)
    tl.store(total_tokens_post_pad_ptr, last_cumsum)
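A concrete instance of the padding arithmetic used by stage 3 (and by the CEILDIV loop in the CUDA kernel above), with numbers chosen purely for illustration:

    # block_size = 16 and per-expert counts [5, 0, 20] give padded extents
    # 16, 0 and 32, so cumsum = [0, 16, 16, 48] and total_tokens_post_pad = 48.
    block_size, counts = 16, [5, 0, 20]
    cumsum = [0]
    for c in counts:
        cumsum.append(cumsum[-1] + -(-c // block_size) * block_size)
    print(cumsum)  # [0, 16, 16, 48]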
@triton.jit
def moe_align_block_size_stage4(
    topk_ids_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
    numel: tl.constexpr,
    tokens_per_thread: tl.constexpr,
):
    pid = tl.program_id(0)
    start_idx = tl.load(cumsum_ptr + pid)
    end_idx = tl.load(cumsum_ptr + pid + 1)

    for i in range(start_idx, end_idx, block_size):
        tl.store(expert_ids_ptr + i // block_size, pid)

    start_idx = pid * tokens_per_thread
    off_t = pid * num_experts

    for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread,
                                         numel)):
        expert_id = tl.load(topk_ids_ptr + i)
        token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
        rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
        tl.store(sorted_token_ids_ptr + rank_post_pad, i)
        tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)
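Stage 4 places token i at cumsum[expert] (the segment base) plus the stage-2 exclusive prefix for its program plus a local running count, which collapses to one sequential scatter. A compact sketch with illustrative names (see also the reference after the CUDA kernel above):

    def stage4_reference(topk_ids, cumsum, sorted_token_ids):
        # offsets[e] stands in for cumsum[e] + exclusive per-program prefix
        # + local running count, which is what each program computes.
        offsets = list(cumsum[:-1])
        for i, e in enumerate(topk_ids):
            sorted_token_ids[offsets[e]] = i
            offsets[e] += 1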
# Triton implementation based on:
# https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0
def moe_align_block_size_triton(
    topk_ids: torch.Tensor,
    num_experts: int,
    block_size: int,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
) -> None:
    numel = topk_ids.numel()
    grid = (num_experts, )
    tokens_cnts = torch.zeros((num_experts + 1, num_experts),
                              dtype=torch.int32,
                              device=topk_ids.device)
    cumsum = torch.zeros((num_experts + 1, ),
                         dtype=torch.int32,
                         device=topk_ids.device)
    tokens_per_thread = ceil_div(numel, num_experts)

    moe_align_block_size_stage1[grid](
        topk_ids,
        tokens_cnts,
        num_experts,
        numel,
        tokens_per_thread,
    )
    moe_align_block_size_stage2[grid](
        tokens_cnts,
        num_experts,
    )
    moe_align_block_size_stage3[(1, )](
        num_tokens_post_pad,
        tokens_cnts,
        cumsum,
        num_experts,
        block_size,
    )
    moe_align_block_size_stage4[grid](
        topk_ids,
        sorted_token_ids,
        expert_ids,
        tokens_cnts,
        cumsum,
        num_experts,
        block_size,
        numel,
        tokens_per_thread,
    )
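Both backends must agree on expert_ids and num_tokens_post_pad, while the order of token indices inside one expert's segment may legitimately differ, because the CUDA path assigns ranks with atomics. A parity-check sketch along those lines (assumes a CUDA build with the wrappers importable as shown):

    import torch
    from vllm import _custom_ops as ops  # assumed import path for the wrappers

    num_experts, block_size = 256, 128
    topk_ids = torch.randint(0, num_experts, (1024, 8),
                             dtype=torch.int32, device="cuda")
    max_len = topk_ids.numel() + num_experts * (block_size - 1)

    def make_buffers():
        sorted_ids = torch.full((max_len, ), topk_ids.numel(),
                                dtype=torch.int32, device="cuda")
        expert_ids = torch.zeros((-(-max_len // block_size), ),
                                 dtype=torch.int32, device="cuda")
        post_pad = torch.empty((1, ), dtype=torch.int32, device="cuda")
        return sorted_ids, expert_ids, post_pad

    t_bufs, c_bufs = make_buffers(), make_buffers()
    moe_align_block_size_triton(topk_ids, num_experts, block_size, *t_bufs)
    ops.sgl_moe_align_block_size(topk_ids, num_experts, block_size, *c_bufs)
    assert torch.equal(t_bufs[1], c_bufs[1])  # expert_ids match
    assert torch.equal(t_bufs[2], c_bufs[2])  # num_tokens_post_pad matches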
def moe_align_block_size(
        topk_ids: torch.Tensor, block_size: int,
        num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -457,6 +595,26 @@
    num_tokens_post_pad = torch.empty((1),
                                      dtype=torch.int32,
                                      device=topk_ids.device)
    if num_experts >= 224:
        if envs.VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON:
            moe_align_block_size_triton(
                topk_ids,
                num_experts,
                block_size,
                sorted_ids,
                expert_ids,
                num_tokens_post_pad,
            )
        else:
            ops.sgl_moe_align_block_size(
                topk_ids,
                num_experts,
                block_size,
                sorted_ids,
                expert_ids,
                num_tokens_post_pad,
            )
    else:
        ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
                                 expert_ids, num_tokens_post_pad)
    return sorted_ids, expert_ids, num_tokens_post_pad
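End to end, callers only see this dispatcher; the expert count and the environment flag pick the backend. A usage sketch (the import path is an assumption; a CUDA device is required):

    import torch
    # Assumed module path for the dispatcher shown above.
    from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size

    topk_ids = torch.randint(0, 256, (32, 8), dtype=torch.int32, device="cuda")
    # num_experts >= 224 routes to the sgl kernel (or Triton when the flag is set).
    sorted_ids, expert_ids, num_tokens_post_pad = moe_align_block_size(
        topk_ids, 128, 256)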