[Kernel] Layernorm performance optimization (#3662)

This commit is contained in:
mawong-amd 2024-03-30 14:26:38 -07:00 committed by GitHub
parent 51c31bc10c
commit b6d103542c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 285 additions and 47 deletions

View File

@ -100,6 +100,11 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8) if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8)
list(APPEND GPU_FLAGS "-DENABLE_FP8_E5M2") list(APPEND GPU_FLAGS "-DENABLE_FP8_E5M2")
list(REMOVE_ITEM GPU_FLAGS
"-D__CUDA_NO_HALF_OPERATORS__"
"-D__CUDA_NO_HALF_CONVERSIONS__"
"-D__CUDA_NO_BFLOAT16_CONVERSIONS__"
"-D__CUDA_NO_HALF2_OPERATORS__")
endif() endif()
elseif(${GPU_LANG} STREQUAL "HIP") elseif(${GPU_LANG} STREQUAL "HIP")

View File

@ -4,6 +4,16 @@
#include "dispatch_utils.h" #include "dispatch_utils.h"
#include "reduction_utils.cuh" #include "reduction_utils.cuh"
#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#else
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
using __nv_bfloat16 = __hip_bfloat16;
using __nv_bfloat162 = __hip_bfloat162;
#endif
namespace vllm { namespace vllm {
@ -35,9 +45,199 @@ __global__ void rms_norm_kernel(
} }
} }
// TODO: Further optimize this kernel.
template<typename scalar_t> /* Converter structs for the conversion from torch types to HIP/CUDA types,
__global__ void fused_add_rms_norm_kernel( and the associated type conversions within HIP/CUDA. These helpers need
to be implemented for now because the relevant type conversion
operators/constructors are not consistently implemented by HIP/CUDA, so
a generic conversion via type casts cannot be implemented.
Each struct should have the member static constexpr bool `exists`:
If false, the optimized kernel is not used for the corresponding torch type.
If true, the struct should be fully defined as shown in the examples below.
*/
template<typename torch_type>
struct _typeConvert { static constexpr bool exists = false; };
template<>
struct _typeConvert<c10::Half> {
static constexpr bool exists = true;
using hip_type = __half;
using packed_hip_type = __half2;
__device__ static inline float convert(hip_type x) { return __half2float(x); }
__device__ static inline float2 convert(packed_hip_type x) { return __half22float2(x); }
__device__ static inline hip_type convert(float x) { return __float2half_rn(x); }
__device__ static inline packed_hip_type convert(float2 x) { return __float22half2_rn(x); }
};
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// CUDA_ARCH < 800 does not have BF16 support
// TODO: Add in ROCm support once public headers handle bf16 maturely
template<>
struct _typeConvert<c10::BFloat16> {
static constexpr bool exists = true;
using hip_type = __nv_bfloat16;
using packed_hip_type = __nv_bfloat162;
__device__ static inline float convert(hip_type x) { return __bfloat162float(x); }
__device__ static inline float2 convert(packed_hip_type x) { return __bfloat1622float2(x); }
__device__ static inline hip_type convert(float x) { return __float2bfloat16(x); }
__device__ static inline packed_hip_type convert(float2 x) { return __float22bfloat162_rn(x); }
};
#endif
/* Vector POD struct to generate vectorized and packed FP16/BF16 ops
for appropriate specializations of fused_add_rms_norm_kernel.
Only functions that are necessary in that kernel are implemented.
Alignment to 16 bytes is required to use 128-bit global memory ops.
*/
template<typename scalar_t, int width>
struct alignas(16) _f16Vec {
/* Not theoretically necessary that width is a power of 2 but should
almost always be the case for optimization purposes */
static_assert(width > 0 && (width & (width - 1)) == 0,
"Width is not a positive power of 2!");
using Converter = _typeConvert<scalar_t>;
using T1 = typename Converter::hip_type;
using T2 = typename Converter::packed_hip_type;
T1 data[width];
__device__ _f16Vec& operator+=(const _f16Vec<scalar_t, width>& other) {
if constexpr (width % 2 == 0) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
T2 temp{data[i], data[i+1]};
temp += T2{other.data[i], other.data[i+1]};
data[i] = temp.x;
data[i+1] = temp.y;
}
} else {
#pragma unroll
for (int i = 0; i < width; ++i)
data[i] += other.data[i];
}
return *this;
}
__device__ _f16Vec& operator*=(const _f16Vec<scalar_t, width>& other) {
if constexpr (width % 2 == 0) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
T2 temp{data[i], data[i+1]};
temp *= T2{other.data[i], other.data[i+1]};
data[i] = temp.x;
data[i+1] = temp.y;
}
} else {
#pragma unroll
for (int i = 0; i < width; ++i)
data[i] *= other.data[i];
}
return *this;
}
__device__ _f16Vec& operator*=(const float scale) {
if constexpr (width % 2 == 0) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
float2 temp_f = Converter::convert(T2{data[i], data[i+1]});
temp_f.x *= scale;
temp_f.y *= scale;
T2 temp = Converter::convert(temp_f);
data[i] = temp.x;
data[i+1] = temp.y;
}
} else {
#pragma unroll
for (int i = 0; i < width; ++i) {
float temp = Converter::convert(data[i]) * scale;
data[i] = Converter::convert(temp);
}
}
return *this;
}
__device__ float sum_squares() const {
float result = 0.0f;
if constexpr (width % 2 == 0) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
float2 z = Converter::convert(T2{data[i], data[i+1]});
result += z.x * z.x + z.y * z.y;
}
} else {
#pragma unroll
for (int i = 0; i < width; ++i) {
float x = Converter::convert(data[i]);
result += x * x;
}
}
return result;
}
};
/* Function specialization in the case of FP16/BF16 tensors.
Additional optimizations we can make in this case are
packed and vectorized operations, which help with the
memory latency bottleneck. */
template<typename scalar_t, int width>
__global__ std::enable_if_t<
(width > 0) && _typeConvert<scalar_t>::exists> fused_add_rms_norm_kernel(
scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon,
const int num_tokens,
const int hidden_size) {
// Sanity checks on our vector struct and type-punned pointer arithmetic
static_assert(std::is_pod_v<_f16Vec<scalar_t, width>>);
static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);
const int vec_hidden_size = hidden_size / width;
__shared__ float s_variance;
float variance = 0.0f;
/* These and the argument pointers are all declared `restrict` as they are
not aliased in practice. Argument pointers should not be dereferenced
in this kernel as that would be undefined behavior */
auto* __restrict__ input_v = reinterpret_cast<_f16Vec<scalar_t, width>*>(input);
auto* __restrict__ residual_v = reinterpret_cast<_f16Vec<scalar_t, width>*>(residual);
auto* __restrict__ weight_v = reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx;
_f16Vec<scalar_t, width> temp = input_v[id];
temp += residual_v[id];
variance += temp.sum_squares();
residual_v[id] = temp;
}
/* Keep the following if-else block in sync with the
calculation of max_block_size in fused_add_rms_norm */
if (num_tokens < 256) {
variance = blockReduceSum<float, 1024>(variance);
} else variance = blockReduceSum<float, 256>(variance);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
__syncthreads();
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx;
_f16Vec<scalar_t, width> temp = residual_v[id];
temp *= s_variance;
temp *= weight_v[idx];
input_v[id] = temp;
}
}
/* Generic fused_add_rms_norm_kernel
The width field is not used here but necessary for other specializations.
*/
template<typename scalar_t, int width>
__global__ std::enable_if_t<
(width == 0) || !_typeConvert<scalar_t>::exists> fused_add_rms_norm_kernel(
scalar_t* __restrict__ input, // [..., hidden_size] scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual, // [..., hidden_size] scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size] const scalar_t* __restrict__ weight, // [hidden_size]
@ -48,12 +248,17 @@ __global__ void fused_add_rms_norm_kernel(
float variance = 0.0f; float variance = 0.0f;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float) input[blockIdx.x * hidden_size + idx]; scalar_t z = input[blockIdx.x * hidden_size + idx];
x += (float) residual[blockIdx.x * hidden_size + idx]; z += residual[blockIdx.x * hidden_size + idx];
float x = (float) z;
variance += x * x; variance += x * x;
residual[blockIdx.x * hidden_size + idx] = (scalar_t) x; residual[blockIdx.x * hidden_size + idx] = z;
} }
variance = blockReduceSum<float>(variance); /* Keep the following if-else block in sync with the
calculation of max_block_size in fused_add_rms_norm */
if (num_tokens < 256) {
variance = blockReduceSum<float, 1024>(variance);
} else variance = blockReduceSum<float, 256>(variance);
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon); s_variance = rsqrtf(variance / hidden_size + epsilon);
} }
@ -93,6 +298,21 @@ void rms_norm(
}); });
} }
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), \
"fused_add_rms_norm_kernel", \
[&] { \
vllm::fused_add_rms_norm_kernel \
<scalar_t, width><<<grid, block, 0, stream>>>( \
input.data_ptr<scalar_t>(), \
residual.data_ptr<scalar_t>(), \
weight.data_ptr<scalar_t>(), \
epsilon, \
num_tokens, \
hidden_size); \
});
void fused_add_rms_norm( void fused_add_rms_norm(
torch::Tensor& input, // [..., hidden_size] torch::Tensor& input, // [..., hidden_size]
torch::Tensor& residual, // [..., hidden_size] torch::Tensor& residual, // [..., hidden_size]
@ -102,19 +322,29 @@ void fused_add_rms_norm(
int num_tokens = input.numel() / hidden_size; int num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens); dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024)); /* This kernel is memory-latency bound in many scenarios.
When num_tokens is large, a smaller block size allows
for increased block occupancy on CUs and better latency
hiding on global mem ops. */
const int max_block_size = (num_tokens < 256) ? 1024 : 256;
dim3 block(std::min(hidden_size, max_block_size));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES( /*If the tensor types are FP16/BF16, try to use the optimized kernel
input.scalar_type(), with packed + vectorized ops.
"fused_add_rms_norm_kernel", Max optimization is achieved with a width-8 vector of FP16/BF16s
[&] { since we can load at most 128 bits at once in a global memory op.
vllm::fused_add_rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>( However, this requires each tensor's data to be aligned to 16
input.data_ptr<scalar_t>(), bytes.
residual.data_ptr<scalar_t>(), */
weight.data_ptr<scalar_t>(), auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
epsilon, auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
num_tokens, auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
hidden_size); bool ptrs_are_aligned = inp_ptr % 16 == 0 && res_ptr % 16 == 0 \
}); && wt_ptr % 16 == 0;
if (ptrs_are_aligned && hidden_size % 8 == 0) {
LAUNCH_FUSED_ADD_RMS_NORM(8);
} else {
LAUNCH_FUSED_ADD_RMS_NORM(0);
}
} }

View File

@ -20,43 +20,45 @@
#include "cuda_compat.h" #include "cuda_compat.h"
namespace vllm { namespace vllm {
template<typename T, int numLanes = WARP_SIZE>
template<typename T>
__inline__ __device__ T warpReduceSum(T val) { __inline__ __device__ T warpReduceSum(T val) {
static_assert(numLanes > 0 && (numLanes & (numLanes - 1)) == 0,
"numLanes is not a positive power of 2!");
static_assert(numLanes <= WARP_SIZE);
#pragma unroll #pragma unroll
for (int mask = WARP_SIZE/2; mask > 0; mask >>= 1) for (int mask = numLanes >> 1; mask > 0; mask >>= 1)
val += VLLM_SHFL_XOR_SYNC(val, mask); val += VLLM_SHFL_XOR_SYNC(val, mask);
return val; return val;
} }
__inline__ __device__ constexpr int _calculateLaneMask(int warp_size) { // Helper function to return the next largest power of 2
return warp_size - 1; static constexpr int _nextPow2(unsigned int num) {
} if (num <= 1) return num;
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
__inline__ __device__ constexpr int _calculateWidShift(int warp_size) {
return 5 + (warp_size >> 6);
} }
/* Calculate the sum of all elements in a block */ /* Calculate the sum of all elements in a block */
template<typename T> template<typename T, int maxBlockSize = 1024>
__inline__ __device__ T blockReduceSum(T val) { __inline__ __device__ T blockReduceSum(T val) {
static __shared__ T shared[WARP_SIZE]; static_assert(maxBlockSize <= 1024);
constexpr auto LANE_MASK = _calculateLaneMask(WARP_SIZE); if constexpr (maxBlockSize > WARP_SIZE) {
constexpr auto WID_SHIFT = _calculateWidShift(WARP_SIZE);
int lane = threadIdx.x & LANE_MASK;
int wid = threadIdx.x >> WID_SHIFT;
val = warpReduceSum<T>(val); val = warpReduceSum<T>(val);
// Calculates max number of lanes that need to participate in the last warpReduce
constexpr int maxActiveLanes = (maxBlockSize + WARP_SIZE - 1) / WARP_SIZE;
static __shared__ T shared[maxActiveLanes];
int lane = threadIdx.x % WARP_SIZE;
int wid = threadIdx.x / WARP_SIZE;
if (lane == 0) if (lane == 0)
shared[wid] = val; shared[wid] = val;
__syncthreads(); __syncthreads();
// Modify from blockDim.x << 5 to blockDim.x / 32. to prevent val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane] : (T)(0.0f);
// blockDim.x is not divided by 32 val = warpReduceSum<T, _nextPow2(maxActiveLanes)>(val);
val = (threadIdx.x < (blockDim.x / (WARP_SIZE * 1.0f))) ? shared[lane] : (T)(0.0f); } else {
val = warpReduceSum<T>(val); // A single warpReduce is equal to blockReduce
val = warpReduceSum<T, _nextPow2(maxBlockSize)>(val);
}
return val; return val;
} }

View File

@ -5,7 +5,8 @@ from vllm.model_executor.layers.layernorm import RMSNorm
DTYPES = [torch.half, torch.bfloat16, torch.float] DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [7, 83, 4096] # Arbitrary values for testing NUM_TOKENS = [7, 83, 4096] # Arbitrary values for testing
HIDDEN_SIZES = [768, 5120, 8192] # Arbitrary values for testing HIDDEN_SIZES = [768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192,
8199] # Arbitrary values for testing
ADD_RESIDUAL = [False, True] ADD_RESIDUAL = [False, True]
SEEDS = [0] SEEDS = [0]
CUDA_DEVICES = [ CUDA_DEVICES = [