From e59ca942f576c78fd457f16f2029bda716c81959 Mon Sep 17 00:00:00 2001 From: bnellnm <49004751+bnellnm@users.noreply.github.com> Date: Tue, 1 Apr 2025 12:07:43 -0400 Subject: [PATCH] Add option to use DeepGemm contiguous grouped gemm kernel for fused MoE operations. (#13932) Signed-off-by: Bill Nell --- benchmarks/kernels/benchmark_moe.py | 97 ++-- tests/kernels/test_block_fp8.py | 279 ++++++++++- vllm/_custom_ops.py | 2 +- vllm/envs.py | 5 + .../layers/fused_moe/fused_moe.py | 468 ++++++++++++++++-- .../model_executor/layers/quantization/fp8.py | 36 ++ 6 files changed, 773 insertions(+), 114 deletions(-) diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index 491f8c39..f1803b39 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -30,19 +30,18 @@ class BenchmarkConfig(TypedDict): num_stages: int -def benchmark_config( - config: BenchmarkConfig, - num_tokens: int, - num_experts: int, - shard_intermediate_size: int, - hidden_size: int, - topk: int, - dtype: torch.dtype, - use_fp8_w8a8: bool, - use_int8_w8a16: bool, - num_iters: int = 100, - block_quant_shape: List[int] = None, -) -> float: +def benchmark_config(config: BenchmarkConfig, + num_tokens: int, + num_experts: int, + shard_intermediate_size: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + num_iters: int = 100, + block_quant_shape: List[int] = None, + use_deep_gemm: bool = False) -> float: init_dtype = torch.float16 if use_fp8_w8a8 else dtype x = torch.randn(num_tokens, hidden_size, dtype=dtype) if use_int8_w8a16: @@ -115,22 +114,41 @@ def benchmark_config( def run(): from vllm.model_executor.layers.fused_moe import override_config with override_config(config): - fused_moe( - x, - w1, - w2, - input_gating, - topk, - renormalize=True, - inplace=True, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a16=use_int8_w8a16, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale, - block_shape=block_quant_shape, - ) + if use_deep_gemm: + topk_weights, topk_ids = fused_topk(x, input_gating, topk, + False) + return fused_experts( + x, + w1, + w2, + topk_weights, + topk_ids, + inplace=True, + use_fp8_w8a8=use_fp8_w8a8, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + block_shape=block_quant_shape, + allow_deep_gemm=True, + ) + else: + fused_moe( + x, + w1, + w2, + input_gating, + topk, + renormalize=True, + inplace=True, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + block_shape=block_quant_shape, + ) # JIT compilation & warmup run() @@ -366,6 +384,7 @@ class BenchmarkWorker: use_fp8_w8a8: bool, use_int8_w8a16: bool, block_quant_shape: List[int] = None, + use_deep_gemm: bool = False, ) -> tuple[dict[str, int], float]: current_platform.seed_everything(self.seed) dtype_str = get_config_dtype_str(dtype, @@ -396,7 +415,8 @@ class BenchmarkWorker: use_fp8_w8a8, use_int8_w8a16, num_iters=100, - block_quant_shape=block_quant_shape) + block_quant_shape=block_quant_shape, + use_deep_gemm=use_deep_gemm) return config, kernel_time def tune( @@ -411,6 +431,7 @@ class BenchmarkWorker: use_int8_w8a16: bool, search_space: list[dict[str, int]], block_quant_shape: list[int], + use_deep_gemm: bool, ) -> dict[str, int]: best_config = None best_time = float("inf") @@ -436,7 +457,8 @@ class BenchmarkWorker: use_fp8_w8a8, use_int8_w8a16, num_iters=20, - 
block_quant_shape=block_quant_shape) + block_quant_shape=block_quant_shape, + use_deep_gemm=use_deep_gemm) except triton.runtime.autotuner.OutOfResources: # Some configurations may be invalid and fail to compile. continue @@ -550,6 +572,8 @@ def main(args: argparse.Namespace): else: batch_sizes = [args.batch_size] + use_deep_gemm = bool(args.use_deep_gemm) + ray.init() num_gpus = int(ray.available_resources()["GPU"]) workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)] @@ -572,10 +596,10 @@ def main(args: argparse.Namespace): start = time.time() configs = _distribute( - "tune", - [(batch_size, E, shard_intermediate_size, hidden_size, topk, dtype, - use_fp8_w8a8, use_int8_w8a16, search_space, block_quant_shape) - for batch_size in batch_sizes]) + "tune", [(batch_size, E, shard_intermediate_size, hidden_size, + topk, dtype, use_fp8_w8a8, use_int8_w8a16, search_space, + block_quant_shape, use_deep_gemm) + for batch_size in batch_sizes]) best_configs = { M: sort_config(config) for M, config in zip(batch_sizes, configs) @@ -589,7 +613,7 @@ def main(args: argparse.Namespace): outputs = _distribute( "benchmark", [(batch_size, E, shard_intermediate_size, hidden_size, topk, dtype, - use_fp8_w8a8, use_int8_w8a16, block_quant_shape) + use_fp8_w8a8, use_int8_w8a16, block_quant_shape, use_deep_gemm) for batch_size in batch_sizes]) for batch_size, (config, kernel_time) in zip(batch_sizes, outputs): @@ -611,6 +635,7 @@ if __name__ == "__main__": type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto") + parser.add_argument("--use-deep-gemm", action="store_true") parser.add_argument("--seed", type=int, default=0) parser.add_argument("--batch-size", type=int, required=False) parser.add_argument("--tune", action="store_true") diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 6206cbd5..fda981f4 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -6,12 +6,22 @@ import itertools import pytest import torch +from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe.fused_moe import ( + deep_gemm_moe_fp8, fused_topk, moe_align_block_size) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8, w8a8_block_fp8_matmul) from vllm.platforms import current_platform +dg_available = False +try: + import deep_gemm + dg_available = True +except ImportError: + pass + if current_platform.get_device_capability() < (9, 0): pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True) @@ -21,17 +31,18 @@ DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32] NUM_TOKENS = [7, 83, 2048] D = [512, 4096, 5120, 13824] GROUP_SIZE = [64, 128, 256, 512] -M = [1, 7, 83, 512, 2048] -N = [128, 512, 1024, 4096, 7748, 13824] -K = [256, 4096, 5120, 3884, 13824] +M = [1, 7, 8, 83, 84, 512, 2048, 4096] +N = [128, 512, 1024, 4096, 7168, 7748, 13824] +K = [256, 4096, 5120, 3884, 13824, 16384] # Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8 # and its hidden size is 7168. 
-M_moe = [1, 7, 83, 512, 2048] -N_moe = [4608] # [128, 4608, 13824] -K_moe = [7168] # [256, 7168, 13824] +M_moe = [1, 2, 7, 83, 128, 512, 2048] +M_moe_dg = [128, 192, 512, 1335, 2048] +N_moe = [128, 256, 1024, 4608] # [13824] +K_moe = [256, 512, 7168] # [13824] BLOCK_SIZE = [[128, 128]] -E = [8, 24] # [8, 24, 128, 256] -TOP_KS = [2] # [1, 2, 6] +E = [2, 8, 16, 24] # [128, 256] +TOP_KS = [1, 2, 6] OUT_DTYPES = [torch.bfloat16] # [torch.float32, torch.half, torch.bfloat16] SEEDS = [0] @@ -217,11 +228,16 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): SEEDS)) @torch.inference_mode() def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): + if topk > E: + pytest.skip(f"Skipping test; topk={topk} > E={E}") + torch.manual_seed(seed) factor_for_scale = 1e-2 fp8_info = torch.finfo(torch.float8_e4m3fn) fp8_max, fp8_min = fp8_info.max, fp8_info.min + vllm_config = VllmConfig() + a = torch.randn((M, K), dtype=dtype) / 10 w1_bf16 = (torch.rand( @@ -246,25 +262,240 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): score = torch.randn((M, E), dtype=dtype) - out = fused_moe( - a, - w1, - w2, - score, - topk, - renormalize=False, - use_fp8_w8a8=True, - w1_scale=w1_s, - w2_scale=w2_s, - block_shape=block_size, - ) - ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, - block_size) + # Set the context to avoid lots of warning spam. + with set_current_vllm_config(vllm_config): + out = fused_moe( + a, + w1, + w2, + score, + topk, + renormalize=False, + use_fp8_w8a8=True, + w1_scale=w1_s, + w2_scale=w2_s, + block_shape=block_size, + ) + ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, + block_size) - print(f"{out.sum()=}") - print(f"{ref_out.sum()=}") + #print(f"{out.sum()=}") + #print(f"{ref_out.sum()=}") rel_diff = (torch.mean( torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / torch.mean(torch.abs(ref_out.to(torch.float32)))) assert rel_diff < 0.03 + + +def per_block_cast_to_fp8( + x: torch.Tensor, + block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + x_padded = torch.zeros( + (deep_gemm.ceil_div(m, 128) * 128, + deep_gemm.ceil_div(n, block_size_n) * block_size_n), + dtype=x.dtype, + device=x.device) + x_padded[:m, :n] = x + x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n) + x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) + x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) + x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous() + scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2)) + return x_scaled_sub, scales + + +@pytest.mark.parametrize( + "M,N,K,block_size,out_dtype,seed", + itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS)) +@torch.inference_mode() +def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): + # only aligned sizes + if M % 4 != 0 or K % 128 != 0 or N % 64 != 0: + pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}") + + torch.manual_seed(seed) + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max = fp8_info.max + + A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max + B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max + + _, block_k = block_size[0], block_size[1] + + A_fp8, As_fp8 = per_token_group_quant_fp8(A_fp32, block_k) + B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32) + + As = As_fp8.to(torch.float32) + Bs = Bs_fp8.to(torch.float32) + + ref_out = 
native_w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, + out_dtype) + + # Transpose earlier so that the testing will not trigger transposing kernels + As_fp8 = deep_gemm.get_col_major_tma_aligned_tensor(As_fp8) + + out = torch.zeros((M, N), device='cuda', dtype=out_dtype) + + assert As_fp8.shape == (M, (K + 127) // + 128), f"{As_fp8.shape} != {(M, (K + 127) // 128)}" + + deep_gemm.gemm_fp8_fp8_bf16_nt((A_fp8, As_fp8), (B_fp8, Bs_fp8), out) + + rel_diff = (torch.mean( + torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / + torch.mean(torch.abs(ref_out.to(torch.float32)))) + assert rel_diff < 0.001 + + +def fp8_perm(m, idx): + if torch.is_floating_point(m) and torch.finfo(m.dtype).bits == 8: + return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype) + else: + return m[idx, ...] + + +def test_moe_permute(a, a_s, topk_ids, num_groups, topk, block_m): + M, K = a.shape + + sorted_token_ids, m_indices, num_pad = moe_align_block_size( + topk_ids, block_m, num_groups, None, pad_sorted_ids=True) + + num_tokens = topk * M + + sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1) + m_indices = torch.repeat_interleave(m_indices, block_m, dim=0) + inv_perm = torch.argsort(sorted_token_ids)[:M * topk] + + a = fp8_perm(a, sorted_token_ids // topk) + if a_s is not None: + a_s = a_s[sorted_token_ids // topk] + + return a, a_s, m_indices, inv_perm + + +def test_moe_unpermute(out, inv_perm, topk, K, topk_weight): + M = topk_weight.shape[0] + out = out[inv_perm, ...] + tmp_out = out.view(-1, topk, K) + return (tmp_out * topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) + + +def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, + block_shape): + """Fused moe with block-wise quantization using DeepGemm grouped gemm.""" + num_groups = w1.shape[0] + M, K = a.shape + N = w2.shape[-1] + + topk_weight, topk_ids = fused_topk(a, score.float(), topk, False) + + block_m = deep_gemm.get_m_alignment_for_contiguous_layout() + + _, block_k = block_shape[0], block_shape[1] + + a_q, a_s = per_token_group_quant_fp8(a, block_m) + + a_q, a_s, m_indices, inv_perm = test_moe_permute(a_q, a_s, topk_ids, + num_groups, topk, block_m) + + inter_out = torch.zeros((a_q.shape[0], N * 2), + dtype=torch.bfloat16, + device=a.device) + + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a_q, a_s), (w1, w1_s), + inter_out, m_indices) + + act_out = SiluAndMul().forward_native(inter_out) + act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k) + + out = torch.zeros(a_q.shape[0], K, dtype=torch.bfloat16, device=a.device) + + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + (act_out_q, act_out_s), (w2, w2_s), out, m_indices) + + final_out = test_moe_unpermute(out, inv_perm, topk, K, topk_weight) + + return final_out + + +@pytest.mark.parametrize( + "M,N,K,E,topk,seed", + itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, SEEDS)) +@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.") +@torch.inference_mode() +def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): + + block_m = deep_gemm.get_m_alignment_for_contiguous_layout() + block_size = [block_m, block_m] + dtype = torch.bfloat16 + + # only aligned sizes + if (N % block_m != 0 or K % block_m != 0 or topk > E): + pytest.skip( + f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}") + + if (N <= 512): + pytest.skip("Skipping N <= 512 until performance issues solved.") + + vllm_config = VllmConfig() + + torch.manual_seed(seed) + fp8_info = 
torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + + a = torch.randn((M, K), dtype=dtype) / 10 + + w1_bf16 = ((torch.rand((E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 * + fp8_max).clamp(min=fp8_min, max=fp8_max) + + w2_bf16 = ((torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * + fp8_max).clamp(min=fp8_min, max=fp8_max) + + score = torch.randn((M, E), dtype=dtype) + + block_n, block_k = block_size[0], block_size[1] + n_tiles_w1 = ((2 * N) + block_n - 1) // block_n + k_tiles_w1 = (K + block_k - 1) // block_k + n_tiles_w2 = (K + block_n - 1) // block_n + k_tiles_w2 = (N + block_k - 1) // block_k + + w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn) + w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn) + + w1_s = torch.empty((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) + w2_s = torch.empty((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) + + w1_s = deep_gemm.get_col_major_tma_aligned_tensor(w1_s).contiguous() + w2_s = deep_gemm.get_col_major_tma_aligned_tensor(w2_s).contiguous() + + assert w1_s.shape == (E, (2 * N + 127) // 128, (K + 127) // 128) + assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2] + + for i in range(E): + w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i]) + w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i]) + + # Set the context to avoid lots of warning spam. + with set_current_vllm_config(vllm_config): + if M >= 128: + ref_out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, + score, topk, block_size) + else: + ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, + topk, block_size) + + topk_weights, topk_ids = fused_topk(a, score.float(), topk, False) + + out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids) + + #print(f"{out.sum()=}") + #print(f"{ref_out.sum()=}") + + rel_diff = (torch.mean( + torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / + torch.mean(torch.abs(ref_out.to(torch.float32)))) + + assert rel_diff < 0.03 diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 2aa99ca2..039397f5 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1224,7 +1224,7 @@ def moe_wna16_gemm(input: torch.Tensor, output: torch.Tensor, def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor, token_expert_indicies: torch.Tensor, - gating_output: float) -> None: + gating_output: torch.Tensor) -> None: torch.ops._moe_C.topk_softmax(topk_weights, topk_ids, token_expert_indicies, gating_output) diff --git a/vllm/envs.py b/vllm/envs.py index b34c2df8..6067f5bd 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -105,6 +105,7 @@ if TYPE_CHECKING: VLLM_V0_USE_OUTLINES_CACHE: bool = False VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False VLLM_TPU_BUCKET_PADDING_GAP: int = 0 + VLLM_USE_DEEP_GEMM: bool = False def get_default_cache_root(): @@ -687,6 +688,10 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_TPU_BUCKET_PADDING_GAP": lambda: int(os.environ["VLLM_TPU_BUCKET_PADDING_GAP"]) if "VLLM_TPU_BUCKET_PADDING_GAP" in os.environ else 0, + + # Allow use of DeepGemm kernels for fused moe ops. 
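+    # The flag only takes effect for fp8 block-quantized MoE layers when the
+    # deep_gemm package is importable and the GPU reports compute capability
+    # 9.0 or newer; otherwise the existing Triton fused MoE kernels are used.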
+ "VLLM_USE_DEEP_GEMM": + lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))), } # end-env-vars-definition diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 70d0037d..977447e0 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1,8 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 """Fused MoE kernel.""" import functools +import importlib.util import json import os +from math import prod from typing import Any, Callable, Dict, List, Optional, Tuple import torch @@ -15,7 +17,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8) from vllm.platforms import current_platform -from vllm.utils import direct_register_custom_op +from vllm.utils import direct_register_custom_op, round_up from .rocm_aiter_fused_moe import (is_rocm_aiter_moe_enabled, rocm_aiter_fused_experts, @@ -23,6 +25,8 @@ from .rocm_aiter_fused_moe import (is_rocm_aiter_moe_enabled, logger = init_logger(__name__) +has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None + @triton.jit def write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, offs_token, @@ -581,7 +585,8 @@ def moe_align_block_size( topk_ids: torch.Tensor, block_size: int, num_experts: int, - expert_map: torch.Tensor = None + expert_map: Optional[torch.Tensor] = None, + pad_sorted_ids: bool = False ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Aligns the token distribution across experts to be compatible with block @@ -596,6 +601,8 @@ def moe_align_block_size( from the global space to the local index space of the current expert parallel shard. If the expert is not in the current expert parallel shard, the mapping is set to -1. + - pad_sorted_ids: A flag indicating whether the sorted_token_ids length + should be padded to a multiple of block_size, Returns: - sorted_token_ids: A tensor containing the sorted token indices according @@ -625,6 +632,8 @@ def moe_align_block_size( by block_size for proper block matrix operations. """ max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) + if pad_sorted_ids: + max_num_tokens_padded = round_up(max_num_tokens_padded, block_size) sorted_ids = torch.empty((max_num_tokens_padded, ), dtype=torch.int32, device=topk_ids.device) @@ -667,6 +676,59 @@ def moe_align_block_size( return sorted_ids, expert_ids, num_tokens_post_pad +def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor, + w2: torch.Tensor, + expert_map: Optional[torch.Tensor]) -> bool: + """ + Check if the given problem size is supported by the DeepGemm grouped + gemm kernel. All of M, N, K and the quantization block_shape must be + aligned by `dg.get_m_alignment_for_contiguous_layout()`. + """ + if not has_deep_gemm: + return False + + # Lazy import to avoid CUDA initialization problems. + import deep_gemm as dg + + # Expert maps not supported yet. + if expert_map is not None: + return False + + align = dg.get_m_alignment_for_contiguous_layout() + M = hidden_states.shape[0] + _, K, N = w2.shape + + # For now, disable DeepGemm for small N until better permute/unpermute + # ops are available. 
+ if N <= 512: + return False + + if align > M or N % align != 0 or K % align != 0: + return False + + return (hidden_states.is_contiguous() and w1.is_contiguous() + and w2.is_contiguous()) + + +def _fp8_quantize( + A: torch.Tensor, + A_scale: Optional[torch.Tensor], + block_shape: Optional[List[int]], +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Perform fp8 quantization on the inputs. If a block_shape + is provided, the output will be blocked. + """ + if block_shape is None: + A, A_scale = ops.scaled_fp8_quant(A, A_scale) + else: + assert len(block_shape) == 2 + _, block_k = block_shape[0], block_shape[1] + A, A_scale = per_token_group_quant_fp8(A, block_k) + assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1] + return A, A_scale + + def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, @@ -691,15 +753,11 @@ def invoke_fused_moe_kernel(A: torch.Tensor, if use_fp8_w8a8: assert B_scale is not None - if block_shape is None: - A, A_scale = ops.scaled_fp8_quant(A, A_scale) - else: - assert len(block_shape) == 2 - block_n, block_k = block_shape[0], block_shape[1] - A, A_scale = per_token_group_quant_fp8(A, block_k) - assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1] - assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2] - assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1] + assert (block_shape is None or triton.cdiv(B.shape[-2], block_shape[0]) + == B_scale.shape[-2]) + assert (block_shape is None or triton.cdiv(B.shape[-1], block_shape[1]) + == B_scale.shape[-1]) + elif use_int8_w8a16 or use_int4_w4a16: assert B_scale is not None assert block_shape is None or block_shape[0] == 0 @@ -1066,7 +1124,7 @@ def fused_topk( gating_output: torch.Tensor, topk: int, renormalize: bool, -): +) -> Tuple[torch.Tensor, torch.Tensor]: assert hidden_states.shape[0] == gating_output.shape[0], ( "Number of tokens mismatch") @@ -1098,14 +1156,16 @@ def fused_topk( # This is used by the Deepseek-V2 and Deepseek-V3 model @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) -def grouped_topk(hidden_states: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool, - num_expert_group: int = 0, - topk_group: int = 0, - scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None): +def grouped_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int = 0, + topk_group: int = 0, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None +) -> Tuple[torch.Tensor, torch.Tensor]: assert hidden_states.shape[0] == gating_output.shape[0], ( "Number of tokens mismatch") @@ -1154,10 +1214,11 @@ def grouped_topk(hidden_states: torch.Tensor, return topk_weights.to(torch.float32), topk_ids.to(torch.int32) -def get_config_dtype_str(dtype: torch.dtype, - use_int4_w4a16: Optional[bool] = False, - use_int8_w8a16: Optional[bool] = False, - use_fp8_w8a8: Optional[bool] = False): +def get_config_dtype_str( + dtype: torch.dtype, + use_int4_w4a16: Optional[bool] = False, + use_int8_w8a16: Optional[bool] = False, + use_fp8_w8a8: Optional[bool] = False) -> Optional[str]: if use_fp8_w8a8: return "fp8_w8a8" elif use_int8_w8a16: @@ -1318,26 +1379,123 @@ def fused_experts(hidden_states: torch.Tensor, w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None) -> torch.Tensor: - return 
dispatch_fused_experts_func(inplace)( - hidden_states=hidden_states, - w1=w1, - w2=w2, - topk_weights=topk_weights, - topk_ids=topk_ids, - activation=activation, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - global_num_experts=global_num_experts, - expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_zp=w1_zp, - w2_zp=w2_zp, - a1_scale=a1_scale, - a2_scale=a2_scale, - block_shape=block_shape) + block_shape: Optional[List[int]] = None, + allow_deep_gemm: bool = False) -> torch.Tensor: + if (allow_deep_gemm and use_fp8_w8a8 + and _valid_deep_gemm(hidden_states, w1, w2, expert_map)): + return deep_gemm_moe_fp8( + hidden_states=hidden_states, + w1=w1, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + ) + else: + return dispatch_fused_experts_func(inplace)( + hidden_states=hidden_states, + w1=w1, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=activation, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, + a1_scale=a1_scale, + a2_scale=a2_scale, + block_shape=block_shape) + + +def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor: + """ + A permutation routine that works on fp8 types. + """ + if torch.is_floating_point(m) and torch.finfo(m.dtype).bits == 8: + return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype) + else: + return m[idx, ...] + + +def _moe_permute( + curr_hidden_states: torch.Tensor, + a1q_scale: Optional[torch.Tensor], + curr_topk_ids: torch.Tensor, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + top_k_num: int, + block_m: int, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, + torch.Tensor]: + """ + Determine the sorted_token_ids, expert_ids for the given problem size. + Permute the hidden states and scales according to `sorted_token_ids`. + """ + tokens_in_chunk, _ = curr_hidden_states.shape + + sorted_token_ids, expert_ids, num_tokens_post_padded = ( + moe_align_block_size(curr_topk_ids, + block_m, + global_num_experts, + expert_map, + pad_sorted_ids=True)) + + inv_perm: Optional[torch.Tensor] = None + + num_tokens = top_k_num * tokens_in_chunk + sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1) + expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0) + inv_perm = torch.argsort(sorted_token_ids)[:num_tokens] + + # Permute according to sorted token ids. + curr_hidden_states = _fp8_perm(curr_hidden_states, + sorted_token_ids // top_k_num) + + if a1q_scale is not None: + a1q_scale = a1q_scale[sorted_token_ids // top_k_num] + + return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids, + inv_perm) + + +def _moe_unpermute_and_reduce( + out: torch.Tensor, + curr_hidden: torch.Tensor, + inv_perm: Optional[torch.Tensor], + topk: int, + K: int, + topk_weight: torch.Tensor, +) -> None: + """ + Unpermute the final result and apply topk_weights, then perform the final + reduction on the hidden states. + """ + M = topk_weight.shape[0] + curr_hidden = curr_hidden[inv_perm, ...] 
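+    # After the gather above, rows are back in token-major order with `topk`
+    # consecutive rows per token: reshape to (M, topk, K), scale each row by
+    # its routing weight, and let moe_sum reduce over the topk dimension.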
+ curr_hidden = curr_hidden.view(-1, topk, K) + curr_hidden.mul_(topk_weight.view(M, -1, 1)) + ops.moe_sum(curr_hidden, out) + + +def _resize_cache(x: torch.Tensor, v: Tuple[int, ...]) -> torch.Tensor: + """ + Shrink the given tensor and apply the given view to it. This is + used to resize the intermediate fused_moe caches. + """ + assert prod(v) <= x.numel() + return x.flatten()[:prod(v)].view(*v) def fused_experts_impl(hidden_states: torch.Tensor, @@ -1376,6 +1534,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, num_tokens, _ = hidden_states.shape E, N, _ = w1.shape + K = w2.shape[1] if global_num_experts == -1: global_num_experts = E top_k_num = topk_ids.shape[1] @@ -1401,13 +1560,11 @@ def fused_experts_impl(hidden_states: torch.Tensor, # We can reuse the memory between these because by the time we need # cache3, we're done with cache1 - cache13 = torch.empty(M * top_k_num * max(N, w2.shape[1]), + cache13 = torch.empty(M * top_k_num * max(N, K), device=hidden_states.device, dtype=hidden_states.dtype) - intermediate_cache1 = cache13[:M * top_k_num * N].view( - (M, topk_ids.shape[1], N)) - intermediate_cache3 = cache13[:M * top_k_num * w2.shape[1]].view( - (M, topk_ids.shape[1], w2.shape[1])) + intermediate_cache1 = cache13[:M * top_k_num * N].view(M, top_k_num, N) + intermediate_cache3 = cache13[:M * top_k_num * K].view(M, top_k_num, K) # This needs separate memory since it's used concurrently with cache1 intermediate_cache2 = torch.empty((M * top_k_num, N // 2), @@ -1452,14 +1609,23 @@ def fused_experts_impl(hidden_states: torch.Tensor, curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] + a1q_scale: Optional[torch.Tensor] = None + + if use_fp8_w8a8: + qcurr_hidden_states, a1q_scale = _fp8_quantize( + curr_hidden_states, a1_scale, block_shape) + else: + qcurr_hidden_states = curr_hidden_states + a1q_scale = a1_scale + sorted_token_ids, expert_ids, num_tokens_post_padded = ( moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], global_num_experts, expert_map)) - invoke_fused_moe_kernel(curr_hidden_states, + invoke_fused_moe_kernel(qcurr_hidden_states, w1, intermediate_cache1, - a1_scale, + a1q_scale, w1_scale, w1_zp, curr_topk_weights, @@ -1485,10 +1651,19 @@ def fused_experts_impl(hidden_states: torch.Tensor, else: raise ValueError(f"Unsupported FusedMoe activation: {activation}") - invoke_fused_moe_kernel(intermediate_cache2, + a2q_scale: Optional[torch.Tensor] = None + + if use_fp8_w8a8: + qintermediate_cache2, a2q_scale = _fp8_quantize( + intermediate_cache2, a2_scale, block_shape) + else: + qintermediate_cache2 = intermediate_cache2 + a2q_scale = a2_scale + + invoke_fused_moe_kernel(qintermediate_cache2, w2, intermediate_cache3, - a2_scale, + a2q_scale, w2_scale, w2_zp, curr_topk_weights, @@ -1617,6 +1792,193 @@ def fused_moe( block_shape=block_shape) +def deep_gemm_moe_fp8( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + activation: str = "silu", + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + This function computes a a8w8-quantized Mixture of Experts (MoE) layer + using two sets of quantized weights, w1_q and w2_q, and top-k gating + mechanism. 
The matrix multiplications are implemented with DeepGemm + grouped gemm. + + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. + Shape: [M, K] + - w1 (torch.Tensor): The first set of fp8 quantized expert weights. + Shape: [num_experts, K, 2N] (the weights are passed transposed) + - w2 (torch.Tensor): The second set of fp8 quantized expert weights. + Shape: [num_experts, N, K] (the weights are passed transposed) + - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q. + Shape: [num_experts] or [num_experts, 2N] + - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q. + Shape: [num_experts] or [num_experts, K] + - topk_weights (torch.Tensor): The weights of each token->expert mapping. + - topk_ids (torch.Tensor): The token->expert mapping for topk_weights. + - inplace (bool): If True, perform the operation in-place. + Defaults to False. + - activation (str): The activation function to apply after the first + MoE layer. + - global_num_experts (int): The total number of experts in the global + expert space. + - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices + from the global expert space to the local expert space of the expert + parallel shard. + - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a. + Shape: scalar or [M] + - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to + quantize the intermediate result between the gemms. + Shape: scalar or [M] + + Returns: + - torch.Tensor: The bfloat16 output tensor after applying the MoE layer. + """ + # Lazy import to avoid CUDA initialization problems. + import deep_gemm as dg + + assert expert_map is None, "Expert maps not supported yet" + + assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.stride(-1) == 1, "Stride of last dimension must be 1" + assert w2.stride(-1) == 1, "Stride of last dimension must be 1" + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16 + ] + assert w1.dtype == torch.float8_e4m3fn + assert w2.dtype == torch.float8_e4m3fn + assert w1.shape[0] == w2.shape[0], "Expert number mismatch" + assert w1.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch" + assert w1.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch" + assert a1_scale is None or a1_scale.dim( + ) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[ + 0] == hidden_states.shape[0], "Input scale shape mismatch" + assert a2_scale is None or a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch" # noqa: E501 + + num_tokens, _ = hidden_states.shape + E, N, _ = w1.shape + K = w2.shape[1] + if global_num_experts == -1: + global_num_experts = E + top_k_num = topk_ids.shape[1] + # We execute the fused_moe kernel in chunks to circumvent this issue: + # https://github.com/vllm-project/vllm/issues/5938 + CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE + + assert _valid_deep_gemm(hidden_states, w1, w2, expert_map) + + if inplace: + out_hidden_states = hidden_states + else: + out_hidden_states = torch.empty_like(hidden_states) + + block_m = dg.get_m_alignment_for_contiguous_layout() + block_shape = [block_m, block_m] + + assert w1_scale is not None + assert w2_scale is not None + + # We attempt to transpose and align offline in Fp8MoEMethod, in which + # case these calls will be nops. 
Otherwise, they'll be performed every + # time the layer is executed. + w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous() + w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous() + + M_sum = topk_ids.numel() + global_num_experts * (block_m - 1) + M_sum = round_up(M_sum, block_m) + + num_chunks = (num_tokens // CHUNK_SIZE) + 1 + + # We can reuse the memory between cache1 and cache3 because by the time + # we need cache3, we're done with cache1 + cache13 = torch.empty(M_sum * max(N, K), + device=hidden_states.device, + dtype=hidden_states.dtype) + + intermediate_cache1 = cache13[:M_sum * N].view(M_sum, N) + intermediate_cache2 = torch.empty((M_sum, N // 2), + device=hidden_states.device, + dtype=hidden_states.dtype) + intermediate_cache3 = cache13[:M_sum * K].view(M_sum, K) + + for chunk in range(num_chunks): + begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, + min((chunk + 1) * CHUNK_SIZE, + num_tokens)) + curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] + tokens_in_chunk, _ = curr_hidden_states.shape + + if tokens_in_chunk == 0: + break + + curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] + curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] + + a1q_scale: Optional[torch.Tensor] = None + + qcurr_hidden_states, a1q_scale = _fp8_quantize(curr_hidden_states, + a1_scale, block_shape) + + (qcurr_hidden_states, a1q_scale, sorted_token_ids, expert_ids, + inv_perm) = _moe_permute(qcurr_hidden_states, a1q_scale, + curr_topk_ids, global_num_experts, + expert_map, top_k_num, block_m) + + # Adjust the intermediate cache size and config for the last chunk. + # Note that in most cases we only have one chunk so the cache size + # and config are already set correctly and do not need to be adjusted. 
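+        # sorted_token_ids was padded to a multiple of block_m by _moe_permute
+        # (pad_sorted_ids=True), so its length is the number of rows the
+        # grouped gemms will process for this chunk.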
+ if tokens_in_chunk < CHUNK_SIZE and chunk > 0: + curr_M = sorted_token_ids.numel() + intermediate_cache1 = _resize_cache(intermediate_cache1, + (curr_M, N)) + intermediate_cache2 = _resize_cache(intermediate_cache2, + (curr_M, N // 2)) + intermediate_cache3 = _resize_cache(intermediate_cache3, + (curr_M, K)) + + dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + (qcurr_hidden_states, a1q_scale), (w1, w1_scale), + intermediate_cache1, expert_ids) + + if activation == "silu": + torch.ops._C.silu_and_mul(intermediate_cache2, + intermediate_cache1.view(-1, N)) + elif activation == "gelu": + torch.ops._C.gelu_and_mul(intermediate_cache2, + intermediate_cache1.view(-1, N)) + else: + raise ValueError(f"Unsupported FusedMoe activation: {activation}") + + a2q_scale: Optional[torch.Tensor] = None + + qintermediate_cache2, a2q_scale = _fp8_quantize( + intermediate_cache2, a2_scale, block_shape) + + dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + (qintermediate_cache2, a2q_scale), (w2, w2_scale), + intermediate_cache3, expert_ids) + + _moe_unpermute_and_reduce( + out_hidden_states[begin_chunk_idx:end_chunk_idx], + intermediate_cache3.view(*intermediate_cache3.shape), inv_perm, + top_k_num, K, curr_topk_weights) + + return out_hidden_states + + #TODO make the grouped gemm kernel consistent with scaled gemm kernel def cutlass_moe_fp8( a: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 11bfdb41..e7c733db 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 +import importlib.util from typing import Any, Callable, Dict, List, Optional import torch @@ -37,6 +38,14 @@ ACTIVATION_SCHEMES = ["static", "dynamic"] logger = init_logger(__name__) +has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None + + +def _is_col_major(x: torch.Tensor) -> bool: + assert x.dim() == 3 + b, m, n = x.shape + return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m + class Fp8Config(QuantizationConfig): """Config class for FP8.""" @@ -424,6 +433,19 @@ class Fp8MoEMethod(FusedMoEMethodBase): self.quant_config = quant_config self.block_quant = self.quant_config.weight_block_size is not None + # Check for DeepGemm support. + self.allow_deep_gemm = False + if envs.VLLM_USE_DEEP_GEMM: + if not has_deep_gemm: + logger.warning_once("Failed to import DeepGemm kernels.") + elif (current_platform.is_cuda() + and current_platform.has_device_capability(90)): + logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.") + self.allow_deep_gemm = True + else: + logger.warning_once( + "DeepGemm not supported on the current platform.") + def create_weights(self, layer: Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs): @@ -585,6 +607,19 @@ class Fp8MoEMethod(FusedMoEMethodBase): requires_grad=False) layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False) + + # DeepGemm scales need to be transposed and aligned. We try to do + # it ahead of time for performance reasons. + if self.allow_deep_gemm: + # Lazy import to avoid CUDA initialization problems. 
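+                # Doing the alignment once here, at weight-loading time, means
+                # the get_col_major_tma_aligned_tensor calls inside
+                # deep_gemm_moe_fp8 become no-ops on the forward path.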
+ import deep_gemm as dg + if _is_col_major(layer.w13_weight_scale_inv): + layer.w13_weight_scale_inv = \ + dg.get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv).contiguous() + if _is_col_major(layer.w2_weight_scale_inv): + layer.w2_weight_scale_inv = \ + dg.get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv).contiguous() + return # If checkpoint is fp16, quantize in place. @@ -773,6 +808,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): a1_scale=layer.w13_input_scale, a2_scale=layer.w2_input_scale, block_shape=self.quant_config.weight_block_size, + allow_deep_gemm=self.allow_deep_gemm, )
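A minimal end-to-end sketch of the new path (not taken from the patch itself): it assumes deep_gemm is installed and a compute-capability-9.0 GPU is available, the sizes are illustrative DeepSeek-V3-style values from the tests, and the all-ones block scales are placeholders standing in for real 128x128 block-wise quantization scales. The same kernels are enabled transparently at serving time for fp8 block-quantized MoE models by setting VLLM_USE_DEEP_GEMM=1, which makes Fp8MoEMethod pass allow_deep_gemm=True.

# Sketch only: assumes a Hopper (SM90) GPU with deep_gemm installed.
import torch

from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts,
                                                            fused_topk)

M, K, N, E, topk = 256, 7168, 4608, 8, 2
fp8 = torch.float8_e4m3fn

a = torch.randn((M, K), dtype=torch.bfloat16, device="cuda")
score = torch.randn((M, E), dtype=torch.bfloat16, device="cuda")

# w1: [E, 2N, K], w2: [E, K, N], with one fp32 scale per 128x128 weight block.
# The all-ones scales are placeholders; real scales come from block-wise
# quantization (e.g. per_block_cast_to_fp8 in the tests above).
w1 = torch.randn((E, 2 * N, K), dtype=torch.bfloat16, device="cuda").to(fp8)
w2 = torch.randn((E, K, N), dtype=torch.bfloat16, device="cuda").to(fp8)
w1_s = torch.ones((E, (2 * N + 127) // 128, (K + 127) // 128),
                  dtype=torch.float32, device="cuda")
w2_s = torch.ones((E, (K + 127) // 128, (N + 127) // 128),
                  dtype=torch.float32, device="cuda")

topk_weights, topk_ids = fused_topk(a, score.float(), topk, renormalize=False)

# allow_deep_gemm only changes the kernel choice; if _valid_deep_gemm rejects
# the problem size, fused_experts falls back to the Triton path.
out = fused_experts(a, w1, w2, topk_weights, topk_ids,
                    use_fp8_w8a8=True,
                    w1_scale=w1_s, w2_scale=w2_s,
                    block_shape=[128, 128],
                    allow_deep_gemm=True)
print(out.shape)  # (M, K), bfloat16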