2025-02-12 19:51:51 -08:00
|
|
|
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
|
|
|
|
|
|
|
|
#include <torch/all.h>
|
|
|
|
|
|
|
|
#include <cuda_runtime_api.h>
|
|
|
|
#include <cuda_runtime.h>
|
|
|
|
|
|
|
|
#include <ATen/cuda/CUDAContext.h>
|
|
|
|
#include <c10/cuda/CUDAGuard.h>
|
|
|
|
|
|
|
|
#include <cuda_fp8.h>
|
|
|
|
|
|
|
|
#include "cuda_utils.h"
|
|
|
|
|
|
|
|
// Type trait that maps a scalar 16-bit float type to its two-lane vector
// type and a two-lane vector type back to its scalar type:
//   half          <-> half2
//   __nv_bfloat16 <-> __nv_bfloat162
template <typename T>
struct TypeConverter {
  // Fallback used for every T that has no specialization below.
  typedef half2 Type;
};  // keep for generality

// half2 -> half
template <>
struct TypeConverter<half2> {
  typedef half Type;
};

// half -> half2
template <>
struct TypeConverter<half> {
  typedef half2 Type;
};

// __nv_bfloat162 -> __nv_bfloat16
template <>
struct TypeConverter<__nv_bfloat162> {
  typedef __nv_bfloat16 Type;
};

// __nv_bfloat16 -> __nv_bfloat162
template <>
struct TypeConverter<__nv_bfloat16> {
  typedef __nv_bfloat162 Type;
};
|
|
|
|
|
|
|
|
// Number of input elements converted by one thread (used by the host-side
// launch-configuration math).
#define ELTS_PER_THREAD 8

// Same per-thread element count, as a typed constant for device code.
constexpr int CVT_FP4_ELTS_PER_THREAD = ELTS_PER_THREAD;
// Number of consecutive elements that share a single scale factor (SF).
constexpr int CVT_FP4_SF_VEC_SIZE = 16;
|
|
|
|
|
|
|
|
// Packs eight fp32 values into eight e2m1 (FP4) codes stored in one uint32_t.
//
// Each `cvt.rn.satfinite.e2m1x2.f32` consumes two floats and produces one
// byte; the four bytes are then concatenated with byte0 as the lowest byte.
// Note that every pair is handed to `cvt` in swapped order (odd element
// first, even element second).
// NOTE(review): the resulting nibble order within each byte follows the PTX
// definition of cvt.e2m1x2 -- confirm against the PTX ISA if it matters.
//
// Only implemented when compiling device code for __CUDA_ARCH__ >= 1000;
// every other compilation pass returns 0.
inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ < 1000)
  // No conversion path for this target.
  return 0;
#else
  uint32_t packed;
  asm volatile(
      "{\n"
      ".reg .b8 byte0;\n"
      ".reg .b8 byte1;\n"
      ".reg .b8 byte2;\n"
      ".reg .b8 byte3;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte0, %2, %1;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte1, %4, %3;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte2, %6, %5;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte3, %8, %7;\n"
      "mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
      "}"
      : "=r"(packed)
      : "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]),
        "f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7]));
  return packed;
#endif
}
|
|
|
|
|
|
|
|
// Packs four float2 values (eight floats in total) into eight e2m1 (FP4)
// codes stored in one uint32_t.
//
// array[k] supplies the two inputs of the k-th `cvt.rn.satfinite.e2m1x2.f32`,
// which writes byte k of the result (byte0 is the lowest byte). For each pair
// the `.y` component is passed as the first source and `.x` as the second.
// NOTE(review): the resulting nibble order within each byte follows the PTX
// definition of cvt.e2m1x2 -- confirm against the PTX ISA if it matters.
//
// Only implemented when compiling device code for __CUDA_ARCH__ >= 1000;
// every other compilation pass returns 0.
inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ < 1000)
  // No conversion path for this target.
  return 0;
#else
  uint32_t packed;
  asm volatile(
      "{\n"
      ".reg .b8 byte0;\n"
      ".reg .b8 byte1;\n"
      ".reg .b8 byte2;\n"
      ".reg .b8 byte3;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte0, %2, %1;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte1, %4, %3;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte2, %6, %5;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte3, %8, %7;\n"
      "mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
      "}"
      : "=r"(packed)
      : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y),
        "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y));
  return packed;
#endif
}
|
|
|
|
|
|
|
|
// Fast approximate reciprocal: returns the result of the PTX instruction
// `rcp.approx.ftz.f32` applied to `a` (approximate, flush-to-zero variant,
// as the instruction modifiers indicate).
inline __device__ float reciprocal_approximate_ftz(float a) {
  float inverse;
  asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(inverse) : "f"(a));
  return inverse;
}
|
|
|
|
|
|
|
|
// Returns the byte address inside `SFout` at which the calling thread must
// store its scale factor (SF), or nullptr when this thread is not the one
// responsible for writing it.
//
// Template parameters:
//   SFType                      pointee type of the SF buffer (the buffer is
//                               addressed byte-wise regardless).
//   CVT_FP4_NUM_THREADS_PER_SF  how many adjacent threads share one SF
//                               (1 or 2); only the thread whose threadIdx.x
//                               is a multiple of this value gets an address.
// Parameters:
//   rowIdx   row (M index) being processed.
//   colIdx   per-thread packed-column index; divided by
//            CVT_FP4_NUM_THREADS_PER_SF to obtain the SF index along K.
//   numCols  number of columns of the input, in elements.
//   SFout    base pointer of the SF output buffer.
//
// SF layout: [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4 (kTile)]
//   --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
//
// Compiled to a stub that always returns nullptr unless __CUDA_ARCH__ >= 1000.
template <class SFType, int CVT_FP4_NUM_THREADS_PER_SF>
__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx,
                                                       int numCols,
                                                       SFType* SFout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 ||
                CVT_FP4_NUM_THREADS_PER_SF == 2);

  // One pair of threads writes one SF to global memory; the other thread of
  // the pair gets no address.
  // TODO: stage through smem for packed STG.32
  // is it better than STG.8 from 4 threads ?
  if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF != 0) {
    return nullptr;
  }

  // Tile geometry: an M tile spans 32 x 4 rows, a K tile spans 4 SFs.
  constexpr int32_t kOuterRows = 32;
  constexpr int32_t kInnerRows = 4;
  constexpr int32_t kSFsPerKTile = 4;
  constexpr int32_t kRowsPerMTile = kOuterRows * kInnerRows;
  constexpr int32_t kSFsPerTile = kRowsPerMTile * kSFsPerKTile;

  // SF vector index (16 elements share one SF in the K dimension).
  int32_t const sfCol = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
  int32_t const sfRow = rowIdx;

  // Number of K tiles per row of tiles: ceil(numCols / (SF vector size * 4)).
  int const colsPerKTile = CVT_FP4_SF_VEC_SIZE * kSFsPerKTile;
  int32_t const kTileCount = (numCols + colsPerKTile - 1) / colsPerKTile;

  // Strides, in bytes, of each level of the layout.
  int64_t const strideMTile = kTileCount * kSFsPerTile;
  int64_t const strideKTile = kSFsPerTile;
  int64_t const strideOuterM = kInnerRows * kSFsPerKTile;
  int64_t const strideInnerM = kSFsPerKTile;
  int64_t const strideInnerK = 1;

  // Coordinates at each level. The M tile layout [32, 4] is column-major:
  // the row's position modulo 32 is the outer index, and which of the four
  // 32-row groups it belongs to is the inner index.
  int32_t const mTile = sfRow / kRowsPerMTile;
  int32_t const kTile = sfCol / kSFsPerKTile;
  int32_t const outerM = sfRow % kOuterRows;
  int32_t const innerM = (sfRow % kRowsPerMTile) / kOuterRows;
  int32_t const innerK = sfCol % kSFsPerKTile;

  // Compute the global offset.
  int64_t const byteOffset = mTile * strideMTile + kTile * strideKTile +
                             outerM * strideOuterM + innerM * strideInnerM +
                             innerK * strideInnerK;

  return reinterpret_cast<uint8_t*>(SFout) + byteOffset;
#endif
  return nullptr;
}
|
|
|
|
|
|
|
|
// 16-byte packed vector used for per-thread loads.
//
// Generic form: four two-lane elements of the vector type associated with
// `Type` through TypeConverter (e.g. 4 x half2 for Type = half).
template <typename Type>
struct PackedVec {
  typename TypeConverter<Type>::Type elts[4];
};

// FP8 (e4m3) form: eight two-lane fp8 elements.
template <>
struct PackedVec<__nv_fp8_e4m3> {
  __nv_fp8x2_e4m3 elts[8];
};
|
|
|
|
|
|
|
|
// Quantizes the 8 values held in `vec` to e2m1 (FP4) and returns them packed
// into one uint32_t.
//
// Steps:
//   1. Reduce |x| over this thread's 8 values, then exchange with the
//      neighbouring lane (lane ^ 1) so that both lanes hold the maximum over
//      16 values.
//   2. Derive the scale factor SF = SFScaleVal * max / 6 and round it to its
//      8-bit storage format (exponent-only when UE8M0_SF, e4m3 otherwise).
//   3. If `SFout` is non-null, store the 8-bit SF there.
//   4. Multiply the inputs by 1 / (roundedSF / SFScaleVal) (or by 0 when the
//      rounded SF is 0) and convert them to e2m1.
//
// Parameters:
//   vec         the 8 input values of this thread (as 4 two-lane elements).
//   SFScaleVal  global scale applied to the scale factor.
//   SFout       destination byte for the SF, or nullptr to skip the store.
//
// The lane exchange uses a full-warp mask, so the function must not be called
// from code in which only part of the warp participates.
// NOTE(review): callers with a partially active warp should be checked
// against the __shfl_xor_sync mask requirements.
//
// Compiled to a stub that returns 0 unless __CUDA_ARCH__ >= 1000.
template <class Type, bool UE8M0_SF = false>
__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
                                         uint8_t* SFout) {
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ < 1000)
  return 0;
#else
  constexpr int kNumPairs = CVT_FP4_ELTS_PER_THREAD / 2;

  // Lane-wise absolute maximum over this thread's pairs.
  auto pairAbsMax = __habs2(vec.elts[0]);
#pragma unroll
  for (int p = 1; p < kNumPairs; p++) {
    pairAbsMax = __hmax2(pairAbsMax, __habs2(vec.elts[p]));
  }

  // Combine with the partner thread so both see the maximum of all 16 values.
  auto const partnerAbsMax = __shfl_xor_sync(uint32_t(-1), pairAbsMax, 1);
  pairAbsMax = __hmax2(partnerAbsMax, pairAbsMax);

  // Collapse the two lanes into the final scalar maximum.
  float const groupAbsMax = float(__hmax(pairAbsMax.x, pairAbsMax.y));

  // SF = (max value of the vector / max value of e2m1), times the global
  // scale. The maximum value of e2m1 is 6.0.
  // TODO: use half as compute data type.
  float sf = SFScaleVal * (groupAbsMax * reciprocal_approximate_ftz(6.0f));

  // 8-bit representation of the SF; `sf` is replaced by the value that the
  // 8-bit representation decodes to.
  uint8_t sfByte;
  if constexpr (UE8M0_SF) {
    // float32 = 1 sign bit + 8 exponent bits + 23 mantissa bits: keep the
    // exponent only and drop the mantissa.
    uint32_t expBits = reinterpret_cast<uint32_t&>(sf) >> 23;
    sfByte = expBits & 0xff;
    // Convert back to fp32.
    reinterpret_cast<uint32_t&>(sf) = expBits << 23;
  } else {
    // Here the SF is always positive, so E4M3 is the same as UE4M3.
    __nv_fp8_e4m3 sfE4m3 = __nv_fp8_e4m3(sf);
    reinterpret_cast<__nv_fp8_e4m3&>(sfByte) = sfE4m3;
    // Convert back to fp32.
    sf = float(sfE4m3);
  }

  // Multiplier applied to the inputs:
  //   reciprocal(roundedSF * reciprocal(SFScaleVal)), or 0 for a zero SF.
  float quantScale;
  if (sf != 0) {
    quantScale =
        reciprocal_approximate_ftz(sf * reciprocal_approximate_ftz(SFScaleVal));
  } else {
    quantScale = 0.0f;
  }

  // Write the SF to global memory (STG.8) when an address was supplied.
  if (SFout) {
    *SFout = sfByte;
  }

  // Widen the inputs to fp32 and apply the scale.
  float2 scaled[kNumPairs];
#pragma unroll
  for (int p = 0; p < kNumPairs; p++) {
    if constexpr (std::is_same_v<Type, half>) {
      scaled[p] = __half22float2(vec.elts[p]);
    } else {
      scaled[p] = __bfloat1622float2(vec.elts[p]);
    }
    scaled[p].x *= quantScale;
    scaled[p].y *= quantScale;
  }

  // Convert to e2m1 values; the caller stores the packed word.
  return fp32_vec_to_e2m1(scaled);
#endif
}
|
|
|
|
|
|
|
|
// Kernel that quantizes a [numRows, numCols] matrix of `Type` values to
// packed e2m1 (FP4) plus one 8-bit scale factor (SF) per 16 elements.
//
// Work distribution: block b handles rows b, b + gridDim.x, ...; within a
// row, thread t handles the packed 8-element vectors t, t + blockDim.x, ...
//
// Parameters:
//   numRows, numCols  matrix shape in elements. Columns are consumed in
//                     groups of CVT_FP4_ELTS_PER_THREAD; any remainder is
//                     not processed.
//   in                input matrix, read as PackedVec<Type> vectors.
//   SFScale           optional pointer to the global SF scale; 1.0f is used
//                     when it is null.
//   out               output, one uint32_t (8 e2m1 values) per input vector.
//   SFout             SF output buffer, addressed through
//                     cvt_quant_to_fp4_get_sf_out_offset.
//
// Preconditions: `in` must be suitably aligned for PackedVec loads.
// cvt_warp_fp16_to_fp4 exchanges data between lane pairs, so the launch
// configuration is expected to keep both lanes of a pair in the column loop
// together.
// NOTE(review): confirm alignment/launch assumptions against the callers.
//
// The body is empty unless compiled for __CUDA_ARCH__ >= 1000.
//
// Use UE4M3 by default.
template <class Type, bool UE8M0_SF = false>
__global__ void
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__launch_bounds__(512, 4) cvt_fp16_to_fp4(
#else
cvt_fp16_to_fp4(
#endif
    int32_t numRows, int32_t numCols, Type const* in, float const* SFScale,
    uint32_t* out, uint32_t* SFout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  using PackedVec = PackedVec<Type>;
  static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
      (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
  static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
                "Vec size is not matched.");

  // Get the global scaling factor, which will be applied to the SF.
  // Note SFScale is the same as next GEMM's alpha, which is
  // (448.f / (Alpha_A / 6.f)).
  float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[0];

  // Number of packed vectors per row (loop-invariant).
  int32_t const numColVecs = numCols / CVT_FP4_ELTS_PER_THREAD;

  // Input tensor row/col loops.
  for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
    // Widen before multiplying: rowIdx * numColVecs can exceed the int32
    // range for large tensors even though each factor fits.
    int64_t const rowBase = static_cast<int64_t>(rowIdx) * numColVecs;

    for (int colIdx = threadIdx.x; colIdx < numColVecs;
         colIdx += blockDim.x) {
      int64_t inOffset = rowBase + colIdx;
      PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
      // Get the output tensor offset.
      // Same as inOffset because 8 elements are packed into one uint32_t.
      int64_t outOffset = inOffset;
      auto& out_pos = out[outOffset];

      // Address of this thread's SF byte (nullptr for non-writer threads).
      auto sf_out =
          cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
                                             CVT_FP4_NUM_THREADS_PER_SF>(
              rowIdx, colIdx, numCols, SFout);

      out_pos =
          cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
    }
  }
#endif
}
|
|
|
|
|
|
|
|
// Host-side launcher for the cvt_fp16_to_fp4 kernel.
//
// Parameters:
//   m, n                 number of rows / columns of `input`, in elements.
//   input                device pointer to the input matrix.
//   SFScale              device pointer to the global SF scale (the kernel
//                        treats a null pointer as 1.0f).
//   output               device pointer to the packed FP4 output; passed to
//                        the kernel as uint32_t*.
//   SFOuput              device pointer to the scale-factor output; passed to
//                        the kernel as uint32_t*.
//   useUE8M0             selects the UE8M0_SF = true kernel instantiation.
//   multiProcessorCount  SM count used to cap the grid size.
//   stream               CUDA stream the kernel is enqueued on.
//
// Launch configuration: one thread per 8 input columns, capped at 512 threads
// per block; one block per row, capped at
// multiProcessorCount * (2048 / blockSize) blocks.
//
// When there is no work (m <= 0 or n < ELTS_PER_THREAD) the function returns
// without launching anything.
template <typename T>
void invokeFP4Quantization(int m, int n, T const* input, float const* SFScale,
                           int64_t* output, int32_t* SFOuput, bool useUE8M0,
                           int multiProcessorCount, cudaStream_t stream) {
  // Each thread converts 8 values.
  int const threadsPerRow = int(n / ELTS_PER_THREAD);

  // Empty input: a zero block size would make `2048 / block.x` divide by
  // zero on the host, and a zero grid size is an invalid launch
  // configuration. There is nothing to quantize, so just return.
  if (m <= 0 || threadsPerRow <= 0) {
    return;
  }

  // Grid, Block size.
  dim3 block(std::min(threadsPerRow, 512));
  // Get number of blocks per SM (assume we can fully utilize the SM).
  int const numBlocksPerSM = 2048 / block.x;
  dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));

  // Launch the cvt kernel.
  if (useUE8M0) {
    cvt_fp16_to_fp4<T, true><<<grid, block, 0, stream>>>(
        m, n, input, SFScale, reinterpret_cast<uint32_t*>(output),
        reinterpret_cast<uint32_t*>(SFOuput));
  } else {
    cvt_fp16_to_fp4<T, false><<<grid, block, 0, stream>>>(
        m, n, input, SFScale, reinterpret_cast<uint32_t*>(output),
        reinterpret_cast<uint32_t*>(SFOuput));
  }
}
|
|
|
|
|
|
|
|
// Explicit instantiations of the launcher for the two supported input types.

// fp16 input.
template void invokeFP4Quantization<half>(int m, int n, half const* input,
                                          float const* SFScale,
                                          int64_t* output, int32_t* SFOuput,
                                          bool useUE8M0,
                                          int multiProcessorCount,
                                          cudaStream_t stream);

// bf16 input.
template void invokeFP4Quantization<__nv_bfloat16>(
    int m, int n, __nv_bfloat16 const* input, float const* SFScale,
    int64_t* output, int32_t* SFOuput, bool useUE8M0, int multiProcessorCount,
    cudaStream_t stream);
|
|
|
|
|
|
|
|
// Torch entry point: quantizes a 2-D fp16/bf16 tensor to FP4 with per-block
// scale factors by dispatching to invokeFP4Quantization on the current CUDA
// stream of the input's device.
//
// Parameters:
//   output     tensor receiving the packed FP4 data (its storage is handed to
//              the launcher as int64_t*).
//   input      tensor to quantize; size(0) is used as the row count and
//              size(1) as the column count, which must be a multiple of 16.
//              Supported dtypes: torch::kHalf and torch::kBFloat16.
//   output_sf  tensor receiving the scale factors (storage handed over as
//              int32_t*).
//   input_sf   tensor whose first element is read by the kernel as the
//              global scale (storage handed over as float const*).
//
// NOTE(review): shapes/dtypes/contiguity of output, output_sf and input_sf
// are not validated here -- presumably the caller guarantees them.
//
// Throws std::runtime_error for any other input dtype; TORCH_CHECK fails when
// the column count is not a multiple of 16.
void scaled_fp4_quant_sm100a(torch::Tensor const& output,
                             torch::Tensor const& input,
                             torch::Tensor const& output_sf,
                             torch::Tensor const& input_sf) {
  int32_t m = input.size(0);
  int32_t n = input.size(1);

  TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");

  // Switch to the input's device BEFORE querying device attributes: the query
  // below passes -1 as the device id, so it must run while the input's device
  // is current or the SM count may describe a different GPU.
  at::cuda::CUDAGuard device_guard{(char)input.get_device()};

  int multiProcessorCount =
      get_device_attribute(cudaDevAttrMultiProcessorCount, -1);

  auto input_sf_ptr = static_cast<float const*>(input_sf.data_ptr());
  auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
  auto output_ptr = static_cast<int64_t*>(output.data_ptr());
  auto stream = at::cuda::getCurrentCUDAStream(input.get_device());

  // We don't support e8m0 scales at this moment.
  bool useUE8M0 = false;

  switch (input.scalar_type()) {
    case torch::kHalf: {
      auto input_ptr = reinterpret_cast<half const*>(input.data_ptr());
      invokeFP4Quantization(m, n, input_ptr, input_sf_ptr, output_ptr, sf_out,
                            useUE8M0, multiProcessorCount, stream);
      break;
    }
    case torch::kBFloat16: {
      auto input_ptr = reinterpret_cast<__nv_bfloat16 const*>(input.data_ptr());
      invokeFP4Quantization(m, n, input_ptr, input_sf_ptr, output_ptr, sf_out,
                            useUE8M0, multiProcessorCount, stream);
      break;
    }
    default: {
      std::cerr << "Observing: " << input.scalar_type()
                << " for the input datatype which is invalid";
      throw std::runtime_error(
          "Unsupported input data type for quantize_to_fp4.");
    }
  }
}
|