[NVIDIA] Support nvfp4 cutlass gemm (#13571)
This commit is contained in: parent 8db1b9d0a1, commit e109e598c7
@@ -229,7 +229,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   # Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case.
   # Please keep this in sync with FetchContent_Declare line below.
-  set(CUTLASS_REVISION "v3.7.0" CACHE STRING "CUTLASS revision to use")
+  set(CUTLASS_REVISION "v3.8.0" CACHE STRING "CUTLASS revision to use")

   # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
   if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR})
@@ -267,6 +267,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     "csrc/permute_cols.cu"
     "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
     "csrc/quantization/fp4/nvfp4_quant_entry.cu"
+    "csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
     "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
     "csrc/cutlass_extensions/common.cpp")
@@ -383,6 +384,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND FP4_ARCHS)
     set(SRCS
       "csrc/quantization/fp4/nvfp4_quant_kernels.cu"
+      "csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu"
     )
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
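The CMake gate above only compiles the nvfp4 kernels when the build uses CUDA 12.8+ and a Blackwell (SM100) target is in FP4_ARCHS. A minimal sketch, assuming vLLM's current_platform helper (it is used by the test added in this PR) and using the CUDA runtime version as a rough stand-in for the build-time check, of how a caller could guard the nvfp4 path:

import torch
from vllm.platforms import current_platform

def nvfp4_gemm_available() -> bool:
    # Needs SM100 hardware (compute capability 10.0)...
    if not current_platform.has_device_capability(100):
        return False
    # ...and a build against CUDA 12.8+; torch.version.cuda is a heuristic
    # proxy for the compile-time gate in CMakeLists.txt above.
    cuda = torch.version.cuda
    if cuda is None:
        return False
    major, minor = (int(x) for x in cuda.split(".")[:2])
    return (major, minor) >= (12, 8)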
@@ -152,6 +152,11 @@ torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type,
                               int64_t row);

 #ifndef USE_ROCM
+void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
+                           torch::Tensor const& B, torch::Tensor const& A_sf,
+                           torch::Tensor const& B_sf,
+                           torch::Tensor const& alpha);
+
 bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
 bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability);
csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu (new file, 37 lines)
@@ -0,0 +1,37 @@
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <torch/all.h>

#if defined ENABLE_NVFP4 && ENABLE_NVFP4
void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
                                  torch::Tensor const& B,
                                  torch::Tensor const& A_sf,
                                  torch::Tensor const& B_sf,
                                  torch::Tensor const& alpha);
#endif

void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
                           torch::Tensor const& B, torch::Tensor const& A_sf,
                           torch::Tensor const& B_sf,
                           torch::Tensor const& alpha) {
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
  return cutlass_scaled_fp4_mm_sm100a(D, A, B, A_sf, B_sf, alpha);
#endif
  TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 mm kernel, vLLM should "
                                     "be compiled using CUDA 12.8 and target "
                                     "compute capability 100 or above.");
}
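The entry point above always builds; only the SM100A implementation is compiled behind ENABLE_NVFP4, so on unsupported builds the op raises at call time. A minimal sketch, assuming TORCH_CHECK_NOT_IMPLEMENTED surfaces as NotImplementedError in Python and with an illustrative dense fallback that is not part of this PR:

import torch
from vllm import _custom_ops as ops

def fp4_mm_or_fallback(a_fp4, b_fp4, a_sf, b_sf, alpha, out_dtype,
                       a_dense=None, b_dense=None):
    try:
        return ops.cutlass_scaled_fp4_mm(a_fp4, b_fp4, a_sf, b_sf, alpha,
                                         out_dtype)
    except NotImplementedError:
        # Kernel absent (built without CUDA 12.8 / SM100 support): run the
        # unquantized GEMM instead, if the caller kept the original tensors.
        assert a_dense is not None and b_dense is not None
        return (a_dense @ b_dense.t()).to(out_dtype)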
csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu (new file, 280 lines)
@@ -0,0 +1,280 @@
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "cutlass_extensions/common.hpp"

#include "cutlass/cutlass.h"

#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"

#include "cutlass/util/packed_stride.hpp"

using namespace cute;

#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
// Kernel Perf config
template <typename T>
struct KernelTraits;

template <>
struct KernelTraits<float> {
  using MmaTileShape = Shape<_128, _128, _256>;
  using ClusterShape = Shape<_1, _1, _1>;
  using PerSmTileShape_MNK = Shape<_128, _128, _256>;
};

template <>
struct KernelTraits<cutlass::half_t> {
  using MmaTileShape = Shape<_256, _256, _256>;
  using ClusterShape = Shape<_4, _4, _1>;
  using PerSmTileShape_MNK = Shape<_128, _256, _256>;
};

template <>
struct KernelTraits<cutlass::bfloat16_t> {
  using MmaTileShape = Shape<_256, _256, _256>;
  using ClusterShape = Shape<_4, _4, _1>;
  using PerSmTileShape_MNK = Shape<_128, _256, _256>;
};

template <typename T>
struct Fp4GemmSm100 {
  // A matrix configuration
  using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
  using LayoutATag = cutlass::layout::RowMajor;
  static constexpr int AlignmentA = 32;

  // B matrix configuration
  using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
  using LayoutBTag = cutlass::layout::ColumnMajor;
  static constexpr int AlignmentB = 32;

  // C/D matrix configuration
  using ElementD = T;
  using ElementC = T;
  using LayoutCTag = cutlass::layout::RowMajor;
  using LayoutDTag = cutlass::layout::RowMajor;
  static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
  // Kernel functional config
  using ElementAccumulator = float;
  using ArchTag = cutlass::arch::Sm100;
  using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;

  // Kernel Perf config
  using MmaTileShape = typename KernelTraits<T>::MmaTileShape;
  using ClusterShape = typename KernelTraits<T>::ClusterShape;
  using PerSmTileShape_MNK = typename KernelTraits<T>::PerSmTileShape_MNK;

  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          ArchTag, OperatorClass, PerSmTileShape_MNK, ClusterShape,
          cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator,
          ElementAccumulator, ElementC, LayoutCTag, AlignmentC, ElementD,
          LayoutDTag, AlignmentD,
          cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp;

  using CollectiveMainloop =
      typename cutlass::gemm::collective::CollectiveBuilder<
          ArchTag, OperatorClass, ElementA, LayoutATag, AlignmentA, ElementB,
          LayoutBTag, AlignmentB, ElementAccumulator, MmaTileShape,
          ClusterShape,
          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
              sizeof(typename CollectiveEpilogue::SharedStorage))>,
          cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  using StrideA = typename Gemm::GemmKernel::StrideA;
  using LayoutA = decltype(cute::make_layout(make_shape(0, 0, 0), StrideA{}));
  using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using LayoutB = decltype(cute::make_layout(make_shape(0, 0, 0), StrideB{}));
  using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB;
  using StrideC = typename Gemm::GemmKernel::StrideC;
  using LayoutC = decltype(cute::make_layout(make_shape(0, 0, 0), StrideC{}));
  using StrideD = typename Gemm::GemmKernel::StrideD;
  using LayoutD = decltype(cute::make_layout(make_shape(0, 0, 0), StrideD{}));
};

template <typename T>
typename T::Gemm::Arguments args_from_options(
    at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
    at::Tensor const& A_sf, at::Tensor const& B_sf, at::Tensor const& alpha,
    int64_t M, int64_t N, int64_t K) {
  using ElementA = typename T::Gemm::ElementA;
  using ElementB = typename T::Gemm::ElementB;
  using ElementSFA = cutlass::float_ue4m3_t;
  using ElementSFB = cutlass::float_ue4m3_t;
  using ElementD = typename T::Gemm::ElementD;
  using ElementCompute = float;
  using StrideA = typename T::StrideA;
  using StrideB = typename T::StrideB;
  using StrideD = typename T::StrideD;
  using Sm100BlkScaledConfig =
      typename T::Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;

  int m = static_cast<int>(M);
  int n = static_cast<int>(N);
  int k = static_cast<int>(K);
  auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {m, k, 1});
  auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {n, k, 1});
  auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {m, n, 1});

  auto layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(
      cute::make_shape(m, n, k, 1));
  auto layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(
      cute::make_shape(m, n, k, 1));

  typename T::Gemm::Arguments arguments{
      cutlass::gemm::GemmUniversalMode::kGemm,
      {m, n, k, 1},
      {// Mainloop arguments
       static_cast<ElementA const*>(A.data_ptr()), stride_A,
       static_cast<ElementB const*>(B.data_ptr()), stride_B,
       static_cast<ElementSFA const*>(A_sf.data_ptr()), layout_SFA,
       static_cast<ElementSFB const*>(B_sf.data_ptr()), layout_SFB},
      {// Epilogue arguments
       {},  // epilogue.thread
       static_cast<ElementD const*>(D.data_ptr()),
       stride_D,
       static_cast<ElementD*>(D.data_ptr()),
       stride_D}};
  auto& fusion_args = arguments.epilogue.thread;
  fusion_args.alpha_ptr = static_cast<ElementCompute const*>(alpha.data_ptr());
  return arguments;
}

template <typename T>
void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
             at::Tensor const& A_sf, at::Tensor const& B_sf,
             at::Tensor const& alpha, int64_t m, int64_t n, int64_t k,
             cudaStream_t stream) {
  typename Fp4GemmSm100<T>::Gemm gemm;

  auto arguments =
      args_from_options<Fp4GemmSm100<T>>(D, A, B, A_sf, B_sf, alpha, m, n, k);

  size_t workspace_size = Fp4GemmSm100<T>::Gemm::get_workspace_size(arguments);
  auto const workspace_options =
      torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
  auto workspace = torch::empty(workspace_size, workspace_options);

  CUTLASS_CHECK(gemm.can_implement(arguments));

  CUTLASS_CHECK(gemm.initialize(arguments, workspace.data_ptr(), stream));

  CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream));
}
#else
template <typename T>
void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
             at::Tensor const& A_sf, at::Tensor const& B_sf,
             at::Tensor const& alpha, int64_t m, int64_t n, int64_t k,
             cudaStream_t stream) {
  TORCH_CHECK(false, "Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
                     "a CUTLASS 3.8 source directory to enable support.");
}
#endif  // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)

#define CHECK_TYPE(x, st, m) \
  TORCH_CHECK(x.scalar_type() == st, "Inconsistency of Tensor type:", m)
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x, m) \
  TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
#define CHECK_INPUT(x, st, m) \
  CHECK_TH_CUDA(x, m);        \
  CHECK_CONTIGUOUS(x, m);     \
  CHECK_TYPE(x, st, m)

constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;

void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
                                  torch::Tensor const& B,
                                  torch::Tensor const& A_sf,
                                  torch::Tensor const& B_sf,
                                  torch::Tensor const& alpha) {
  CHECK_INPUT(A, FLOAT4_E2M1X2, "a");
  CHECK_INPUT(B, FLOAT4_E2M1X2, "b");

  CHECK_INPUT(A_sf, SF_DTYPE, "scale_a");
  CHECK_INPUT(B_sf, SF_DTYPE, "scale_b");

  CHECK_INPUT(alpha, at::ScalarType::Float, "alpha");

  TORCH_CHECK(A.dim() == 2, "a must be a matrix");
  TORCH_CHECK(B.dim() == 2, "b must be a matrix");
  TORCH_CHECK(A.sizes()[1] == B.sizes()[1],
              "a and b shapes cannot be multiplied (", A.sizes()[0], "x",
              A.sizes()[1], " and ", B.sizes()[0], "x", B.sizes()[1], ")");

  auto const m = A.sizes()[0];
  auto const n = B.sizes()[0];
  auto const k = A.sizes()[1] * 2;

  constexpr int alignment = 32;
  TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ", alignment,
              ", but got a shape: (", A.sizes()[0], "x", A.sizes()[1],
              "), k: ", k, ".");
  TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ", alignment,
              ", but got b shape: (", B.sizes()[0], "x", B.sizes()[1], ").");

  auto round_up = [](int x, int y) { return (x + y - 1) / y * y; };
  int rounded_m = round_up(m, 128);
  int rounded_n = round_up(n, 128);
  // Since k is divisible by 32 (alignment), k / 16 is guaranteed to be an
  // integer.
  int rounded_k = round_up(k / 16, 4);

  TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
  TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
  TORCH_CHECK(A_sf.sizes()[1] == B_sf.sizes()[1],
              "scale_a and scale_b shapes cannot be multiplied (",
              A_sf.sizes()[0], "x", A_sf.sizes()[1], " and ", B_sf.sizes()[0],
              "x", B_sf.sizes()[1], ")");
  TORCH_CHECK(A_sf.sizes()[0] == rounded_m && A_sf.sizes()[1] == rounded_k,
              "scale_a must be padded and swizzled to a shape (", rounded_m,
              "x", rounded_k, "), but got a shape (", A_sf.sizes()[0], "x",
              A_sf.sizes()[1], ")");
  TORCH_CHECK(B_sf.sizes()[0] == rounded_n && B_sf.sizes()[1] == rounded_k,
              "scale_b must be padded and swizzled to a shape (", rounded_n,
              "x", rounded_k, "), but got a shape (", B_sf.sizes()[0], "x",
              B_sf.sizes()[1], ")");

  auto out_dtype = D.dtype();
  at::cuda::CUDAGuard device_guard{(char)A.get_device()};
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());

  if (out_dtype == at::ScalarType::Half) {
    runGemm<cutlass::half_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
  } else if (out_dtype == at::ScalarType::BFloat16) {
    runGemm<cutlass::bfloat16_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
  } else if (out_dtype == at::ScalarType::Float) {
    runGemm<float>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
  } else {
    TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm");
  }
}
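The scale-factor checks above encode the SM100 block-scaling layout: one FP8 (ue4m3) scale per 16 FP4 values along k, with the scale matrix padded to 128-row and 4-column tiles before swizzling. A minimal sketch (the helper name is illustrative) of the padded shapes the kernel expects:

BLOCK_SIZE = 16  # FP4 values covered by one ue4m3 scale

def round_up(x: int, y: int) -> int:
    return (x + y - 1) // y * y

def expected_scale_shapes(m: int, n: int, k: int):
    # k counts unpacked FP4 elements (twice the packed uint8 width).
    rounded_m = round_up(m, 128)
    rounded_n = round_up(n, 128)
    rounded_k = round_up(k // BLOCK_SIZE, 4)
    return (rounded_m, rounded_k), (rounded_n, rounded_k)

# e.g. m=150, n=128, k=128 -> scale_a is (256, 8), scale_b is (128, 8)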
@@ -302,6 +302,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "SymInt size_k) -> Tensor");
   // conditionally compiled so impl registration is in source file

+  // CUTLASS nvfp4 block scaled GEMM
+  ops.def(
+      "cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
+      " Tensor block_scale_a, Tensor block_scale_b,"
+      " Tensor alpha) -> ()");
+  ops.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm);
+
   // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
   // quantization, as well as bias
   ops.def(
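The schema marks the output as mutated ("Tensor! out") and returns nothing, so callers allocate the output themselves, as the _custom_ops wrapper further below does. A minimal sketch of calling the registered op directly, with zero-filled placeholder inputs (real inputs must be packed FP4 data with swizzled FP8 scales) and assuming an SM100 build:

import torch

m, n, k = 128, 128, 128  # m and n already multiples of 128, k of 32
out = torch.empty((m, n), dtype=torch.half, device="cuda")
a = torch.zeros((m, k // 2), dtype=torch.uint8, device="cuda")  # packed fp4
b = torch.zeros((n, k // 2), dtype=torch.uint8, device="cuda")  # packed fp4
# Swizzled block scales, viewed as fp8 e4m3; shapes follow the checks above.
a_sf = torch.zeros((m, k // 16), dtype=torch.uint8,
                   device="cuda").view(torch.float8_e4m3fn)
b_sf = torch.zeros((n, k // 16), dtype=torch.uint8,
                   device="cuda").view(torch.float8_e4m3fn)
alpha = torch.ones(1, dtype=torch.float32, device="cuda")
torch.ops._C.cutlass_scaled_fp4_mm(out, a, b, a_sf, b_sf, alpha)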
tests/kernels/test_nvfp4_scaled_mm.py (new file, 150 lines)
@@ -0,0 +1,150 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch

from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types

if not current_platform.has_device_capability(100):
    pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.",
                allow_module_level=True)

DTYPES = [torch.float16, torch.bfloat16]
# m, n, k
SHAPES = [(128, 128, 64), (128, 128, 128), (256, 128, 64), (128, 256, 128)]
PAD_SHAPES = [(150, 128, 64), (128, 128, 96)]
SHAPES.extend(PAD_SHAPES)

SEEDS = [42]
CUDA_DEVICES = ['cuda:0']

FLOAT4_E2M1_MAX = scalar_types.float4_e2m1fn.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

kE2M1ToFloatArray = [
    0.,
    0.5,
    1.,
    1.5,
    2.,
    3.,
    4.,
    6.,
]


def e2m1_to_fp32(int4_value):
    signBit = (int4_value & 0x8)
    int4_absValue = int4_value & 0x7
    float_result = kE2M1ToFloatArray[int4_absValue]
    if (signBit):
        float_result = -float_result
    return float_result


def break_fp4_bytes(a, dtype):
    assert (a.dtype == torch.uint8)
    m, n = a.shape
    a = a.flatten()
    # Get upper 4 bits
    highHalfByte = (a & 0xF0) >> 4
    # Get lower 4 bits
    lowHalfByte = a & 0x0F
    fH = torch.tensor([e2m1_to_fp32(x) for x in highHalfByte]).to(a.device)
    fL = torch.tensor([e2m1_to_fp32(x) for x in lowHalfByte]).to(a.device)
    # [0xAB, 0xCD] -> [0xB, 0xA, 0xD, 0xC]
    out = torch.stack((fL, fH), dim=-1).reshape(m, n * 2)
    return out


def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
    sf_m, sf_k = a_sf_swizzled.shape
    m_tiles = (m + 128 - 1) // 128
    f = block_size * 4
    k_tiles = (k + f - 1) // f
    tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4))
    tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5))
    out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size)
    return out[0:m, 0:k]


def dequantize_to_dtype(tensor_fp4,
                        tensor_sf,
                        global_scale,
                        dtype,
                        device,
                        block_size=16):
    """Dequantize the fp4 tensor back to high precision."""
    # Two fp4 values are packed into one uint8.
    assert tensor_fp4.dtype == torch.uint8
    m, packed_k = tensor_fp4.shape
    k = packed_k * 2
    tensor_f32 = break_fp4_bytes(tensor_fp4, dtype)
    tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size)
    tensor_sf = tensor_sf.view(torch.float8_e4m3fn)
    tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
    tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale

    # scale the tensor
    out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k)
    return out


def get_ref_results(a_fp4, b_fp4, a_sf, b_sf, a_global_scale, b_global_scale,
                    m, n, dtype, block_size, device):
    _, m_k = a_fp4.shape
    _, n_k = b_fp4.shape
    assert (m_k == n_k)
    a_in_dtype = dequantize_to_dtype(a_fp4,
                                     a_sf,
                                     a_global_scale,
                                     dtype=dtype,
                                     device=device,
                                     block_size=block_size)
    b_in_dtype = dequantize_to_dtype(b_fp4,
                                     b_sf,
                                     b_global_scale,
                                     dtype=dtype,
                                     device=device,
                                     block_size=block_size)
    return torch.matmul(a_in_dtype, b_in_dtype.t())


@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_nvfp4_gemm(
    dtype: torch.dtype,
    shape: tuple[int, int, int],
    seed: int,
    device: str,
) -> None:
    current_platform.seed_everything(seed)
    m, n, packed_k = shape
    k = packed_k * 2
    block_size = 16
    a_dtype = torch.randn((m, k), dtype=dtype, device=device)
    b_dtype = torch.randn((n, k), dtype=dtype, device=device)

    a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
                      torch.amax(a_dtype.flatten(), dim=-1)).to(torch.float32)
    b_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
                      torch.amax(b_dtype.flatten(), dim=-1)).to(torch.float32)
    alpha = 1. / (a_global_scale * b_global_scale)
    a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a_dtype, a_global_scale)
    b_fp4, b_scale_interleaved = ops.scaled_fp4_quant(b_dtype, b_global_scale)

    expected_out = get_ref_results(a_fp4, b_fp4, a_scale_interleaved,
                                   b_scale_interleaved, a_global_scale,
                                   b_global_scale, m, n, dtype, block_size,
                                   device)
    out = ops.cutlass_scaled_fp4_mm(a_fp4, b_fp4, a_scale_interleaved,
                                    b_scale_interleaved, alpha, dtype)

    torch.testing.assert_close(out,
                               expected_out.to(dtype=dtype),
                               atol=1e-1,
                               rtol=1e-1)
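For reference, the reference path in this test implies the following relationship, up to FP4/FP8 rounding (a sketch inferred from the dequantization code above, not stated explicitly in the PR):

D \approx \alpha \, (\hat{A} \odot S_A)\,(\hat{B} \odot S_B)^{\mathsf{T}},
\qquad \alpha = \frac{1}{s_A^{\mathrm{global}}\, s_B^{\mathrm{global}}}

where \hat{A}, \hat{B} are the unpacked E2M1 values, S_A, S_B broadcast the per-16-element UE4M3 block scales along k, and the global per-tensor scales are folded into alpha exactly as the test does with alpha = 1. / (a_global_scale * b_global_scale).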
@@ -433,6 +433,18 @@ if hasattr(torch.ops._C, "ggml_dequantize"):


 # cutlass
+def cutlass_scaled_fp4_mm(a: torch.Tensor, b: torch.Tensor,
+                          block_scale_a: torch.Tensor,
+                          block_scale_b: torch.Tensor, alpha: torch.Tensor,
+                          out_dtype: torch.dtype) -> torch.Tensor:
+    assert a.ndim == 2 and b.ndim == 2
+    m, n = a.shape[0], b.shape[0]
+    out = torch.empty((m, n), dtype=out_dtype, device=a.device)
+    torch.ops._C.cutlass_scaled_fp4_mm(out, a, b, block_scale_a, block_scale_b,
+                                       alpha)
+    return out
+
+
 def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
     return torch.ops._C.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
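A minimal end-to-end sketch of the new Python surface, assuming an SM100 GPU and a build with ENABLE_NVFP4, following the same recipe as the test above (tensor names and shapes are illustrative):

import torch
from vllm import _custom_ops as ops
from vllm.scalar_type import scalar_types

FLOAT4_E2M1_MAX = scalar_types.float4_e2m1fn.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

m, n, k = 128, 256, 256
a = torch.randn((m, k), dtype=torch.half, device="cuda")
b = torch.randn((n, k), dtype=torch.half, device="cuda")

# Global (per-tensor) scales and alpha, mirroring the test above.
a_gs = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
        torch.amax(a.flatten(), dim=-1)).to(torch.float32)
b_gs = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
        torch.amax(b.flatten(), dim=-1)).to(torch.float32)
alpha = 1. / (a_gs * b_gs)

a_fp4, a_sf = ops.scaled_fp4_quant(a, a_gs)  # packed uint8 + swizzled scales
b_fp4, b_sf = ops.scaled_fp4_quant(b, b_gs)

out = ops.cutlass_scaled_fp4_mm(a_fp4, b_fp4, a_sf, b_sf, alpha, torch.half)
print(out.shape)  # torch.Size([128, 256])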