[ROCm][Quantization][Kernel] Using HIP FP8 header (#12593)
This commit is contained in:
parent
2f42a4888c
commit
aabeb2688f
@ -174,6 +174,25 @@ include(FetchContent)
|
||||
file(MAKE_DIRECTORY ${FETCHCONTENT_BASE_DIR}) # Ensure the directory exists
|
||||
message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}")
|
||||
|
||||
#
|
||||
# Set rocm version dev int.
|
||||
#
|
||||
if(VLLM_GPU_LANG STREQUAL "HIP")
|
||||
#
|
||||
# Overriding the default -O set up by cmake, adding ggdb3 for the most verbose devug info
|
||||
#
|
||||
set(CMAKE_${VLLM_GPU_LANG}_FLAGS_DEBUG "${CMAKE_${VLLM_GPU_LANG}_FLAGS_DEBUG} -O0 -ggdb3")
|
||||
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -O0 -ggdb3")
|
||||
|
||||
|
||||
#
|
||||
# Certain HIP functions are marked as [[nodiscard]], yet vllm ignores the result which generates
|
||||
# a lot of warnings that always mask real issues. Suppressing until this is properly addressed.
|
||||
#
|
||||
set(CMAKE_${VLLM_GPU_LANG}_FLAGS "${CMAKE_${VLLM_GPU_LANG}_FLAGS} -Wno-unused-result")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-result")
|
||||
endif()
|
||||
|
||||
#
|
||||
# Define other extension targets
|
||||
#
|
||||
|
@ -1,137 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifdef __HIPCC__
|
||||
#include <hip/hip_runtime.h>
|
||||
#else
|
||||
#include <type_traits>
|
||||
#include <stdint.h>
|
||||
#include <math.h>
|
||||
#include <iostream>
|
||||
#endif
|
||||
|
||||
#include "hip_float8_impl.h"
|
||||
|
||||
struct alignas(1) hip_fp8 {
|
||||
struct from_bits_t {};
|
||||
HIP_FP8_HOST_DEVICE static constexpr from_bits_t from_bits() {
|
||||
return from_bits_t();
|
||||
}
|
||||
uint8_t data;
|
||||
|
||||
hip_fp8() = default;
|
||||
HIP_FP8_HOST_DEVICE constexpr hip_fp8(const hip_fp8&) = default;
|
||||
HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v) = delete;
|
||||
explicit HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v, from_bits_t)
|
||||
: data(v) {}
|
||||
|
||||
#ifdef __HIP__MI300__
|
||||
// NOTE: ON-DEVICE... always optimal bias
|
||||
explicit HIP_FP8_DEVICE hip_fp8(float v)
|
||||
: data(hip_fp8_impl::to_fp8_from_fp32(v)) {}
|
||||
|
||||
explicit HIP_FP8_DEVICE hip_fp8(_Float16 v)
|
||||
: hip_fp8(static_cast<float>(v)) {}
|
||||
|
||||
// Host only implementation using s/w simulation
|
||||
explicit HIP_FP8_HOST
|
||||
#else // __HIP__MI300__
|
||||
// both Host and DEVICE for non-MI300 using s/w simulation
|
||||
explicit HIP_FP8_HOST_DEVICE
|
||||
#endif // __HIP__MI300__
|
||||
hip_fp8(float v) {
|
||||
data = hip_fp8_impl::to_float8<4, 3, float, true /*negative_zero_nan*/,
|
||||
true /*clip*/>(v);
|
||||
}
|
||||
|
||||
explicit HIP_FP8_HOST_DEVICE hip_fp8(double v)
|
||||
: hip_fp8(static_cast<float>(v)) {}
|
||||
|
||||
#ifdef __HIP__MI300__
|
||||
// upcast using device specific intrinsic
|
||||
explicit inline HIP_FP8_DEVICE operator float() const {
|
||||
float fval;
|
||||
uint32_t i32val = static_cast<uint32_t>(data);
|
||||
|
||||
// upcast
|
||||
asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0"
|
||||
: "=v"(fval)
|
||||
: "v"(i32val));
|
||||
|
||||
return fval;
|
||||
}
|
||||
|
||||
explicit inline HIP_FP8_HOST operator float() const
|
||||
#else // __HIP__MI300__
|
||||
explicit inline HIP_FP8_HOST_DEVICE operator float() const
|
||||
#endif // __HIP__MI300__
|
||||
{
|
||||
return hip_fp8_impl::from_float8<4, 3, float, true /*negative_zero_nan*/>(
|
||||
data);
|
||||
}
|
||||
};
|
||||
|
||||
namespace std {
|
||||
inline hip_fp8 sin(hip_fp8 a) { return hip_fp8(sinf(float(a))); }
|
||||
inline hip_fp8 cos(hip_fp8 a) { return hip_fp8(cosf(float(a))); }
|
||||
HIP_FP8_HOST_DEVICE constexpr hip_fp8 real(const hip_fp8& a) { return a; }
|
||||
} // namespace std
|
||||
|
||||
// Special operator overloading
|
||||
inline std::ostream& operator<<(std::ostream& os, const hip_fp8& f8) {
|
||||
return os << float(f8);
|
||||
}
|
||||
|
||||
// all + operator overloading with mixed types
|
||||
// mixed types, always converts to f32, does computation in f32, and returns
|
||||
// float
|
||||
inline HIP_FP8_HOST_DEVICE float operator+(const float fa, hip_fp8 b) {
|
||||
return (fa + float(b));
|
||||
}
|
||||
|
||||
inline HIP_FP8_HOST_DEVICE float operator+(hip_fp8 a, const float fb) {
|
||||
return (float(a) + fb);
|
||||
}
|
||||
|
||||
inline HIP_FP8_HOST_DEVICE hip_fp8 operator+(hip_fp8 a, hip_fp8 b) {
|
||||
return hip_fp8(float(a) + float(b));
|
||||
}
|
||||
|
||||
inline HIP_FP8_HOST_DEVICE hip_fp8& operator+=(hip_fp8& a, hip_fp8 b) {
|
||||
return a = hip_fp8(float(a) + float(b));
|
||||
}
|
||||
|
||||
// overloading multiplication, always returns float,
|
||||
inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, hip_fp8 b) {
|
||||
return float(a) * float(b);
|
||||
}
|
||||
|
||||
inline HIP_FP8_HOST_DEVICE float operator*(float a, hip_fp8 b) {
|
||||
return (a * float(b));
|
||||
}
|
||||
|
||||
inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, float b) {
|
||||
return (float(a) * b);
|
||||
}
|
||||
|
||||
inline HIP_FP8_HOST_DEVICE float operator*(int32_t a, hip_fp8 b) {
|
||||
return ((float)a * float(b));
|
||||
}
|
||||
|
||||
inline HIP_FP8_HOST_DEVICE float operator*(double a, hip_fp8 b) {
|
||||
return ((float)a * float(b));
|
||||
}
|
||||
|
||||
// overloading for compare
|
||||
inline HIP_FP8_HOST_DEVICE bool operator==(hip_fp8 a, hip_fp8 b) {
|
||||
return (a.data == b.data);
|
||||
}
|
||||
inline HIP_FP8_HOST_DEVICE bool operator!=(hip_fp8 a, hip_fp8 b) {
|
||||
return (a.data != b.data);
|
||||
}
|
||||
|
||||
inline HIP_FP8_HOST_DEVICE bool operator>=(hip_fp8 a, hip_fp8 b) {
|
||||
return static_cast<float>(a) >= static_cast<float>(b);
|
||||
}
|
||||
inline HIP_FP8_HOST_DEVICE bool operator>(hip_fp8 a, hip_fp8 b) {
|
||||
return static_cast<float>(a) > static_cast<float>(b);
|
||||
}
|
@ -1,315 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#if defined(__HIPCC__) && defined(__gfx942__)
|
||||
#define __HIP__MI300__
|
||||
#endif
|
||||
|
||||
#ifdef __HIPCC__
|
||||
#define HIP_FP8_HOST_DEVICE __host__ __device__
|
||||
#define HIP_FP8_HOST __host__
|
||||
#define HIP_FP8_DEVICE __device__
|
||||
#else
|
||||
#define HIP_FP8_HOST_DEVICE
|
||||
#define HIP_FP8_HOST
|
||||
#define HIP_FP8_DEVICE
|
||||
#endif
|
||||
|
||||
namespace hip_fp8_impl {
|
||||
|
||||
#ifdef __HIP__MI300__
|
||||
HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v) {
|
||||
uint8_t i8data;
|
||||
union {
|
||||
float fval;
|
||||
uint32_t i32val;
|
||||
uint8_t i8val[4]; // NOTE: not endian independent
|
||||
} val;
|
||||
|
||||
uint32_t ival = 0;
|
||||
val.fval = v;
|
||||
|
||||
if ((val.i32val & 0x7F800000) !=
|
||||
0x7F800000) { /// propagate NAN/INF, no clipping
|
||||
val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
|
||||
}
|
||||
|
||||
ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival,
|
||||
false); // false -> WORD0
|
||||
val.i32val = ival;
|
||||
i8data = val.i8val[0];
|
||||
|
||||
return i8data;
|
||||
}
|
||||
#endif // __HIP__MI300__
|
||||
|
||||
HIP_FP8_HOST inline int clz(uint32_t x) { return __builtin_clz(x); }
|
||||
#if defined(__HIPCC__) || defined(__CUDA_ARCH__)
|
||||
HIP_FP8_DEVICE inline int clz(uint32_t x) { return __clz(x); }
|
||||
#endif
|
||||
|
||||
template <int we, int wm, typename T, bool negative_zero_nan, bool clip>
|
||||
HIP_FP8_HOST_DEVICE uint8_t to_float8(T _x, bool stoch = false,
|
||||
uint32_t rng = 0) {
|
||||
#ifdef __HIPCC__
|
||||
constexpr bool is_half = std::is_same<T, _Float16>::value;
|
||||
#else
|
||||
constexpr bool is_half = false;
|
||||
#endif
|
||||
constexpr bool is_float = std::is_same<T, float>::value;
|
||||
static_assert(wm + we == 7, "wm+we==7");
|
||||
static_assert(is_half || is_float, "Only half and float can be cast to f8");
|
||||
|
||||
const int mfmt = (sizeof(T) == 4) ? 23 : 10;
|
||||
uint32_t x;
|
||||
if (sizeof(T) == 4) {
|
||||
x = reinterpret_cast<uint32_t&>(_x);
|
||||
} else {
|
||||
x = reinterpret_cast<uint16_t&>(_x);
|
||||
}
|
||||
|
||||
uint32_t head, mantissa;
|
||||
int exponent, bias;
|
||||
uint32_t sign;
|
||||
|
||||
if (sizeof(T) == 4) {
|
||||
head = x & 0xFF800000;
|
||||
mantissa = x & 0x7FFFFF;
|
||||
exponent = (head >> 23) & 0xFF;
|
||||
sign = head >> 31;
|
||||
bias = 127;
|
||||
} else {
|
||||
head = x & 0xFC00;
|
||||
mantissa = x & 0x3FF;
|
||||
exponent = (head >> 10) & 0x1F;
|
||||
sign = head >> 15;
|
||||
bias = 15;
|
||||
}
|
||||
|
||||
uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm);
|
||||
|
||||
// Deal with inf and NaNs
|
||||
if (negative_zero_nan) {
|
||||
if (sizeof(T) == 4) {
|
||||
if ((x & 0x7F800000) == 0x7F800000) {
|
||||
return 0x80;
|
||||
}
|
||||
} else {
|
||||
// if(__hisinf(x) || __hisnan(x))
|
||||
if ((x & 0x7C00) == 0x7C00) {
|
||||
return 0x80;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (sizeof(T) == 4) {
|
||||
if ((x & 0x7F800000) == 0x7F800000) {
|
||||
return signed_inf + (mantissa != 0 ? 1 : 0);
|
||||
}
|
||||
} else {
|
||||
if ((x & 0x7C00) == 0x7C00) {
|
||||
return signed_inf + (mantissa != 0 ? 1 : 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (x == 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// First need to check if it is normal or denorm as there is a difference of
|
||||
// implicit 1 Then need to adjust the exponent to align with the F8 exponent,
|
||||
// in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng
|
||||
// to mantissa and truncate. And for RNE, no need to add rng. Then probably
|
||||
// need to check whether there is carry and adjust exponent and mantissa again
|
||||
|
||||
// For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent
|
||||
// bits
|
||||
const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0);
|
||||
const int f8_denormal_act_exponent =
|
||||
1 - f8_bias; // actual exponent of f8 denormal
|
||||
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
|
||||
// f8_exponent is the converted f8 exponent with bias encoding
|
||||
// exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
|
||||
// the difference needs to be adjusted and mantissa shifted
|
||||
int act_exponent, f8_exponent, exponent_diff;
|
||||
|
||||
if (exponent == 0) { // fp32/fp16 is in denormal.
|
||||
/* fp32 denormal is below 2^-127 so it is usually not a concern here, we
|
||||
mostly concern fp16 here. In this case, f8 is usually in denormal. But there
|
||||
could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has
|
||||
exponent bias 16. It means that there are some numbers in fp16 denormal but they
|
||||
are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
|
||||
where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
|
||||
(NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */
|
||||
act_exponent = exponent - bias + 1;
|
||||
exponent_diff =
|
||||
f8_denormal_act_exponent -
|
||||
act_exponent; // actual exponent is exponent-bias+1 as it is denormal
|
||||
} else { // fp32/fp16 is normal with implicit 1
|
||||
act_exponent = exponent - bias;
|
||||
if (act_exponent <= f8_denormal_act_exponent) {
|
||||
/* This is the case where fp32/fp16 is normal but it is in f8 denormal
|
||||
range. For example fp8 nanoo mode, denormal exponent is -7, but if the
|
||||
fp32/fp16 actual exponent is -7, it is actually larger due to the implicit 1,
|
||||
Therefore it needs to be adjust to -6 and mantissa shift right by 1.
|
||||
So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
|
||||
exponent_diff = f8_denormal_act_exponent - act_exponent;
|
||||
} else { // both fp32/fp16 and f8 are in normal range
|
||||
exponent_diff = 0; // exponent_diff=0 does not mean there is no
|
||||
// difference for this case, act_exponent could be
|
||||
// larger. Just that it does not need shift mantissa
|
||||
}
|
||||
mantissa += (1 << mfmt); // Add the implicit 1 into mantissa
|
||||
}
|
||||
|
||||
bool midpoint = (mantissa & ((1 << (mfmt - wm + exponent_diff)) - 1)) ==
|
||||
static_cast<uint32_t>(1 << (mfmt - wm + exponent_diff - 1));
|
||||
/* This part is a bit tricky. The judgment of whether it is a tie needs to be
|
||||
done before we shift right as shift right could rip off some residual part
|
||||
and make something not midpoint look like midpoint. For example, the fp16
|
||||
number 0x1002 (0 00100 0000000010), it is larger than midpoint, but after
|
||||
shift right by 4 bits, it would look like midpoint.
|
||||
*/
|
||||
|
||||
if (exponent_diff > 0) {
|
||||
mantissa >>= exponent_diff;
|
||||
} else if (exponent_diff == -1) {
|
||||
mantissa <<= -exponent_diff;
|
||||
}
|
||||
bool implicit_one = mantissa & (1 << mfmt);
|
||||
// if there is no implicit 1, it means the f8 is denormal and need to adjust
|
||||
// to denorm exponent
|
||||
f8_exponent = (act_exponent + exponent_diff) /*actual f8 exponent*/ +
|
||||
f8_bias - (implicit_one ? 0 : 1);
|
||||
|
||||
// Now we have the exponent and mantissa adjusted
|
||||
uint32_t drop_mask = (1 << (mfmt - wm)) - 1;
|
||||
bool odd = mantissa & (1 << (mfmt - wm)); // if the least significant bit
|
||||
// that is not truncated is 1
|
||||
mantissa +=
|
||||
(stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) &
|
||||
drop_mask;
|
||||
|
||||
// Now we deal with overflow
|
||||
if (f8_exponent == 0) {
|
||||
if ((1 << mfmt) & mantissa) {
|
||||
f8_exponent = 1; // denormal overflow to become normal, promote exponent
|
||||
}
|
||||
} else {
|
||||
if ((1 << (mfmt + 1)) & mantissa) {
|
||||
mantissa >>= 1;
|
||||
f8_exponent++;
|
||||
}
|
||||
}
|
||||
|
||||
mantissa >>= (mfmt - wm);
|
||||
|
||||
// above range: quantize to maximum possible float of the same sign
|
||||
const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2);
|
||||
if (f8_exponent > max_exp) {
|
||||
if (clip) {
|
||||
mantissa = (1 << wm) - 1;
|
||||
f8_exponent = max_exp;
|
||||
} else {
|
||||
return signed_inf;
|
||||
}
|
||||
}
|
||||
|
||||
if (f8_exponent == 0 && mantissa == 0) {
|
||||
return negative_zero_nan ? 0 : (sign << 7);
|
||||
}
|
||||
mantissa &= (1 << wm) - 1;
|
||||
return (sign << 7) | (f8_exponent << wm) | mantissa;
|
||||
}
|
||||
|
||||
template <int we, int wm, typename T = float, bool negative_zero_nan = true>
|
||||
inline HIP_FP8_HOST_DEVICE T from_float8(uint8_t x) {
|
||||
#ifdef __HIPCC__
|
||||
constexpr bool is_half = std::is_same<T, _Float16>::value;
|
||||
#else
|
||||
constexpr bool is_half = false;
|
||||
#endif
|
||||
constexpr bool is_float = std::is_same<T, float>::value;
|
||||
static_assert(is_half || is_float, "only half and float are supported");
|
||||
|
||||
constexpr int weo = is_half ? 5 : 8;
|
||||
constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7);
|
||||
|
||||
T fInf, fNegInf, fNaN, fNeg0;
|
||||
|
||||
#ifdef __HIPCC__
|
||||
if (is_half) {
|
||||
const uint16_t ihInf = 0x7C00;
|
||||
const uint16_t ihNegInf = 0xFC00;
|
||||
const uint16_t ihNaN = 0x7C01;
|
||||
const uint16_t ihNeg0 = 0x8000;
|
||||
fInf = reinterpret_cast<const _Float16&>(ihInf);
|
||||
fNegInf = reinterpret_cast<const _Float16&>(ihNegInf);
|
||||
fNaN = reinterpret_cast<const _Float16&>(ihNaN);
|
||||
fNeg0 = reinterpret_cast<const _Float16&>(ihNeg0);
|
||||
} else
|
||||
#endif
|
||||
if (is_float) {
|
||||
const uint32_t ifInf = 0x7F800000;
|
||||
const uint32_t ifNegInf = 0xFF800000;
|
||||
const uint32_t ifNaN = 0x7F800001;
|
||||
const uint32_t ifNeg0 = 0x80000000;
|
||||
fInf = reinterpret_cast<const float&>(ifInf);
|
||||
fNegInf = reinterpret_cast<const float&>(ifNegInf);
|
||||
fNaN = reinterpret_cast<const float&>(ifNaN);
|
||||
fNeg0 = reinterpret_cast<const float&>(ifNeg0);
|
||||
}
|
||||
|
||||
if (x == 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
uint32_t sign = x >> 7;
|
||||
uint32_t mantissa = x & ((1 << wm) - 1);
|
||||
int exponent = (x & 0x7F) >> wm;
|
||||
if (negative_zero_nan) {
|
||||
if (x == 0x80) {
|
||||
return fNaN;
|
||||
}
|
||||
} else {
|
||||
if (x == 0x80) {
|
||||
return fNeg0;
|
||||
}
|
||||
if (exponent == ((1 << we) - 1)) {
|
||||
return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN;
|
||||
}
|
||||
}
|
||||
typename std::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type retval;
|
||||
if (we == 5 && is_half && !negative_zero_nan) {
|
||||
retval = x << 8;
|
||||
return reinterpret_cast<const T&>(retval);
|
||||
}
|
||||
|
||||
const int exp_low_cutoff =
|
||||
(1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0);
|
||||
|
||||
// subnormal input
|
||||
if (exponent == 0) {
|
||||
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
|
||||
int sh = 1 + clz(mantissa) - (32 - wm);
|
||||
mantissa <<= sh;
|
||||
exponent += 1 - sh;
|
||||
mantissa &= ((1 << wm) - 1);
|
||||
}
|
||||
exponent += exp_low_cutoff - 1;
|
||||
mantissa <<= wmo - wm;
|
||||
|
||||
// subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
|
||||
if (exponent <= 0) {
|
||||
mantissa |= 1 << wmo;
|
||||
mantissa >>= 1 - exponent;
|
||||
exponent = 0;
|
||||
}
|
||||
|
||||
if (sizeof(T) == 2) {
|
||||
retval = (sign << 15) | (exponent << 10) | mantissa;
|
||||
} else {
|
||||
retval = (sign << 31) | (exponent << 23) | mantissa;
|
||||
}
|
||||
return reinterpret_cast<const T&>(retval);
|
||||
}
|
||||
|
||||
} // namespace hip_fp8_impl
|
@ -1,13 +1,11 @@
|
||||
#pragma once
|
||||
#include "hip_float8.h"
|
||||
#include <hip/hip_fp8.h>
|
||||
|
||||
#include <hip/hip_fp16.h>
|
||||
#include <hip/hip_bf16.h>
|
||||
#include <hip/hip_bfloat16.h>
|
||||
|
||||
#include "../../../attention/dtype_fp8.cuh"
|
||||
#include "../../../attention/dtype_float32.cuh"
|
||||
#include "../../../attention/dtype_bfloat16.cuh"
|
||||
#include "../../../attention/attention_dtypes.h"
|
||||
|
||||
namespace vllm {
|
||||
#ifdef USE_ROCM
|
||||
@ -26,40 +24,31 @@ __inline__ __device__ Tout scaled_vec_conversion(const Tin& x,
|
||||
return x;
|
||||
}
|
||||
|
||||
#if HIP_FP8_TYPE_FNUZ
|
||||
using fp8_type = __hip_fp8_e4m3_fnuz;
|
||||
using fp8x2_type = __hip_fp8x2_e4m3_fnuz;
|
||||
#elif HIP_FP8_TYPE_OCP
|
||||
using fp8_type = __hip_fp8_e4m3;
|
||||
using fp8x2_type = __hip_fp8x2_e4m3;
|
||||
#endif
|
||||
|
||||
// fp8 -> half
|
||||
template <>
|
||||
__inline__ __device__ uint16_t
|
||||
vec_conversion<uint16_t, uint8_t>(const uint8_t& a) {
|
||||
hip_fp8 f8{a, hip_fp8::from_bits()};
|
||||
__half_raw res;
|
||||
res.data = static_cast<float>(f8);
|
||||
return res.x;
|
||||
return __hip_cvt_fp8_to_halfraw(a, fp8_type::__default_interpret).x;
|
||||
}
|
||||
|
||||
// fp8x2 -> half2
|
||||
template <>
|
||||
__inline__ __device__ uint32_t
|
||||
vec_conversion<uint32_t, uint16_t>(const uint16_t& a) {
|
||||
#if defined(__HIP__MI300__) && \
|
||||
defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
|
||||
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
|
||||
union {
|
||||
__half2_raw h2r;
|
||||
uint32_t ui32;
|
||||
} tmp;
|
||||
tmp.h2r.x.data = f2[0];
|
||||
tmp.h2r.y.data = f2[1];
|
||||
tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
|
||||
return tmp.ui32;
|
||||
#else
|
||||
union {
|
||||
uint16_t u16[2];
|
||||
uint32_t u32;
|
||||
} tmp;
|
||||
|
||||
tmp.u16[0] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a));
|
||||
tmp.u16[1] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a >> 8U));
|
||||
return tmp.u32;
|
||||
#endif
|
||||
}
|
||||
|
||||
// fp8x4 -> half2x2
|
||||
@ -92,9 +81,9 @@ using __nv_bfloat16 = __hip_bfloat16;
|
||||
template <>
|
||||
__inline__ __device__ __nv_bfloat16
|
||||
vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) {
|
||||
hip_fp8 f8{a, hip_fp8::from_bits()};
|
||||
float f{f8};
|
||||
return __float2bfloat16(f);
|
||||
fp8_type f8;
|
||||
f8.__x = a;
|
||||
return __float2bfloat16(static_cast<float>(f8));
|
||||
}
|
||||
|
||||
using __nv_bfloat162 = __hip_bfloat162;
|
||||
@ -136,27 +125,18 @@ __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a) {
|
||||
// fp8 -> float
|
||||
template <>
|
||||
__inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a) {
|
||||
hip_fp8 fp8{a, hip_fp8::from_bits()};
|
||||
return static_cast<float>(fp8);
|
||||
fp8_type f8;
|
||||
f8.__x = a;
|
||||
return static_cast<float>(f8);
|
||||
}
|
||||
|
||||
// fp8x2 -> float2
|
||||
template <>
|
||||
__inline__ __device__ float2
|
||||
vec_conversion<float2, uint16_t>(const uint16_t& a) {
|
||||
#if defined(__HIP__MI300__) && \
|
||||
defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
|
||||
float2 res;
|
||||
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
|
||||
res.x = f2[0];
|
||||
res.y = f2[1];
|
||||
return res;
|
||||
#else
|
||||
float2 res;
|
||||
res.x = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a));
|
||||
res.y = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U));
|
||||
return res;
|
||||
#endif
|
||||
fp8x2_type f8x2;
|
||||
f8x2.__x = a;
|
||||
return static_cast<float2>(f8x2);
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
@ -169,6 +149,15 @@ vec_conversion<Float4_, uint32_t>(const uint32_t& a) {
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ float4
|
||||
vec_conversion<float4, uint32_t>(const uint32_t& a) {
|
||||
Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
|
||||
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x8 -> float8
|
||||
template <>
|
||||
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a) {
|
||||
@ -189,33 +178,36 @@ __inline__ __device__ uint8_t
|
||||
vec_conversion<uint8_t, uint16_t>(const uint16_t& a) {
|
||||
__half_raw tmp;
|
||||
tmp.x = a;
|
||||
return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation,
|
||||
fp8_type::__default_interpret);
|
||||
}
|
||||
|
||||
hip_fp8 f8{static_cast<float>(tmp.data)};
|
||||
return f8.data;
|
||||
template <>
|
||||
__inline__ __device__ uint16_t
|
||||
vec_conversion<uint16_t, uint32_t>(const uint32_t& a) {
|
||||
union {
|
||||
uint32_t ui32;
|
||||
__half2_raw h2r;
|
||||
} tmp;
|
||||
tmp.ui32 = a;
|
||||
return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation,
|
||||
fp8_type::__default_interpret);
|
||||
}
|
||||
|
||||
// bf16 -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t
|
||||
vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a) {
|
||||
hip_fp8 res{__bfloat162float(a)};
|
||||
return res.data;
|
||||
return __hip_cvt_float_to_fp8(__bfloat162float(a),
|
||||
fp8_type::__default_saturation,
|
||||
fp8_type::__default_interpret);
|
||||
}
|
||||
|
||||
// float -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a) {
|
||||
hip_fp8 f8(a);
|
||||
return f8.data;
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ float4
|
||||
vec_conversion<float4, uint32_t>(const uint32_t& a) {
|
||||
Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
|
||||
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
|
||||
return res;
|
||||
return __hip_cvt_float_to_fp8(a, fp8_type::__default_saturation,
|
||||
fp8_type::__default_interpret);
|
||||
}
|
||||
|
||||
// float2 -> half2
|
||||
@ -307,90 +299,22 @@ vec_conversion<bf16_8_t, Float8_>(const Float8_& a) {
|
||||
|
||||
*/
|
||||
|
||||
// fp8 -> half
|
||||
template <>
|
||||
__inline__ __device__ uint16_t
|
||||
scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, const float scale) {
|
||||
hip_fp8 f8{a, hip_fp8::from_bits()};
|
||||
__half_raw res;
|
||||
res.data = static_cast<float>(f8) * scale;
|
||||
return res.x;
|
||||
}
|
||||
|
||||
// fp8x2 -> half2
|
||||
template <>
|
||||
__inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(
|
||||
const uint16_t& a, const float scale) {
|
||||
#if defined(__HIP__MI300__) && \
|
||||
defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
|
||||
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
|
||||
union {
|
||||
__half2_raw h2r;
|
||||
uint32_t ui32;
|
||||
} tmp;
|
||||
tmp.h2r.x.data = f2[0] * scale;
|
||||
tmp.h2r.y.data = f2[1] * scale;
|
||||
return tmp.ui32;
|
||||
#else
|
||||
union {
|
||||
uint16_t u16[2];
|
||||
uint32_t u32;
|
||||
} tmp;
|
||||
|
||||
tmp.u16[0] =
|
||||
scaled_vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a), scale);
|
||||
tmp.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>(
|
||||
static_cast<uint8_t>(a >> 8U), scale);
|
||||
return tmp.u32;
|
||||
#endif
|
||||
}
|
||||
|
||||
// fp8x4 -> half2x2
|
||||
template <>
|
||||
__inline__ __device__ uint2
|
||||
scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, const float scale) {
|
||||
union {
|
||||
uint2 u32x2;
|
||||
uint32_t u32[2];
|
||||
} tmp;
|
||||
tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale);
|
||||
tmp.u32[1] =
|
||||
scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale);
|
||||
return tmp.u32x2;
|
||||
}
|
||||
|
||||
// fp8x8 -> half2x4
|
||||
template <>
|
||||
__inline__ __device__ uint4
|
||||
scaled_vec_conversion<uint4, uint2>(const uint2& a, const float scale) {
|
||||
union {
|
||||
uint4 u64x2;
|
||||
uint2 u64[2];
|
||||
} tmp;
|
||||
tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale);
|
||||
tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale);
|
||||
return tmp.u64x2;
|
||||
}
|
||||
|
||||
using __nv_bfloat16 = __hip_bfloat16;
|
||||
|
||||
// fp8 -> __nv_bfloat16
|
||||
template <>
|
||||
__inline__ __device__ __nv_bfloat16
|
||||
scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a,
|
||||
const float scale) {
|
||||
hip_fp8 f8{a, hip_fp8::from_bits()};
|
||||
float f{f8};
|
||||
return __float2bfloat16(f * scale);
|
||||
scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) {
|
||||
fp8_type f8;
|
||||
f8.__x = a;
|
||||
return __float2bfloat16(static_cast<float>(f8) * scale);
|
||||
}
|
||||
|
||||
using __nv_bfloat162 = __hip_bfloat162;
|
||||
|
||||
// fp8x2 -> __nv_bfloat162
|
||||
template <>
|
||||
__inline__ __device__ __nv_bfloat162
|
||||
scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a,
|
||||
const float scale) {
|
||||
float scale) {
|
||||
__nv_bfloat162 res;
|
||||
res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale);
|
||||
res.y =
|
||||
@ -400,8 +324,8 @@ scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a,
|
||||
|
||||
// fp8x4 -> bf16_4_t
|
||||
template <>
|
||||
__inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
|
||||
const uint32_t& a, const float scale) {
|
||||
__inline__ __device__ bf16_4_t
|
||||
scaled_vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a, float scale) {
|
||||
bf16_4_t res;
|
||||
res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale);
|
||||
res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U),
|
||||
@ -412,7 +336,7 @@ __inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
|
||||
// fp8x8 -> bf16_8_t
|
||||
template <>
|
||||
__inline__ __device__ bf16_8_t
|
||||
scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, const float scale) {
|
||||
scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, float scale) {
|
||||
bf16_4_t tmp1, tmp2;
|
||||
tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale);
|
||||
tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale);
|
||||
@ -427,29 +351,19 @@ scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, const float scale) {
|
||||
// fp8 -> float
|
||||
template <>
|
||||
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
|
||||
const uint8_t& a, const float scale) {
|
||||
hip_fp8 fp8{a, hip_fp8::from_bits()};
|
||||
return static_cast<float>(fp8) * scale;
|
||||
const uint8_t& a, float scale) {
|
||||
fp8_type f8;
|
||||
f8.__x = a;
|
||||
return static_cast<float>(f8) * scale;
|
||||
}
|
||||
|
||||
// fp8x2 -> float2
|
||||
template <>
|
||||
__inline__ __device__ float2
|
||||
scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, const float scale) {
|
||||
#if defined(__HIP__MI300__) && \
|
||||
defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
|
||||
float2 res;
|
||||
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
|
||||
res.x = f2[0] * scale;
|
||||
res.y = f2[1] * scale;
|
||||
return res;
|
||||
#else
|
||||
float2 res;
|
||||
res.x = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a), scale);
|
||||
res.y = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U),
|
||||
scale);
|
||||
return res;
|
||||
#endif
|
||||
scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, float scale) {
|
||||
fp8x2_type f8x2;
|
||||
f8x2.__x = a;
|
||||
return static_cast<float2>(f8x2) * scale;
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
@ -462,10 +376,18 @@ scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale) {
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ float4
|
||||
scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, float scale) {
|
||||
Float4_ res = scaled_vec_conversion<Float4_, uint32_t>(a, scale);
|
||||
return {res.x.x, res.x.y, res.y.x, res.y.y};
|
||||
}
|
||||
|
||||
// fp8x8 -> float8
|
||||
template <>
|
||||
__inline__ __device__ Float8_
|
||||
scaled_vec_conversion<Float8_, uint2>(const uint2& a, const float scale) {
|
||||
scaled_vec_conversion<Float8_, uint2>(const uint2& a, float scale) {
|
||||
Float4_ tmp1, tmp2;
|
||||
tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale);
|
||||
tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale);
|
||||
@ -477,44 +399,184 @@ scaled_vec_conversion<Float8_, uint2>(const uint2& a, const float scale) {
|
||||
return res;
|
||||
}
|
||||
|
||||
/* Quantize(HP / scale) => FP8 */
|
||||
// fp8 -> half
|
||||
template <>
|
||||
__inline__ __device__ uint16_t
|
||||
scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale) {
|
||||
__half_raw res;
|
||||
res.data = scaled_vec_conversion<float, uint8_t>(a, scale);
|
||||
return res.x;
|
||||
}
|
||||
|
||||
// TODO(Hai): vectorized to add
|
||||
// fp8x2 -> half2
|
||||
template <>
|
||||
__inline__ __device__ uint32_t
|
||||
scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, float scale) {
|
||||
__half2_raw h2r =
|
||||
__hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
|
||||
union {
|
||||
__half2_raw h2r;
|
||||
uint32_t ui32;
|
||||
} tmp;
|
||||
tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
|
||||
tmp.h2r.x.data *= scale;
|
||||
tmp.h2r.y.data *= scale;
|
||||
return tmp.ui32;
|
||||
}
|
||||
|
||||
// fp8x4 -> half2x2
|
||||
template <>
|
||||
__inline__ __device__ uint2
|
||||
scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, float scale) {
|
||||
union {
|
||||
uint2 u32x2;
|
||||
uint32_t u32[2];
|
||||
} tmp;
|
||||
tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale);
|
||||
tmp.u32[1] =
|
||||
scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale);
|
||||
return tmp.u32x2;
|
||||
}
|
||||
|
||||
// fp8x8 -> half2x4
|
||||
template <>
|
||||
__inline__ __device__ uint4 scaled_vec_conversion<uint4, uint2>(const uint2& a,
|
||||
float scale) {
|
||||
union {
|
||||
uint4 u64x2;
|
||||
uint2 u64[2];
|
||||
} tmp;
|
||||
tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale);
|
||||
tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale);
|
||||
return tmp.u64x2;
|
||||
}
|
||||
|
||||
// half -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t
|
||||
scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, const float scale) {
|
||||
scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, float scale) {
|
||||
__half_raw tmp;
|
||||
tmp.x = a;
|
||||
tmp.data /= scale;
|
||||
return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation,
|
||||
fp8_type::__default_interpret);
|
||||
}
|
||||
|
||||
hip_fp8 f8{static_cast<float>(tmp.data) / scale};
|
||||
return f8.data;
|
||||
// halfx2 -> fp8x2
|
||||
template <>
|
||||
__inline__ __device__ uint16_t
|
||||
scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale) {
|
||||
union {
|
||||
uint32_t ui32;
|
||||
__half2_raw h2r;
|
||||
} tmp;
|
||||
tmp.ui32 = a;
|
||||
tmp.h2r.x.data /= scale;
|
||||
tmp.h2r.y.data /= scale;
|
||||
return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation,
|
||||
fp8_type::__default_interpret);
|
||||
}
|
||||
|
||||
// half2x2 -> fp8x4
|
||||
template <>
|
||||
__inline__ __device__ uint32_t
|
||||
scaled_vec_conversion<uint32_t, uint2>(const uint2& a, float scale) {
|
||||
union {
|
||||
uint16_t ui16[2];
|
||||
uint32_t ui32;
|
||||
} tmp;
|
||||
tmp.ui16[0] = scaled_vec_conversion<uint16_t, uint32_t>(a.x, scale);
|
||||
tmp.ui16[1] = scaled_vec_conversion<uint16_t, uint32_t>(a.y, scale);
|
||||
return tmp.ui32;
|
||||
}
|
||||
|
||||
// half2x4 -> fp8x8
|
||||
template <>
|
||||
__inline__ __device__ uint2 scaled_vec_conversion<uint2, uint4>(const uint4& a,
|
||||
float scale) {
|
||||
union {
|
||||
uint2 ui2[2];
|
||||
uint4 ui4;
|
||||
} tmp;
|
||||
tmp.ui4 = a;
|
||||
uint2 res;
|
||||
res.x = scaled_vec_conversion<uint32_t, uint2>(tmp.ui2[0], scale);
|
||||
res.y = scaled_vec_conversion<uint32_t, uint2>(tmp.ui2[1], scale);
|
||||
return res;
|
||||
}
|
||||
|
||||
// bf16 -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
|
||||
const __nv_bfloat16& a, const float scale) {
|
||||
hip_fp8 res{__bfloat162float(a) / scale};
|
||||
return res.data;
|
||||
const __nv_bfloat16& a, float scale) {
|
||||
return __hip_cvt_float_to_fp8(__bfloat162float(a) / scale,
|
||||
fp8_type::__default_saturation,
|
||||
fp8_type::__default_interpret);
|
||||
}
|
||||
|
||||
// bf16x2 -> fp8x2
|
||||
template <>
|
||||
__inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, __nv_bfloat162>(
|
||||
const __nv_bfloat162& a, float scale) {
|
||||
union {
|
||||
uint8_t ui8[2];
|
||||
uint16_t ui16;
|
||||
} tmp;
|
||||
tmp.ui8[0] = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.x, scale);
|
||||
tmp.ui8[1] = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.y, scale);
|
||||
return tmp.ui16;
|
||||
}
|
||||
|
||||
// bf16x4 -> fp8x4
|
||||
template <>
|
||||
__inline__ __device__ uint32_t
|
||||
scaled_vec_conversion<uint32_t, bf16_4_t>(const bf16_4_t& a, float scale) {
|
||||
union {
|
||||
uint16_t ui16[2];
|
||||
uint32_t ui32;
|
||||
} tmp;
|
||||
tmp.ui16[0] = scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.x, scale);
|
||||
tmp.ui16[1] = scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.y, scale);
|
||||
return tmp.ui32;
|
||||
}
|
||||
|
||||
// bf16x8 -> fp8x8
|
||||
template <>
|
||||
__inline__ __device__ uint2
|
||||
scaled_vec_conversion<uint2, bf16_8_t>(const bf16_8_t& a, float scale) {
|
||||
uint2 res;
|
||||
res.x = scaled_vec_conversion<uint32_t, bf16_4_t>({a.x, a.y}, scale);
|
||||
res.y = scaled_vec_conversion<uint32_t, bf16_4_t>({a.z, a.w}, scale);
|
||||
return res;
|
||||
}
|
||||
|
||||
// float -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t
|
||||
scaled_vec_conversion<uint8_t, float>(const float& a, const float scale) {
|
||||
hip_fp8 f8(a / scale);
|
||||
return f8.data;
|
||||
scaled_vec_conversion<uint8_t, float>(const float& a, float scale) {
|
||||
return __hip_cvt_float_to_fp8(a / scale, fp8_type::__default_saturation,
|
||||
fp8_type::__default_interpret);
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
// floatx2 -> fp8x2
|
||||
template <>
|
||||
__inline__ __device__ float4
|
||||
scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, const float scale) {
|
||||
Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale);
|
||||
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
|
||||
return res;
|
||||
__inline__ __device__ uint16_t
|
||||
scaled_vec_conversion<uint16_t, float2>(const float2& a, float scale) {
|
||||
return __hip_cvt_float2_to_fp8x2(a / scale, fp8_type::__default_saturation,
|
||||
fp8_type::__default_interpret);
|
||||
}
|
||||
|
||||
// floatx4 -> fp8x4
|
||||
template <>
|
||||
__inline__ __device__ uint32_t
|
||||
scaled_vec_conversion<uint32_t, float4>(const float4& a, float scale) {
|
||||
union {
|
||||
uint16_t ui16[2];
|
||||
uint32_t ui32;
|
||||
} tmp;
|
||||
tmp.ui16[0] = scaled_vec_conversion<uint16_t, float2>({a.x, a.y}, scale);
|
||||
tmp.ui16[1] = scaled_vec_conversion<uint16_t, float2>({a.z, a.w}, scale);
|
||||
return tmp.ui32;
|
||||
}
|
||||
#endif // ENABLE_FP8
|
||||
|
||||
|
@ -12,7 +12,7 @@ C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
|
||||
std::numeric_limits<FP8_TYPE>::max();
|
||||
#else
|
||||
#include <c10/util/Float8_e4m3fnuz.h>
|
||||
#include "amd/hip_float8.h"
|
||||
#include "amd/quant_utils.cuh"
|
||||
using FP8_TYPE = c10::Float8_e4m3fnuz;
|
||||
// Using the default max value from pytorch (240.0) will cause accuracy
|
||||
// issue when running dynamic quantization. Here use 224.0f for rocm.
|
||||
@ -47,8 +47,10 @@ __device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
|
||||
return static_cast<c10::Float8_e4m3fn>(r);
|
||||
#else
|
||||
// Use hardware cvt instruction for fp8 on rocm
|
||||
return c10::Float8_e4m3fnuz(hip_fp8(r).data,
|
||||
c10::Float8_e4m3fnuz::from_bits());
|
||||
return c10::Float8_e4m3fnuz(
|
||||
__hip_cvt_float_to_fp8(r, fp8::fp8_type::__default_saturation,
|
||||
fp8::fp8_type::__default_interpret),
|
||||
c10::Float8_e4m3fnuz::from_bits());
|
||||
#endif
|
||||
}
|
||||
|
||||
|
@ -159,19 +159,20 @@ def test_reshape_and_cache(
|
||||
device)
|
||||
key_cache, value_cache = key_caches[0], value_caches[0]
|
||||
|
||||
# Using default kv_scale
|
||||
k_scale = (key.amax() / 64.0).to(torch.float32)
|
||||
v_scale = (value.amax() / 64.0).to(torch.float32)
|
||||
|
||||
# Clone the KV caches.
|
||||
if kv_cache_dtype == "fp8":
|
||||
cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
|
||||
ops.convert_fp8(cloned_key_cache, key_cache)
|
||||
ops.convert_fp8(cloned_key_cache, key_cache, k_scale.item())
|
||||
cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
|
||||
ops.convert_fp8(cloned_value_cache, value_cache)
|
||||
ops.convert_fp8(cloned_value_cache, value_cache, v_scale.item())
|
||||
else:
|
||||
cloned_key_cache = key_cache.clone()
|
||||
cloned_value_cache = value_cache.clone()
|
||||
|
||||
# Using default kv_scale
|
||||
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
|
||||
# Call the reshape_and_cache kernel.
|
||||
opcheck(torch.ops._C_cache_ops.reshape_and_cache,
|
||||
(key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
|
||||
@ -182,9 +183,9 @@ def test_reshape_and_cache(
|
||||
|
||||
if kv_cache_dtype == "fp8":
|
||||
result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
|
||||
ops.convert_fp8(result_key_cache, key_cache)
|
||||
ops.convert_fp8(result_key_cache, key_cache, k_scale.item())
|
||||
result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
|
||||
ops.convert_fp8(result_value_cache, value_cache)
|
||||
ops.convert_fp8(result_value_cache, value_cache, v_scale.item())
|
||||
|
||||
# Run the reference implementation.
|
||||
reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
|
||||
@ -268,15 +269,16 @@ def test_reshape_and_cache_flash(
|
||||
del key_caches
|
||||
del value_caches
|
||||
|
||||
k_scale = (key.amax() / 256.0).to(torch.float32)
|
||||
v_scale = (value.amax() / 256.0).to(torch.float32)
|
||||
k_scale = (key.amax() / 64.0).to(torch.float32)
|
||||
v_scale = (value.amax() / 64.0).to(torch.float32)
|
||||
|
||||
# Clone the KV caches.
|
||||
if kv_cache_dtype == "fp8":
|
||||
cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
|
||||
ops.convert_fp8(cloned_key_cache, key_cache, k_scale, kv_cache_dtype)
|
||||
ops.convert_fp8(cloned_key_cache, key_cache, k_scale.item(),
|
||||
kv_cache_dtype)
|
||||
cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
|
||||
ops.convert_fp8(cloned_value_cache, value_cache, v_scale,
|
||||
ops.convert_fp8(cloned_value_cache, value_cache, v_scale.item(),
|
||||
kv_cache_dtype)
|
||||
else:
|
||||
cloned_key_cache = key_cache.clone()
|
||||
|
Loading…
x
Reference in New Issue
Block a user