#pragma once

/**
 * __device__ helper functions to deal with float -> quant datatype conversion
 */

#include "quantization/vectorization.cuh"
// TODO(luka/varun): refactor common.cuh to use this file instead
#include "quantization/fp8/common.cuh"

namespace vllm {

// TODO(luka/varun): combine into common utilities for int8
// (with int8_quant_kernels.cu)
static __device__ __forceinline__ int8_t float_to_int8_rn(float const x) {
#ifdef USE_ROCM
  static const float i8_min =
      static_cast<float>(std::numeric_limits<int8_t>::min());
  static const float i8_max =
      static_cast<float>(std::numeric_limits<int8_t>::max());
  // round
  float dst = std::nearbyint(x);
  // saturate
  dst = std::clamp(dst, i8_min, i8_max);
  return static_cast<int8_t>(dst);
#else
  // CUDA path
  uint32_t dst;
  asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x));
  return reinterpret_cast<const int8_t&>(dst);
#endif
}

template <typename fp8_type>
static __device__ __forceinline__ fp8_type float_to_fp8(float const x) {
  float const r = fmax(-fp8_e4m3_adjusted_max_v<fp8_type>,
                       fmin(x, fp8_e4m3_adjusted_max_v<fp8_type>));
  return static_cast<fp8_type>(r);
}

template <typename quant_type_t, bool is_scale_inverted, typename enable = void>
struct ScaledQuant;

template <typename quant_type_t, bool is_scale_inverted>
struct ScaledQuant<
    quant_type_t, is_scale_inverted,
    typename std::enable_if_t<std::is_same_v<quant_type_t, int8_t>>> {
  static __device__ __forceinline__ quant_type_t quant_fn(float const x,
                                                          float const scale) {
    if constexpr (is_scale_inverted) {
      return float_to_int8_rn(x * scale);
    } else {
      return float_to_int8_rn(x / scale);
    }
  }
};

template <typename quant_type_t, bool is_scale_inverted>
struct ScaledQuant<quant_type_t, is_scale_inverted,
                   typename std::enable_if_t<
                       std::is_same_v<quant_type_t, c10::Float8_e4m3fn> ||
                       std::is_same_v<quant_type_t, c10::Float8_e4m3fnuz>>> {
  static __device__ __forceinline__ quant_type_t quant_fn(float const x,
                                                          float const scale) {
    if constexpr (is_scale_inverted) {
      return float_to_fp8<quant_type_t>(x * scale);
    } else {
      return float_to_fp8<quant_type_t>(x / scale);
    }
  }
};

template <typename scalar_t, typename quant_type_t, bool is_scale_inverted>
__device__ void scaled_quant_conversion(quant_type_t* __restrict__ output,
                                        scalar_t const* __restrict__ input,
                                        float const scale, int const tid,
                                        int const num_elements,
                                        int const step) {
  for (int i = tid; i < num_elements; i += step) {
    output[i] =
        ScaledQuant<quant_type_t, is_scale_inverted>::quant_fn(input[i], scale);
  }
}

}  // namespace vllm
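
// ----------------------------------------------------------------------------
// Usage sketch (illustrative only): a minimal grid-stride kernel that drives
// scaled_quant_conversion to quantize a buffer to int8. The kernel name, the
// launch shape, and the choice of is_scale_inverted = false (divide by scale)
// are assumptions made for this example; they are not part of the helpers
// above.
// ----------------------------------------------------------------------------
namespace vllm {

template <typename scalar_t>
__global__ void example_scaled_int8_quant_kernel(
    int8_t* __restrict__ out, scalar_t const* __restrict__ in,
    float const scale, int const num_elements) {
  int const tid = blockIdx.x * blockDim.x + threadIdx.x;
  int const step = gridDim.x * blockDim.x;
  // ScaledQuant<int8_t, false> divides each element by `scale`, rounds to
  // nearest, and saturates to the int8 range before storing.
  scaled_quant_conversion<scalar_t, int8_t, false>(out, in, scale, tid,
                                                   num_elements, step);
}

// Example host-side launch, assuming device pointers d_out/d_in and element
// count n are already set up:
//   example_scaled_int8_quant_kernel<float><<<256, 256>>>(d_out, d_in, scale, n);

}  // namespace vllm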