Add option to use DeepGemm contiguous grouped gemm kernel for fused MoE operations. (#13932)
Signed-off-by: Bill Nell <bnell@redhat.com>
parent a57a3044aa
commit e59ca942f5
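The change is opt-in: DeepGemm's contiguous grouped gemm is only used for fp8 fused-MoE layers when the new VLLM_USE_DEEP_GEMM environment variable is set and the problem shape passes the _valid_deep_gemm check added below. A minimal sketch of enabling it at serving time, assuming an fp8 block-quantized MoE checkpoint is available locally (the model path is a placeholder; everything else is the existing vLLM API):

    import os
    os.environ["VLLM_USE_DEEP_GEMM"] = "1"  # env var added by this commit

    from vllm import LLM

    # Placeholder checkpoint: any fp8 block-quantized MoE model (for example a
    # DeepSeek-V3-style model). Fp8MoEMethod checks at init time that deep_gemm
    # is importable and that the GPU is CUDA with compute capability >= 9.0,
    # and otherwise falls back to the existing Triton fused-MoE path.
    llm = LLM(model="path/to/fp8-moe-checkpoint")
    print(llm.generate("Hello"))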
@@ -30,8 +30,7 @@ class BenchmarkConfig(TypedDict):
     num_stages: int


-def benchmark_config(
-    config: BenchmarkConfig,
+def benchmark_config(config: BenchmarkConfig,
     num_tokens: int,
     num_experts: int,
     shard_intermediate_size: int,
@@ -42,7 +41,7 @@ def benchmark_config(
     use_int8_w8a16: bool,
     num_iters: int = 100,
     block_quant_shape: List[int] = None,
-) -> float:
+    use_deep_gemm: bool = False) -> float:
     init_dtype = torch.float16 if use_fp8_w8a8 else dtype
     x = torch.randn(num_tokens, hidden_size, dtype=dtype)
     if use_int8_w8a16:
@@ -115,6 +114,25 @@ def benchmark_config(
     def run():
         from vllm.model_executor.layers.fused_moe import override_config
         with override_config(config):
+            if use_deep_gemm:
+                topk_weights, topk_ids = fused_topk(x, input_gating, topk,
+                                                    False)
+                return fused_experts(
+                    x,
+                    w1,
+                    w2,
+                    topk_weights,
+                    topk_ids,
+                    inplace=True,
+                    use_fp8_w8a8=use_fp8_w8a8,
+                    w1_scale=w1_scale,
+                    w2_scale=w2_scale,
+                    a1_scale=a1_scale,
+                    a2_scale=a2_scale,
+                    block_shape=block_quant_shape,
+                    allow_deep_gemm=True,
+                )
+            else:
                 fused_moe(
                     x,
                     w1,
@@ -366,6 +384,7 @@ class BenchmarkWorker:
         use_fp8_w8a8: bool,
         use_int8_w8a16: bool,
         block_quant_shape: List[int] = None,
+        use_deep_gemm: bool = False,
     ) -> tuple[dict[str, int], float]:
         current_platform.seed_everything(self.seed)
         dtype_str = get_config_dtype_str(dtype,
@@ -396,7 +415,8 @@ class BenchmarkWorker:
                                        use_fp8_w8a8,
                                        use_int8_w8a16,
                                        num_iters=100,
-                                       block_quant_shape=block_quant_shape)
+                                       block_quant_shape=block_quant_shape,
+                                       use_deep_gemm=use_deep_gemm)
         return config, kernel_time

     def tune(
@@ -411,6 +431,7 @@ class BenchmarkWorker:
         use_int8_w8a16: bool,
         search_space: list[dict[str, int]],
         block_quant_shape: list[int],
+        use_deep_gemm: bool,
     ) -> dict[str, int]:
         best_config = None
         best_time = float("inf")
@@ -436,7 +457,8 @@ class BenchmarkWorker:
                        use_fp8_w8a8,
                        use_int8_w8a16,
                        num_iters=20,
-                       block_quant_shape=block_quant_shape)
+                       block_quant_shape=block_quant_shape,
+                       use_deep_gemm=use_deep_gemm)
                except triton.runtime.autotuner.OutOfResources:
                    # Some configurations may be invalid and fail to compile.
                    continue
@@ -550,6 +572,8 @@ def main(args: argparse.Namespace):
    else:
        batch_sizes = [args.batch_size]

+    use_deep_gemm = bool(args.use_deep_gemm)
+
    ray.init()
    num_gpus = int(ray.available_resources()["GPU"])
    workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]
@@ -572,9 +596,9 @@ def main(args: argparse.Namespace):

        start = time.time()
        configs = _distribute(
-           "tune",
-           [(batch_size, E, shard_intermediate_size, hidden_size, topk, dtype,
-             use_fp8_w8a8, use_int8_w8a16, search_space, block_quant_shape)
+           "tune", [(batch_size, E, shard_intermediate_size, hidden_size,
+                     topk, dtype, use_fp8_w8a8, use_int8_w8a16, search_space,
+                     block_quant_shape, use_deep_gemm)
            for batch_size in batch_sizes])
        best_configs = {
            M: sort_config(config)
@@ -589,7 +613,7 @@ def main(args: argparse.Namespace):
        outputs = _distribute(
            "benchmark",
            [(batch_size, E, shard_intermediate_size, hidden_size, topk, dtype,
-             use_fp8_w8a8, use_int8_w8a16, block_quant_shape)
+             use_fp8_w8a8, use_int8_w8a16, block_quant_shape, use_deep_gemm)
            for batch_size in batch_sizes])

        for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
@@ -611,6 +635,7 @@ if __name__ == "__main__":
                        type=str,
                        choices=["auto", "fp8_w8a8", "int8_w8a16"],
                        default="auto")
+    parser.add_argument("--use-deep-gemm", action="store_true")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--batch-size", type=int, required=False)
    parser.add_argument("--tune", action="store_true")
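Note on the benchmark changes above: the new `--use-deep-gemm` flag is simply threaded through `BenchmarkWorker.benchmark`/`tune` into `benchmark_config`, where it switches the timed call from `fused_moe` to `fused_experts(..., allow_deep_gemm=True)`. Assuming this is the MoE tuning script, a run that exercises the new path would look like `python benchmark_moe.py --dtype fp8_w8a8 --tune --use-deep-gemm` (only `--use-deep-gemm` is new here; the other flags already existed).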
@@ -6,12 +6,22 @@ import itertools
 import pytest
 import torch

+from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import fused_moe
+from vllm.model_executor.layers.fused_moe.fused_moe import (
+    deep_gemm_moe_fp8, fused_topk, moe_align_block_size)
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8, w8a8_block_fp8_matmul)
 from vllm.platforms import current_platform

+dg_available = False
+try:
+    import deep_gemm
+    dg_available = True
+except ImportError:
+    pass
+
 if current_platform.get_device_capability() < (9, 0):
     pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
                 allow_module_level=True)
@@ -21,17 +31,18 @@ DTYPES = [torch.bfloat16]  # [torch.half, torch.bfloat16, torch.float32]
 NUM_TOKENS = [7, 83, 2048]
 D = [512, 4096, 5120, 13824]
 GROUP_SIZE = [64, 128, 256, 512]
-M = [1, 7, 83, 512, 2048]
-N = [128, 512, 1024, 4096, 7748, 13824]
-K = [256, 4096, 5120, 3884, 13824]
+M = [1, 7, 8, 83, 84, 512, 2048, 4096]
+N = [128, 512, 1024, 4096, 7168, 7748, 13824]
+K = [256, 4096, 5120, 3884, 13824, 16384]
 # Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8
 # and its hidden size is 7168.
-M_moe = [1, 7, 83, 512, 2048]
-N_moe = [4608]  # [128, 4608, 13824]
-K_moe = [7168]  # [256, 7168, 13824]
+M_moe = [1, 2, 7, 83, 128, 512, 2048]
+M_moe_dg = [128, 192, 512, 1335, 2048]
+N_moe = [128, 256, 1024, 4608]  # [13824]
+K_moe = [256, 512, 7168]  # [13824]
 BLOCK_SIZE = [[128, 128]]
-E = [8, 24]  # [8, 24, 128, 256]
-TOP_KS = [2]  # [1, 2, 6]
+E = [2, 8, 16, 24]  # [128, 256]
+TOP_KS = [1, 2, 6]
 OUT_DTYPES = [torch.bfloat16]  # [torch.float32, torch.half, torch.bfloat16]
 SEEDS = [0]

@@ -217,11 +228,16 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
                          SEEDS))
 @torch.inference_mode()
 def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
+    if topk > E:
+        pytest.skip(f"Skipping test; topk={topk} > E={E}")
+
     torch.manual_seed(seed)
     factor_for_scale = 1e-2
     fp8_info = torch.finfo(torch.float8_e4m3fn)
     fp8_max, fp8_min = fp8_info.max, fp8_info.min

+    vllm_config = VllmConfig()
+
     a = torch.randn((M, K), dtype=dtype) / 10

     w1_bf16 = (torch.rand(
@@ -246,6 +262,8 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):

     score = torch.randn((M, E), dtype=dtype)

+    # Set the context to avoid lots of warning spam.
+    with set_current_vllm_config(vllm_config):
         out = fused_moe(
             a,
             w1,
@@ -261,10 +279,223 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
     ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk,
                                        block_size)

-    print(f"{out.sum()=}")
-    print(f"{ref_out.sum()=}")
+    #print(f"{out.sum()=}")
+    #print(f"{ref_out.sum()=}")

     rel_diff = (torch.mean(
         torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
                 torch.mean(torch.abs(ref_out.to(torch.float32))))
     assert rel_diff < 0.03
+
+
+def per_block_cast_to_fp8(
+        x: torch.Tensor,
+        block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
+    assert x.dim() == 2
+    m, n = x.shape
+    x_padded = torch.zeros(
+        (deep_gemm.ceil_div(m, 128) * 128,
+         deep_gemm.ceil_div(n, block_size_n) * block_size_n),
+        dtype=x.dtype,
+        device=x.device)
+    x_padded[:m, :n] = x
+    x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n)
+    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
+    x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
+    x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
+    scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
+    return x_scaled_sub, scales
+
+
+@pytest.mark.parametrize(
+    "M,N,K,block_size,out_dtype,seed",
+    itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS))
+@torch.inference_mode()
+def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
+    # only aligned sizes
+    if M % 4 != 0 or K % 128 != 0 or N % 64 != 0:
+        pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}")
+
+    torch.manual_seed(seed)
+    fp8_info = torch.finfo(torch.float8_e4m3fn)
+    fp8_max = fp8_info.max
+
+    A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
+    B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
+
+    _, block_k = block_size[0], block_size[1]
+
+    A_fp8, As_fp8 = per_token_group_quant_fp8(A_fp32, block_k)
+    B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32)
+
+    As = As_fp8.to(torch.float32)
+    Bs = Bs_fp8.to(torch.float32)
+
+    ref_out = native_w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size,
+                                           out_dtype)
+
+    # Transpose earlier so that the testing will not trigger transposing kernels
+    As_fp8 = deep_gemm.get_col_major_tma_aligned_tensor(As_fp8)
+
+    out = torch.zeros((M, N), device='cuda', dtype=out_dtype)
+
+    assert As_fp8.shape == (M, (K + 127) //
+                            128), f"{As_fp8.shape} != {(M, (K + 127) // 128)}"
+
+    deep_gemm.gemm_fp8_fp8_bf16_nt((A_fp8, As_fp8), (B_fp8, Bs_fp8), out)
+
+    rel_diff = (torch.mean(
+        torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
+                torch.mean(torch.abs(ref_out.to(torch.float32))))
+    assert rel_diff < 0.001
+
+
+def fp8_perm(m, idx):
+    if torch.is_floating_point(m) and torch.finfo(m.dtype).bits == 8:
+        return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype)
+    else:
+        return m[idx, ...]
+
+
+def test_moe_permute(a, a_s, topk_ids, num_groups, topk, block_m):
+    M, K = a.shape
+
+    sorted_token_ids, m_indices, num_pad = moe_align_block_size(
+        topk_ids, block_m, num_groups, None, pad_sorted_ids=True)
+
+    num_tokens = topk * M
+
+    sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)
+    m_indices = torch.repeat_interleave(m_indices, block_m, dim=0)
+    inv_perm = torch.argsort(sorted_token_ids)[:M * topk]
+
+    a = fp8_perm(a, sorted_token_ids // topk)
+    if a_s is not None:
+        a_s = a_s[sorted_token_ids // topk]
+
+    return a, a_s, m_indices, inv_perm
+
+
+def test_moe_unpermute(out, inv_perm, topk, K, topk_weight):
+    M = topk_weight.shape[0]
+    out = out[inv_perm, ...]
+    tmp_out = out.view(-1, topk, K)
+    return (tmp_out * topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)
+
+
+def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk,
+                                 block_shape):
+    """Fused moe with block-wise quantization using DeepGemm grouped gemm."""
+    num_groups = w1.shape[0]
+    M, K = a.shape
+    N = w2.shape[-1]
+
+    topk_weight, topk_ids = fused_topk(a, score.float(), topk, False)
+
+    block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
+
+    _, block_k = block_shape[0], block_shape[1]
+
+    a_q, a_s = per_token_group_quant_fp8(a, block_m)
+
+    a_q, a_s, m_indices, inv_perm = test_moe_permute(a_q, a_s, topk_ids,
+                                                     num_groups, topk, block_m)
+
+    inter_out = torch.zeros((a_q.shape[0], N * 2),
+                            dtype=torch.bfloat16,
+                            device=a.device)
+
+    deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a_q, a_s), (w1, w1_s),
+                                                        inter_out, m_indices)
+
+    act_out = SiluAndMul().forward_native(inter_out)
+    act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k)
+
+    out = torch.zeros(a_q.shape[0], K, dtype=torch.bfloat16, device=a.device)
+
+    deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
+        (act_out_q, act_out_s), (w2, w2_s), out, m_indices)
+
+    final_out = test_moe_unpermute(out, inv_perm, topk, K, topk_weight)
+
+    return final_out
+
+
+@pytest.mark.parametrize(
+    "M,N,K,E,topk,seed",
+    itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, SEEDS))
+@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
+@torch.inference_mode()
+def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
+
+    block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
+    block_size = [block_m, block_m]
+    dtype = torch.bfloat16
+
+    # only aligned sizes
+    if (N % block_m != 0 or K % block_m != 0 or topk > E):
+        pytest.skip(
+            f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}")
+
+    if (N <= 512):
+        pytest.skip("Skipping N <= 512 until performance issues solved.")
+
+    vllm_config = VllmConfig()
+
+    torch.manual_seed(seed)
+    fp8_info = torch.finfo(torch.float8_e4m3fn)
+    fp8_max, fp8_min = fp8_info.max, fp8_info.min
+
+    a = torch.randn((M, K), dtype=dtype) / 10
+
+    w1_bf16 = ((torch.rand((E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 *
+               fp8_max).clamp(min=fp8_min, max=fp8_max)
+
+    w2_bf16 = ((torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 *
+               fp8_max).clamp(min=fp8_min, max=fp8_max)
+
+    score = torch.randn((M, E), dtype=dtype)
+
+    block_n, block_k = block_size[0], block_size[1]
+    n_tiles_w1 = ((2 * N) + block_n - 1) // block_n
+    k_tiles_w1 = (K + block_k - 1) // block_k
+    n_tiles_w2 = (K + block_n - 1) // block_n
+    k_tiles_w2 = (N + block_k - 1) // block_k
+
+    w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn)
+    w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn)
+
+    w1_s = torch.empty((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32)
+    w2_s = torch.empty((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32)
+
+    w1_s = deep_gemm.get_col_major_tma_aligned_tensor(w1_s).contiguous()
+    w2_s = deep_gemm.get_col_major_tma_aligned_tensor(w2_s).contiguous()
+
+    assert w1_s.shape == (E, (2 * N + 127) // 128, (K + 127) // 128)
+    assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2]
+
+    for i in range(E):
+        w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i])
+        w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i])
+
+    # Set the context to avoid lots of warning spam.
+    with set_current_vllm_config(vllm_config):
+        if M >= 128:
+            ref_out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s,
+                                                   score, topk, block_size)
+        else:
+            ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score,
+                                               topk, block_size)
+
+        topk_weights, topk_ids = fused_topk(a, score.float(), topk, False)
+
+        out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids)
+
+    #print(f"{out.sum()=}")
+    #print(f"{ref_out.sum()=}")
+
+    rel_diff = (torch.mean(
+        torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
+                torch.mean(torch.abs(ref_out.to(torch.float32))))
+
+    assert rel_diff < 0.03
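For orientation, the new tests quantize weights with 128x128 blocks (one fp32 scale per block, via per_block_cast_to_fp8) and activations with 1x128 groups (via per_token_group_quant_fp8). A small arithmetic sketch of the scale shapes the tests assert, using the Deepseek-V3-flavoured sizes from the constants above (pure Python, no GPU or deep_gemm required):

    def ceil_div(a: int, b: int) -> int:
        return (a + b - 1) // b

    E, N, K, M, block = 8, 4608, 7168, 128, 128

    w1_scale_shape = (E, ceil_div(2 * N, block), ceil_div(K, block))  # w1 is [E, 2N, K]
    w2_scale_shape = (E, ceil_div(K, block), ceil_div(N, block))      # w2 is [E, K, N]
    a_scale_shape = (M, ceil_div(K, block))                           # one scale per token per K-group

    print(w1_scale_shape)  # (8, 72, 56)
    print(w2_scale_shape)  # (8, 56, 36)
    print(a_scale_shape)   # (128, 56)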
@@ -1224,7 +1224,7 @@ def moe_wna16_gemm(input: torch.Tensor, output: torch.Tensor,

 def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                  token_expert_indicies: torch.Tensor,
-                 gating_output: float) -> None:
+                 gating_output: torch.Tensor) -> None:
     torch.ops._moe_C.topk_softmax(topk_weights, topk_ids,
                                   token_expert_indicies, gating_output)

@@ -105,6 +105,7 @@ if TYPE_CHECKING:
     VLLM_V0_USE_OUTLINES_CACHE: bool = False
     VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False
     VLLM_TPU_BUCKET_PADDING_GAP: int = 0
+    VLLM_USE_DEEP_GEMM: bool = False


 def get_default_cache_root():
@@ -687,6 +688,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_TPU_BUCKET_PADDING_GAP":
     lambda: int(os.environ["VLLM_TPU_BUCKET_PADDING_GAP"])
     if "VLLM_TPU_BUCKET_PADDING_GAP" in os.environ else 0,
+
+    # Allow use of DeepGemm kernels for fused moe ops.
+    "VLLM_USE_DEEP_GEMM":
+    lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))),
 }

 # end-env-vars-definition
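The new variable is looked up through the environment_variables table like the other entries, so it can be checked from Python as well as from the shell. A minimal sketch (only VLLM_USE_DEEP_GEMM is new here):

    import os
    os.environ.setdefault("VLLM_USE_DEEP_GEMM", "1")  # opt in

    import vllm.envs as envs

    # Mirrors the lambda registered above:
    # bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0")))
    print(envs.VLLM_USE_DEEP_GEMM)  # True when the variable is set to "1"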
@@ -1,8 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 """Fused MoE kernel."""
 import functools
+import importlib.util
 import json
 import os
+from math import prod
 from typing import Any, Callable, Dict, List, Optional, Tuple

 import torch
@@ -15,7 +17,7 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8)
 from vllm.platforms import current_platform
-from vllm.utils import direct_register_custom_op
+from vllm.utils import direct_register_custom_op, round_up

 from .rocm_aiter_fused_moe import (is_rocm_aiter_moe_enabled,
                                    rocm_aiter_fused_experts,
@@ -23,6 +25,8 @@ from .rocm_aiter_fused_moe import (is_rocm_aiter_moe_enabled,

 logger = init_logger(__name__)

+has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
+

 @triton.jit
 def write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, offs_token,
@@ -581,7 +585,8 @@ def moe_align_block_size(
     topk_ids: torch.Tensor,
     block_size: int,
     num_experts: int,
-    expert_map: torch.Tensor = None
+    expert_map: Optional[torch.Tensor] = None,
+    pad_sorted_ids: bool = False
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Aligns the token distribution across experts to be compatible with block
@@ -596,6 +601,8 @@ def moe_align_block_size(
         from the global space to the local index space of the current
         expert parallel shard. If the expert is not in the current expert
         parallel shard, the mapping is set to -1.
+    - pad_sorted_ids: A flag indicating whether the sorted_token_ids length
+      should be padded to a multiple of block_size,

     Returns:
     - sorted_token_ids: A tensor containing the sorted token indices according
@@ -625,6 +632,8 @@ def moe_align_block_size(
         by block_size for proper block matrix operations.
     """
     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
+    if pad_sorted_ids:
+        max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
     sorted_ids = torch.empty((max_num_tokens_padded, ),
                              dtype=torch.int32,
                              device=topk_ids.device)
@@ -667,6 +676,59 @@ def moe_align_block_size(
     return sorted_ids, expert_ids, num_tokens_post_pad


+def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor,
+                     w2: torch.Tensor,
+                     expert_map: Optional[torch.Tensor]) -> bool:
+    """
+    Check if the given problem size is supported by the DeepGemm grouped
+    gemm kernel.  All of M, N, K and the quantization block_shape must be
+    aligned by `dg.get_m_alignment_for_contiguous_layout()`.
+    """
+    if not has_deep_gemm:
+        return False
+
+    # Lazy import to avoid CUDA initialization problems.
+    import deep_gemm as dg
+
+    # Expert maps not supported yet.
+    if expert_map is not None:
+        return False
+
+    align = dg.get_m_alignment_for_contiguous_layout()
+    M = hidden_states.shape[0]
+    _, K, N = w2.shape
+
+    # For now, disable DeepGemm for small N until better permute/unpermute
+    # ops are available.
+    if N <= 512:
+        return False
+
+    if align > M or N % align != 0 or K % align != 0:
+        return False
+
+    return (hidden_states.is_contiguous() and w1.is_contiguous()
+            and w2.is_contiguous())
+
+
+def _fp8_quantize(
+    A: torch.Tensor,
+    A_scale: Optional[torch.Tensor],
+    block_shape: Optional[List[int]],
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Perform fp8 quantization on the inputs.  If a block_shape
+    is provided, the output will be blocked.
+    """
+    if block_shape is None:
+        A, A_scale = ops.scaled_fp8_quant(A, A_scale)
+    else:
+        assert len(block_shape) == 2
+        _, block_k = block_shape[0], block_shape[1]
+        A, A_scale = per_token_group_quant_fp8(A, block_k)
+        assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
+    return A, A_scale
+
+
 def invoke_fused_moe_kernel(A: torch.Tensor,
                             B: torch.Tensor,
                             C: torch.Tensor,
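_fp8_quantize centralizes the activation quantization that previously lived inside invoke_fused_moe_kernel: tensor-wide fp8 scaling when no block_shape is given, per-token-group scaling (one scale per block_k channels) otherwise. A simplified, CPU-only PyTorch illustration of the blocked layout; the real path calls the per_token_group_quant_fp8 Triton kernel, so treat this only as a sketch of the output shapes and scaling rule:

    import torch

    def naive_per_token_group_quant_fp8(x: torch.Tensor, block_k: int = 128):
        # One scale per token per contiguous group of block_k channels.
        M, K = x.shape
        assert K % block_k == 0
        fp8_max = torch.finfo(torch.float8_e4m3fn).max
        groups = x.view(M, K // block_k, block_k).float()
        scales = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-4) / fp8_max
        q = (groups / scales).to(torch.float8_e4m3fn)
        return q.view(M, K), scales.view(M, K // block_k)

    x = torch.randn(4, 256, dtype=torch.bfloat16)
    xq, xs = naive_per_token_group_quant_fp8(x)
    print(xq.shape, xq.dtype, xs.shape)
    # torch.Size([4, 256]) torch.float8_e4m3fn torch.Size([4, 2])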
@@ -691,15 +753,11 @@ def invoke_fused_moe_kernel(A: torch.Tensor,

     if use_fp8_w8a8:
         assert B_scale is not None
-        if block_shape is None:
-            A, A_scale = ops.scaled_fp8_quant(A, A_scale)
-        else:
-            assert len(block_shape) == 2
-            block_n, block_k = block_shape[0], block_shape[1]
-            A, A_scale = per_token_group_quant_fp8(A, block_k)
-            assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
-            assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
-            assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
+        assert (block_shape is None or triton.cdiv(B.shape[-2], block_shape[0])
+                == B_scale.shape[-2])
+        assert (block_shape is None or triton.cdiv(B.shape[-1], block_shape[1])
+                == B_scale.shape[-1])
+
     elif use_int8_w8a16 or use_int4_w4a16:
         assert B_scale is not None
         assert block_shape is None or block_shape[0] == 0
@@ -1066,7 +1124,7 @@ def fused_topk(
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
-):
+) -> Tuple[torch.Tensor, torch.Tensor]:
     assert hidden_states.shape[0] == gating_output.shape[0], (
         "Number of tokens mismatch")

@@ -1098,14 +1156,16 @@ def fused_topk(

 # This is used by the Deepseek-V2 and Deepseek-V3 model
 @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
-def grouped_topk(hidden_states: torch.Tensor,
+def grouped_topk(
+    hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
     num_expert_group: int = 0,
     topk_group: int = 0,
     scoring_func: str = "softmax",
-    e_score_correction_bias: Optional[torch.Tensor] = None):
+    e_score_correction_bias: Optional[torch.Tensor] = None
+) -> Tuple[torch.Tensor, torch.Tensor]:

     assert hidden_states.shape[0] == gating_output.shape[0], (
         "Number of tokens mismatch")
@@ -1154,10 +1214,11 @@ def grouped_topk(hidden_states: torch.Tensor,
     return topk_weights.to(torch.float32), topk_ids.to(torch.int32)


-def get_config_dtype_str(dtype: torch.dtype,
+def get_config_dtype_str(
+    dtype: torch.dtype,
     use_int4_w4a16: Optional[bool] = False,
     use_int8_w8a16: Optional[bool] = False,
-    use_fp8_w8a8: Optional[bool] = False):
+    use_fp8_w8a8: Optional[bool] = False) -> Optional[str]:
     if use_fp8_w8a8:
         return "fp8_w8a8"
     elif use_int8_w8a16:
@@ -1318,7 +1379,25 @@ def fused_experts(hidden_states: torch.Tensor,
                   w2_zp: Optional[torch.Tensor] = None,
                   a1_scale: Optional[torch.Tensor] = None,
                   a2_scale: Optional[torch.Tensor] = None,
-                  block_shape: Optional[List[int]] = None) -> torch.Tensor:
+                  block_shape: Optional[List[int]] = None,
+                  allow_deep_gemm: bool = False) -> torch.Tensor:
+    if (allow_deep_gemm and use_fp8_w8a8
+            and _valid_deep_gemm(hidden_states, w1, w2, expert_map)):
+        return deep_gemm_moe_fp8(
+            hidden_states=hidden_states,
+            w1=w1,
+            w2=w2,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            activation=activation,
+            global_num_experts=global_num_experts,
+            expert_map=expert_map,
+            w1_scale=w1_scale,
+            w2_scale=w2_scale,
+            a1_scale=a1_scale,
+            a2_scale=a2_scale,
+        )
+    else:
         return dispatch_fused_experts_func(inplace)(
             hidden_states=hidden_states,
             w1=w1,
@@ -1340,6 +1419,85 @@ def fused_experts(hidden_states: torch.Tensor,
             block_shape=block_shape)


+def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
+    """
+    A permutation routine that works on fp8 types.
+    """
+    if torch.is_floating_point(m) and torch.finfo(m.dtype).bits == 8:
+        return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype)
+    else:
+        return m[idx, ...]
+
+
+def _moe_permute(
+    curr_hidden_states: torch.Tensor,
+    a1q_scale: Optional[torch.Tensor],
+    curr_topk_ids: torch.Tensor,
+    global_num_experts: int,
+    expert_map: Optional[torch.Tensor],
+    top_k_num: int,
+    block_m: int,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor,
+           torch.Tensor]:
+    """
+    Determine the sorted_token_ids, expert_ids for the given problem size.
+    Permute the hidden states and scales according to `sorted_token_ids`.
+    """
+    tokens_in_chunk, _ = curr_hidden_states.shape
+
+    sorted_token_ids, expert_ids, num_tokens_post_padded = (
+        moe_align_block_size(curr_topk_ids,
+                             block_m,
+                             global_num_experts,
+                             expert_map,
+                             pad_sorted_ids=True))
+
+    inv_perm: Optional[torch.Tensor] = None
+
+    num_tokens = top_k_num * tokens_in_chunk
+    sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)
+    expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0)
+    inv_perm = torch.argsort(sorted_token_ids)[:num_tokens]
+
+    # Permute according to sorted token ids.
+    curr_hidden_states = _fp8_perm(curr_hidden_states,
+                                   sorted_token_ids // top_k_num)
+
+    if a1q_scale is not None:
+        a1q_scale = a1q_scale[sorted_token_ids // top_k_num]
+
+    return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
+            inv_perm)
+
+
+def _moe_unpermute_and_reduce(
+    out: torch.Tensor,
+    curr_hidden: torch.Tensor,
+    inv_perm: Optional[torch.Tensor],
+    topk: int,
+    K: int,
+    topk_weight: torch.Tensor,
+) -> None:
+    """
+    Unpermute the final result and apply topk_weights, then perform the final
+    reduction on the hidden states.
+    """
+    M = topk_weight.shape[0]
+    curr_hidden = curr_hidden[inv_perm, ...]
+    curr_hidden = curr_hidden.view(-1, topk, K)
+    curr_hidden.mul_(topk_weight.view(M, -1, 1))
+    ops.moe_sum(curr_hidden, out)
+
+
+def _resize_cache(x: torch.Tensor, v: Tuple[int, ...]) -> torch.Tensor:
+    """
+    Shrink the given tensor and apply the given view to it.  This is
+    used to resize the intermediate fused_moe caches.
+    """
+    assert prod(v) <= x.numel()
+    return x.flatten()[:prod(v)].view(*v)
+
+
 def fused_experts_impl(hidden_states: torch.Tensor,
                        w1: torch.Tensor,
                        w2: torch.Tensor,
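_moe_permute and _moe_unpermute_and_reduce implement the usual trick behind contiguous grouped gemm: replicate each token top_k times, sort the copies by expert so every expert owns a contiguous block of rows, run one grouped matmul, then invert the permutation and apply the routing weights. A self-contained, CPU-only sketch of that idea with a plain per-expert matmul standing in for the DeepGemm kernel (no fp8, no block padding, and not the vLLM implementation):

    import torch

    def toy_grouped_moe(x, w, topk_ids, topk_weights):
        # x: [M, K], w: [E, K, N], topk_ids/topk_weights: [M, top_k]
        M, K = x.shape
        E, _, N = w.shape
        top_k = topk_ids.shape[1]

        # Replicate tokens and sort the copies by destination expert
        # (the role of sorted_token_ids / expert_ids).
        flat_ids = topk_ids.reshape(-1)
        order = torch.argsort(flat_ids, stable=True)
        xs = x.repeat_interleave(top_k, dim=0)[order]
        experts = flat_ids[order]

        # One matmul per expert over its contiguous rows (the grouped gemm).
        out = torch.empty(M * top_k, N, dtype=x.dtype)
        for e in range(E):
            rows = experts == e
            out[rows] = xs[rows] @ w[e]

        # Invert the permutation and reduce with the routing weights
        # (the role of inv_perm and moe_sum).
        inv_perm = torch.argsort(order)
        out = out[inv_perm].view(M, top_k, N)
        return (out * topk_weights.unsqueeze(-1)).sum(dim=1)

    M, K, N, E, top_k = 6, 16, 8, 4, 2
    x = torch.randn(M, K)
    w = torch.randn(E, K, N)
    topk_weights, topk_ids = torch.topk(torch.randn(M, E).softmax(dim=-1), top_k)

    ref = torch.zeros(M, N)
    for i in range(top_k):
        ref += topk_weights[:, i:i + 1] * torch.bmm(
            x.unsqueeze(1), w[topk_ids[:, i]]).squeeze(1)

    assert torch.allclose(toy_grouped_moe(x, w, topk_ids, topk_weights), ref,
                          atol=1e-5)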
@@ -1376,6 +1534,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,

     num_tokens, _ = hidden_states.shape
     E, N, _ = w1.shape
+    K = w2.shape[1]
     if global_num_experts == -1:
         global_num_experts = E
     top_k_num = topk_ids.shape[1]
@@ -1401,13 +1560,11 @@ def fused_experts_impl(hidden_states: torch.Tensor,

     # We can reuse the memory between these because by the time we need
     # cache3, we're done with cache1
-    cache13 = torch.empty(M * top_k_num * max(N, w2.shape[1]),
+    cache13 = torch.empty(M * top_k_num * max(N, K),
                           device=hidden_states.device,
                           dtype=hidden_states.dtype)
-    intermediate_cache1 = cache13[:M * top_k_num * N].view(
-        (M, topk_ids.shape[1], N))
-    intermediate_cache3 = cache13[:M * top_k_num * w2.shape[1]].view(
-        (M, topk_ids.shape[1], w2.shape[1]))
+    intermediate_cache1 = cache13[:M * top_k_num * N].view(M, top_k_num, N)
+    intermediate_cache3 = cache13[:M * top_k_num * K].view(M, top_k_num, K)

     # This needs separate memory since it's used concurrently with cache1
     intermediate_cache2 = torch.empty((M * top_k_num, N // 2),
@@ -1452,14 +1609,23 @@ def fused_experts_impl(hidden_states: torch.Tensor,
         curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
         curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]

+        a1q_scale: Optional[torch.Tensor] = None
+
+        if use_fp8_w8a8:
+            qcurr_hidden_states, a1q_scale = _fp8_quantize(
+                curr_hidden_states, a1_scale, block_shape)
+        else:
+            qcurr_hidden_states = curr_hidden_states
+            a1q_scale = a1_scale
+
         sorted_token_ids, expert_ids, num_tokens_post_padded = (
             moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'],
                                  global_num_experts, expert_map))

-        invoke_fused_moe_kernel(curr_hidden_states,
+        invoke_fused_moe_kernel(qcurr_hidden_states,
                                 w1,
                                 intermediate_cache1,
-                                a1_scale,
+                                a1q_scale,
                                 w1_scale,
                                 w1_zp,
                                 curr_topk_weights,
@@ -1485,10 +1651,19 @@ def fused_experts_impl(hidden_states: torch.Tensor,
         else:
             raise ValueError(f"Unsupported FusedMoe activation: {activation}")

-        invoke_fused_moe_kernel(intermediate_cache2,
+        a2q_scale: Optional[torch.Tensor] = None
+
+        if use_fp8_w8a8:
+            qintermediate_cache2, a2q_scale = _fp8_quantize(
+                intermediate_cache2, a2_scale, block_shape)
+        else:
+            qintermediate_cache2 = intermediate_cache2
+            a2q_scale = a2_scale
+
+        invoke_fused_moe_kernel(qintermediate_cache2,
                                 w2,
                                 intermediate_cache3,
-                                a2_scale,
+                                a2q_scale,
                                 w2_scale,
                                 w2_zp,
                                 curr_topk_weights,
@@ -1617,6 +1792,193 @@ def fused_moe(
         block_shape=block_shape)


+def deep_gemm_moe_fp8(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    w1_scale: torch.Tensor,
+    w2_scale: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    inplace: bool = False,
+    activation: str = "silu",
+    global_num_experts: int = -1,
+    expert_map: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    """
+    This function computes a a8w8-quantized Mixture of Experts (MoE) layer
+    using two sets of quantized weights, w1_q and w2_q, and top-k gating
+    mechanism. The matrix multiplications are implemented with DeepGemm
+    grouped gemm.
+
+    Parameters:
+    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
+        Shape: [M, K]
+    - w1 (torch.Tensor): The first set of fp8 quantized expert weights.
+        Shape: [num_experts, K, 2N] (the weights are passed transposed)
+    - w2 (torch.Tensor): The second set of fp8 quantized expert weights.
+        Shape: [num_experts, N, K] (the weights are passed transposed)
+    - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
+        Shape: [num_experts] or [num_experts, 2N]
+    - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
+        Shape: [num_experts] or [num_experts, K]
+    - topk_weights (torch.Tensor): The weights of each token->expert mapping.
+    - topk_ids (torch.Tensor): The token->expert mapping for topk_weights.
+    - inplace (bool): If True, perform the operation in-place.
+        Defaults to False.
+    - activation (str): The activation function to apply after the first
+        MoE layer.
+    - global_num_experts (int): The total number of experts in the global
+        expert space.
+    - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
+        from the global expert space to the local expert space of the expert
+        parallel shard.
+    - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
+        Shape: scalar or [M]
+    - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
+        quantize the intermediate result between the gemms.
+        Shape: scalar or [M]
+
+    Returns:
+    - torch.Tensor: The bfloat16 output tensor after applying the MoE layer.
+    """
+    # Lazy import to avoid CUDA initialization problems.
+    import deep_gemm as dg
+
+    assert expert_map is None, "Expert maps not supported yet"
+
+    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
+
+    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
+    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
+    assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
+    assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
+    assert hidden_states.dtype in [
+        torch.float32, torch.float16, torch.bfloat16
+    ]
+    assert w1.dtype == torch.float8_e4m3fn
+    assert w2.dtype == torch.float8_e4m3fn
+    assert w1.shape[0] == w2.shape[0], "Expert number mismatch"
+    assert w1.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
+    assert w1.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
+    assert a1_scale is None or a1_scale.dim(
+    ) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[
+        0] == hidden_states.shape[0], "Input scale shape mismatch"
+    assert a2_scale is None or a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch"  # noqa: E501
+
+    num_tokens, _ = hidden_states.shape
+    E, N, _ = w1.shape
+    K = w2.shape[1]
+    if global_num_experts == -1:
+        global_num_experts = E
+    top_k_num = topk_ids.shape[1]
+    # We execute the fused_moe kernel in chunks to circumvent this issue:
+    # https://github.com/vllm-project/vllm/issues/5938
+    CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
+
+    assert _valid_deep_gemm(hidden_states, w1, w2, expert_map)
+
+    if inplace:
+        out_hidden_states = hidden_states
+    else:
+        out_hidden_states = torch.empty_like(hidden_states)
+
+    block_m = dg.get_m_alignment_for_contiguous_layout()
+    block_shape = [block_m, block_m]
+
+    assert w1_scale is not None
+    assert w2_scale is not None
+
+    # We attempt to transpose and align offline in Fp8MoEMethod, in which
+    # case these calls will be nops.  Otherwise, they'll be performed every
+    # time the layer is executed.
+    w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous()
+    w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous()
+
+    M_sum = topk_ids.numel() + global_num_experts * (block_m - 1)
+    M_sum = round_up(M_sum, block_m)
+
+    num_chunks = (num_tokens // CHUNK_SIZE) + 1
+
+    # We can reuse the memory between cache1 and cache3 because by the time
+    # we need cache3, we're done with cache1
+    cache13 = torch.empty(M_sum * max(N, K),
+                          device=hidden_states.device,
+                          dtype=hidden_states.dtype)
+
+    intermediate_cache1 = cache13[:M_sum * N].view(M_sum, N)
+    intermediate_cache2 = torch.empty((M_sum, N // 2),
+                                      device=hidden_states.device,
+                                      dtype=hidden_states.dtype)
+    intermediate_cache3 = cache13[:M_sum * K].view(M_sum, K)
+
+    for chunk in range(num_chunks):
+        begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE,
+                                          min((chunk + 1) * CHUNK_SIZE,
+                                              num_tokens))
+        curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
+        tokens_in_chunk, _ = curr_hidden_states.shape
+
+        if tokens_in_chunk == 0:
+            break
+
+        curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
+        curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
+
+        a1q_scale: Optional[torch.Tensor] = None
+
+        qcurr_hidden_states, a1q_scale = _fp8_quantize(curr_hidden_states,
+                                                       a1_scale, block_shape)
+
+        (qcurr_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
+         inv_perm) = _moe_permute(qcurr_hidden_states, a1q_scale,
+                                  curr_topk_ids, global_num_experts,
+                                  expert_map, top_k_num, block_m)
+
+        # Adjust the intermediate cache size and config for the last chunk.
+        # Note that in most cases we only have one chunk so the cache size
+        # and config are already set correctly and do not need to be adjusted.
+        if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
+            curr_M = sorted_token_ids.numel()
+            intermediate_cache1 = _resize_cache(intermediate_cache1,
+                                                (curr_M, N))
+            intermediate_cache2 = _resize_cache(intermediate_cache2,
+                                                (curr_M, N // 2))
+            intermediate_cache3 = _resize_cache(intermediate_cache3,
+                                                (curr_M, K))
+
+        dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
+            (qcurr_hidden_states, a1q_scale), (w1, w1_scale),
+            intermediate_cache1, expert_ids)
+
+        if activation == "silu":
+            torch.ops._C.silu_and_mul(intermediate_cache2,
+                                      intermediate_cache1.view(-1, N))
+        elif activation == "gelu":
+            torch.ops._C.gelu_and_mul(intermediate_cache2,
+                                      intermediate_cache1.view(-1, N))
+        else:
+            raise ValueError(f"Unsupported FusedMoe activation: {activation}")
+
+        a2q_scale: Optional[torch.Tensor] = None
+
+        qintermediate_cache2, a2q_scale = _fp8_quantize(
+            intermediate_cache2, a2_scale, block_shape)
+
+        dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
+            (qintermediate_cache2, a2q_scale), (w2, w2_scale),
+            intermediate_cache3, expert_ids)
+
+        _moe_unpermute_and_reduce(
+            out_hidden_states[begin_chunk_idx:end_chunk_idx],
+            intermediate_cache3.view(*intermediate_cache3.shape), inv_perm,
+            top_k_num, K, curr_topk_weights)
+
+    return out_hidden_states
+
+
 #TODO make the grouped gemm kernel consistent with scaled gemm kernel
 def cutlass_moe_fp8(
     a: torch.Tensor,
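The M_sum sizing in deep_gemm_moe_fp8 mirrors moe_align_block_size: each expert can add at most block_m - 1 rows of padding, and the total is rounded up to a multiple of block_m so the contiguous layout holds. A small worked example (round_up is redefined locally for self-containment; block_m = 128 matches the 128x128 block shape used throughout this PR and is assumed to be what get_m_alignment_for_contiguous_layout returns):

    def round_up(x: int, mult: int) -> int:
        return (x + mult - 1) // mult * mult

    M, top_k, E, block_m = 128, 2, 8, 128

    M_sum = M * top_k + E * (block_m - 1)  # topk_ids.numel() + E * (block_m - 1)
    M_sum = round_up(M_sum, block_m)
    print(M_sum)  # 1280: the 256 routed token copies land in 10 blocks of 128 rows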
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0

+import importlib.util
 from typing import Any, Callable, Dict, List, Optional

 import torch
@@ -37,6 +38,14 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]

 logger = init_logger(__name__)

+has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
+
+
+def _is_col_major(x: torch.Tensor) -> bool:
+    assert x.dim() == 3
+    b, m, n = x.shape
+    return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m
+

 class Fp8Config(QuantizationConfig):
     """Config class for FP8."""
@@ -424,6 +433,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         self.quant_config = quant_config
         self.block_quant = self.quant_config.weight_block_size is not None

+        # Check for DeepGemm support.
+        self.allow_deep_gemm = False
+        if envs.VLLM_USE_DEEP_GEMM:
+            if not has_deep_gemm:
+                logger.warning_once("Failed to import DeepGemm kernels.")
+            elif (current_platform.is_cuda()
+                  and current_platform.has_device_capability(90)):
+                logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.")
+                self.allow_deep_gemm = True
+            else:
+                logger.warning_once(
+                    "DeepGemm not supported on the current platform.")
+
     def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                        intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
@@ -585,6 +607,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                                                  requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                                  requires_grad=False)
+
+            # DeepGemm scales need to be transposed and aligned.  We try to do
+            # it ahead of time for performance reasons.
+            if self.allow_deep_gemm:
+                # Lazy import to avoid CUDA initialization problems.
+                import deep_gemm as dg
+                if _is_col_major(layer.w13_weight_scale_inv):
+                    layer.w13_weight_scale_inv = \
+                        dg.get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv).contiguous()
+                if _is_col_major(layer.w2_weight_scale_inv):
+                    layer.w2_weight_scale_inv = \
+                        dg.get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv).contiguous()
+
             return

         # If checkpoint is fp16, quantize in place.
@@ -773,6 +808,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 a1_scale=layer.w13_input_scale,
                 a2_scale=layer.w2_input_scale,
                 block_shape=self.quant_config.weight_block_size,
+                allow_deep_gemm=self.allow_deep_gemm,
             )

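For reference, _is_col_major only inspects strides: it reports whether each per-expert scale matrix is stored column-major, which is how process_weights_after_loading decides whether to run get_col_major_tma_aligned_tensor ahead of time so the corresponding calls inside deep_gemm_moe_fp8 become no-ops. A tiny stride sketch (the helper is copied from the diff above; the shapes are arbitrary):

    import torch

    def _is_col_major(x: torch.Tensor) -> bool:
        assert x.dim() == 3
        b, m, n = x.shape
        return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m

    row_major = torch.empty(8, 72, 56)                  # [E, n_tiles, k_tiles]
    col_major = torch.empty(8, 56, 72).transpose(1, 2)  # same shape, column-major strides

    print(_is_col_major(row_major))  # False
    print(_is_col_major(col_major))  # True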