From af647fb8b3ea9d910f7d1bc104af8986d048a8e2 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Mon, 29 Jul 2024 22:24:58 -0400 Subject: [PATCH] [Kernel] Tuned int8 kernels for Ada Lovelace (#6848) Co-authored-by: Varun Sundar Rabindranath --- .../cutlass_w8a8/scaled_mm_c2x.cu | 27 +- ...uh => scaled_mm_c2x_sm89_fp8_dispatch.cuh} | 54 +-- .../scaled_mm_c2x_sm89_int8_dispatch.cuh | 353 ++++++++++++++++++ tests/kernels/test_cutlass.py | 4 +- 4 files changed, 395 insertions(+), 43 deletions(-) rename csrc/quantization/cutlass_w8a8/{scaled_mm_c2x_sm89_dispatch.cuh => scaled_mm_c2x_sm89_fp8_dispatch.cuh} (89%) create mode 100644 csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu index d26c43de..aac4900f 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu @@ -4,7 +4,8 @@ #include "scaled_mm_c2x.cuh" #include "scaled_mm_c2x_sm80_dispatch.cuh" -#include "scaled_mm_c2x_sm89_dispatch.cuh" +#include "scaled_mm_c2x_sm89_fp8_dispatch.cuh" +#include "scaled_mm_c2x_sm89_int8_dispatch.cuh" /* This file defines quantized GEMM operations using the CUTLASS 2.x API, for @@ -98,25 +99,17 @@ template