#pragma once

#include <torch/all.h>

#ifndef USE_ROCM
  #include <cuda_bf16.h>
  #include <cuda_fp16.h>
#else
  #include <hip/hip_bf16.h>
  #include <hip/hip_fp16.h>

// Alias the HIP bf16 types to their CUDA spellings so the rest of this
// header can be written once against the __nv_* names.
using __nv_bfloat16 = __hip_bfloat16;
using __nv_bfloat162 = __hip_bfloat162;
#endif

namespace vllm {
/* Converter structs for the conversion from torch types to HIP/CUDA types,
   and the associated type conversions within HIP/CUDA. These helpers need
   to be implemented for now because the relevant type conversion
   operators/constructors are not consistently implemented by HIP/CUDA, so
   a generic conversion via type casts cannot be implemented.

   Each struct should have the member static constexpr bool `exists`:
   If false, the optimized kernel is not used for the corresponding torch type.
   If true, the struct should be fully defined as shown in the examples below.
*/
template <typename torch_type>
struct _typeConvert {
  static constexpr bool exists = false;
};
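
// Illustrative sketch (an assumption, not part of the original header):
// callers would typically gate an optimized kernel on `exists` at compile
// time. The kernel and parameter names below are hypothetical.
//
//   template <typename scalar_t>
//   __global__ void example_kernel(scalar_t* __restrict__ data, int n) {
//     if constexpr (_typeConvert<scalar_t>::exists) {
//       // optimized path: use hip_type/packed_hip_type conversions
//     } else {
//       // generic scalar fallback
//     }
//   }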

#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
// CUDA < 12.0 runs into issues with packed type conversion
template <>
struct _typeConvert<c10::Half> {
  static constexpr bool exists = true;
  using hip_type = __half;
  using packed_hip_type = __half2;

  __device__ static inline float convert(hip_type x) { return __half2float(x); }
  __device__ static inline float2 convert(packed_hip_type x) {
    return __half22float2(x);
  }
  __device__ static inline hip_type convert(float x) {
    return __float2half_rn(x);
  }
  __device__ static inline packed_hip_type convert(float2 x) {
    return __float22half2_rn(x);
  }
};

  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// CUDA_ARCH < 800 does not have BF16 support
// TODO: Add in ROCm support once public headers handle bf16 maturely
template <>
struct _typeConvert<c10::BFloat16> {
  static constexpr bool exists = true;
  using hip_type = __nv_bfloat16;
  using packed_hip_type = __nv_bfloat162;

  __device__ static inline float convert(hip_type x) {
    return __bfloat162float(x);
  }
  __device__ static inline float2 convert(packed_hip_type x) {
    return __bfloat1622float2(x);
  }
  __device__ static inline hip_type convert(float x) {
    return __float2bfloat16(x);
  }
  __device__ static inline packed_hip_type convert(float2 x) {
    return __float22bfloat162_rn(x);
  }
};
  #endif  // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#endif  // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
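
// Round-trip example (illustrative only): overload resolution selects the
// scalar or packed conversion from the argument type.
//
//   __device__ void convert_demo(__half2 h2) {
//     float2 f = _typeConvert<c10::Half>::convert(h2);     // packed -> float2
//     __half2 back = _typeConvert<c10::Half>::convert(f);  // float2 -> packed
//   }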

/* Vector POD struct to generate vectorized and packed FP16/BF16 ops
   for appropriate specializations of fused_add_rms_norm_kernel.
   Only functions that are necessary in that kernel are implemented.
   Alignment to 16 bytes is required to use 128-bit global memory ops:
   e.g. _f16Vec<c10::Half, 8> packs eight 2-byte elements into exactly
   16 bytes. A usage sketch follows the struct definition below.
*/
template <typename scalar_t, int width>
struct alignas(16) _f16Vec {
  /* It is not theoretically necessary that width be a power of 2, but it
     should almost always be one for optimization purposes. */
  static_assert(width > 0 && (width & (width - 1)) == 0,
                "Width is not a positive power of 2!");
  using Converter = _typeConvert<scalar_t>;
  using T1 = typename Converter::hip_type;
  using T2 = typename Converter::packed_hip_type;
  T1 data[width];

  __device__ _f16Vec& operator+=(const _f16Vec<scalar_t, width>& other) {
    if constexpr (width % 2 == 0) {
      // Even widths: add elements two at a time through the packed type.
#pragma unroll
      for (int i = 0; i < width; i += 2) {
        T2 temp{data[i], data[i + 1]};
        temp += T2{other.data[i], other.data[i + 1]};
        data[i] = temp.x;
        data[i + 1] = temp.y;
      }
    } else {
#pragma unroll
      for (int i = 0; i < width; ++i) data[i] += other.data[i];
    }
    return *this;
  }

  __device__ _f16Vec& operator*=(const _f16Vec<scalar_t, width>& other) {
    if constexpr (width % 2 == 0) {
#pragma unroll
      for (int i = 0; i < width; i += 2) {
        T2 temp{data[i], data[i + 1]};
        temp *= T2{other.data[i], other.data[i + 1]};
        data[i] = temp.x;
        data[i + 1] = temp.y;
      }
    } else {
#pragma unroll
      for (int i = 0; i < width; ++i) data[i] *= other.data[i];
    }
    return *this;
  }

  __device__ _f16Vec& operator*=(const float scale) {
    if constexpr (width % 2 == 0) {
      // Scale in fp32: widen a packed pair to float2, multiply, and narrow
      // back to the packed type.
#pragma unroll
      for (int i = 0; i < width; i += 2) {
        float2 temp_f = Converter::convert(T2{data[i], data[i + 1]});
        temp_f.x *= scale;
        temp_f.y *= scale;
        T2 temp = Converter::convert(temp_f);
        data[i] = temp.x;
        data[i + 1] = temp.y;
      }
    } else {
#pragma unroll
      for (int i = 0; i < width; ++i) {
        float temp = Converter::convert(data[i]) * scale;
        data[i] = Converter::convert(temp);
      }
    }
    return *this;
  }

  __device__ float sum_squares() const {
    float result = 0.0f;
    if constexpr (width % 2 == 0) {
      // Accumulate in fp32 to limit precision loss in the reduction.
#pragma unroll
      for (int i = 0; i < width; i += 2) {
        float2 z = Converter::convert(T2{data[i], data[i + 1]});
        result += z.x * z.x + z.y * z.y;
      }
    } else {
#pragma unroll
      for (int i = 0; i < width; ++i) {
        float x = Converter::convert(data[i]);
        result += x * x;
      }
    }
    return result;
  }
};
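
// Usage sketch (an assumption, not code from this header): a
// fused_add_rms_norm-style kernel can reinterpret 16-byte-aligned buffers
// as _f16Vec to get 128-bit loads/stores and reduce in fp32. The names
// `input`, `residual`, and `hidden_size` are hypothetical.
//
//   template <typename scalar_t, int width>
//   __global__ void rms_norm_sketch(scalar_t* __restrict__ input,
//                                   scalar_t* __restrict__ residual,
//                                   int hidden_size) {
//     using V = _f16Vec<scalar_t, width>;
//     auto* in_v = reinterpret_cast<V*>(input);
//     auto* res_v = reinterpret_cast<V*>(residual);
//     float variance = 0.0f;
//     for (int i = threadIdx.x; i < hidden_size / width; i += blockDim.x) {
//       V tmp = in_v[i];
//       tmp += res_v[i];                // fused residual add, packed
//       variance += tmp.sum_squares();  // fp32 accumulation
//       res_v[i] = tmp;
//     }
//     // ... block-reduce `variance`, then normalize with operator*=(float)
//   }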
}  // namespace vllm