[Hotfix][CI/Build][Kernel] CUDA 11.8 does not support layernorm optimizations (#3782)
parent bc0c0192d1
commit 59a6abf3c9
@@ -100,6 +100,8 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
     if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8)
       list(APPEND GPU_FLAGS "-DENABLE_FP8_E5M2")
+    endif()
+    if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0)
       list(REMOVE_ITEM GPU_FLAGS
         "-D__CUDA_NO_HALF_OPERATORS__"
         "-D__CUDA_NO_HALF_CONVERSIONS__"
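For context, a minimal sketch, not part of this commit, of why the -D__CUDA_NO_HALF_*__ suppressions matter: PyTorch injects these macros to hide the __half/__half2 operator overloads from <cuda_fp16.h>, and removing them re-enables packed __half2 arithmetic in device code. The kernel name packed_half_axpy below is illustrative only.

// Illustrative only: packed __half2 arithmetic that compiles only when the
// __CUDA_NO_HALF2_OPERATORS__ suppression is NOT defined (sm_53 or newer).
#include <cuda_fp16.h>

__global__ void packed_half_axpy(__half2* y, const __half2* x, __half2 a, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // operator* and operator+ for __half2 come from <cuda_fp16.h> and are
    // hidden when __CUDA_NO_HALF2_OPERATORS__ is defined.
    y[i] = a * x[i] + y[i];
  }
}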
@@ -59,6 +59,8 @@ __global__ void rms_norm_kernel(
 template<typename torch_type>
 struct _typeConvert { static constexpr bool exists = false; };
 
+#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
+// CUDA < 12.0 runs into issues with packed type conversion
 template<>
 struct _typeConvert<c10::Half> {
   static constexpr bool exists = true;
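A hedged sketch of how a trait like _typeConvert is typically consumed; the names demo_typeConvert and fused_add_demo are assumptions for illustration, not code from this file. The primary template reports exists = false, the guarded specialization reports true, and a kernel can test that flag at compile time so builds without the specialization (for example CUDA older than 12.0) fall back to an unpacked path.

#include <cuda_fp16.h>

template <typename T>
struct demo_typeConvert { static constexpr bool exists = false; };

// Packed conversions are only declared when they are known to work,
// mirroring the version guard added in this commit (CUDA-only sketch).
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)
template <>
struct demo_typeConvert<__half> {
  static constexpr bool exists = true;
  using packed_type = __half2;
  __device__ static inline __half2 convert(float2 x) { return __float22half2_rn(x); }
  __device__ static inline float2 convert(__half2 x) { return __half22float2(x); }
};
#endif

template <typename T>
__global__ void fused_add_demo(T* out, const T* a, const T* b, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if constexpr (demo_typeConvert<T>::exists) {
    // Packed path: two elements per thread via the packed type
    // (assumes the input pointers are suitably aligned).
    using P = typename demo_typeConvert<T>::packed_type;
    if (2 * i + 1 < n) {
      float2 fa = demo_typeConvert<T>::convert(reinterpret_cast<const P*>(a)[i]);
      float2 fb = demo_typeConvert<T>::convert(reinterpret_cast<const P*>(b)[i]);
      reinterpret_cast<P*>(out)[i] =
          demo_typeConvert<T>::convert(make_float2(fa.x + fb.x, fa.y + fb.y));
    }
  } else {
    // Scalar fallback; assumes operator+ is usable for T in device code.
    if (i < n) out[i] = a[i] + b[i];
  }
}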
@@ -85,8 +87,8 @@ struct _typeConvert<c10::BFloat16> {
   __device__ static inline hip_type convert(float x) { return __float2bfloat16(x); }
   __device__ static inline packed_hip_type convert(float2 x) { return __float22bfloat162_rn(x); }
 };
-#endif
+#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
 
 /* Vector POD struct to generate vectorized and packed FP16/BF16 ops
    for appropriate specializations of fused_add_rms_norm_kernel.
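A rough illustration of the "vector POD struct" idea the comment above refers to; the struct name f16x8_t and the width of eight halves are assumptions for this example, not the layout used in the file. The point is that arithmetic is expressed on packed __half2 lanes while reductions accumulate in float, which is what makes the fused add + RMS norm path vectorizable.

#include <cuda_fp16.h>

// Illustrative packed FP16 vector: eight halves stored as four __half2 values.
struct f16x8_t {
  __half2 data[4];

  // Packed add: one __half2 operation per lane pair instead of eight scalar adds.
  __device__ f16x8_t operator+(const f16x8_t& other) const {
    f16x8_t out;
#pragma unroll
    for (int i = 0; i < 4; ++i) out.data[i] = __hadd2(data[i], other.data[i]);
    return out;
  }

  // Sum of squares accumulated in float, as an RMS-norm reduction would do.
  __device__ float sum_squares() const {
    float acc = 0.f;
#pragma unroll
    for (int i = 0; i < 4; ++i) {
      float2 f = __half22float2(data[i]);
      acc += f.x * f.x + f.y * f.y;
    }
    return acc;
  }
};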