forward fix PR 14245, restore build on ROCm 6.2 (#14709)
Signed-off-by: Jeff Daily <jeff.daily@amd.com>
This commit is contained in:
parent
7888e1d0a3
commit
2a602b055a
@ -19,12 +19,24 @@ __device__ __forceinline__ fp8_type cvt_c10(float const r) {
|
|||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// __hip_fp8_e4m3 only exists starting in ROCm 6.3. The macro
|
||||||
|
// HIP_FP8_TYPE_OCP comes from the hip_fp8.h header and also makes
|
||||||
|
// its first appearance in ROCm 6.3. Since VLLM_DISPATCH_FP8_TYPES
|
||||||
|
// on ROCm instantiates both OCP and FNUZ kernels, we need to replace
|
||||||
|
// the new HW cvt with something reasonable that doesn't rely on the
|
||||||
|
// ROCm 6.3 feature. This allows compiling on ROCm 6.2 or newer.
|
||||||
template <>
|
template <>
|
||||||
__device__ __forceinline__ c10::Float8_e4m3fn cvt_c10(float const r) {
|
__device__ __forceinline__ c10::Float8_e4m3fn cvt_c10(float const r) {
|
||||||
|
#if HIP_FP8_TYPE_OCP
|
||||||
return c10::Float8_e4m3fn(
|
return c10::Float8_e4m3fn(
|
||||||
__hip_cvt_float_to_fp8(r, __hip_fp8_e4m3::__default_saturation,
|
__hip_cvt_float_to_fp8(r, __hip_fp8_e4m3::__default_saturation,
|
||||||
__hip_fp8_e4m3::__default_interpret),
|
__hip_fp8_e4m3::__default_interpret),
|
||||||
c10::Float8_e4m3fn::from_bits());
|
c10::Float8_e4m3fn::from_bits());
|
||||||
|
#else
|
||||||
|
// Cast implemented by pytorch. Uses bit manipulation instead of HW cvt.
|
||||||
|
// HW cvt above is faster when it is available (ROCm 6.3 or newer).
|
||||||
|
return static_cast<c10::Float8_e4m3fn>(r);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
|
Loading…
x
Reference in New Issue
Block a user