#pragma once /** * Quantization utilities including: * Adjusted maximum values for qtypes. * Minimum scaling factors for qtypes. */ #include #include #ifndef USE_ROCM #include #define MAYBE_HOST_DEVICE C10_HOST_DEVICE #else #include #include #include // ROCm doesn't seem to need C10_HOST_DEVICE for static constexpr #define MAYBE_HOST_DEVICE #endif template || std::is_same_v || std::is_same_v>> struct quant_type_max { static constexpr T val() { return std::numeric_limits::max(); } }; // Using the default max value from pytorch (240.0 0x7F) will cause accuracy // issues when running dynamic quantization. Here use 224.0 0x7E for rocm. template <> struct quant_type_max { static constexpr c10::Float8_e4m3fnuz val() { return c10::Float8_e4m3fnuz(0x7E, c10::Float8_e4m3fnuz::from_bits()); } }; template MAYBE_HOST_DEVICE static constexpr T quant_type_max_v = quant_type_max::val(); template || std::is_same_v || std::is_same_v>> struct min_scaling_factor { C10_DEVICE C10_ALWAYS_INLINE static float val() { return 1.0f / (quant_type_max_v * 512.0f); } }; template <> struct min_scaling_factor { C10_DEVICE C10_ALWAYS_INLINE static float val() { return std::numeric_limits::epsilon(); } };