#pragma once

#include "cuda_primitives.h"

#include <float.h>
#include <stdint.h>

#define MMHA_USE_FP32_ACUM_FOR_FMA
#define MMHA_USE_FP32_ACUM_FOR_OUT

namespace cacheflow {

// A vector type to store Q, K, V elements.
template<typename T, int VEC_SIZE>
struct Vec {};
template<>
struct Vec<float, 1> {
  using Type = float;
};
template<>
struct Vec<float, 2> {
  using Type = float2;
};
template<>
struct Vec<float, 4> {
  using Type = float4;
};
template<>
struct Vec<uint16_t, 1> {
  using Type = uint16_t;
};
template<>
struct Vec<uint16_t, 2> {
  using Type = uint32_t;
};
template<>
struct Vec<uint16_t, 4> {
  using Type = uint2;
};
template<>
struct Vec<uint16_t, 8> {
  using Type = uint4;
};

// The FP32 accumulator vector type that corresponds to a given Q/K/V vector type.
template<typename T>
struct FloatVec {};
template<>
struct FloatVec<float> {
  using Type = float;
};
template<>
struct FloatVec<float2> {
  using Type = float2;
};
template<>
struct FloatVec<float4> {
  using Type = float4;
};
template<>
struct FloatVec<uint16_t> {
  using Type = float;
};
template<>
struct FloatVec<uint32_t> {
  using Type = float2;
};
template<>
struct FloatVec<uint2> {
  using Type = Float4_;
};
template<>
struct FloatVec<uint4> {
  using Type = Float8_;
};

template<int THREADS_PER_KEY, typename K_vec, int N>
inline __device__ float qk_dot_(const K_vec (&q)[N], const K_vec (&k)[N])
{
  using K_vec_acum = typename FloatVec<K_vec>::Type;

  // Compute the parallel products for Q*K^T (treat vector lanes separately).
  K_vec_acum qk_vec = mul<K_vec_acum, K_vec, K_vec>(q[0], k[0]);
#pragma unroll
  for (int ii = 1; ii < N; ++ii) {
    qk_vec = fma(q[ii], k[ii], qk_vec);
  }

  // Finalize the reduction across lanes.
  float qk = sum(qk_vec);
#pragma unroll
  for (int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2) {
    qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
  }
  return qk;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, int THREADS_PER_KEY>
struct Qk_dot {
  template<typename K_vec, int N>
  static inline __device__ float dot(const K_vec (&q)[N], const K_vec (&k)[N])
  {
    return qk_dot_<THREADS_PER_KEY>(q, k);
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float4 hmma_fp32(const uint2& a, uint32_t b)
{
  float4 c;
  float zero = 0.f;
  asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n"
               "    {%0, %1, %2, %3}, \n"
               "    {%4, %5}, \n"
               "    {%6}, \n"
               "    {%7, %7, %7, %7}; \n"

               : "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w)
               : "r"(a.x), "r"(a.y), "r"(b), "f"(zero));
  return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int N>
inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N], const uint32_t (&k)[N])
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
  using K_vec_acum = typename FloatVec<uint32_t>::Type;
  K_vec_acum qk_vec = mul<K_vec_acum, uint32_t, uint32_t>(q[0], k[0]);
#pragma unroll
  for (int ii = 1; ii < N; ++ii) {
    qk_vec = fma(q[ii], k[ii], qk_vec);
  }
  // 0x3c003c00u packs two fp16 ones, so the HMMA adds the two half lanes of the accumulator.
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
  uint32_t qk_vec_ = float2_to_half2(qk_vec);
  return hmma_fp32(make_uint2(qk_vec_, 0u), 0x3c003c00u).x;
#else
  return hmma_fp32(make_uint2(qk_vec, 0u), 0x3c003c00u).x;
#endif
#else
  return 0.f;
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
struct Qk_dot<uint16_t, 4> {
  template<int N>
  static inline __device__ float dot(const uint32_t (&q)[N], const uint32_t (&k)[N])
  {
#if __CUDA_ARCH__ >= 750 && defined(MMHA_USE_HMMA_FOR_REDUCTION)
    return qk_hmma_dot_(q, k);
#else
    return qk_dot_<4>(q, k);
#endif // defined MMHA_USE_HMMA_FOR_REDUCTION
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////
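// Usage sketch for Qk_dot (illustrative only, not part of the original kernels): a group of
// THREADS_PER_KEY consecutive lanes cooperatively holds one key. Each thread loads its slice
// of the query and of the key as K_vec values and calls Qk_dot<>::dot, which multiplies the
// slices element-wise, accumulates in FP32 (via FloatVec), and shuffle-reduces across the
// group. NUM_VECS_PER_THREAD and the load step below are hypothetical names for this sketch.
//
//   constexpr int THREADS_PER_KEY = 4;
//   using K_vec = typename Vec<uint16_t, 2>::Type;   // uint32_t: two packed fp16 values.
//   K_vec q_vecs[NUM_VECS_PER_THREAD];
//   K_vec k_vecs[NUM_VECS_PER_THREAD];
//   // ... load this thread's slices of the query and of one key ...
//   // Every thread in the group of THREADS_PER_KEY lanes ends up with the full dot product.
//   float qk = Qk_dot<uint16_t, THREADS_PER_KEY>::dot(q_vecs, k_vecs);

////////////////////////////////////////////////////////////////////////////////////////////////////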
template<int WARPS_PER_BLOCK, int WARP_SIZE = 32>
inline __device__ float block_sum(float* red_smem, float sum)
{
  // Decompose the thread index into warp / lane.
  int warp = threadIdx.x / WARP_SIZE;
  int lane = threadIdx.x % WARP_SIZE;

  // Compute the sum per warp.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
    sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
  }

  // Warp leaders store the data to shared memory.
  if (lane == 0) {
    red_smem[warp] = sum;
  }

  // Make sure the data is in shared memory.
  __syncthreads();

  // The warps compute the final sums.
  if (lane < WARPS_PER_BLOCK) {
    sum = red_smem[lane];
  }

  // Parallel reduction inside the warp.
#pragma unroll
  for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
    sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
  }

  // Broadcast to other threads.
  return __shfl_sync(uint32_t(-1), sum, 0);
}

} // namespace cacheflow

#undef MMHA_USE_FP32_ACUM_FOR_FMA
#undef MMHA_USE_FP32_ACUM_FOR_OUT
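
////////////////////////////////////////////////////////////////////////////////////////////////////
// Usage sketch for cacheflow::block_sum (illustrative only; the caller, the red_smem sizing, and
// the variable names below are assumptions, not part of this header). A typical caller reduces a
// per-thread partial value, such as the softmax denominator, across the whole thread block; every
// thread must participate because of the __syncthreads() inside.
//
//   __shared__ float red_smem[NUM_WARPS];     // one slot per warp; NUM_WARPS is hypothetical.
//   float exp_sum = ...;                      // this thread's partial sum of exp(qk - qk_max).
//   // All threads receive the block-wide total, so each can normalize its own logits.
//   exp_sum = cacheflow::block_sum<NUM_WARPS>(&red_smem[0], exp_sum);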