[ROCm][Quantization][Kernel] Using HIP FP8 header (#12593)

Gregory Shtrasberg authored 2025-02-25 03:39:59 -05:00, committed by GitHub
parent 2f42a4888c
commit aabeb2688f
6 changed files with 267 additions and 634 deletions

CMakeLists.txt

@@ -174,6 +174,25 @@ include(FetchContent)
file(MAKE_DIRECTORY ${FETCHCONTENT_BASE_DIR}) # Ensure the directory exists
message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}")
#
# Set rocm version dev int.
#
if(VLLM_GPU_LANG STREQUAL "HIP")
#
# Overriding the default -O set up by cmake, adding ggdb3 for the most verbose debug info
#
set(CMAKE_${VLLM_GPU_LANG}_FLAGS_DEBUG "${CMAKE_${VLLM_GPU_LANG}_FLAGS_DEBUG} -O0 -ggdb3")
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -O0 -ggdb3")
#
# Certain HIP functions are marked as [[nodiscard]], yet vllm ignores the result, which generates
# a lot of warnings that can mask real issues. Suppressing until this is properly addressed.
#
set(CMAKE_${VLLM_GPU_LANG}_FLAGS "${CMAKE_${VLLM_GPU_LANG}_FLAGS} -Wno-unused-result")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-result")
endif()
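For reference, the kind of pattern -Wno-unused-result silences (a hedged sketch, not code from the vllm tree; which HIP calls carry [[nodiscard]] depends on the ROCm version):

#include <hip/hip_runtime.h>

void sync_example() {
  // Many HIP runtime calls return a hipError_t that newer ROCm headers mark
  // [[nodiscard]]; dropping the result triggers -Wunused-result.
  hipDeviceSynchronize();        // would warn without -Wno-unused-result
  (void)hipDeviceSynchronize();  // explicit discard, quiet either way
}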
#
# Define other extension targets
#

csrc/quantization/fp8/amd/hip_float8.h (deleted)

@@ -1,137 +0,0 @@
#pragma once
#ifdef __HIPCC__
#include <hip/hip_runtime.h>
#else
#include <type_traits>
#include <stdint.h>
#include <math.h>
#include <iostream>
#endif
#include "hip_float8_impl.h"
struct alignas(1) hip_fp8 {
struct from_bits_t {};
HIP_FP8_HOST_DEVICE static constexpr from_bits_t from_bits() {
return from_bits_t();
}
uint8_t data;
hip_fp8() = default;
HIP_FP8_HOST_DEVICE constexpr hip_fp8(const hip_fp8&) = default;
HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v) = delete;
explicit HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v, from_bits_t)
: data(v) {}
#ifdef __HIP__MI300__
// NOTE: ON-DEVICE... always optimal bias
explicit HIP_FP8_DEVICE hip_fp8(float v)
: data(hip_fp8_impl::to_fp8_from_fp32(v)) {}
explicit HIP_FP8_DEVICE hip_fp8(_Float16 v)
: hip_fp8(static_cast<float>(v)) {}
// Host only implementation using s/w simulation
explicit HIP_FP8_HOST
#else // __HIP__MI300__
// both Host and DEVICE for non-MI300 using s/w simulation
explicit HIP_FP8_HOST_DEVICE
#endif // __HIP__MI300__
hip_fp8(float v) {
data = hip_fp8_impl::to_float8<4, 3, float, true /*negative_zero_nan*/,
true /*clip*/>(v);
}
explicit HIP_FP8_HOST_DEVICE hip_fp8(double v)
: hip_fp8(static_cast<float>(v)) {}
#ifdef __HIP__MI300__
// upcast using device specific intrinsic
explicit inline HIP_FP8_DEVICE operator float() const {
float fval;
uint32_t i32val = static_cast<uint32_t>(data);
// upcast
asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0"
: "=v"(fval)
: "v"(i32val));
return fval;
}
explicit inline HIP_FP8_HOST operator float() const
#else // __HIP__MI300__
explicit inline HIP_FP8_HOST_DEVICE operator float() const
#endif // __HIP__MI300__
{
return hip_fp8_impl::from_float8<4, 3, float, true /*negative_zero_nan*/>(
data);
}
};
namespace std {
inline hip_fp8 sin(hip_fp8 a) { return hip_fp8(sinf(float(a))); }
inline hip_fp8 cos(hip_fp8 a) { return hip_fp8(cosf(float(a))); }
HIP_FP8_HOST_DEVICE constexpr hip_fp8 real(const hip_fp8& a) { return a; }
} // namespace std
// Special operator overloading
inline std::ostream& operator<<(std::ostream& os, const hip_fp8& f8) {
return os << float(f8);
}
// all + operator overloading with mixed types
// mixed types, always converts to f32, does computation in f32, and returns
// float
inline HIP_FP8_HOST_DEVICE float operator+(const float fa, hip_fp8 b) {
return (fa + float(b));
}
inline HIP_FP8_HOST_DEVICE float operator+(hip_fp8 a, const float fb) {
return (float(a) + fb);
}
inline HIP_FP8_HOST_DEVICE hip_fp8 operator+(hip_fp8 a, hip_fp8 b) {
return hip_fp8(float(a) + float(b));
}
inline HIP_FP8_HOST_DEVICE hip_fp8& operator+=(hip_fp8& a, hip_fp8 b) {
return a = hip_fp8(float(a) + float(b));
}
// overloading multiplication, always returns float,
inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, hip_fp8 b) {
return float(a) * float(b);
}
inline HIP_FP8_HOST_DEVICE float operator*(float a, hip_fp8 b) {
return (a * float(b));
}
inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, float b) {
return (float(a) * b);
}
inline HIP_FP8_HOST_DEVICE float operator*(int32_t a, hip_fp8 b) {
return ((float)a * float(b));
}
inline HIP_FP8_HOST_DEVICE float operator*(double a, hip_fp8 b) {
return ((float)a * float(b));
}
// overloading for compare
inline HIP_FP8_HOST_DEVICE bool operator==(hip_fp8 a, hip_fp8 b) {
return (a.data == b.data);
}
inline HIP_FP8_HOST_DEVICE bool operator!=(hip_fp8 a, hip_fp8 b) {
return (a.data != b.data);
}
inline HIP_FP8_HOST_DEVICE bool operator>=(hip_fp8 a, hip_fp8 b) {
return static_cast<float>(a) >= static_cast<float>(b);
}
inline HIP_FP8_HOST_DEVICE bool operator>(hip_fp8 a, hip_fp8 b) {
return static_cast<float>(a) > static_cast<float>(b);
}
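The header removed above is superseded by ROCm's own <hip/hip_fp8.h>. A rough device-side sketch of the correspondence, using only API that appears later in this diff and assuming the FNUZ (gfx94x) variant:

#include <hip/hip_fp8.h>

__device__ float fp8_roundtrip(float v) {
  // float -> fp8 bits, replacing hip_fp8(v).data
  uint8_t bits =
      __hip_cvt_float_to_fp8(v, __hip_fp8_e4m3_fnuz::__default_saturation,
                             __hip_fp8_e4m3_fnuz::__default_interpret);
  // fp8 bits -> float, replacing float(hip_fp8{bits, hip_fp8::from_bits()})
  __hip_fp8_e4m3_fnuz f8;
  f8.__x = bits;
  return static_cast<float>(f8);
}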

csrc/quantization/fp8/amd/hip_float8_impl.h (deleted)

@@ -1,315 +0,0 @@
#pragma once
#if defined(__HIPCC__) && defined(__gfx942__)
#define __HIP__MI300__
#endif
#ifdef __HIPCC__
#define HIP_FP8_HOST_DEVICE __host__ __device__
#define HIP_FP8_HOST __host__
#define HIP_FP8_DEVICE __device__
#else
#define HIP_FP8_HOST_DEVICE
#define HIP_FP8_HOST
#define HIP_FP8_DEVICE
#endif
namespace hip_fp8_impl {
#ifdef __HIP__MI300__
HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v) {
uint8_t i8data;
union {
float fval;
uint32_t i32val;
uint8_t i8val[4]; // NOTE: not endian independent
} val;
uint32_t ival = 0;
val.fval = v;
if ((val.i32val & 0x7F800000) !=
0x7F800000) { /// propagate NAN/INF, no clipping
val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
}
ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival,
false); // false -> WORD0
val.i32val = ival;
i8data = val.i8val[0];
return i8data;
}
#endif // __HIP__MI300__
HIP_FP8_HOST inline int clz(uint32_t x) { return __builtin_clz(x); }
#if defined(__HIPCC__) || defined(__CUDA_ARCH__)
HIP_FP8_DEVICE inline int clz(uint32_t x) { return __clz(x); }
#endif
template <int we, int wm, typename T, bool negative_zero_nan, bool clip>
HIP_FP8_HOST_DEVICE uint8_t to_float8(T _x, bool stoch = false,
uint32_t rng = 0) {
#ifdef __HIPCC__
constexpr bool is_half = std::is_same<T, _Float16>::value;
#else
constexpr bool is_half = false;
#endif
constexpr bool is_float = std::is_same<T, float>::value;
static_assert(wm + we == 7, "wm+we==7");
static_assert(is_half || is_float, "Only half and float can be cast to f8");
const int mfmt = (sizeof(T) == 4) ? 23 : 10;
uint32_t x;
if (sizeof(T) == 4) {
x = reinterpret_cast<uint32_t&>(_x);
} else {
x = reinterpret_cast<uint16_t&>(_x);
}
uint32_t head, mantissa;
int exponent, bias;
uint32_t sign;
if (sizeof(T) == 4) {
head = x & 0xFF800000;
mantissa = x & 0x7FFFFF;
exponent = (head >> 23) & 0xFF;
sign = head >> 31;
bias = 127;
} else {
head = x & 0xFC00;
mantissa = x & 0x3FF;
exponent = (head >> 10) & 0x1F;
sign = head >> 15;
bias = 15;
}
uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm);
// Deal with inf and NaNs
if (negative_zero_nan) {
if (sizeof(T) == 4) {
if ((x & 0x7F800000) == 0x7F800000) {
return 0x80;
}
} else {
// if(__hisinf(x) || __hisnan(x))
if ((x & 0x7C00) == 0x7C00) {
return 0x80;
}
}
} else {
if (sizeof(T) == 4) {
if ((x & 0x7F800000) == 0x7F800000) {
return signed_inf + (mantissa != 0 ? 1 : 0);
}
} else {
if ((x & 0x7C00) == 0x7C00) {
return signed_inf + (mantissa != 0 ? 1 : 0);
}
}
}
if (x == 0) {
return 0;
}
// First need to check whether it is normal or denormal, as the implicit 1
// differs. Then adjust the exponent to align with the F8 exponent, shifting
// the mantissa in the meanwhile. For stochastic rounding, add rng to the
// mantissa and truncate; for RNE, no rng is added. Then check whether there
// is a carry and, if so, adjust the exponent and mantissa again.
// For IEEE bias mode, the bias is 2^(k-1) - 1 where k is the width of the
// exponent bits.
const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0);
const int f8_denormal_act_exponent =
1 - f8_bias; // actual exponent of f8 denormal
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
// f8_exponent is the converted f8 exponent with bias encoding
// exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
// the difference needs to be adjusted and mantissa shifted
int act_exponent, f8_exponent, exponent_diff;
if (exponent == 0) { // fp32/fp16 is in denormal.
/* fp32 denormals are below 2^-127, so they are usually not a concern here; we
are mostly concerned with fp16. In this case, f8 is usually denormal, but there
can be exceptions. fp16 denormals have exponent bias 15 while bf8 with NANOO has
exponent bias 16. That means some numbers that are fp16 denormals are bf8
(NANOO) normals - the smallest bf8 (NANOO) normal is 2^-15. fp16 numbers where
exponent==0 (actual exponent -14) and the highest mantissa bit is 1 are bf8
(NANOO) normal. In this case, the fp16 mantissa should be shifted left by 1 */
act_exponent = exponent - bias + 1;
exponent_diff =
f8_denormal_act_exponent -
act_exponent; // actual exponent is exponent-bias+1 as it is denormal
} else { // fp32/fp16 is normal with implicit 1
act_exponent = exponent - bias;
if (act_exponent <= f8_denormal_act_exponent) {
/* This is the case where fp32/fp16 is normal but falls in the f8 denormal
range. For example, in fp8 nanoo mode the denormal exponent is -7, but if the
fp32/fp16 actual exponent is -7, the value is actually larger due to the
implicit 1. Therefore it needs to be adjusted to -6 and the mantissa shifted
right by 1. So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
exponent_diff = f8_denormal_act_exponent - act_exponent;
} else { // both fp32/fp16 and f8 are in normal range
exponent_diff = 0; // exponent_diff=0 does not mean there is no
// difference for this case, act_exponent could be
// larger. Just that it does not need shift mantissa
}
mantissa += (1 << mfmt); // Add the implicit 1 into mantissa
}
bool midpoint = (mantissa & ((1 << (mfmt - wm + exponent_diff)) - 1)) ==
static_cast<uint32_t>(1 << (mfmt - wm + exponent_diff - 1));
/* This part is a bit tricky. The judgment of whether it is a tie needs to be
done before we shift right as shift right could rip off some residual part
and make something not midpoint look like midpoint. For example, the fp16
number 0x1002 (0 00100 0000000010), it is larger than midpoint, but after
shift right by 4 bits, it would look like midpoint.
*/
if (exponent_diff > 0) {
mantissa >>= exponent_diff;
} else if (exponent_diff == -1) {
mantissa <<= -exponent_diff;
}
bool implicit_one = mantissa & (1 << mfmt);
// if there is no implicit 1, it means the f8 is denormal and need to adjust
// to denorm exponent
f8_exponent = (act_exponent + exponent_diff) /*actual f8 exponent*/ +
f8_bias - (implicit_one ? 0 : 1);
// Now we have the exponent and mantissa adjusted
uint32_t drop_mask = (1 << (mfmt - wm)) - 1;
bool odd = mantissa & (1 << (mfmt - wm)); // if the least significant bit
// that is not truncated is 1
mantissa +=
(stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) &
drop_mask;
// Now we deal with overflow
if (f8_exponent == 0) {
if ((1 << mfmt) & mantissa) {
f8_exponent = 1; // denormal overflow to become normal, promote exponent
}
} else {
if ((1 << (mfmt + 1)) & mantissa) {
mantissa >>= 1;
f8_exponent++;
}
}
mantissa >>= (mfmt - wm);
// above range: quantize to maximum possible float of the same sign
const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2);
if (f8_exponent > max_exp) {
if (clip) {
mantissa = (1 << wm) - 1;
f8_exponent = max_exp;
} else {
return signed_inf;
}
}
if (f8_exponent == 0 && mantissa == 0) {
return negative_zero_nan ? 0 : (sign << 7);
}
mantissa &= (1 << wm) - 1;
return (sign << 7) | (f8_exponent << wm) | mantissa;
}
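// A worked example of the conversion above: with we=4, wm=3 and
// negative_zero_nan=true -- the FNUZ layout used throughout vllm -- the bias
// is (1 << (we - 1)) - 1 + 1 = 8, so:
//   to_float8<4, 3, float, true, true>(1.0f)   == 0x40  (exponent 8, mantissa 0)
//   to_float8<4, 3, float, true, true>(240.0f) == 0x7F  (largest normal, which
//                                                        is why the MI300 path
//                                                        clamps to +/-240)
//   to_float8<4, 3, float, true, true>(-0.0f)  == 0x00  (no negative zero)
// 0x80 is the format's single NaN encoding.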
template <int we, int wm, typename T = float, bool negative_zero_nan = true>
inline HIP_FP8_HOST_DEVICE T from_float8(uint8_t x) {
#ifdef __HIPCC__
constexpr bool is_half = std::is_same<T, _Float16>::value;
#else
constexpr bool is_half = false;
#endif
constexpr bool is_float = std::is_same<T, float>::value;
static_assert(is_half || is_float, "only half and float are supported");
constexpr int weo = is_half ? 5 : 8;
constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7);
T fInf, fNegInf, fNaN, fNeg0;
#ifdef __HIPCC__
if (is_half) {
const uint16_t ihInf = 0x7C00;
const uint16_t ihNegInf = 0xFC00;
const uint16_t ihNaN = 0x7C01;
const uint16_t ihNeg0 = 0x8000;
fInf = reinterpret_cast<const _Float16&>(ihInf);
fNegInf = reinterpret_cast<const _Float16&>(ihNegInf);
fNaN = reinterpret_cast<const _Float16&>(ihNaN);
fNeg0 = reinterpret_cast<const _Float16&>(ihNeg0);
} else
#endif
if (is_float) {
const uint32_t ifInf = 0x7F800000;
const uint32_t ifNegInf = 0xFF800000;
const uint32_t ifNaN = 0x7F800001;
const uint32_t ifNeg0 = 0x80000000;
fInf = reinterpret_cast<const float&>(ifInf);
fNegInf = reinterpret_cast<const float&>(ifNegInf);
fNaN = reinterpret_cast<const float&>(ifNaN);
fNeg0 = reinterpret_cast<const float&>(ifNeg0);
}
if (x == 0) {
return 0;
}
uint32_t sign = x >> 7;
uint32_t mantissa = x & ((1 << wm) - 1);
int exponent = (x & 0x7F) >> wm;
if (negative_zero_nan) {
if (x == 0x80) {
return fNaN;
}
} else {
if (x == 0x80) {
return fNeg0;
}
if (exponent == ((1 << we) - 1)) {
return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN;
}
}
typename std::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type retval;
if (we == 5 && is_half && !negative_zero_nan) {
retval = x << 8;
return reinterpret_cast<const T&>(retval);
}
const int exp_low_cutoff =
(1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0);
// subnormal input
if (exponent == 0) {
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
int sh = 1 + clz(mantissa) - (32 - wm);
mantissa <<= sh;
exponent += 1 - sh;
mantissa &= ((1 << wm) - 1);
}
exponent += exp_low_cutoff - 1;
mantissa <<= wmo - wm;
// subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
if (exponent <= 0) {
mantissa |= 1 << wmo;
mantissa >>= 1 - exponent;
exponent = 0;
}
if (sizeof(T) == 2) {
retval = (sign << 15) | (exponent << 10) | mantissa;
} else {
retval = (sign << 31) | (exponent << 23) | mantissa;
}
return reinterpret_cast<const T&>(retval);
}
} // namespace hip_fp8_impl

csrc/quantization/fp8/amd/quant_utils.cuh

@@ -1,13 +1,11 @@
 #pragma once
-#include "hip_float8.h"
+#include <hip/hip_fp8.h>
 #include <hip/hip_fp16.h>
 #include <hip/hip_bf16.h>
 #include <hip/hip_bfloat16.h>
-#include "../../../attention/dtype_fp8.cuh"
-#include "../../../attention/dtype_float32.cuh"
-#include "../../../attention/dtype_bfloat16.cuh"
+#include "../../../attention/attention_dtypes.h"
namespace vllm {
#ifdef USE_ROCM
@@ -26,40 +24,31 @@ __inline__ __device__ Tout scaled_vec_conversion(const Tin& x,
   return x;
 }

+#if HIP_FP8_TYPE_FNUZ
+using fp8_type = __hip_fp8_e4m3_fnuz;
+using fp8x2_type = __hip_fp8x2_e4m3_fnuz;
+#elif HIP_FP8_TYPE_OCP
+using fp8_type = __hip_fp8_e4m3;
+using fp8x2_type = __hip_fp8x2_e4m3;
+#endif
+
 // fp8 -> half
 template <>
 __inline__ __device__ uint16_t
 vec_conversion<uint16_t, uint8_t>(const uint8_t& a) {
-  hip_fp8 f8{a, hip_fp8::from_bits()};
-  __half_raw res;
-  res.data = static_cast<float>(f8);
-  return res.x;
+  return __hip_cvt_fp8_to_halfraw(a, fp8_type::__default_interpret).x;
 }

 // fp8x2 -> half2
 template <>
 __inline__ __device__ uint32_t
 vec_conversion<uint32_t, uint16_t>(const uint16_t& a) {
-#if defined(__HIP__MI300__) && \
-    defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
-  const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
   union {
     __half2_raw h2r;
     uint32_t ui32;
   } tmp;
-  tmp.h2r.x.data = f2[0];
-  tmp.h2r.y.data = f2[1];
+  tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
   return tmp.ui32;
-#else
-  union {
-    uint16_t u16[2];
-    uint32_t u32;
-  } tmp;
-  tmp.u16[0] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a));
-  tmp.u16[1] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a >> 8U));
-  return tmp.u32;
-#endif
 }
// fp8x4 -> half2x2
@@ -92,9 +81,9 @@ using __nv_bfloat16 = __hip_bfloat16;
 template <>
 __inline__ __device__ __nv_bfloat16
 vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) {
-  hip_fp8 f8{a, hip_fp8::from_bits()};
-  float f{f8};
-  return __float2bfloat16(f);
+  fp8_type f8;
+  f8.__x = a;
+  return __float2bfloat16(static_cast<float>(f8));
 }
using __nv_bfloat162 = __hip_bfloat162;
@@ -136,27 +125,18 @@ __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a) {
 // fp8 -> float
 template <>
 __inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a) {
-  hip_fp8 fp8{a, hip_fp8::from_bits()};
-  return static_cast<float>(fp8);
+  fp8_type f8;
+  f8.__x = a;
+  return static_cast<float>(f8);
 }

 // fp8x2 -> float2
 template <>
 __inline__ __device__ float2
 vec_conversion<float2, uint16_t>(const uint16_t& a) {
-#if defined(__HIP__MI300__) && \
-    defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
-  float2 res;
-  const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
-  res.x = f2[0];
-  res.y = f2[1];
-  return res;
-#else
-  float2 res;
-  res.x = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a));
-  res.y = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U));
-  return res;
-#endif
+  fp8x2_type f8x2;
+  f8x2.__x = a;
+  return static_cast<float2>(f8x2);
 }
// fp8x4 -> float4
@@ -169,6 +149,15 @@ vec_conversion<Float4_, uint32_t>(const uint32_t& a) {
   return res;
 }

+// fp8x4 -> float4
+template <>
+__inline__ __device__ float4
+vec_conversion<float4, uint32_t>(const uint32_t& a) {
+  Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
+  float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
+  return res;
+}
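A note on the overload added above: Float4_ (from vllm's attention dtypes) is a struct of two float2 fields, while float4 is the builtin HIP/CUDA vector type; the new specialization simply repacks the same four lanes for callers that expect the builtin type.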
// fp8x8 -> float8
template <>
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a) {
@@ -189,33 +178,36 @@ __inline__ __device__ uint8_t
 vec_conversion<uint8_t, uint16_t>(const uint16_t& a) {
   __half_raw tmp;
   tmp.x = a;
-  hip_fp8 f8{static_cast<float>(tmp.data)};
-  return f8.data;
+  return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation,
+                                  fp8_type::__default_interpret);
 }

+template <>
+__inline__ __device__ uint16_t
+vec_conversion<uint16_t, uint32_t>(const uint32_t& a) {
+  union {
+    uint32_t ui32;
+    __half2_raw h2r;
+  } tmp;
+  tmp.ui32 = a;
+  return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation,
+                                     fp8_type::__default_interpret);
+}
+
 // bf16 -> fp8
 template <>
 __inline__ __device__ uint8_t
 vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a) {
-  hip_fp8 res{__bfloat162float(a)};
-  return res.data;
+  return __hip_cvt_float_to_fp8(__bfloat162float(a),
+                                fp8_type::__default_saturation,
+                                fp8_type::__default_interpret);
 }

 // float -> fp8
 template <>
 __inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a) {
-  hip_fp8 f8(a);
-  return f8.data;
-}
-
-// fp8x4 -> float4
-template <>
-__inline__ __device__ float4
-vec_conversion<float4, uint32_t>(const uint32_t& a) {
-  Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
-  float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
-  return res;
+  return __hip_cvt_float_to_fp8(a, fp8_type::__default_saturation,
+                                fp8_type::__default_interpret);
 }
// float2 -> half2
@@ -307,90 +299,22 @@
 */

-// fp8 -> half
-template <>
-__inline__ __device__ uint16_t
-scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, const float scale) {
-  hip_fp8 f8{a, hip_fp8::from_bits()};
-  __half_raw res;
-  res.data = static_cast<float>(f8) * scale;
-  return res.x;
-}
-
-// fp8x2 -> half2
-template <>
-__inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(
-    const uint16_t& a, const float scale) {
-#if defined(__HIP__MI300__) && \
-    defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
-  const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
-  union {
-    __half2_raw h2r;
-    uint32_t ui32;
-  } tmp;
-  tmp.h2r.x.data = f2[0] * scale;
-  tmp.h2r.y.data = f2[1] * scale;
-  return tmp.ui32;
-#else
-  union {
-    uint16_t u16[2];
-    uint32_t u32;
-  } tmp;
-  tmp.u16[0] =
-      scaled_vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a), scale);
-  tmp.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>(
-      static_cast<uint8_t>(a >> 8U), scale);
-  return tmp.u32;
-#endif
-}
-
-// fp8x4 -> half2x2
-template <>
-__inline__ __device__ uint2
-scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, const float scale) {
-  union {
-    uint2 u32x2;
-    uint32_t u32[2];
-  } tmp;
-  tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale);
-  tmp.u32[1] =
-      scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale);
-  return tmp.u32x2;
-}
-
-// fp8x8 -> half2x4
-template <>
-__inline__ __device__ uint4
-scaled_vec_conversion<uint4, uint2>(const uint2& a, const float scale) {
-  union {
-    uint4 u64x2;
-    uint2 u64[2];
-  } tmp;
-  tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale);
-  tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale);
-  return tmp.u64x2;
-}
-
 using __nv_bfloat16 = __hip_bfloat16;

 // fp8 -> __nv_bfloat16
 template <>
 __inline__ __device__ __nv_bfloat16
-scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a,
-                                              const float scale) {
-  hip_fp8 f8{a, hip_fp8::from_bits()};
-  float f{f8};
-  return __float2bfloat16(f * scale);
+scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) {
+  fp8_type f8;
+  f8.__x = a;
+  return __float2bfloat16(static_cast<float>(f8) * scale);
 }

 using __nv_bfloat162 = __hip_bfloat162;

 // fp8x2 -> __nv_bfloat162
 template <>
 __inline__ __device__ __nv_bfloat162
 scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a,
-                                                const float scale) {
+                                                float scale) {
   __nv_bfloat162 res;
   res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale);
   res.y =
@@ -400,8 +324,8 @@ scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a,
 // fp8x4 -> bf16_4_t
 template <>
-__inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
-    const uint32_t& a, const float scale) {
+__inline__ __device__ bf16_4_t
+scaled_vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a, float scale) {
   bf16_4_t res;
   res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale);
   res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U),
@@ -412,7 +336,7 @@ __inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
 // fp8x8 -> bf16_8_t
 template <>
 __inline__ __device__ bf16_8_t
-scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, const float scale) {
+scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, float scale) {
   bf16_4_t tmp1, tmp2;
   tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale);
   tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale);
@@ -427,29 +351,19 @@ scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, const float scale) {
 // fp8 -> float
 template <>
 __inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
-    const uint8_t& a, const float scale) {
-  hip_fp8 fp8{a, hip_fp8::from_bits()};
-  return static_cast<float>(fp8) * scale;
+    const uint8_t& a, float scale) {
+  fp8_type f8;
+  f8.__x = a;
+  return static_cast<float>(f8) * scale;
 }

 // fp8x2 -> float2
 template <>
 __inline__ __device__ float2
-scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, const float scale) {
-#if defined(__HIP__MI300__) && \
-    defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
-  float2 res;
-  const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
-  res.x = f2[0] * scale;
-  res.y = f2[1] * scale;
-  return res;
-#else
-  float2 res;
-  res.x = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a), scale);
-  res.y = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U),
-                                                scale);
-  return res;
-#endif
+scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, float scale) {
+  fp8x2_type f8x2;
+  f8x2.__x = a;
+  return static_cast<float2>(f8x2) * scale;
 }
// fp8x4 -> float4
@@ -462,10 +376,18 @@ scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale) {
   return res;
 }

+// fp8x4 -> float4
+template <>
+__inline__ __device__ float4
+scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, float scale) {
+  Float4_ res = scaled_vec_conversion<Float4_, uint32_t>(a, scale);
+  return {res.x.x, res.x.y, res.y.x, res.y.y};
+}
+
 // fp8x8 -> float8
 template <>
 __inline__ __device__ Float8_
-scaled_vec_conversion<Float8_, uint2>(const uint2& a, const float scale) {
+scaled_vec_conversion<Float8_, uint2>(const uint2& a, float scale) {
   Float4_ tmp1, tmp2;
   tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale);
   tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale);
@@ -477,44 +399,184 @@ scaled_vec_conversion<Float8_, uint2>(const uint2& a, const float scale) {
   return res;
 }

 /* Quantize(HP / scale) => FP8 */

+// fp8 -> half
+template <>
+__inline__ __device__ uint16_t
+scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale) {
+  __half_raw res;
+  res.data = scaled_vec_conversion<float, uint8_t>(a, scale);
+  return res.x;
+}
+
 // TODO(Hai): vectorized to add

+// fp8x2 -> half2
+template <>
+__inline__ __device__ uint32_t
+scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, float scale) {
+  union {
+    __half2_raw h2r;
+    uint32_t ui32;
+  } tmp;
+  tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
+  tmp.h2r.x.data *= scale;
+  tmp.h2r.y.data *= scale;
+  return tmp.ui32;
+}
+
+// fp8x4 -> half2x2
+template <>
+__inline__ __device__ uint2
+scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, float scale) {
+  union {
+    uint2 u32x2;
+    uint32_t u32[2];
+  } tmp;
+  tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale);
+  tmp.u32[1] =
+      scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale);
+  return tmp.u32x2;
+}
+
+// fp8x8 -> half2x4
+template <>
+__inline__ __device__ uint4 scaled_vec_conversion<uint4, uint2>(const uint2& a,
+                                                                float scale) {
+  union {
+    uint4 u64x2;
+    uint2 u64[2];
+  } tmp;
+  tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale);
+  tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale);
+  return tmp.u64x2;
+}
+
 // half -> fp8
 template <>
 __inline__ __device__ uint8_t
-scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, const float scale) {
+scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, float scale) {
   __half_raw tmp;
   tmp.x = a;
-  hip_fp8 f8{static_cast<float>(tmp.data) / scale};
-  return f8.data;
+  tmp.data /= scale;
+  return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation,
+                                  fp8_type::__default_interpret);
 }

+// halfx2 -> fp8x2
+template <>
+__inline__ __device__ uint16_t
+scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale) {
+  union {
+    uint32_t ui32;
+    __half2_raw h2r;
+  } tmp;
+  tmp.ui32 = a;
+  tmp.h2r.x.data /= scale;
+  tmp.h2r.y.data /= scale;
+  return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation,
+                                     fp8_type::__default_interpret);
+}
+
+// half2x2 -> fp8x4
+template <>
+__inline__ __device__ uint32_t
+scaled_vec_conversion<uint32_t, uint2>(const uint2& a, float scale) {
+  union {
+    uint16_t ui16[2];
+    uint32_t ui32;
+  } tmp;
+  tmp.ui16[0] = scaled_vec_conversion<uint16_t, uint32_t>(a.x, scale);
+  tmp.ui16[1] = scaled_vec_conversion<uint16_t, uint32_t>(a.y, scale);
+  return tmp.ui32;
+}
+
+// half2x4 -> fp8x8
+template <>
+__inline__ __device__ uint2 scaled_vec_conversion<uint2, uint4>(const uint4& a,
+                                                                float scale) {
+  union {
+    uint2 ui2[2];
+    uint4 ui4;
+  } tmp;
+  tmp.ui4 = a;
+  uint2 res;
+  res.x = scaled_vec_conversion<uint32_t, uint2>(tmp.ui2[0], scale);
+  res.y = scaled_vec_conversion<uint32_t, uint2>(tmp.ui2[1], scale);
+  return res;
+}
+
 // bf16 -> fp8
 template <>
 __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
-    const __nv_bfloat16& a, const float scale) {
-  hip_fp8 res{__bfloat162float(a) / scale};
-  return res.data;
+    const __nv_bfloat16& a, float scale) {
+  return __hip_cvt_float_to_fp8(__bfloat162float(a) / scale,
+                                fp8_type::__default_saturation,
+                                fp8_type::__default_interpret);
 }

+// bf16x2 -> fp8x2
+template <>
+__inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, __nv_bfloat162>(
+    const __nv_bfloat162& a, float scale) {
+  union {
+    uint8_t ui8[2];
+    uint16_t ui16;
+  } tmp;
+  tmp.ui8[0] = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.x, scale);
+  tmp.ui8[1] = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.y, scale);
+  return tmp.ui16;
+}
+
+// bf16x4 -> fp8x4
+template <>
+__inline__ __device__ uint32_t
+scaled_vec_conversion<uint32_t, bf16_4_t>(const bf16_4_t& a, float scale) {
+  union {
+    uint16_t ui16[2];
+    uint32_t ui32;
+  } tmp;
+  tmp.ui16[0] = scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.x, scale);
+  tmp.ui16[1] = scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.y, scale);
+  return tmp.ui32;
+}
+
+// bf16x8 -> fp8x8
+template <>
+__inline__ __device__ uint2
+scaled_vec_conversion<uint2, bf16_8_t>(const bf16_8_t& a, float scale) {
+  uint2 res;
+  res.x = scaled_vec_conversion<uint32_t, bf16_4_t>({a.x, a.y}, scale);
+  res.y = scaled_vec_conversion<uint32_t, bf16_4_t>({a.z, a.w}, scale);
+  return res;
+}
+
 // float -> fp8
 template <>
 __inline__ __device__ uint8_t
-scaled_vec_conversion<uint8_t, float>(const float& a, const float scale) {
-  hip_fp8 f8(a / scale);
-  return f8.data;
+scaled_vec_conversion<uint8_t, float>(const float& a, float scale) {
+  return __hip_cvt_float_to_fp8(a / scale, fp8_type::__default_saturation,
+                                fp8_type::__default_interpret);
 }

-// fp8x4 -> float4
-template <>
-__inline__ __device__ float4
-scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, const float scale) {
-  Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale);
-  float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
-  return res;
-}
+// floatx2 -> fp8x2
+template <>
+__inline__ __device__ uint16_t
+scaled_vec_conversion<uint16_t, float2>(const float2& a, float scale) {
+  return __hip_cvt_float2_to_fp8x2(a / scale, fp8_type::__default_saturation,
+                                   fp8_type::__default_interpret);
+}

+// floatx4 -> fp8x4
+template <>
+__inline__ __device__ uint32_t
+scaled_vec_conversion<uint32_t, float4>(const float4& a, float scale) {
+  union {
+    uint16_t ui16[2];
+    uint32_t ui32;
+  } tmp;
+  tmp.ui16[0] = scaled_vec_conversion<uint16_t, float2>({a.x, a.y}, scale);
+  tmp.ui16[1] = scaled_vec_conversion<uint16_t, float2>({a.z, a.w}, scale);
+  return tmp.ui32;
+}
#endif // ENABLE_FP8
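The scaled conversions follow one convention, stated in the banner comment: quantization computes fp8 = cvt(value / scale) and dequantization computes value = cvt(fp8) * scale. A minimal device-side round trip through both paths (a sketch assuming this header is included and that the conversions live in vllm::fp8, per the fp8:: qualifier used in common.cuh below):

__global__ void fp8_roundtrip(const float* in, float* out, int n, float scale) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) return;
  // Quantize(HP / scale) => FP8
  uint8_t q = vllm::fp8::scaled_vec_conversion<uint8_t, float>(in[i], scale);
  // Dequantize: FP8 * scale => HP (exact up to fp8 rounding error)
  out[i] = vllm::fp8::scaled_vec_conversion<float, uint8_t>(q, scale);
}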

csrc/quantization/fp8/common.cuh

@@ -12,7 +12,7 @@ C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
std::numeric_limits<FP8_TYPE>::max();
#else
#include <c10/util/Float8_e4m3fnuz.h>
#include "amd/hip_float8.h"
#include "amd/quant_utils.cuh"
using FP8_TYPE = c10::Float8_e4m3fnuz;
// Using the default max value from pytorch (240.0) will cause accuracy
// issues when running dynamic quantization. Here we use 224.0f for ROCm.
@@ -47,8 +47,10 @@ __device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
   return static_cast<c10::Float8_e4m3fn>(r);
 #else
   // Use hardware cvt instruction for fp8 on rocm
-  return c10::Float8_e4m3fnuz(hip_fp8(r).data,
-                              c10::Float8_e4m3fnuz::from_bits());
+  return c10::Float8_e4m3fnuz(
+      __hip_cvt_float_to_fp8(r, fp8::fp8_type::__default_saturation,
+                             fp8::fp8_type::__default_interpret),
+      c10::Float8_e4m3fnuz::from_bits());
#endif
}
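The 224.0f cap interacts with scale selection: dynamic quantization derives the per-tensor scale from the observed absolute maximum, so declaring a smaller FP8 max yields a slightly larger scale and keeps values away from the format's edge. A hedged sketch of that derivation (names are illustrative, not from this file):

// Illustrative scale derivation for dynamic FP8 quantization on ROCm.
float compute_dynamic_scale(float amax /* max |value| in the tensor */) {
  const float fp8_max = 224.0f;  // FP8_E4M3_MAX here, not the format's 240.0
  return amax / fp8_max;         // the kernel then converts val / scale to fp8
}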

tests/kernels/test_cache.py

@@ -159,19 +159,20 @@ def test_reshape_and_cache(
                                                  device)
     key_cache, value_cache = key_caches[0], value_caches[0]

+    # Using default kv_scale
+    k_scale = (key.amax() / 64.0).to(torch.float32)
+    v_scale = (value.amax() / 64.0).to(torch.float32)
+
     # Clone the KV caches.
     if kv_cache_dtype == "fp8":
         cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
-        ops.convert_fp8(cloned_key_cache, key_cache)
+        ops.convert_fp8(cloned_key_cache, key_cache, k_scale.item())
         cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
-        ops.convert_fp8(cloned_value_cache, value_cache)
+        ops.convert_fp8(cloned_value_cache, value_cache, v_scale.item())
     else:
         cloned_key_cache = key_cache.clone()
         cloned_value_cache = value_cache.clone()

-    # Using default kv_scale
-    k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)

     # Call the reshape_and_cache kernel.
     opcheck(torch.ops._C_cache_ops.reshape_and_cache,
             (key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
@@ -182,9 +183,9 @@ def test_reshape_and_cache(
     if kv_cache_dtype == "fp8":
         result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
-        ops.convert_fp8(result_key_cache, key_cache)
+        ops.convert_fp8(result_key_cache, key_cache, k_scale.item())
         result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
-        ops.convert_fp8(result_value_cache, value_cache)
+        ops.convert_fp8(result_value_cache, value_cache, v_scale.item())

     # Run the reference implementation.
     reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
@@ -268,15 +269,16 @@ def test_reshape_and_cache_flash(
     del key_caches
     del value_caches

-    k_scale = (key.amax() / 256.0).to(torch.float32)
-    v_scale = (value.amax() / 256.0).to(torch.float32)
+    k_scale = (key.amax() / 64.0).to(torch.float32)
+    v_scale = (value.amax() / 64.0).to(torch.float32)

     # Clone the KV caches.
     if kv_cache_dtype == "fp8":
         cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
-        ops.convert_fp8(cloned_key_cache, key_cache, k_scale, kv_cache_dtype)
+        ops.convert_fp8(cloned_key_cache, key_cache, k_scale.item(),
+                        kv_cache_dtype)
         cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
-        ops.convert_fp8(cloned_value_cache, value_cache, v_scale,
+        ops.convert_fp8(cloned_value_cache, value_cache, v_scale.item(),
                         kv_cache_dtype)
     else:
         cloned_key_cache = key_cache.clone()