Add option to use DeepGemm contiguous grouped gemm kernel for fused MoE operations. (#13932)

Signed-off-by: Bill Nell <bnell@redhat.com>

commit e59ca942f5 (parent a57a3044aa)
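For reviewers who want to try the change, here is a minimal usage sketch based only on the knobs added in this diff (the VLLM_USE_DEEP_GEMM environment variable, the benchmark's --use-deep-gemm flag, and the new allow_deep_gemm argument to fused_experts). It mirrors how the benchmark code below drives the new path rather than documenting a supported public API; the helper name and the tensor shapes are placeholders.

    # Sketch only: mirrors the benchmark usage in this diff.
    # VLLM_USE_DEEP_GEMM gates the DeepGemm path inside Fp8MoEMethod;
    # fused_experts additionally requires allow_deep_gemm=True and
    # _valid_deep_gemm() to accept the problem size.
    import os

    os.environ["VLLM_USE_DEEP_GEMM"] = "1"  # opt in; the default is off

    from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts,
                                                                fused_topk)


    def run_fp8_moe_with_deep_gemm(x, w1, w2, w1_scale, w2_scale,
                                   gating_output, topk):
        # Assumes x is [M, hidden] and w1/w2 are fp8 expert weights with
        # block-wise [128, 128] scales, as in the tests added by this PR.
        topk_weights, topk_ids = fused_topk(x, gating_output, topk, False)
        return fused_experts(x,
                             w1,
                             w2,
                             topk_weights,
                             topk_ids,
                             use_fp8_w8a8=True,
                             w1_scale=w1_scale,
                             w2_scale=w2_scale,
                             block_shape=[128, 128],
                             allow_deep_gemm=True)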
@@ -30,19 +30,18 @@ class BenchmarkConfig(TypedDict):
    num_stages: int


def benchmark_config(
    config: BenchmarkConfig,
    num_tokens: int,
    num_experts: int,
    shard_intermediate_size: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a16: bool,
    num_iters: int = 100,
    block_quant_shape: List[int] = None,
) -> float:
def benchmark_config(config: BenchmarkConfig,
                     num_tokens: int,
                     num_experts: int,
                     shard_intermediate_size: int,
                     hidden_size: int,
                     topk: int,
                     dtype: torch.dtype,
                     use_fp8_w8a8: bool,
                     use_int8_w8a16: bool,
                     num_iters: int = 100,
                     block_quant_shape: List[int] = None,
                     use_deep_gemm: bool = False) -> float:
    init_dtype = torch.float16 if use_fp8_w8a8 else dtype
    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
    if use_int8_w8a16:
@@ -115,22 +114,41 @@ def benchmark_config(
    def run():
        from vllm.model_executor.layers.fused_moe import override_config
        with override_config(config):
            fused_moe(
                x,
                w1,
                w2,
                input_gating,
                topk,
                renormalize=True,
                inplace=True,
                use_fp8_w8a8=use_fp8_w8a8,
                use_int8_w8a16=use_int8_w8a16,
                w1_scale=w1_scale,
                w2_scale=w2_scale,
                a1_scale=a1_scale,
                a2_scale=a2_scale,
                block_shape=block_quant_shape,
            )
            if use_deep_gemm:
                topk_weights, topk_ids = fused_topk(x, input_gating, topk,
                                                    False)
                return fused_experts(
                    x,
                    w1,
                    w2,
                    topk_weights,
                    topk_ids,
                    inplace=True,
                    use_fp8_w8a8=use_fp8_w8a8,
                    w1_scale=w1_scale,
                    w2_scale=w2_scale,
                    a1_scale=a1_scale,
                    a2_scale=a2_scale,
                    block_shape=block_quant_shape,
                    allow_deep_gemm=True,
                )
            else:
                fused_moe(
                    x,
                    w1,
                    w2,
                    input_gating,
                    topk,
                    renormalize=True,
                    inplace=True,
                    use_fp8_w8a8=use_fp8_w8a8,
                    use_int8_w8a16=use_int8_w8a16,
                    w1_scale=w1_scale,
                    w2_scale=w2_scale,
                    a1_scale=a1_scale,
                    a2_scale=a2_scale,
                    block_shape=block_quant_shape,
                )

    # JIT compilation & warmup
    run()
@@ -366,6 +384,7 @@ class BenchmarkWorker:
        use_fp8_w8a8: bool,
        use_int8_w8a16: bool,
        block_quant_shape: List[int] = None,
        use_deep_gemm: bool = False,
    ) -> tuple[dict[str, int], float]:
        current_platform.seed_everything(self.seed)
        dtype_str = get_config_dtype_str(dtype,
@@ -396,7 +415,8 @@ class BenchmarkWorker:
                                        use_fp8_w8a8,
                                        use_int8_w8a16,
                                        num_iters=100,
                                        block_quant_shape=block_quant_shape)
                                        block_quant_shape=block_quant_shape,
                                        use_deep_gemm=use_deep_gemm)
        return config, kernel_time

    def tune(
@@ -411,6 +431,7 @@ class BenchmarkWorker:
        use_int8_w8a16: bool,
        search_space: list[dict[str, int]],
        block_quant_shape: list[int],
        use_deep_gemm: bool,
    ) -> dict[str, int]:
        best_config = None
        best_time = float("inf")
@@ -436,7 +457,8 @@ class BenchmarkWorker:
                        use_fp8_w8a8,
                        use_int8_w8a16,
                        num_iters=20,
                        block_quant_shape=block_quant_shape)
                        block_quant_shape=block_quant_shape,
                        use_deep_gemm=use_deep_gemm)
                except triton.runtime.autotuner.OutOfResources:
                    # Some configurations may be invalid and fail to compile.
                    continue
@@ -550,6 +572,8 @@ def main(args: argparse.Namespace):
    else:
        batch_sizes = [args.batch_size]

    use_deep_gemm = bool(args.use_deep_gemm)

    ray.init()
    num_gpus = int(ray.available_resources()["GPU"])
    workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]
@@ -572,10 +596,10 @@ def main(args: argparse.Namespace):

        start = time.time()
        configs = _distribute(
            "tune",
            [(batch_size, E, shard_intermediate_size, hidden_size, topk, dtype,
              use_fp8_w8a8, use_int8_w8a16, search_space, block_quant_shape)
             for batch_size in batch_sizes])
            "tune", [(batch_size, E, shard_intermediate_size, hidden_size,
                      topk, dtype, use_fp8_w8a8, use_int8_w8a16, search_space,
                      block_quant_shape, use_deep_gemm)
                     for batch_size in batch_sizes])
        best_configs = {
            M: sort_config(config)
            for M, config in zip(batch_sizes, configs)
@@ -589,7 +613,7 @@ def main(args: argparse.Namespace):
        outputs = _distribute(
            "benchmark",
            [(batch_size, E, shard_intermediate_size, hidden_size, topk, dtype,
              use_fp8_w8a8, use_int8_w8a16, block_quant_shape)
              use_fp8_w8a8, use_int8_w8a16, block_quant_shape, use_deep_gemm)
             for batch_size in batch_sizes])

        for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
@@ -611,6 +635,7 @@ if __name__ == "__main__":
                        type=str,
                        choices=["auto", "fp8_w8a8", "int8_w8a16"],
                        default="auto")
    parser.add_argument("--use-deep-gemm", action="store_true")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--batch-size", type=int, required=False)
    parser.add_argument("--tune", action="store_true")
@@ -6,12 +6,22 @@ import itertools
import pytest
import torch

from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_moe import (
    deep_gemm_moe_fp8, fused_topk, moe_align_block_size)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8, w8a8_block_fp8_matmul)
from vllm.platforms import current_platform

dg_available = False
try:
    import deep_gemm
    dg_available = True
except ImportError:
    pass

if current_platform.get_device_capability() < (9, 0):
    pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
                allow_module_level=True)
@@ -21,17 +31,18 @@ DTYPES = [torch.bfloat16]  # [torch.half, torch.bfloat16, torch.float32]
NUM_TOKENS = [7, 83, 2048]
D = [512, 4096, 5120, 13824]
GROUP_SIZE = [64, 128, 256, 512]
M = [1, 7, 83, 512, 2048]
N = [128, 512, 1024, 4096, 7748, 13824]
K = [256, 4096, 5120, 3884, 13824]
M = [1, 7, 8, 83, 84, 512, 2048, 4096]
N = [128, 512, 1024, 4096, 7168, 7748, 13824]
K = [256, 4096, 5120, 3884, 13824, 16384]
# Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8
# and its hidden size is 7168.
M_moe = [1, 7, 83, 512, 2048]
N_moe = [4608]  # [128, 4608, 13824]
K_moe = [7168]  # [256, 7168, 13824]
M_moe = [1, 2, 7, 83, 128, 512, 2048]
M_moe_dg = [128, 192, 512, 1335, 2048]
N_moe = [128, 256, 1024, 4608]  # [13824]
K_moe = [256, 512, 7168]  # [13824]
BLOCK_SIZE = [[128, 128]]
E = [8, 24]  # [8, 24, 128, 256]
TOP_KS = [2]  # [1, 2, 6]
E = [2, 8, 16, 24]  # [128, 256]
TOP_KS = [1, 2, 6]
OUT_DTYPES = [torch.bfloat16]  # [torch.float32, torch.half, torch.bfloat16]
SEEDS = [0]

@@ -217,11 +228,16 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
                         SEEDS))
@torch.inference_mode()
def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
    if topk > E:
        pytest.skip(f"Skipping test; topk={topk} > E={E}")

    torch.manual_seed(seed)
    factor_for_scale = 1e-2
    fp8_info = torch.finfo(torch.float8_e4m3fn)
    fp8_max, fp8_min = fp8_info.max, fp8_info.min

    vllm_config = VllmConfig()

    a = torch.randn((M, K), dtype=dtype) / 10

    w1_bf16 = (torch.rand(
@@ -246,25 +262,240 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):

    score = torch.randn((M, E), dtype=dtype)

    out = fused_moe(
        a,
        w1,
        w2,
        score,
        topk,
        renormalize=False,
        use_fp8_w8a8=True,
        w1_scale=w1_s,
        w2_scale=w2_s,
        block_shape=block_size,
    )
    ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk,
                                       block_size)
    # Set the context to avoid lots of warning spam.
    with set_current_vllm_config(vllm_config):
        out = fused_moe(
            a,
            w1,
            w2,
            score,
            topk,
            renormalize=False,
            use_fp8_w8a8=True,
            w1_scale=w1_s,
            w2_scale=w2_s,
            block_shape=block_size,
        )
        ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk,
                                           block_size)

    print(f"{out.sum()=}")
    print(f"{ref_out.sum()=}")
    #print(f"{out.sum()=}")
    #print(f"{ref_out.sum()=}")

    rel_diff = (torch.mean(
        torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
                torch.mean(torch.abs(ref_out.to(torch.float32))))
    assert rel_diff < 0.03


def per_block_cast_to_fp8(
        x: torch.Tensor,
        block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
    assert x.dim() == 2
    m, n = x.shape
    x_padded = torch.zeros(
        (deep_gemm.ceil_div(m, 128) * 128,
         deep_gemm.ceil_div(n, block_size_n) * block_size_n),
        dtype=x.dtype,
        device=x.device)
    x_padded[:m, :n] = x
    x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n)
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
    x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
    scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
    return x_scaled_sub, scales


@pytest.mark.parametrize(
    "M,N,K,block_size,out_dtype,seed",
    itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS))
@torch.inference_mode()
def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
    # only aligned sizes
    if M % 4 != 0 or K % 128 != 0 or N % 64 != 0:
        pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}")

    torch.manual_seed(seed)
    fp8_info = torch.finfo(torch.float8_e4m3fn)
    fp8_max = fp8_info.max

    A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
    B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max

    _, block_k = block_size[0], block_size[1]

    A_fp8, As_fp8 = per_token_group_quant_fp8(A_fp32, block_k)
    B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32)

    As = As_fp8.to(torch.float32)
    Bs = Bs_fp8.to(torch.float32)

    ref_out = native_w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size,
                                           out_dtype)

    # Transpose earlier so that the testing will not trigger transposing kernels
    As_fp8 = deep_gemm.get_col_major_tma_aligned_tensor(As_fp8)

    out = torch.zeros((M, N), device='cuda', dtype=out_dtype)

    assert As_fp8.shape == (M, (K + 127) //
                            128), f"{As_fp8.shape} != {(M, (K + 127) // 128)}"

    deep_gemm.gemm_fp8_fp8_bf16_nt((A_fp8, As_fp8), (B_fp8, Bs_fp8), out)

    rel_diff = (torch.mean(
        torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
                torch.mean(torch.abs(ref_out.to(torch.float32))))
    assert rel_diff < 0.001


def fp8_perm(m, idx):
    if torch.is_floating_point(m) and torch.finfo(m.dtype).bits == 8:
        return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype)
    else:
        return m[idx, ...]


def test_moe_permute(a, a_s, topk_ids, num_groups, topk, block_m):
    M, K = a.shape

    sorted_token_ids, m_indices, num_pad = moe_align_block_size(
        topk_ids, block_m, num_groups, None, pad_sorted_ids=True)

    num_tokens = topk * M

    sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)
    m_indices = torch.repeat_interleave(m_indices, block_m, dim=0)
    inv_perm = torch.argsort(sorted_token_ids)[:M * topk]

    a = fp8_perm(a, sorted_token_ids // topk)
    if a_s is not None:
        a_s = a_s[sorted_token_ids // topk]

    return a, a_s, m_indices, inv_perm


def test_moe_unpermute(out, inv_perm, topk, K, topk_weight):
    M = topk_weight.shape[0]
    out = out[inv_perm, ...]
    tmp_out = out.view(-1, topk, K)
    return (tmp_out * topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)


def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk,
                                 block_shape):
    """Fused moe with block-wise quantization using DeepGemm grouped gemm."""
    num_groups = w1.shape[0]
    M, K = a.shape
    N = w2.shape[-1]

    topk_weight, topk_ids = fused_topk(a, score.float(), topk, False)

    block_m = deep_gemm.get_m_alignment_for_contiguous_layout()

    _, block_k = block_shape[0], block_shape[1]

    a_q, a_s = per_token_group_quant_fp8(a, block_m)

    a_q, a_s, m_indices, inv_perm = test_moe_permute(a_q, a_s, topk_ids,
                                                     num_groups, topk, block_m)

    inter_out = torch.zeros((a_q.shape[0], N * 2),
                            dtype=torch.bfloat16,
                            device=a.device)

    deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a_q, a_s), (w1, w1_s),
                                                        inter_out, m_indices)

    act_out = SiluAndMul().forward_native(inter_out)
    act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k)

    out = torch.zeros(a_q.shape[0], K, dtype=torch.bfloat16, device=a.device)

    deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
        (act_out_q, act_out_s), (w2, w2_s), out, m_indices)

    final_out = test_moe_unpermute(out, inv_perm, topk, K, topk_weight)

    return final_out


@pytest.mark.parametrize(
    "M,N,K,E,topk,seed",
    itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, SEEDS))
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
@torch.inference_mode()
def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):

    block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
    block_size = [block_m, block_m]
    dtype = torch.bfloat16

    # only aligned sizes
    if (N % block_m != 0 or K % block_m != 0 or topk > E):
        pytest.skip(
            f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}")

    if (N <= 512):
        pytest.skip("Skipping N <= 512 until performance issues solved.")

    vllm_config = VllmConfig()

    torch.manual_seed(seed)
    fp8_info = torch.finfo(torch.float8_e4m3fn)
    fp8_max, fp8_min = fp8_info.max, fp8_info.min

    a = torch.randn((M, K), dtype=dtype) / 10

    w1_bf16 = ((torch.rand((E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 *
               fp8_max).clamp(min=fp8_min, max=fp8_max)

    w2_bf16 = ((torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 *
               fp8_max).clamp(min=fp8_min, max=fp8_max)

    score = torch.randn((M, E), dtype=dtype)

    block_n, block_k = block_size[0], block_size[1]
    n_tiles_w1 = ((2 * N) + block_n - 1) // block_n
    k_tiles_w1 = (K + block_k - 1) // block_k
    n_tiles_w2 = (K + block_n - 1) // block_n
    k_tiles_w2 = (N + block_k - 1) // block_k

    w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn)
    w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn)

    w1_s = torch.empty((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32)
    w2_s = torch.empty((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32)

    w1_s = deep_gemm.get_col_major_tma_aligned_tensor(w1_s).contiguous()
    w2_s = deep_gemm.get_col_major_tma_aligned_tensor(w2_s).contiguous()

    assert w1_s.shape == (E, (2 * N + 127) // 128, (K + 127) // 128)
    assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2]

    for i in range(E):
        w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i])
        w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i])

    # Set the context to avoid lots of warning spam.
    with set_current_vllm_config(vllm_config):
        if M >= 128:
            ref_out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s,
                                                   score, topk, block_size)
        else:
            ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score,
                                               topk, block_size)

        topk_weights, topk_ids = fused_topk(a, score.float(), topk, False)

        out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids)

    #print(f"{out.sum()=}")
    #print(f"{ref_out.sum()=}")

    rel_diff = (torch.mean(
        torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
                torch.mean(torch.abs(ref_out.to(torch.float32))))

    assert rel_diff < 0.03
@@ -1224,7 +1224,7 @@ def moe_wna16_gemm(input: torch.Tensor, output: torch.Tensor,

def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                 token_expert_indicies: torch.Tensor,
                 gating_output: float) -> None:
                 gating_output: torch.Tensor) -> None:
    torch.ops._moe_C.topk_softmax(topk_weights, topk_ids,
                                  token_expert_indicies, gating_output)

@@ -105,6 +105,7 @@ if TYPE_CHECKING:
    VLLM_V0_USE_OUTLINES_CACHE: bool = False
    VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False
    VLLM_TPU_BUCKET_PADDING_GAP: int = 0
    VLLM_USE_DEEP_GEMM: bool = False


def get_default_cache_root():
@@ -687,6 +688,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
    "VLLM_TPU_BUCKET_PADDING_GAP":
    lambda: int(os.environ["VLLM_TPU_BUCKET_PADDING_GAP"])
    if "VLLM_TPU_BUCKET_PADDING_GAP" in os.environ else 0,

    # Allow use of DeepGemm kernels for fused moe ops.
    "VLLM_USE_DEEP_GEMM":
    lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))),
}

# end-env-vars-definition
@@ -1,8 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
"""Fused MoE kernel."""
import functools
import importlib.util
import json
import os
from math import prod
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
@@ -15,7 +17,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8)
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from vllm.utils import direct_register_custom_op, round_up

from .rocm_aiter_fused_moe import (is_rocm_aiter_moe_enabled,
                                   rocm_aiter_fused_experts,
@@ -23,6 +25,8 @@ from .rocm_aiter_fused_moe import (is_rocm_aiter_moe_enabled,

logger = init_logger(__name__)

has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None


@triton.jit
def write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, offs_token,
@@ -581,7 +585,8 @@ def moe_align_block_size(
    topk_ids: torch.Tensor,
    block_size: int,
    num_experts: int,
    expert_map: torch.Tensor = None
    expert_map: Optional[torch.Tensor] = None,
    pad_sorted_ids: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Aligns the token distribution across experts to be compatible with block
@@ -596,6 +601,8 @@ def moe_align_block_size(
        from the global space to the local index space of the current
        expert parallel shard. If the expert is not in the current expert
        parallel shard, the mapping is set to -1.
    - pad_sorted_ids: A flag indicating whether the sorted_token_ids length
      should be padded to a multiple of block_size,

    Returns:
    - sorted_token_ids: A tensor containing the sorted token indices according
@@ -625,6 +632,8 @@ def moe_align_block_size(
        by block_size for proper block matrix operations.
    """
    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    if pad_sorted_ids:
        max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
    sorted_ids = torch.empty((max_num_tokens_padded, ),
                             dtype=torch.int32,
                             device=topk_ids.device)
@@ -667,6 +676,59 @@ def moe_align_block_size(
    return sorted_ids, expert_ids, num_tokens_post_pad


def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor,
                     w2: torch.Tensor,
                     expert_map: Optional[torch.Tensor]) -> bool:
    """
    Check if the given problem size is supported by the DeepGemm grouped
    gemm kernel. All of M, N, K and the quantization block_shape must be
    aligned by `dg.get_m_alignment_for_contiguous_layout()`.
    """
    if not has_deep_gemm:
        return False

    # Lazy import to avoid CUDA initialization problems.
    import deep_gemm as dg

    # Expert maps not supported yet.
    if expert_map is not None:
        return False

    align = dg.get_m_alignment_for_contiguous_layout()
    M = hidden_states.shape[0]
    _, K, N = w2.shape

    # For now, disable DeepGemm for small N until better permute/unpermute
    # ops are available.
    if N <= 512:
        return False

    if align > M or N % align != 0 or K % align != 0:
        return False

    return (hidden_states.is_contiguous() and w1.is_contiguous()
            and w2.is_contiguous())


def _fp8_quantize(
    A: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    block_shape: Optional[List[int]],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Perform fp8 quantization on the inputs. If a block_shape
    is provided, the output will be blocked.
    """
    if block_shape is None:
        A, A_scale = ops.scaled_fp8_quant(A, A_scale)
    else:
        assert len(block_shape) == 2
        _, block_k = block_shape[0], block_shape[1]
        A, A_scale = per_token_group_quant_fp8(A, block_k)
        assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
    return A, A_scale


def invoke_fused_moe_kernel(A: torch.Tensor,
                            B: torch.Tensor,
                            C: torch.Tensor,
@@ -691,15 +753,11 @@ def invoke_fused_moe_kernel(A: torch.Tensor,

    if use_fp8_w8a8:
        assert B_scale is not None
        if block_shape is None:
            A, A_scale = ops.scaled_fp8_quant(A, A_scale)
        else:
            assert len(block_shape) == 2
            block_n, block_k = block_shape[0], block_shape[1]
            A, A_scale = per_token_group_quant_fp8(A, block_k)
            assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
            assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
            assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
        assert (block_shape is None or triton.cdiv(B.shape[-2], block_shape[0])
                == B_scale.shape[-2])
        assert (block_shape is None or triton.cdiv(B.shape[-1], block_shape[1])
                == B_scale.shape[-1])

    elif use_int8_w8a16 or use_int4_w4a16:
        assert B_scale is not None
        assert block_shape is None or block_shape[0] == 0
@@ -1066,7 +1124,7 @@ def fused_topk(
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
):
) -> Tuple[torch.Tensor, torch.Tensor]:
    assert hidden_states.shape[0] == gating_output.shape[0], (
        "Number of tokens mismatch")

@@ -1098,14 +1156,16 @@ def fused_topk(

# This is used by the Deepseek-V2 and Deepseek-V3 model
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def grouped_topk(hidden_states: torch.Tensor,
                 gating_output: torch.Tensor,
                 topk: int,
                 renormalize: bool,
                 num_expert_group: int = 0,
                 topk_group: int = 0,
                 scoring_func: str = "softmax",
                 e_score_correction_bias: Optional[torch.Tensor] = None):
def grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:

    assert hidden_states.shape[0] == gating_output.shape[0], (
        "Number of tokens mismatch")
@@ -1154,10 +1214,11 @@ def grouped_topk(hidden_states: torch.Tensor,
    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)


def get_config_dtype_str(dtype: torch.dtype,
                         use_int4_w4a16: Optional[bool] = False,
                         use_int8_w8a16: Optional[bool] = False,
                         use_fp8_w8a8: Optional[bool] = False):
def get_config_dtype_str(
        dtype: torch.dtype,
        use_int4_w4a16: Optional[bool] = False,
        use_int8_w8a16: Optional[bool] = False,
        use_fp8_w8a8: Optional[bool] = False) -> Optional[str]:
    if use_fp8_w8a8:
        return "fp8_w8a8"
    elif use_int8_w8a16:
@@ -1318,26 +1379,123 @@ def fused_experts(hidden_states: torch.Tensor,
                  w2_zp: Optional[torch.Tensor] = None,
                  a1_scale: Optional[torch.Tensor] = None,
                  a2_scale: Optional[torch.Tensor] = None,
                  block_shape: Optional[List[int]] = None) -> torch.Tensor:
    return dispatch_fused_experts_func(inplace)(
        hidden_states=hidden_states,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        activation=activation,
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        use_int4_w4a16=use_int4_w4a16,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        w1_zp=w1_zp,
        w2_zp=w2_zp,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        block_shape=block_shape)
                  block_shape: Optional[List[int]] = None,
                  allow_deep_gemm: bool = False) -> torch.Tensor:
    if (allow_deep_gemm and use_fp8_w8a8
            and _valid_deep_gemm(hidden_states, w1, w2, expert_map)):
        return deep_gemm_moe_fp8(
            hidden_states=hidden_states,
            w1=w1,
            w2=w2,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            activation=activation,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
        )
    else:
        return dispatch_fused_experts_func(inplace)(
            hidden_states=hidden_states,
            w1=w1,
            w2=w2,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            activation=activation,
            use_fp8_w8a8=use_fp8_w8a8,
            use_int8_w8a16=use_int8_w8a16,
            use_int4_w4a16=use_int4_w4a16,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            w1_zp=w1_zp,
            w2_zp=w2_zp,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            block_shape=block_shape)


def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """
    A permutation routine that works on fp8 types.
    """
    if torch.is_floating_point(m) and torch.finfo(m.dtype).bits == 8:
        return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype)
    else:
        return m[idx, ...]


def _moe_permute(
    curr_hidden_states: torch.Tensor,
    a1q_scale: Optional[torch.Tensor],
    curr_topk_ids: torch.Tensor,
    global_num_experts: int,
    expert_map: Optional[torch.Tensor],
    top_k_num: int,
    block_m: int,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor,
           torch.Tensor]:
    """
    Determine the sorted_token_ids, expert_ids for the given problem size.
    Permute the hidden states and scales according to `sorted_token_ids`.
    """
    tokens_in_chunk, _ = curr_hidden_states.shape

    sorted_token_ids, expert_ids, num_tokens_post_padded = (
        moe_align_block_size(curr_topk_ids,
                             block_m,
                             global_num_experts,
                             expert_map,
                             pad_sorted_ids=True))

    inv_perm: Optional[torch.Tensor] = None

    num_tokens = top_k_num * tokens_in_chunk
    sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)
    expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0)
    inv_perm = torch.argsort(sorted_token_ids)[:num_tokens]

    # Permute according to sorted token ids.
    curr_hidden_states = _fp8_perm(curr_hidden_states,
                                   sorted_token_ids // top_k_num)

    if a1q_scale is not None:
        a1q_scale = a1q_scale[sorted_token_ids // top_k_num]

    return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
            inv_perm)


def _moe_unpermute_and_reduce(
    out: torch.Tensor,
    curr_hidden: torch.Tensor,
    inv_perm: Optional[torch.Tensor],
    topk: int,
    K: int,
    topk_weight: torch.Tensor,
) -> None:
    """
    Unpermute the final result and apply topk_weights, then perform the final
    reduction on the hidden states.
    """
    M = topk_weight.shape[0]
    curr_hidden = curr_hidden[inv_perm, ...]
    curr_hidden = curr_hidden.view(-1, topk, K)
    curr_hidden.mul_(topk_weight.view(M, -1, 1))
    ops.moe_sum(curr_hidden, out)


def _resize_cache(x: torch.Tensor, v: Tuple[int, ...]) -> torch.Tensor:
    """
    Shrink the given tensor and apply the given view to it. This is
    used to resize the intermediate fused_moe caches.
    """
    assert prod(v) <= x.numel()
    return x.flatten()[:prod(v)].view(*v)


def fused_experts_impl(hidden_states: torch.Tensor,
@@ -1376,6 +1534,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,

    num_tokens, _ = hidden_states.shape
    E, N, _ = w1.shape
    K = w2.shape[1]
    if global_num_experts == -1:
        global_num_experts = E
    top_k_num = topk_ids.shape[1]
@@ -1401,13 +1560,11 @@ def fused_experts_impl(hidden_states: torch.Tensor,

    # We can reuse the memory between these because by the time we need
    # cache3, we're done with cache1
    cache13 = torch.empty(M * top_k_num * max(N, w2.shape[1]),
    cache13 = torch.empty(M * top_k_num * max(N, K),
                          device=hidden_states.device,
                          dtype=hidden_states.dtype)
    intermediate_cache1 = cache13[:M * top_k_num * N].view(
        (M, topk_ids.shape[1], N))
    intermediate_cache3 = cache13[:M * top_k_num * w2.shape[1]].view(
        (M, topk_ids.shape[1], w2.shape[1]))
    intermediate_cache1 = cache13[:M * top_k_num * N].view(M, top_k_num, N)
    intermediate_cache3 = cache13[:M * top_k_num * K].view(M, top_k_num, K)

    # This needs separate memory since it's used concurrently with cache1
    intermediate_cache2 = torch.empty((M * top_k_num, N // 2),
@@ -1452,14 +1609,23 @@ def fused_experts_impl(hidden_states: torch.Tensor,
        curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
        curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]

        a1q_scale: Optional[torch.Tensor] = None

        if use_fp8_w8a8:
            qcurr_hidden_states, a1q_scale = _fp8_quantize(
                curr_hidden_states, a1_scale, block_shape)
        else:
            qcurr_hidden_states = curr_hidden_states
            a1q_scale = a1_scale

        sorted_token_ids, expert_ids, num_tokens_post_padded = (
            moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'],
                                 global_num_experts, expert_map))

        invoke_fused_moe_kernel(curr_hidden_states,
        invoke_fused_moe_kernel(qcurr_hidden_states,
                                w1,
                                intermediate_cache1,
                                a1_scale,
                                a1q_scale,
                                w1_scale,
                                w1_zp,
                                curr_topk_weights,
@@ -1485,10 +1651,19 @@ def fused_experts_impl(hidden_states: torch.Tensor,
        else:
            raise ValueError(f"Unsupported FusedMoe activation: {activation}")

        invoke_fused_moe_kernel(intermediate_cache2,
        a2q_scale: Optional[torch.Tensor] = None

        if use_fp8_w8a8:
            qintermediate_cache2, a2q_scale = _fp8_quantize(
                intermediate_cache2, a2_scale, block_shape)
        else:
            qintermediate_cache2 = intermediate_cache2
            a2q_scale = a2_scale

        invoke_fused_moe_kernel(qintermediate_cache2,
                                w2,
                                intermediate_cache3,
                                a2_scale,
                                a2q_scale,
                                w2_scale,
                                w2_zp,
                                curr_topk_weights,
@@ -1617,6 +1792,193 @@ def fused_moe(
        block_shape=block_shape)


def deep_gemm_moe_fp8(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace: bool = False,
    activation: str = "silu",
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    This function computes an a8w8-quantized Mixture of Experts (MoE) layer
    using two sets of quantized weights, w1_q and w2_q, and top-k gating
    mechanism. The matrix multiplications are implemented with DeepGemm
    grouped gemm.

    Parameters:
    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
        Shape: [M, K]
    - w1 (torch.Tensor): The first set of fp8 quantized expert weights.
        Shape: [num_experts, K, 2N] (the weights are passed transposed)
    - w2 (torch.Tensor): The second set of fp8 quantized expert weights.
        Shape: [num_experts, N, K] (the weights are passed transposed)
    - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
        Shape: [num_experts] or [num_experts, 2N]
    - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
        Shape: [num_experts] or [num_experts, K]
    - topk_weights (torch.Tensor): The weights of each token->expert mapping.
    - topk_ids (torch.Tensor): The token->expert mapping for topk_weights.
    - inplace (bool): If True, perform the operation in-place.
        Defaults to False.
    - activation (str): The activation function to apply after the first
        MoE layer.
    - global_num_experts (int): The total number of experts in the global
        expert space.
    - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
        from the global expert space to the local expert space of the expert
        parallel shard.
    - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
        Shape: scalar or [M]
    - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
        quantize the intermediate result between the gemms.
        Shape: scalar or [M]

    Returns:
    - torch.Tensor: The bfloat16 output tensor after applying the MoE layer.
    """
    # Lazy import to avoid CUDA initialization problems.
    import deep_gemm as dg

    assert expert_map is None, "Expert maps not supported yet"

    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"

    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
    assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
    assert hidden_states.dtype in [
        torch.float32, torch.float16, torch.bfloat16
    ]
    assert w1.dtype == torch.float8_e4m3fn
    assert w2.dtype == torch.float8_e4m3fn
    assert w1.shape[0] == w2.shape[0], "Expert number mismatch"
    assert w1.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
    assert w1.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
    assert a1_scale is None or a1_scale.dim(
    ) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[
        0] == hidden_states.shape[0], "Input scale shape mismatch"
    assert a2_scale is None or a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch"  # noqa: E501

    num_tokens, _ = hidden_states.shape
    E, N, _ = w1.shape
    K = w2.shape[1]
    if global_num_experts == -1:
        global_num_experts = E
    top_k_num = topk_ids.shape[1]
    # We execute the fused_moe kernel in chunks to circumvent this issue:
    # https://github.com/vllm-project/vllm/issues/5938
    CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE

    assert _valid_deep_gemm(hidden_states, w1, w2, expert_map)

    if inplace:
        out_hidden_states = hidden_states
    else:
        out_hidden_states = torch.empty_like(hidden_states)

    block_m = dg.get_m_alignment_for_contiguous_layout()
    block_shape = [block_m, block_m]

    assert w1_scale is not None
    assert w2_scale is not None

    # We attempt to transpose and align offline in Fp8MoEMethod, in which
    # case these calls will be nops. Otherwise, they'll be performed every
    # time the layer is executed.
    w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous()
    w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous()

    M_sum = topk_ids.numel() + global_num_experts * (block_m - 1)
    M_sum = round_up(M_sum, block_m)

    num_chunks = (num_tokens // CHUNK_SIZE) + 1

    # We can reuse the memory between cache1 and cache3 because by the time
    # we need cache3, we're done with cache1
    cache13 = torch.empty(M_sum * max(N, K),
                          device=hidden_states.device,
                          dtype=hidden_states.dtype)

    intermediate_cache1 = cache13[:M_sum * N].view(M_sum, N)
    intermediate_cache2 = torch.empty((M_sum, N // 2),
                                      device=hidden_states.device,
                                      dtype=hidden_states.dtype)
    intermediate_cache3 = cache13[:M_sum * K].view(M_sum, K)

    for chunk in range(num_chunks):
        begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE,
                                          min((chunk + 1) * CHUNK_SIZE,
                                              num_tokens))
        curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
        tokens_in_chunk, _ = curr_hidden_states.shape

        if tokens_in_chunk == 0:
            break

        curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
        curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]

        a1q_scale: Optional[torch.Tensor] = None

        qcurr_hidden_states, a1q_scale = _fp8_quantize(curr_hidden_states,
                                                       a1_scale, block_shape)

        (qcurr_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
         inv_perm) = _moe_permute(qcurr_hidden_states, a1q_scale,
                                  curr_topk_ids, global_num_experts,
                                  expert_map, top_k_num, block_m)

        # Adjust the intermediate cache size and config for the last chunk.
        # Note that in most cases we only have one chunk so the cache size
        # and config are already set correctly and do not need to be adjusted.
        if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
            curr_M = sorted_token_ids.numel()
            intermediate_cache1 = _resize_cache(intermediate_cache1,
                                                (curr_M, N))
            intermediate_cache2 = _resize_cache(intermediate_cache2,
                                                (curr_M, N // 2))
            intermediate_cache3 = _resize_cache(intermediate_cache3,
                                                (curr_M, K))

        dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
            (qcurr_hidden_states, a1q_scale), (w1, w1_scale),
            intermediate_cache1, expert_ids)

        if activation == "silu":
            torch.ops._C.silu_and_mul(intermediate_cache2,
                                      intermediate_cache1.view(-1, N))
        elif activation == "gelu":
            torch.ops._C.gelu_and_mul(intermediate_cache2,
                                      intermediate_cache1.view(-1, N))
        else:
            raise ValueError(f"Unsupported FusedMoe activation: {activation}")

        a2q_scale: Optional[torch.Tensor] = None

        qintermediate_cache2, a2q_scale = _fp8_quantize(
            intermediate_cache2, a2_scale, block_shape)

        dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
            (qintermediate_cache2, a2q_scale), (w2, w2_scale),
            intermediate_cache3, expert_ids)

        _moe_unpermute_and_reduce(
            out_hidden_states[begin_chunk_idx:end_chunk_idx],
            intermediate_cache3.view(*intermediate_cache3.shape), inv_perm,
            top_k_num, K, curr_topk_weights)

    return out_hidden_states


#TODO make the grouped gemm kernel consistent with scaled gemm kernel
def cutlass_moe_fp8(
    a: torch.Tensor,
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

import importlib.util
from typing import Any, Callable, Dict, List, Optional

import torch
@@ -37,6 +38,14 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]

logger = init_logger(__name__)

has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None


def _is_col_major(x: torch.Tensor) -> bool:
    assert x.dim() == 3
    b, m, n = x.shape
    return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m


class Fp8Config(QuantizationConfig):
    """Config class for FP8."""
@@ -424,6 +433,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
        self.quant_config = quant_config
        self.block_quant = self.quant_config.weight_block_size is not None

        # Check for DeepGemm support.
        self.allow_deep_gemm = False
        if envs.VLLM_USE_DEEP_GEMM:
            if not has_deep_gemm:
                logger.warning_once("Failed to import DeepGemm kernels.")
            elif (current_platform.is_cuda()
                  and current_platform.has_device_capability(90)):
                logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.")
                self.allow_deep_gemm = True
            else:
                logger.warning_once(
                    "DeepGemm not supported on the current platform.")

    def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                       intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
@@ -585,6 +607,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                                                  requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                                 requires_grad=False)

            # DeepGemm scales need to be transposed and aligned. We try to do
            # it ahead of time for performance reasons.
            if self.allow_deep_gemm:
                # Lazy import to avoid CUDA initialization problems.
                import deep_gemm as dg
                if _is_col_major(layer.w13_weight_scale_inv):
                    layer.w13_weight_scale_inv = \
                        dg.get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv).contiguous()
                if _is_col_major(layer.w2_weight_scale_inv):
                    layer.w2_weight_scale_inv = \
                        dg.get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv).contiguous()

            return

        # If checkpoint is fp16, quantize in place.
@@ -773,6 +808,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                a1_scale=layer.w13_input_scale,
                a2_scale=layer.w2_input_scale,
                block_shape=self.quant_config.weight_block_size,
                allow_deep_gemm=self.allow_deep_gemm,
            )
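As a closing reference, the conditions under which fused_experts actually takes the DeepGemm path can be summarized as below. This is a plain-Python paraphrase of _valid_deep_gemm from the diff above, not code in the PR: the helper name deep_gemm_path_applies is hypothetical, and the alignment value of 128 is assumed from deep_gemm.get_m_alignment_for_contiguous_layout(); the real check also requires the hidden states and weights to be contiguous.

    def deep_gemm_path_applies(M: int, N: int, K: int, use_fp8_w8a8: bool,
                               has_expert_map: bool, align: int = 128) -> bool:
        """Paraphrase of _valid_deep_gemm: the contiguous grouped gemm is used
        only for fp8 w8a8 without an expert map, when M covers at least one
        alignment block, N and K are multiples of the alignment, and N > 512
        (small N stays on the Triton path until faster permute/unpermute ops
        are available)."""
        if not use_fp8_w8a8 or has_expert_map:
            return False
        if N <= 512:
            return False
        return M >= align and N % align == 0 and K % align == 0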