[Kernel] Layernorm performance optimization (#3662)
parent 51c31bc10c
commit b6d103542c
@@ -100,6 +100,11 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
     if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8)
       list(APPEND GPU_FLAGS "-DENABLE_FP8_E5M2")
+      list(REMOVE_ITEM GPU_FLAGS
+          "-D__CUDA_NO_HALF_OPERATORS__"
+          "-D__CUDA_NO_HALF_CONVERSIONS__"
+          "-D__CUDA_NO_BFLOAT16_CONVERSIONS__"
+          "-D__CUDA_NO_HALF2_OPERATORS__")
     endif()

   elseif(${GPU_LANG} STREQUAL "HIP")
@@ -4,6 +4,16 @@
 #include "dispatch_utils.h"
 #include "reduction_utils.cuh"
+#ifndef USE_ROCM
+  #include <cuda_bf16.h>
+  #include <cuda_fp16.h>
+#else
+  #include <hip/hip_bf16.h>
+  #include <hip/hip_fp16.h>
+
+  using __nv_bfloat16 = __hip_bfloat16;
+  using __nv_bfloat162 = __hip_bfloat162;
+#endif

 namespace vllm {

@@ -35,9 +45,199 @@ __global__ void rms_norm_kernel(
   }
 }

-// TODO: Further optimize this kernel.
-template<typename scalar_t>
-__global__ void fused_add_rms_norm_kernel(
+/* Converter structs for the conversion from torch types to HIP/CUDA types,
+   and the associated type conversions within HIP/CUDA. These helpers need
+   to be implemented for now because the relevant type conversion
+   operators/constructors are not consistently implemented by HIP/CUDA, so
+   a generic conversion via type casts cannot be implemented.
+
+   Each struct should have the member static constexpr bool `exists`:
+   If false, the optimized kernel is not used for the corresponding torch type.
+   If true, the struct should be fully defined as shown in the examples below.
+ */
+template<typename torch_type>
+struct _typeConvert { static constexpr bool exists = false; };
+
+template<>
+struct _typeConvert<c10::Half> {
+  static constexpr bool exists = true;
+  using hip_type = __half;
+  using packed_hip_type = __half2;
+
+  __device__ static inline float convert(hip_type x) { return __half2float(x); }
+  __device__ static inline float2 convert(packed_hip_type x) { return __half22float2(x); }
+  __device__ static inline hip_type convert(float x) { return __float2half_rn(x); }
+  __device__ static inline packed_hip_type convert(float2 x) { return __float22half2_rn(x); }
+};
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+// CUDA_ARCH < 800 does not have BF16 support
+// TODO: Add in ROCm support once public headers handle bf16 maturely
+template<>
+struct _typeConvert<c10::BFloat16> {
+  static constexpr bool exists = true;
+  using hip_type = __nv_bfloat16;
+  using packed_hip_type = __nv_bfloat162;
+
+  __device__ static inline float convert(hip_type x) { return __bfloat162float(x); }
+  __device__ static inline float2 convert(packed_hip_type x) { return __bfloat1622float2(x); }
+  __device__ static inline hip_type convert(float x) { return __float2bfloat16(x); }
+  __device__ static inline packed_hip_type convert(float2 x) { return __float22bfloat162_rn(x); }
+};
+#endif
+
+
+/* Vector POD struct to generate vectorized and packed FP16/BF16 ops
+   for appropriate specializations of fused_add_rms_norm_kernel.
+   Only functions that are necessary in that kernel are implemented.
+   Alignment to 16 bytes is required to use 128-bit global memory ops.
+ */
+template<typename scalar_t, int width>
+struct alignas(16) _f16Vec {
+  /* Not theoretically necessary that width is a power of 2 but should
+     almost always be the case for optimization purposes */
+  static_assert(width > 0 && (width & (width - 1)) == 0,
+                "Width is not a positive power of 2!");
+  using Converter = _typeConvert<scalar_t>;
+  using T1 = typename Converter::hip_type;
+  using T2 = typename Converter::packed_hip_type;
+  T1 data[width];
+
+  __device__ _f16Vec& operator+=(const _f16Vec<scalar_t, width>& other) {
+    if constexpr (width % 2 == 0) {
+      #pragma unroll
+      for (int i = 0; i < width; i += 2) {
+        T2 temp{data[i], data[i+1]};
+        temp += T2{other.data[i], other.data[i+1]};
+        data[i] = temp.x;
+        data[i+1] = temp.y;
+      }
+    } else {
+      #pragma unroll
+      for (int i = 0; i < width; ++i)
+        data[i] += other.data[i];
+    }
+    return *this;
+  }
+
+  __device__ _f16Vec& operator*=(const _f16Vec<scalar_t, width>& other) {
+    if constexpr (width % 2 == 0) {
+      #pragma unroll
+      for (int i = 0; i < width; i += 2) {
+        T2 temp{data[i], data[i+1]};
+        temp *= T2{other.data[i], other.data[i+1]};
+        data[i] = temp.x;
+        data[i+1] = temp.y;
+      }
+    } else {
+      #pragma unroll
+      for (int i = 0; i < width; ++i)
+        data[i] *= other.data[i];
+    }
+    return *this;
+  }
+
+  __device__ _f16Vec& operator*=(const float scale) {
+    if constexpr (width % 2 == 0) {
+      #pragma unroll
+      for (int i = 0; i < width; i += 2) {
+        float2 temp_f = Converter::convert(T2{data[i], data[i+1]});
+        temp_f.x *= scale;
+        temp_f.y *= scale;
+        T2 temp = Converter::convert(temp_f);
+        data[i] = temp.x;
+        data[i+1] = temp.y;
+      }
+    } else {
+      #pragma unroll
+      for (int i = 0; i < width; ++i) {
+        float temp = Converter::convert(data[i]) * scale;
+        data[i] = Converter::convert(temp);
+      }
+    }
+    return *this;
+  }
+
+  __device__ float sum_squares() const {
+    float result = 0.0f;
+    if constexpr (width % 2 == 0) {
+      #pragma unroll
+      for (int i = 0; i < width; i += 2) {
+        float2 z = Converter::convert(T2{data[i], data[i+1]});
+        result += z.x * z.x + z.y * z.y;
+      }
+    } else {
+      #pragma unroll
+      for (int i = 0; i < width; ++i) {
+        float x = Converter::convert(data[i]);
+        result += x * x;
+      }
+    }
+    return result;
+  }
+};
+
+/* Function specialization in the case of FP16/BF16 tensors.
+   Additional optimizations we can make in this case are
+   packed and vectorized operations, which help with the
+   memory latency bottleneck. */
+template<typename scalar_t, int width>
+__global__ std::enable_if_t<
+  (width > 0) && _typeConvert<scalar_t>::exists> fused_add_rms_norm_kernel(
+  scalar_t* __restrict__ input,           // [..., hidden_size]
+  scalar_t* __restrict__ residual,        // [..., hidden_size]
+  const scalar_t* __restrict__ weight,    // [hidden_size]
+  const float epsilon,
+  const int num_tokens,
+  const int hidden_size) {
+  // Sanity checks on our vector struct and type-punned pointer arithmetic
+  static_assert(std::is_pod_v<_f16Vec<scalar_t, width>>);
+  static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);
+
+  const int vec_hidden_size = hidden_size / width;
+  __shared__ float s_variance;
+  float variance = 0.0f;
+  /* These and the argument pointers are all declared `restrict` as they are
+     not aliased in practice. Argument pointers should not be dereferenced
+     in this kernel as that would be undefined behavior */
+  auto* __restrict__ input_v = reinterpret_cast<_f16Vec<scalar_t, width>*>(input);
+  auto* __restrict__ residual_v = reinterpret_cast<_f16Vec<scalar_t, width>*>(residual);
+  auto* __restrict__ weight_v = reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);
+
+  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
+    int id = blockIdx.x * vec_hidden_size + idx;
+    _f16Vec<scalar_t, width> temp = input_v[id];
+    temp += residual_v[id];
+    variance += temp.sum_squares();
+    residual_v[id] = temp;
+  }
+  /* Keep the following if-else block in sync with the
+     calculation of max_block_size in fused_add_rms_norm */
+  if (num_tokens < 256) {
+    variance = blockReduceSum<float, 1024>(variance);
+  } else variance = blockReduceSum<float, 256>(variance);
+  if (threadIdx.x == 0) {
+    s_variance = rsqrtf(variance / hidden_size + epsilon);
+  }
+  __syncthreads();
+
+  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
+    int id = blockIdx.x * vec_hidden_size + idx;
+    _f16Vec<scalar_t, width> temp = residual_v[id];
+    temp *= s_variance;
+    temp *= weight_v[idx];
+    input_v[id] = temp;
+  }
+}
+
+
+/* Generic fused_add_rms_norm_kernel
+   The width field is not used here but necessary for other specializations.
+ */
+template<typename scalar_t, int width>
+__global__ std::enable_if_t<
+  (width == 0) || !_typeConvert<scalar_t>::exists> fused_add_rms_norm_kernel(
   scalar_t* __restrict__ input,           // [..., hidden_size]
   scalar_t* __restrict__ residual,        // [..., hidden_size]
   const scalar_t* __restrict__ weight,    // [hidden_size]
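For illustration only (not part of this commit): the width-even path of `_f16Vec::operator*=(float)` leans on the packed FP16 conversions above. A minimal standalone CUDA sketch of that round-trip, with an assumed kernel name `scale_pair_demo`: two halves are widened to float2, scaled in FP32, and narrowed back with round-to-nearest, which is the per-pair pattern used inside the vector struct.

// Illustrative sketch, not code from this diff.
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdio>

__global__ void scale_pair_demo(float2* out, float scale) {
  __half2 h = __floats2half2_rn(1.5f, -2.0f);  // pack two FP16 values
  float2 f = __half22float2(h);                // widen both lanes to FP32
  f.x *= scale;
  f.y *= scale;
  h = __float22half2_rn(f);                    // narrow back (round-to-nearest)
  *out = __half22float2(h);                    // widen again only for printing
}

int main() {
  float2* d_out;
  float2 h_out;
  cudaMalloc(&d_out, sizeof(float2));
  scale_pair_demo<<<1, 1>>>(d_out, 0.5f);
  cudaMemcpy(&h_out, d_out, sizeof(float2), cudaMemcpyDeviceToHost);
  std::printf("%f %f\n", h_out.x, h_out.y);    // expect 0.750000 -1.000000
  cudaFree(d_out);
  return 0;
}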
@@ -48,12 +248,17 @@ __global__ void fused_add_rms_norm_kernel(
   float variance = 0.0f;

   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    float x = (float) input[blockIdx.x * hidden_size + idx];
-    x += (float) residual[blockIdx.x * hidden_size + idx];
+    scalar_t z = input[blockIdx.x * hidden_size + idx];
+    z += residual[blockIdx.x * hidden_size + idx];
+    float x = (float) z;
     variance += x * x;
-    residual[blockIdx.x * hidden_size + idx] = (scalar_t) x;
+    residual[blockIdx.x * hidden_size + idx] = z;
   }
-  variance = blockReduceSum<float>(variance);
+  /* Keep the following if-else block in sync with the
+     calculation of max_block_size in fused_add_rms_norm */
+  if (num_tokens < 256) {
+    variance = blockReduceSum<float, 1024>(variance);
+  } else variance = blockReduceSum<float, 256>(variance);
   if (threadIdx.x == 0) {
     s_variance = rsqrtf(variance / hidden_size + epsilon);
   }
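As a reading aid, a plain C++ reference sketch (assumed semantics, not code from this commit) of what both kernel variants compute per token row: the residual buffer ends up holding input + residual, and the input buffer ends up holding the RMS-normalized sum scaled by the weight.

// Reference semantics sketch for one row of length hidden_size.
#include <cmath>
#include <vector>

void fused_add_rms_norm_ref(std::vector<float>& input,      // in/out: normalized result
                            std::vector<float>& residual,   // in/out: input + residual
                            const std::vector<float>& weight,
                            float epsilon) {
  const size_t n = input.size();
  float sum_sq = 0.0f;
  for (size_t i = 0; i < n; ++i) {
    residual[i] += input[i];             // fused residual add, written back
    sum_sq += residual[i] * residual[i];
  }
  const float inv_rms = 1.0f / std::sqrt(sum_sq / n + epsilon);
  for (size_t i = 0; i < n; ++i)
    input[i] = residual[i] * inv_rms * weight[i];   // RMSNorm scaled by weight
}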
@@ -93,6 +298,21 @@ void rms_norm(
     });
 }

+#define LAUNCH_FUSED_ADD_RMS_NORM(width)                          \
+  VLLM_DISPATCH_FLOATING_TYPES(                                    \
+    input.scalar_type(),                                           \
+    "fused_add_rms_norm_kernel",                                   \
+    [&] {                                                          \
+      vllm::fused_add_rms_norm_kernel                              \
+      <scalar_t, width><<<grid, block, 0, stream>>>(               \
+        input.data_ptr<scalar_t>(),                                \
+        residual.data_ptr<scalar_t>(),                             \
+        weight.data_ptr<scalar_t>(),                               \
+        epsilon,                                                   \
+        num_tokens,                                                \
+        hidden_size);                                              \
+    });
+
 void fused_add_rms_norm(
   torch::Tensor& input,     // [..., hidden_size]
   torch::Tensor& residual,  // [..., hidden_size]
@@ -102,19 +322,29 @@ void fused_add_rms_norm(
   int num_tokens = input.numel() / hidden_size;

   dim3 grid(num_tokens);
-  dim3 block(std::min(hidden_size, 1024));
+  /* This kernel is memory-latency bound in many scenarios.
+     When num_tokens is large, a smaller block size allows
+     for increased block occupancy on CUs and better latency
+     hiding on global mem ops. */
+  const int max_block_size = (num_tokens < 256) ? 1024 : 256;
+  dim3 block(std::min(hidden_size, max_block_size));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_FLOATING_TYPES(
-    input.scalar_type(),
-    "fused_add_rms_norm_kernel",
-    [&] {
-      vllm::fused_add_rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
-        input.data_ptr<scalar_t>(),
-        residual.data_ptr<scalar_t>(),
-        weight.data_ptr<scalar_t>(),
-        epsilon,
-        num_tokens,
-        hidden_size);
-    });
+  /* If the tensor types are FP16/BF16, try to use the optimized kernel
+     with packed + vectorized ops.
+     Max optimization is achieved with a width-8 vector of FP16/BF16s
+     since we can load at most 128 bits at once in a global memory op.
+     However, this requires each tensor's data to be aligned to 16
+     bytes.
+   */
+  auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
+  auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
+  auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
+  bool ptrs_are_aligned = inp_ptr % 16 == 0 && res_ptr % 16 == 0
+                          && wt_ptr % 16 == 0;
+  if (ptrs_are_aligned && hidden_size % 8 == 0) {
+    LAUNCH_FUSED_ADD_RMS_NORM(8);
+  } else {
+    LAUNCH_FUSED_ADD_RMS_NORM(0);
+  }
 }
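The launch-site gate above is the key dispatch decision. A hypothetical standalone restatement (the helper name `can_use_vectorized` is illustrative, not from this commit) of which shapes take the width-8 vectorized kernel versus the width-0 generic fallback:

// Illustrative sketch, not code from this diff.
#include <cstdint>
#include <cstdio>

static bool can_use_vectorized(const void* input, const void* residual,
                               const void* weight, int hidden_size) {
  // Vectorized path requires 16-byte-aligned base pointers and a
  // hidden_size divisible by the vector width of 8.
  auto aligned16 = [](const void* p) {
    return reinterpret_cast<std::uintptr_t>(p) % 16 == 0;
  };
  return aligned16(input) && aligned16(residual) && aligned16(weight) &&
         hidden_size % 8 == 0;
}

int main() {
  alignas(16) char buf[32];
  std::printf("%d\n", can_use_vectorized(buf, buf, buf, 5120));  // 1: vectorized
  std::printf("%d\n", can_use_vectorized(buf, buf, buf, 5125));  // 0: generic fallback
  return 0;
}

With the test values added below, hidden_size 5120 would take the vectorized path while 5125 exercises the generic fallback (assuming 16-byte-aligned allocations).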
@@ -20,43 +20,45 @@
 #include "cuda_compat.h"

 namespace vllm {

-template<typename T>
+template<typename T, int numLanes = WARP_SIZE>
 __inline__ __device__ T warpReduceSum(T val) {
+  static_assert(numLanes > 0 && (numLanes & (numLanes - 1)) == 0,
+                "numLanes is not a positive power of 2!");
+  static_assert(numLanes <= WARP_SIZE);
   #pragma unroll
-  for (int mask = WARP_SIZE/2; mask > 0; mask >>= 1)
+  for (int mask = numLanes >> 1; mask > 0; mask >>= 1)
     val += VLLM_SHFL_XOR_SYNC(val, mask);
   return val;
 }

-__inline__ __device__ constexpr int _calculateLaneMask(int warp_size) {
-  return warp_size - 1;
-}
-
-__inline__ __device__ constexpr int _calculateWidShift(int warp_size) {
-  return 5 + (warp_size >> 6);
+// Helper function to return the next largest power of 2
+static constexpr int _nextPow2(unsigned int num) {
+  if (num <= 1) return num;
+  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
 }

 /* Calculate the sum of all elements in a block */
-template<typename T>
+template<typename T, int maxBlockSize = 1024>
 __inline__ __device__ T blockReduceSum(T val) {
-  static __shared__ T shared[WARP_SIZE];
-  constexpr auto LANE_MASK = _calculateLaneMask(WARP_SIZE);
-  constexpr auto WID_SHIFT = _calculateWidShift(WARP_SIZE);
-  int lane = threadIdx.x & LANE_MASK;
-  int wid = threadIdx.x >> WID_SHIFT;
-
-  val = warpReduceSum<T>(val);
-
-  if (lane == 0)
-    shared[wid] = val;
-
-  __syncthreads();
-
-  // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent
-  // blockDim.x is not divided by 32
-  val = (threadIdx.x < (blockDim.x / (WARP_SIZE * 1.0f))) ? shared[lane] : (T)(0.0f);
-  val = warpReduceSum<T>(val);
+  static_assert(maxBlockSize <= 1024);
+  if constexpr (maxBlockSize > WARP_SIZE) {
+    val = warpReduceSum<T>(val);
+    // Calculates max number of lanes that need to participate in the last warpReduce
+    constexpr int maxActiveLanes = (maxBlockSize + WARP_SIZE - 1) / WARP_SIZE;
+    static __shared__ T shared[maxActiveLanes];
+    int lane = threadIdx.x % WARP_SIZE;
+    int wid = threadIdx.x / WARP_SIZE;
+    if (lane == 0)
+      shared[wid] = val;
+
+    __syncthreads();
+
+    val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane] : (T)(0.0f);
+    val = warpReduceSum<T, _nextPow2(maxActiveLanes)>(val);
+  } else {
+    // A single warpReduce is equal to blockReduce
+    val = warpReduceSum<T, _nextPow2(maxBlockSize)>(val);
+  }
   return val;
 }
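A small compile-time sketch (illustrative only; `kWarpSize = 32` and the helper names are assumptions, not code from this commit) of the arithmetic behind the new `blockReduceSum<T, maxBlockSize>`: with a 256-thread cap, at most 8 per-warp partial sums survive the first reduction, so the final pass only needs an 8-lane `warpReduceSum`.

// Illustrative sketch, not code from this diff.
#include <climits>

static constexpr int kWarpSize = 32;   // assumption: CUDA warp width

// Mirrors _nextPow2 from the diff: round up to the next power of 2.
static constexpr int nextPow2(unsigned int num) {
  if (num <= 1) return num;
  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}

// Mirrors maxActiveLanes: how many warp partials reach the final reduction.
static constexpr int maxActiveLanes(int maxBlockSize) {
  return (maxBlockSize + kWarpSize - 1) / kWarpSize;
}

static_assert(nextPow2(5) == 8 && nextPow2(8) == 8,
              "rounds up to the next power of 2");
static_assert(maxActiveLanes(256) == 8,
              "blockReduceSum<float, 256>: 8 warp partials");
static_assert(maxActiveLanes(1024) == 32,
              "blockReduceSum<float, 1024>: 32 warp partials");
static_assert(nextPow2(maxActiveLanes(256)) == 8,
              "final warpReduceSum only needs 8 lanes");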
@@ -5,7 +5,8 @@ from vllm.model_executor.layers.layernorm import RMSNorm

 DTYPES = [torch.half, torch.bfloat16, torch.float]
 NUM_TOKENS = [7, 83, 4096]  # Arbitrary values for testing
-HIDDEN_SIZES = [768, 5120, 8192]  # Arbitrary values for testing
+HIDDEN_SIZES = [768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192,
+                8199]  # Arbitrary values for testing
 ADD_RESIDUAL = [False, True]
 SEEDS = [0]
 CUDA_DEVICES = [