[CI/Build] Enforce style for C++ and CUDA code with clang-format (#4722)

This commit is contained in:
Michael Goin 2024-05-22 03:18:41 -04:00 committed by GitHub
parent 9b9a10d6cb
commit 5f6d10c14c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
64 changed files with 6398 additions and 6790 deletions

26
.clang-format Normal file
View File

@ -0,0 +1,26 @@
BasedOnStyle: Google
UseTab: Never
IndentWidth: 2
ColumnLimit: 80
# Force pointers to the type for C++.
DerivePointerAlignment: false
PointerAlignment: Left
# Reordering #include statements can (and currently will) introduce errors
SortIncludes: false
# Style choices
AlignConsecutiveAssignments: false
AlignConsecutiveDeclarations: false
IndentPPDirectives: BeforeHash
IncludeCategories:
- Regex: '^<'
Priority: 4
- Regex: '^"(llvm|llvm-c|clang|clang-c|mlir|mlir-c)/'
Priority: 3
- Regex: '^"(qoda|\.\.)/'
Priority: 2
- Regex: '.*'
Priority: 1

42
.github/workflows/clang-format.yml vendored Normal file
View File

@ -0,0 +1,42 @@
name: clang-format
on:
# Trigger the workflow on push or pull request,
# but only for the main branch
push:
branches:
- main
pull_request:
branches:
- main
jobs:
clang-format:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.11"]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install clang-format==18.1.5
- name: Running clang-format
run: |
EXCLUDES=(
'csrc/moe/topk_softmax_kernels.cu'
'csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu'
'csrc/punica/bgmv/bgmv_config.h'
'csrc/punica/bgmv/bgmv_impl.cuh'
'csrc/punica/bgmv/vec_dtypes.cuh'
'csrc/punica/punica_ops.cu'
'csrc/punica/type_convert.h'
)
find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \
| grep -vFf <(printf "%s\n" "${EXCLUDES[@]}") \
| xargs clang-format --dry-run --Werror

View File

@ -10,11 +10,11 @@
namespace vllm { namespace vllm {
// Activation and gating kernel template. // Activation and gating kernel template.
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)> template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void act_and_mul_kernel( __global__ void act_and_mul_kernel(
scalar_t* __restrict__ out, // [..., d] scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d] const scalar_t* __restrict__ input, // [..., 2, d]
const int d) { const int d) {
const int64_t token_idx = blockIdx.x; const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]); const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
@ -23,72 +23,66 @@ __global__ void act_and_mul_kernel(
} }
} }
template<typename T> template <typename T>
__device__ __forceinline__ T silu_kernel(const T& x) { __device__ __forceinline__ T silu_kernel(const T& x) {
// x * sigmoid(x) // x * sigmoid(x)
return (T) (((float) x) / (1.0f + expf((float) -x))); return (T)(((float)x) / (1.0f + expf((float)-x)));
} }
template<typename T> template <typename T>
__device__ __forceinline__ T gelu_kernel(const T& x) { __device__ __forceinline__ T gelu_kernel(const T& x) {
// Equivalent to PyTorch GELU with 'none' approximation. // Equivalent to PyTorch GELU with 'none' approximation.
// Refer to: // Refer to:
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38 // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38
const float f = (float) x; const float f = (float)x;
constexpr float ALPHA = M_SQRT1_2; constexpr float ALPHA = M_SQRT1_2;
return (T) (f * 0.5f * (1.0f + ::erf(f * ALPHA))); return (T)(f * 0.5f * (1.0f + ::erf(f * ALPHA)));
} }
template<typename T> template <typename T>
__device__ __forceinline__ T gelu_tanh_kernel(const T& x) { __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
// Equivalent to PyTorch GELU with 'tanh' approximation. // Equivalent to PyTorch GELU with 'tanh' approximation.
// Refer to: // Refer to:
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30 // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30
const float f = (float) x; const float f = (float)x;
constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f; constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f;
constexpr float KAPPA = 0.044715; constexpr float KAPPA = 0.044715;
float x_cube = f * f * f; float x_cube = f * f * f;
float inner = BETA * (f + KAPPA * x_cube); float inner = BETA * (f + KAPPA * x_cube);
return (T) (0.5f * f * (1.0f + ::tanhf(inner))); return (T)(0.5f * f * (1.0f + ::tanhf(inner)));
} }
} // namespace vllm } // namespace vllm
// Launch activation and gating kernel. // Launch activation and gating kernel.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \ #define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
int d = input.size(-1) / 2; \ int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \ int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \ dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \ dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \ VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), \ input.scalar_type(), "act_and_mul_kernel", [&] { \
"act_and_mul_kernel", \ vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>> \
[&] { \ <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \ input.data_ptr<scalar_t>(), d); \
out.data_ptr<scalar_t>(), \ });
input.data_ptr<scalar_t>(), \
d); \
});
void silu_and_mul( void silu_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& out, // [..., d] torch::Tensor& input) // [..., 2 * d]
torch::Tensor& input) // [..., 2 * d]
{ {
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel); LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
} }
void gelu_and_mul( void gelu_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& out, // [..., d] torch::Tensor& input) // [..., 2 * d]
torch::Tensor& input) // [..., 2 * d]
{ {
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel); LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
} }
void gelu_tanh_and_mul( void gelu_tanh_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& out, // [..., d] torch::Tensor& input) // [..., 2 * d]
torch::Tensor& input) // [..., 2 * d]
{ {
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel); LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel);
} }
@ -96,11 +90,11 @@ void gelu_tanh_and_mul(
namespace vllm { namespace vllm {
// Element-wise activation kernel template. // Element-wise activation kernel template.
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)> template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void activation_kernel( __global__ void activation_kernel(
scalar_t* __restrict__ out, // [..., d] scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., d] const scalar_t* __restrict__ input, // [..., d]
const int d) { const int d) {
const int64_t token_idx = blockIdx.x; const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]); const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]);
@ -108,54 +102,49 @@ __global__ void activation_kernel(
} }
} }
} // namespace vllm } // namespace vllm
// Launch element-wise activation kernel. // Launch element-wise activation kernel.
#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \ #define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
int d = input.size(-1); \ int d = input.size(-1); \
int64_t num_tokens = input.numel() / d; \ int64_t num_tokens = input.numel() / d; \
dim3 grid(num_tokens); \ dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \ dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \ VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] { \
input.scalar_type(), \ vllm::activation_kernel<scalar_t, KERNEL<scalar_t>> \
"activation_kernel", \ <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
[&] { \ input.data_ptr<scalar_t>(), d); \
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \ });
out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), \
d); \
});
namespace vllm { namespace vllm {
template<typename T> template <typename T>
__device__ __forceinline__ T gelu_new_kernel(const T& x) { __device__ __forceinline__ T gelu_new_kernel(const T& x) {
const float x3 = (float) (x * x * x); const float x3 = (float)(x * x * x);
const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3)))); const T t = (T)tanhf((T)(0.79788456f * (float)(x + (T)(0.044715f * x3))));
return ((T) 0.5) * x * (((T) 1.0) + t); return ((T)0.5) * x * (((T)1.0) + t);
} }
template<typename T> template <typename T>
__device__ __forceinline__ T gelu_fast_kernel(const T& x) { __device__ __forceinline__ T gelu_fast_kernel(const T& x) {
const float f = (float) x; const float f = (float)x;
const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x)); const T t =
return ((T) 0.5) * x * (((T) 1.0) + t); (T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x));
return ((T)0.5) * x * (((T)1.0) + t);
} }
} // namespace vllm } // namespace vllm
void gelu_new( void gelu_new(torch::Tensor& out, // [..., d]
torch::Tensor& out, // [..., d] torch::Tensor& input) // [..., d]
torch::Tensor& input) // [..., d]
{ {
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel); LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
} }
void gelu_fast( void gelu_fast(torch::Tensor& out, // [..., d]
torch::Tensor& out, // [..., d] torch::Tensor& input) // [..., d]
torch::Tensor& input) // [..., d]
{ {
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel); LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
} }

View File

@ -1,5 +1,6 @@
/* /*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * Adapted from
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Copyright (c) 2023, The vLLM team. * Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* *
@ -22,31 +23,31 @@
namespace vllm { namespace vllm {
// A vector type to store Q, K, V elements. // A vector type to store Q, K, V elements.
template<typename T, int VEC_SIZE> template <typename T, int VEC_SIZE>
struct Vec {}; struct Vec {};
// A vector type to store FP32 accumulators. // A vector type to store FP32 accumulators.
template<typename T> template <typename T>
struct FloatVec {}; struct FloatVec {};
// Template vector operations. // Template vector operations.
template<typename Acc, typename A, typename B> template <typename Acc, typename A, typename B>
inline __device__ Acc mul(A a, B b); inline __device__ Acc mul(A a, B b);
template<typename T> template <typename T>
inline __device__ float sum(T v); inline __device__ float sum(T v);
template<typename T> template <typename T>
inline __device__ float dot(T a, T b) { inline __device__ float dot(T a, T b) {
return sum(mul<T, T, T>(a, b)); return sum(mul<T, T, T>(a, b));
} }
template<typename A, typename T> template <typename A, typename T>
inline __device__ float dot(T a, T b) { inline __device__ float dot(T a, T b) {
return sum(mul<A, T, T>(a, b)); return sum(mul<A, T, T>(a, b));
} }
template<typename T> template <typename T>
inline __device__ void zero(T& dst) { inline __device__ void zero(T& dst) {
constexpr int WORDS = sizeof(T) / 4; constexpr int WORDS = sizeof(T) / 4;
union { union {
@ -61,4 +62,4 @@ inline __device__ void zero(T& dst) {
dst = tmp.raw; dst = tmp.raw;
} }
} // namespace vllm } // namespace vllm

View File

@ -1,5 +1,6 @@
/* /*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp * Adapted from
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* Copyright (c) 2023, The vLLM team. * Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* *
@ -27,15 +28,15 @@
#ifdef USE_ROCM #ifdef USE_ROCM
#include <hip/hip_bf16.h> #include <hip/hip_bf16.h>
#include "../quantization/fp8/amd/quant_utils.cuh" #include "../quantization/fp8/amd/quant_utils.cuh"
typedef __hip_bfloat16 __nv_bfloat16; typedef __hip_bfloat16 __nv_bfloat16;
#else #else
#include "../quantization/fp8/nvidia/quant_utils.cuh" #include "../quantization/fp8/nvidia/quant_utils.cuh"
#endif #endif
#ifndef USE_ROCM #ifndef USE_ROCM
#define WARP_SIZE 32 #define WARP_SIZE 32
#else #else
#define WARP_SIZE warpSize #define WARP_SIZE warpSize
#endif #endif
#define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MAX(a, b) ((a) > (b) ? (a) : (b))
@ -45,7 +46,7 @@
namespace vllm { namespace vllm {
// Utility function for attention softmax. // Utility function for attention softmax.
template<int NUM_WARPS> template <int NUM_WARPS>
inline __device__ float block_sum(float* red_smem, float sum) { inline __device__ float block_sum(float* red_smem, float sum) {
// Decompose the thread index into warp / lane. // Decompose the thread index into warp / lane.
int warp = threadIdx.x / WARP_SIZE; int warp = threadIdx.x / WARP_SIZE;
@ -82,31 +83,28 @@ inline __device__ float block_sum(float* red_smem, float sum) {
// TODO(woosuk): Merge the last two dimensions of the grid. // TODO(woosuk): Merge the last two dimensions of the grid.
// Grid: (num_heads, num_seqs, max_num_partitions). // Grid: (num_heads, num_seqs, max_num_partitions).
template< template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
typename scalar_t, int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
typename cache_t, int PARTITION_SIZE = 0> // Zero means no partitioning.
int HEAD_SIZE,
int BLOCK_SIZE,
int NUM_THREADS,
vllm::Fp8KVCacheDataType KV_DTYPE,
int PARTITION_SIZE = 0> // Zero means no partitioning.
__device__ void paged_attention_kernel( __device__ void paged_attention_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ max_logits, // [num_seqs, num_heads,
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] // max_num_partitions]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions,
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] // head_size]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const int num_kv_heads, // [num_heads] const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
const float scale, // head_size/x, block_size, x]
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
const int* __restrict__ seq_lens, // [num_seqs] // head_size, block_size]
const int max_num_blocks_per_seq, const int num_kv_heads, // [num_heads]
const float* __restrict__ alibi_slopes, // [num_heads] const float scale,
const int q_stride, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int kv_block_stride, const int* __restrict__ seq_lens, // [num_seqs]
const int kv_head_stride, const int max_num_blocks_per_seq,
const float kv_scale) { const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float kv_scale) {
const int seq_idx = blockIdx.y; const int seq_idx = blockIdx.y;
const int partition_idx = blockIdx.z; const int partition_idx = blockIdx.z;
const int max_num_partitions = gridDim.z; const int max_num_partitions = gridDim.z;
@ -118,22 +116,29 @@ __device__ void paged_attention_kernel(
} }
const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE); const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks; const int num_blocks_per_partition =
USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
// [start_block_idx, end_block_idx) is the range of blocks to process. // [start_block_idx, end_block_idx) is the range of blocks to process.
const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; const int start_block_idx =
const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks); USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
const int end_block_idx =
MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks);
const int num_blocks = end_block_idx - start_block_idx; const int num_blocks = end_block_idx - start_block_idx;
// [start_token_idx, end_token_idx) is the range of tokens to process. // [start_token_idx, end_token_idx) is the range of tokens to process.
const int start_token_idx = start_block_idx * BLOCK_SIZE; const int start_token_idx = start_block_idx * BLOCK_SIZE;
const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len); const int end_token_idx =
MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len);
const int num_tokens = end_token_idx - start_token_idx; const int num_tokens = end_token_idx - start_token_idx;
constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS constexpr int NUM_THREAD_GROUPS =
NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE
// divides NUM_THREADS
assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); constexpr int NUM_TOKENS_PER_THREAD_GROUP =
DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
const int thread_idx = threadIdx.x; const int thread_idx = threadIdx.x;
const int warp_idx = thread_idx / WARP_SIZE; const int warp_idx = thread_idx / WARP_SIZE;
@ -143,13 +148,14 @@ __device__ void paged_attention_kernel(
const int num_heads = gridDim.x; const int num_heads = gridDim.x;
const int num_queries_per_kv = num_heads / num_kv_heads; const int num_queries_per_kv = num_heads / num_kv_heads;
const int kv_head_idx = head_idx / num_queries_per_kv; const int kv_head_idx = head_idx / num_queries_per_kv;
const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; const float alibi_slope =
alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
// A vector type to store a part of a key or a query. // A vector type to store a part of a key or a query.
// The vector size is configured in such a way that the threads in a thread group // The vector size is configured in such a way that the threads in a thread
// fetch or compute 16 bytes at a time. // group fetch or compute 16 bytes at a time. For example, if the size of a
// For example, if the size of a thread group is 4 and the data type is half, // thread group is 4 and the data type is half, then the vector size is 16 /
// then the vector size is 16 / (4 * sizeof(half)) == 2. // (4 * sizeof(half)) == 2.
constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type; using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type; using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
@ -163,18 +169,21 @@ __device__ void paged_attention_kernel(
// Load the query to registers. // Load the query to registers.
// Each thread in a thread group has a different part of the query. // Each thread in a thread group has a different part of the query.
// For example, if the the thread group size is 4, then the first thread in the group // For example, if the the thread group size is 4, then the first thread in
// has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ... // the group has 0, 4, 8, ... th vectors of the query, and the second thread
// th vectors of the query, and so on. // has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because
// NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous. // q is split from a qkv tensor, it may not be contiguous.
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
#pragma unroll #pragma unroll
for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) { for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
i += NUM_THREAD_GROUPS) {
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE); q_vecs[thread_group_offset][i] =
*reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
} }
__syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a
// memory wall right before we use q_vecs
// Memory planning. // Memory planning.
extern __shared__ char shared_mem[]; extern __shared__ char shared_mem[];
@ -193,44 +202,50 @@ __device__ void paged_attention_kernel(
// Each thread group in a warp fetches a key from the block, and computes // Each thread group in a warp fetches a key from the block, and computes
// dot product with the query. // dot product with the query.
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
// NOTE(woosuk): The block number is stored in int32. However, we cast it to int64 block_idx += NUM_WARPS) {
// because int32 can lead to overflow when this variable is multiplied by large numbers // NOTE(woosuk): The block number is stored in int32. However, we cast it to
// (e.g., kv_block_stride). // int64 because int32 can lead to overflow when this variable is multiplied
const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]); // by large numbers (e.g., kv_block_stride).
const int64_t physical_block_number =
static_cast<int64_t>(block_table[block_idx]);
// Load a key to registers. // Load a key to registers.
// Each thread in a thread group has a different part of the key. // Each thread in a thread group has a different part of the key.
// For example, if the the thread group size is 4, then the first thread in the group // For example, if the the thread group size is 4, then the first thread in
// has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th // the group has 0, 4, 8, ... th vectors of the key, and the second thread
// vectors of the key, and so on. // has 1, 5, 9, ... th vectors of the key, and so on.
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; const int physical_block_offset =
(thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
K_vec k_vecs[NUM_VECS_PER_THREAD]; K_vec k_vecs[NUM_VECS_PER_THREAD];
#pragma unroll #pragma unroll
for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride const cache_t* k_ptr =
+ kv_head_idx * kv_head_stride k_cache + physical_block_number * kv_block_stride +
+ physical_block_offset * x; kv_head_idx * kv_head_stride + physical_block_offset * x;
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
const int offset1 = (vec_idx * VEC_SIZE) / x; const int offset1 = (vec_idx * VEC_SIZE) / x;
const int offset2 = (vec_idx * VEC_SIZE) % x; const int offset2 = (vec_idx * VEC_SIZE) % x;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2); k_vecs[j] = *reinterpret_cast<const K_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2);
} else { } else {
// Vector conversion from Quant_vec to K_vec. // Vector conversion from Quant_vec to K_vec.
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>( Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2); k_ptr + offset1 * BLOCK_SIZE * x + offset2);
k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(k_vec_quant, kv_scale); k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
k_vec_quant, kv_scale);
} }
} }
// Compute dot product. // Compute dot product.
// This includes a reduction across the threads in the same thread group. // This includes a reduction across the threads in the same thread group.
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs); float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(
q_vecs[thread_group_offset], k_vecs);
// Add the ALiBi bias if slopes are given. // Add the ALiBi bias if slopes are given.
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0; qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;
@ -285,13 +300,12 @@ __device__ void paged_attention_kernel(
// If partitioning is enabled, store the max logit and exp_sum. // If partitioning is enabled, store the max logit and exp_sum.
if (USE_PARTITIONING && thread_idx == 0) { if (USE_PARTITIONING && thread_idx == 0) {
float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions float* max_logits_ptr = max_logits +
+ head_idx * max_num_partitions seq_idx * num_heads * max_num_partitions +
+ partition_idx; head_idx * max_num_partitions + partition_idx;
*max_logits_ptr = qk_max; *max_logits_ptr = qk_max;
float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions +
+ head_idx * max_num_partitions head_idx * max_num_partitions + partition_idx;
+ partition_idx;
*exp_sums_ptr = exp_sum; *exp_sums_ptr = exp_sum;
} }
@ -304,7 +318,8 @@ __device__ void paged_attention_kernel(
constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); constexpr int NUM_ROWS_PER_THREAD =
DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy. // NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
float accs[NUM_ROWS_PER_THREAD]; float accs[NUM_ROWS_PER_THREAD];
@ -315,18 +330,21 @@ __device__ void paged_attention_kernel(
scalar_t zero_value; scalar_t zero_value;
zero(zero_value); zero(zero_value);
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
// NOTE(woosuk): The block number is stored in int32. However, we cast it to int64 block_idx += NUM_WARPS) {
// because int32 can lead to overflow when this variable is multiplied by large numbers // NOTE(woosuk): The block number is stored in int32. However, we cast it to
// (e.g., kv_block_stride). // int64 because int32 can lead to overflow when this variable is multiplied
const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]); // by large numbers (e.g., kv_block_stride).
const int64_t physical_block_number =
static_cast<int64_t>(block_table[block_idx]);
const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
L_vec logits_vec; L_vec logits_vec;
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx)); from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx -
start_token_idx));
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
+ kv_head_idx * kv_head_stride; kv_head_idx * kv_head_stride;
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
@ -337,14 +355,17 @@ __device__ void paged_attention_kernel(
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset); v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
} else { } else {
V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset); V_quant_vec v_quant_vec =
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec. // Vector conversion from V_quant_vec to V_vec.
v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec, kv_scale); v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
kv_scale);
} }
if (block_idx == num_seq_blocks - 1) { if (block_idx == num_seq_blocks - 1) {
// NOTE(woosuk): When v_vec contains the tokens that are out of the context, // NOTE(woosuk): When v_vec contains the tokens that are out of the
// we should explicitly zero out the values since they may contain NaNs. // context, we should explicitly zero out the values since they may
// See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 // contain NaNs. See
// https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec); scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll #pragma unroll
for (int j = 0; j < V_VEC_SIZE; j++) { for (int j = 0; j < V_VEC_SIZE; j++) {
@ -367,8 +388,8 @@ __device__ void paged_attention_kernel(
accs[i] = acc; accs[i] = acc;
} }
// NOTE(woosuk): A barrier is required because the shared memory space for logits // NOTE(woosuk): A barrier is required because the shared memory space for
// is reused for the output. // logits is reused for the output.
__syncthreads(); __syncthreads();
// Perform reduction across warps. // Perform reduction across warps.
@ -405,9 +426,9 @@ __device__ void paged_attention_kernel(
// Write the final output. // Write the final output.
if (warp_idx == 0) { if (warp_idx == 0) {
scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE scalar_t* out_ptr =
+ head_idx * max_num_partitions * HEAD_SIZE out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
+ partition_idx * HEAD_SIZE; head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE;
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
@ -419,79 +440,75 @@ __device__ void paged_attention_kernel(
} }
// Grid: (num_heads, num_seqs, 1). // Grid: (num_heads, num_seqs, 1).
template< template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
typename scalar_t, int NUM_THREADS,
typename cache_t, vllm::Fp8KVCacheDataType KV_DTYPE>
int HEAD_SIZE,
int BLOCK_SIZE,
int NUM_THREADS,
vllm::Fp8KVCacheDataType KV_DTYPE>
__global__ void paged_attention_v1_kernel( __global__ void paged_attention_v1_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] // head_size/x, block_size, x]
const int num_kv_heads, // [num_heads] const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
const float scale, // head_size, block_size]
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int num_kv_heads, // [num_heads]
const int* __restrict__ seq_lens, // [num_seqs] const float scale,
const int max_num_blocks_per_seq, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const float* __restrict__ alibi_slopes, // [num_heads] const int* __restrict__ seq_lens, // [num_seqs]
const int q_stride, const int max_num_blocks_per_seq,
const int kv_block_stride, const float* __restrict__ alibi_slopes, // [num_heads]
const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float kv_scale) { const float kv_scale) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE>( paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
/* exp_sums */ nullptr, /* max_logits */ nullptr, KV_DTYPE>(
out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens, /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale); v_cache, num_kv_heads, scale, block_tables, seq_lens,
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
kv_head_stride, kv_scale);
} }
// Grid: (num_heads, num_seqs, max_num_partitions). // Grid: (num_heads, num_seqs, max_num_partitions).
template< template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
typename scalar_t, int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
typename cache_t, int PARTITION_SIZE>
int HEAD_SIZE,
int BLOCK_SIZE,
int NUM_THREADS,
vllm::Fp8KVCacheDataType KV_DTYPE,
int PARTITION_SIZE>
__global__ void paged_attention_v2_kernel( __global__ void paged_attention_v2_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ max_logits, // [num_seqs, num_heads,
scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] // max_num_partitions]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] // max_num_partitions, head_size]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const int num_kv_heads, // [num_heads] const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
const float scale, // head_size/x, block_size, x]
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
const int* __restrict__ seq_lens, // [num_seqs] // head_size, block_size]
const int max_num_blocks_per_seq, const int num_kv_heads, // [num_heads]
const float* __restrict__ alibi_slopes, // [num_heads] const float scale,
const int q_stride, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int kv_block_stride, const int* __restrict__ seq_lens, // [num_seqs]
const int kv_head_stride, const int max_num_blocks_per_seq,
const float kv_scale) { const float* __restrict__ alibi_slopes, // [num_heads]
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE, PARTITION_SIZE>( const int q_stride, const int kv_block_stride, const int kv_head_stride,
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, const float kv_scale) {
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
q_stride, kv_block_stride, kv_head_stride, kv_scale); KV_DTYPE, PARTITION_SIZE>(
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
kv_block_stride, kv_head_stride, kv_scale);
} }
// Grid: (num_heads, num_seqs). // Grid: (num_heads, num_seqs).
template< template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
typename scalar_t, int PARTITION_SIZE>
int HEAD_SIZE,
int NUM_THREADS,
int PARTITION_SIZE>
__global__ void paged_attention_v2_reduce_kernel( __global__ void paged_attention_v2_reduce_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] const float* __restrict__ exp_sums, // [num_seqs, num_heads,
const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] // max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] const float* __restrict__ max_logits, // [num_seqs, num_heads,
const int* __restrict__ seq_lens, // [num_seqs] // max_num_partitions]
const int max_num_partitions) { const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_partitions) {
const int num_heads = gridDim.x; const int num_heads = gridDim.x;
const int head_idx = blockIdx.x; const int head_idx = blockIdx.x;
const int seq_idx = blockIdx.y; const int seq_idx = blockIdx.y;
@ -499,9 +516,11 @@ __global__ void paged_attention_v2_reduce_kernel(
const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE); const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
if (num_partitions == 1) { if (num_partitions == 1) {
// No need to reduce. Only copy tmp_out to out. // No need to reduce. Only copy tmp_out to out.
scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; scalar_t* out_ptr =
const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+ head_idx * max_num_partitions * HEAD_SIZE; const scalar_t* tmp_out_ptr =
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE;
for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) { for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
out_ptr[i] = tmp_out_ptr[i]; out_ptr[i] = tmp_out_ptr[i];
} }
@ -520,8 +539,9 @@ __global__ void paged_attention_v2_reduce_kernel(
// Load max logits to shared memory. // Load max logits to shared memory.
float* shared_max_logits = reinterpret_cast<float*>(shared_mem); float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions const float* max_logits_ptr = max_logits +
+ head_idx * max_num_partitions; seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions;
float max_logit = -FLT_MAX; float max_logit = -FLT_MAX;
for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
const float l = max_logits_ptr[i]; const float l = max_logits_ptr[i];
@ -550,9 +570,11 @@ __global__ void paged_attention_v2_reduce_kernel(
max_logit = VLLM_SHFL_SYNC(max_logit, 0); max_logit = VLLM_SHFL_SYNC(max_logit, 0);
// Load rescaled exp sums to shared memory. // Load rescaled exp sums to shared memory.
float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions); float* shared_exp_sums =
const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
+ head_idx * max_num_partitions; const float* exp_sums_ptr = exp_sums +
seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions;
float global_exp_sum = 0.0f; float global_exp_sum = 0.0f;
for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
float l = shared_max_logits[i]; float l = shared_max_logits[i];
@ -565,61 +587,45 @@ __global__ void paged_attention_v2_reduce_kernel(
const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f); const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
// Aggregate tmp_out to out. // Aggregate tmp_out to out.
const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE const scalar_t* tmp_out_ptr =
+ head_idx * max_num_partitions * HEAD_SIZE; tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; head_idx * max_num_partitions * HEAD_SIZE;
scalar_t* out_ptr =
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
#pragma unroll #pragma unroll
for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) { for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
float acc = 0.0f; float acc = 0.0f;
for (int j = 0; j < num_partitions; ++j) { for (int j = 0; j < num_partitions; ++j) {
acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum; acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] *
inv_global_exp_sum;
} }
from_float(out_ptr[i], acc); from_float(out_ptr[i], acc);
} }
} }
} // namespace vllm } // namespace vllm
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \ ((void*)vllm::paged_attention_v1_kernel< \
KV_DTYPE>), shared_mem_size); \ T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE>), \
vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \ shared_mem_size); \
KV_DTYPE><<<grid, block, shared_mem_size, stream>>>( \ vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
out_ptr, \ NUM_THREADS, KV_DTYPE> \
query_ptr, \ <<<grid, block, shared_mem_size, stream>>>( \
key_cache_ptr, \ out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
value_cache_ptr, \ scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
num_kv_heads, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
scale, \ kv_scale);
block_tables_ptr, \
seq_lens_ptr, \
max_num_blocks_per_seq, \
alibi_slopes_ptr, \
q_stride, \
kv_block_stride, \
kv_head_stride, \
kv_scale);
// TODO(woosuk): Tune NUM_THREADS. // TODO(woosuk): Tune NUM_THREADS.
template< template <typename T, typename CACHE_T, int BLOCK_SIZE,
typename T, vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS = 128>
typename CACHE_T,
int BLOCK_SIZE,
vllm::Fp8KVCacheDataType KV_DTYPE,
int NUM_THREADS = 128>
void paged_attention_v1_launcher( void paged_attention_v1_launcher(
torch::Tensor& out, torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& query, torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& key_cache, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
torch::Tensor& value_cache, const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale) {
int num_kv_heads,
float scale,
torch::Tensor& block_tables,
torch::Tensor& seq_lens,
int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
float kv_scale) {
int num_seqs = query.size(0); int num_seqs = query.size(0);
int num_heads = query.size(1); int num_heads = query.size(1);
int head_size = query.size(2); int head_size = query.size(2);
@ -632,9 +638,10 @@ void paged_attention_v1_launcher(
assert(head_size % thread_group_size == 0); assert(head_size % thread_group_size == 0);
// NOTE: alibi_slopes is optional. // NOTE: alibi_slopes is optional.
const float* alibi_slopes_ptr = alibi_slopes ? const float* alibi_slopes_ptr =
reinterpret_cast<const float*>(alibi_slopes.value().data_ptr()) alibi_slopes
: nullptr; ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;
T* out_ptr = reinterpret_cast<T*>(out.data_ptr()); T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr()); T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
@ -644,7 +651,8 @@ void paged_attention_v1_launcher(
int* seq_lens_ptr = seq_lens.data_ptr<int>(); int* seq_lens_ptr = seq_lens.data_ptr<int>();
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int padded_max_seq_len = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; int padded_max_seq_len =
DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
int logits_size = padded_max_seq_len * sizeof(float); int logits_size = padded_max_seq_len * sizeof(float);
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
// Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
@ -683,19 +691,10 @@ void paged_attention_v1_launcher(
} }
} }
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \ #define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \
paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE>( \ paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE>( \
out, \ out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
query, \ seq_lens, max_seq_len, alibi_slopes, kv_scale);
key_cache, \
value_cache, \
num_kv_heads, \
scale, \
block_tables, \
seq_lens, \
max_seq_len, \
alibi_slopes, \
kv_scale);
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes // NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256. // 1, 2, 4, 64, 128, 256.
@ -716,74 +715,45 @@ void paged_attention_v1_launcher(
} }
void paged_attention_v1( void paged_attention_v1(
torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size] torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] torch::Tensor&
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
int num_kv_heads, // [num_heads] torch::Tensor&
float scale, value_cache, // [num_blocks, num_heads, head_size, block_size]
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] int num_kv_heads, // [num_heads]
torch::Tensor& seq_lens, // [num_seqs] float scale,
int block_size, torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
int max_seq_len, torch::Tensor& seq_lens, // [num_seqs]
const c10::optional<torch::Tensor>& alibi_slopes, int block_size, int max_seq_len,
const std::string& kv_cache_dtype, const c10::optional<torch::Tensor>& alibi_slopes,
float kv_scale) { const std::string& kv_cache_dtype, float kv_scale){
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, CALL_V1_LAUNCHER_BLOCK_SIZE) DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
} CALL_V1_LAUNCHER_BLOCK_SIZE)}
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
NUM_THREADS, KV_DTYPE, PARTITION_SIZE> \
<<<grid, block, shared_mem_size, stream>>>( \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, kv_scale); \
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
PARTITION_SIZE> \
<<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
max_num_partitions);
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ template <typename T, typename CACHE_T, int BLOCK_SIZE,
vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \ vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS = 128,
KV_DTYPE, PARTITION_SIZE> \ int PARTITION_SIZE = 512>
<<<grid, block, shared_mem_size, stream>>>( \
exp_sums_ptr, \
max_logits_ptr, \
tmp_out_ptr, \
query_ptr, \
key_cache_ptr, \
value_cache_ptr, \
num_kv_heads, \
scale, \
block_tables_ptr, \
seq_lens_ptr, \
max_num_blocks_per_seq, \
alibi_slopes_ptr, \
q_stride, \
kv_block_stride, \
kv_head_stride, \
kv_scale); \
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE> \
<<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
out_ptr, \
exp_sums_ptr, \
max_logits_ptr, \
tmp_out_ptr, \
seq_lens_ptr, \
max_num_partitions);
template<
typename T,
typename CACHE_T,
int BLOCK_SIZE,
vllm::Fp8KVCacheDataType KV_DTYPE,
int NUM_THREADS = 128,
int PARTITION_SIZE = 512>
void paged_attention_v2_launcher( void paged_attention_v2_launcher(
torch::Tensor& out, torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& exp_sums, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& max_logits, torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& tmp_out, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
torch::Tensor& query, const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale) {
torch::Tensor& key_cache,
torch::Tensor& value_cache,
int num_kv_heads,
float scale,
torch::Tensor& block_tables,
torch::Tensor& seq_lens,
int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
float kv_scale) {
int num_seqs = query.size(0); int num_seqs = query.size(0);
int num_heads = query.size(1); int num_heads = query.size(1);
int head_size = query.size(2); int head_size = query.size(2);
@ -796,9 +766,10 @@ void paged_attention_v2_launcher(
assert(head_size % thread_group_size == 0); assert(head_size % thread_group_size == 0);
// NOTE: alibi_slopes is optional. // NOTE: alibi_slopes is optional.
const float* alibi_slopes_ptr = alibi_slopes ? const float* alibi_slopes_ptr =
reinterpret_cast<const float*>(alibi_slopes.value().data_ptr()) alibi_slopes
: nullptr; ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;
T* out_ptr = reinterpret_cast<T*>(out.data_ptr()); T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr()); float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
@ -853,59 +824,50 @@ void paged_attention_v2_launcher(
} }
} }
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \ #define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \
paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE>( \ paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE>( \
out, \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
exp_sums, \ num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
max_logits, \ kv_scale);
tmp_out, \
query, \
key_cache, \
value_cache, \
num_kv_heads, \
scale, \
block_tables, \
seq_lens, \
max_seq_len, \
alibi_slopes, \
kv_scale);
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes // NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256. // 1, 2, 4, 64, 128, 256.
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ #define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
switch (block_size) { \ switch (block_size) { \
case 8: \ case 8: \
CALL_V2_LAUNCHER(T, CACHE_T, 8, KV_DTYPE); \ CALL_V2_LAUNCHER(T, CACHE_T, 8, KV_DTYPE); \
break; \ break; \
case 16: \ case 16: \
CALL_V2_LAUNCHER(T, CACHE_T, 16, KV_DTYPE); \ CALL_V2_LAUNCHER(T, CACHE_T, 16, KV_DTYPE); \
break; \ break; \
case 32: \ case 32: \
CALL_V2_LAUNCHER(T, CACHE_T, 32, KV_DTYPE); \ CALL_V2_LAUNCHER(T, CACHE_T, 32, KV_DTYPE); \
break; \ break; \
default: \ default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \ break; \
} }
void paged_attention_v2( void paged_attention_v2(
torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor& tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] torch::Tensor&
torch::Tensor& query, // [num_seqs, num_heads, head_size] tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] torch::Tensor&
int num_kv_heads, // [num_heads] key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
float scale, torch::Tensor&
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] value_cache, // [num_blocks, num_heads, head_size, block_size]
torch::Tensor& seq_lens, // [num_seqs] int num_kv_heads, // [num_heads]
int block_size, float scale,
int max_seq_len, torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
const c10::optional<torch::Tensor>& alibi_slopes, torch::Tensor& seq_lens, // [num_seqs]
const std::string& kv_cache_dtype, int block_size, int max_seq_len,
float kv_scale) { const c10::optional<torch::Tensor>& alibi_slopes,
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, CALL_V2_LAUNCHER_BLOCK_SIZE) const std::string& kv_cache_dtype, float kv_scale) {
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
CALL_V2_LAUNCHER_BLOCK_SIZE)
} }
#undef WARP_SIZE #undef WARP_SIZE

View File

@ -1,5 +1,6 @@
/* /*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp * Adapted from
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* Copyright (c) 2023, The vLLM team. * Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* *
@ -26,7 +27,7 @@
namespace vllm { namespace vllm {
// Q*K^T operation. // Q*K^T operation.
template<int THREAD_GROUP_SIZE, typename Vec, int N> template <int THREAD_GROUP_SIZE, typename Vec, int N>
inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
using A_vec = typename FloatVec<Vec>::Type; using A_vec = typename FloatVec<Vec>::Type;
// Compute the parallel products for Q*K^T (treat vector lanes separately). // Compute the parallel products for Q*K^T (treat vector lanes separately).
@ -45,12 +46,12 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
return qk; return qk;
} }
template<typename T, int THREAD_GROUP_SIZE> template <typename T, int THREAD_GROUP_SIZE>
struct Qk_dot { struct Qk_dot {
template<typename Vec, int N> template <typename Vec, int N>
static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) { static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) {
return qk_dot_<THREAD_GROUP_SIZE>(q, k); return qk_dot_<THREAD_GROUP_SIZE>(q, k);
} }
}; };
} // namespace vllm } // namespace vllm

View File

@ -1,6 +1,8 @@
/* /*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp * Adapted from
* and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* and
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Copyright (c) 2023, The vLLM team. * Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* *
@ -28,8 +30,8 @@
#include <hip/hip_bf16.h> #include <hip/hip_bf16.h>
#include <hip/hip_fp16.h> #include <hip/hip_fp16.h>
typedef __hip_bfloat162 __nv_bfloat162; typedef __hip_bfloat162 __nv_bfloat162;
typedef __hip_bfloat16 __nv_bfloat16; typedef __hip_bfloat16 __nv_bfloat16;
#endif #endif
#include <stdint.h> #include <stdint.h>
@ -50,37 +52,37 @@ struct bf16_8_t {
}; };
// BF16 vector types for Q, K, V. // BF16 vector types for Q, K, V.
template<> template <>
struct Vec<__nv_bfloat16, 1> { struct Vec<__nv_bfloat16, 1> {
using Type = __nv_bfloat16; using Type = __nv_bfloat16;
}; };
template<> template <>
struct Vec<__nv_bfloat16, 2> { struct Vec<__nv_bfloat16, 2> {
using Type = __nv_bfloat162; using Type = __nv_bfloat162;
}; };
template<> template <>
struct Vec<__nv_bfloat16, 4> { struct Vec<__nv_bfloat16, 4> {
using Type = bf16_4_t; using Type = bf16_4_t;
}; };
template<> template <>
struct Vec<__nv_bfloat16, 8> { struct Vec<__nv_bfloat16, 8> {
using Type = bf16_8_t; using Type = bf16_8_t;
}; };
// FP32 accumulator vector types corresponding to Vec. // FP32 accumulator vector types corresponding to Vec.
template<> template <>
struct FloatVec<__nv_bfloat16> { struct FloatVec<__nv_bfloat16> {
using Type = float; using Type = float;
}; };
template<> template <>
struct FloatVec<__nv_bfloat162> { struct FloatVec<__nv_bfloat162> {
using Type = float2; using Type = float2;
}; };
template<> template <>
struct FloatVec<bf16_4_t> { struct FloatVec<bf16_4_t> {
using Type = Float4_; using Type = Float4_;
}; };
template<> template <>
struct FloatVec<bf16_8_t> { struct FloatVec<bf16_8_t> {
using Type = Float8_; using Type = Float8_;
}; };
@ -108,9 +110,9 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
assert(false); assert(false);
#else #else
#ifndef USE_ROCM #ifndef USE_ROCM
return a + b; return a + b;
#else #else
return __hadd(a, b); return __hadd(a, b);
#endif #endif
#endif #endif
} }
@ -161,7 +163,7 @@ inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) {
} }
// Vector multiplication. // Vector multiplication.
template<> template <>
inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) { inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false); assert(false);
@ -170,7 +172,7 @@ inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
#endif #endif
} }
template<> template <>
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) { inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false); assert(false);
@ -179,12 +181,12 @@ inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
#endif #endif
} }
template<> template <>
inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) { inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b); return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
} }
template<> template <>
inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) { inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) {
bf16_4_t c; bf16_4_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
@ -192,7 +194,7 @@ inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) {
return c; return c;
} }
template<> template <>
inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) { inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) {
__nv_bfloat162 s = bf162bf162(a); __nv_bfloat162 s = bf162bf162(a);
bf16_4_t c; bf16_4_t c;
@ -201,7 +203,7 @@ inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) {
return c; return c;
} }
template<> template <>
inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) { inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) {
bf16_8_t c; bf16_8_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
@ -211,7 +213,7 @@ inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) {
return c; return c;
} }
template<> template <>
inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) { inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) {
__nv_bfloat162 s = bf162bf162(a); __nv_bfloat162 s = bf162bf162(a);
bf16_8_t c; bf16_8_t c;
@ -222,26 +224,26 @@ inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) {
return c; return c;
} }
template<> template <>
inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) { inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) {
float fa = __bfloat162float(a); float fa = __bfloat162float(a);
float fb = __bfloat162float(b); float fb = __bfloat162float(b);
return fa * fb; return fa * fb;
} }
template<> template <>
inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) { inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
float2 fa = bf1622float2(a); float2 fa = bf1622float2(a);
float2 fb = bf1622float2(b); float2 fb = bf1622float2(b);
return mul<float2, float2, float2>(fa, fb); return mul<float2, float2, float2>(fa, fb);
} }
template<> template <>
inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) { inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
return mul<float2, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b); return mul<float2, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
} }
template<> template <>
inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) { inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) {
Float4_ fc; Float4_ fc;
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
@ -249,7 +251,7 @@ inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) {
return fc; return fc;
} }
template<> template <>
inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) { inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) {
__nv_bfloat162 s = bf162bf162(a); __nv_bfloat162 s = bf162bf162(a);
Float4_ fc; Float4_ fc;
@ -258,7 +260,7 @@ inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) {
return fc; return fc;
} }
template<> template <>
inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) { inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) {
Float8_ fc; Float8_ fc;
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
@ -268,7 +270,7 @@ inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) {
return fc; return fc;
} }
template<> template <>
inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) { inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
__nv_bfloat162 s = bf162bf162(a); __nv_bfloat162 s = bf162bf162(a);
Float8_ fc; Float8_ fc;
@ -280,7 +282,8 @@ inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
} }
// Vector fused multiply-add. // Vector fused multiply-add.
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) { inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b,
__nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false); assert(false);
#else #else
@ -288,7 +291,8 @@ inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bf
#endif #endif
} }
inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) { inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b,
__nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false); assert(false);
#else #else
@ -379,23 +383,23 @@ inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) {
} }
// Vector sum. // Vector sum.
template<> template <>
inline __device__ float sum(__nv_bfloat16 v) { inline __device__ float sum(__nv_bfloat16 v) {
return __bfloat162float(v); return __bfloat162float(v);
} }
template<> template <>
inline __device__ float sum(__nv_bfloat162 v) { inline __device__ float sum(__nv_bfloat162 v) {
float2 vf = bf1622float2(v); float2 vf = bf1622float2(v);
return vf.x + vf.y; return vf.x + vf.y;
} }
template<> template <>
inline __device__ float sum(bf16_4_t v) { inline __device__ float sum(bf16_4_t v) {
return sum(v.x) + sum(v.y); return sum(v.x) + sum(v.y);
} }
template<> template <>
inline __device__ float sum(bf16_8_t v) { inline __device__ float sum(bf16_8_t v) {
return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w); return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w);
} }
@ -448,4 +452,4 @@ inline __device__ void zero(__nv_bfloat16& dst) {
#endif #endif
} }
} // namespace vllm } // namespace vllm

View File

@ -1,6 +1,8 @@
/* /*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp * Adapted from
* and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* and
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Copyright (c) 2023, The vLLM team. * Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* *
@ -30,37 +32,37 @@
namespace vllm { namespace vllm {
// FP16 vector types for Q, K, V. // FP16 vector types for Q, K, V.
template<> template <>
struct Vec<uint16_t, 1> { struct Vec<uint16_t, 1> {
using Type = uint16_t; using Type = uint16_t;
}; };
template<> template <>
struct Vec<uint16_t, 2> { struct Vec<uint16_t, 2> {
using Type = uint32_t; using Type = uint32_t;
}; };
template<> template <>
struct Vec<uint16_t, 4> { struct Vec<uint16_t, 4> {
using Type = uint2; using Type = uint2;
}; };
template<> template <>
struct Vec<uint16_t, 8> { struct Vec<uint16_t, 8> {
using Type = uint4; using Type = uint4;
}; };
// FP32 accumulator vector types corresponding to Vec. // FP32 accumulator vector types corresponding to Vec.
template<> template <>
struct FloatVec<uint16_t> { struct FloatVec<uint16_t> {
using Type = float; using Type = float;
}; };
template<> template <>
struct FloatVec<uint32_t> { struct FloatVec<uint32_t> {
using Type = float2; using Type = float2;
}; };
template<> template <>
struct FloatVec<uint2> { struct FloatVec<uint2> {
using Type = Float4_; using Type = Float4_;
}; };
template<> template <>
struct FloatVec<uint4> { struct FloatVec<uint4> {
using Type = Float8_; using Type = Float8_;
}; };
@ -73,8 +75,8 @@ inline __device__ uint32_t h0_h0(uint16_t a) {
return b; return b;
#else #else
union { union {
uint32_t u32; uint32_t u32;
uint16_t u16[2]; uint16_t u16[2];
} tmp; } tmp;
tmp.u16[0] = a; tmp.u16[0] = a;
tmp.u16[1] = a; tmp.u16[1] = a;
@ -130,10 +132,12 @@ inline __device__ uint32_t float2_to_half2(float2 f) {
} tmp; } tmp;
#ifndef USE_ROCM #ifndef USE_ROCM
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n"
: "=r"(tmp.u32)
: "f"(f.y), "f"(f.x));
#else #else
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
#endif #endif
#else #else
tmp.u16[0] = float_to_half(f.x); tmp.u16[0] = float_to_half(f.x);
@ -201,7 +205,7 @@ inline __device__ Float8_ add(uint4 a, Float8_ fb) {
} }
// Vector multiplication. // Vector multiplication.
template<> template <>
inline __device__ uint16_t mul(uint16_t a, uint16_t b) { inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
uint16_t c; uint16_t c;
#ifndef USE_ROCM #ifndef USE_ROCM
@ -212,7 +216,7 @@ inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
return c; return c;
} }
template<> template <>
inline __device__ uint32_t mul(uint32_t a, uint32_t b) { inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
uint32_t c; uint32_t c;
#ifndef USE_ROCM #ifndef USE_ROCM
@ -223,12 +227,12 @@ inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
return c; return c;
} }
template<> template <>
inline __device__ uint32_t mul(uint16_t a, uint32_t b) { inline __device__ uint32_t mul(uint16_t a, uint32_t b) {
return mul<uint32_t, uint32_t, uint32_t>(h0_h0(a), b); return mul<uint32_t, uint32_t, uint32_t>(h0_h0(a), b);
} }
template<> template <>
inline __device__ uint2 mul(uint2 a, uint2 b) { inline __device__ uint2 mul(uint2 a, uint2 b) {
uint2 c; uint2 c;
c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x); c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
@ -236,7 +240,7 @@ inline __device__ uint2 mul(uint2 a, uint2 b) {
return c; return c;
} }
template<> template <>
inline __device__ uint2 mul(uint16_t a, uint2 b) { inline __device__ uint2 mul(uint16_t a, uint2 b) {
uint32_t s = h0_h0(a); uint32_t s = h0_h0(a);
uint2 c; uint2 c;
@ -245,7 +249,7 @@ inline __device__ uint2 mul(uint16_t a, uint2 b) {
return c; return c;
} }
template<> template <>
inline __device__ uint4 mul(uint4 a, uint4 b) { inline __device__ uint4 mul(uint4 a, uint4 b) {
uint4 c; uint4 c;
c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x); c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
@ -255,7 +259,7 @@ inline __device__ uint4 mul(uint4 a, uint4 b) {
return c; return c;
} }
template<> template <>
inline __device__ uint4 mul(uint16_t a, uint4 b) { inline __device__ uint4 mul(uint16_t a, uint4 b) {
uint32_t s = h0_h0(a); uint32_t s = h0_h0(a);
uint4 c; uint4 c;
@ -266,26 +270,26 @@ inline __device__ uint4 mul(uint16_t a, uint4 b) {
return c; return c;
} }
template<> template <>
inline __device__ float mul(uint16_t a, uint16_t b) { inline __device__ float mul(uint16_t a, uint16_t b) {
float fa = half_to_float(a); float fa = half_to_float(a);
float fb = half_to_float(b); float fb = half_to_float(b);
return fa * fb; return fa * fb;
} }
template<> template <>
inline __device__ float2 mul(uint32_t a, uint32_t b) { inline __device__ float2 mul(uint32_t a, uint32_t b) {
float2 fa = half2_to_float2(a); float2 fa = half2_to_float2(a);
float2 fb = half2_to_float2(b); float2 fb = half2_to_float2(b);
return mul<float2, float2, float2>(fa, fb); return mul<float2, float2, float2>(fa, fb);
} }
template<> template <>
inline __device__ float2 mul(uint16_t a, uint32_t b) { inline __device__ float2 mul(uint16_t a, uint32_t b) {
return mul<float2, uint32_t, uint32_t>(h0_h0(a), b); return mul<float2, uint32_t, uint32_t>(h0_h0(a), b);
} }
template<> template <>
inline __device__ Float4_ mul(uint2 a, uint2 b) { inline __device__ Float4_ mul(uint2 a, uint2 b) {
Float4_ fc; Float4_ fc;
fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x); fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
@ -293,7 +297,7 @@ inline __device__ Float4_ mul(uint2 a, uint2 b) {
return fc; return fc;
} }
template<> template <>
inline __device__ Float4_ mul(uint16_t a, uint2 b) { inline __device__ Float4_ mul(uint16_t a, uint2 b) {
uint32_t s = h0_h0(a); uint32_t s = h0_h0(a);
Float4_ fc; Float4_ fc;
@ -302,7 +306,7 @@ inline __device__ Float4_ mul(uint16_t a, uint2 b) {
return fc; return fc;
} }
template<> template <>
inline __device__ Float8_ mul(uint4 a, uint4 b) { inline __device__ Float8_ mul(uint4 a, uint4 b) {
Float8_ fc; Float8_ fc;
fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x); fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
@ -312,7 +316,7 @@ inline __device__ Float8_ mul(uint4 a, uint4 b) {
return fc; return fc;
} }
template<> template <>
inline __device__ Float8_ mul(uint16_t a, uint4 b) { inline __device__ Float8_ mul(uint16_t a, uint4 b) {
uint32_t s = h0_h0(a); uint32_t s = h0_h0(a);
Float8_ fc; Float8_ fc;
@ -327,9 +331,13 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) {
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) { inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
uint32_t d; uint32_t d;
#ifndef USE_ROCM #ifndef USE_ROCM
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
: "=r"(d)
: "r"(a), "r"(b), "r"(c));
#else #else
asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c)); asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n"
: "=v"(d)
: "v"(a), "v"(b), "v"(c));
#endif #endif
return d; return d;
} }
@ -423,24 +431,24 @@ inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) {
} }
// Vector sum. // Vector sum.
template<> template <>
inline __device__ float sum(uint16_t v) { inline __device__ float sum(uint16_t v) {
return half_to_float(v); return half_to_float(v);
} }
template<> template <>
inline __device__ float sum(uint32_t v) { inline __device__ float sum(uint32_t v) {
float2 tmp = half2_to_float2(v); float2 tmp = half2_to_float2(v);
return tmp.x + tmp.y; return tmp.x + tmp.y;
} }
template<> template <>
inline __device__ float sum(uint2 v) { inline __device__ float sum(uint2 v) {
uint32_t c = add(v.x, v.y); uint32_t c = add(v.x, v.y);
return sum(c); return sum(c);
} }
template<> template <>
inline __device__ float sum(uint4 v) { inline __device__ float sum(uint4 v) {
uint32_t c = add(v.x, v.y); uint32_t c = add(v.x, v.y);
c = add(c, v.z); c = add(c, v.z);
@ -470,13 +478,9 @@ inline __device__ void from_float(uint4& dst, Float8_ src) {
} }
// From float16 to float32. // From float16 to float32.
inline __device__ float to_float(uint16_t u) { inline __device__ float to_float(uint16_t u) { return half_to_float(u); }
return half_to_float(u);
}
inline __device__ float2 to_float(uint32_t u) { inline __device__ float2 to_float(uint32_t u) { return half2_to_float2(u); }
return half2_to_float2(u);
}
inline __device__ Float4_ to_float(uint2 u) { inline __device__ Float4_ to_float(uint2 u) {
Float4_ tmp; Float4_ tmp;
@ -495,8 +499,6 @@ inline __device__ Float8_ to_float(uint4 u) {
} }
// Zero-out a variable. // Zero-out a variable.
inline __device__ void zero(uint16_t& dst) { inline __device__ void zero(uint16_t& dst) { dst = uint16_t(0); }
dst = uint16_t(0);
}
} // namespace vllm } // namespace vllm

View File

@ -1,6 +1,8 @@
/* /*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp * Adapted from
* and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* and
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Copyright (c) 2023, The vLLM team. * Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* *
@ -38,37 +40,35 @@ struct Float8_ {
}; };
// FP32 vector types for Q, K, V. // FP32 vector types for Q, K, V.
template<> template <>
struct Vec<float, 1> { struct Vec<float, 1> {
using Type = float; using Type = float;
}; };
template<> template <>
struct Vec<float, 2> { struct Vec<float, 2> {
using Type = float2; using Type = float2;
}; };
template<> template <>
struct Vec<float, 4> { struct Vec<float, 4> {
using Type = float4; using Type = float4;
}; };
// FP32 accumulator vector types corresponding to Vec. // FP32 accumulator vector types corresponding to Vec.
template<> template <>
struct FloatVec<float> { struct FloatVec<float> {
using Type = float; using Type = float;
}; };
template<> template <>
struct FloatVec<float2> { struct FloatVec<float2> {
using Type = float2; using Type = float2;
}; };
template<> template <>
struct FloatVec<float4> { struct FloatVec<float4> {
using Type = float4; using Type = float4;
}; };
// Vector addition. // Vector addition.
inline __device__ float add(float a, float b) { inline __device__ float add(float a, float b) { return a + b; }
return a + b;
}
inline __device__ float2 add(float2 a, float2 b) { inline __device__ float2 add(float2 a, float2 b) {
float2 c; float2 c;
@ -87,12 +87,12 @@ inline __device__ float4 add(float4 a, float4 b) {
} }
// Vector multiplication. // Vector multiplication.
template<> template <>
inline __device__ float mul<float, float>(float a, float b) { inline __device__ float mul<float, float>(float a, float b) {
return a * b; return a * b;
} }
template<> template <>
inline __device__ float2 mul(float2 a, float2 b) { inline __device__ float2 mul(float2 a, float2 b) {
float2 c; float2 c;
c.x = a.x * b.x; c.x = a.x * b.x;
@ -100,7 +100,7 @@ inline __device__ float2 mul(float2 a, float2 b) {
return c; return c;
} }
template<> template <>
inline __device__ float2 mul(float a, float2 b) { inline __device__ float2 mul(float a, float2 b) {
float2 c; float2 c;
c.x = a * b.x; c.x = a * b.x;
@ -108,7 +108,7 @@ inline __device__ float2 mul(float a, float2 b) {
return c; return c;
} }
template<> template <>
inline __device__ float4 mul(float4 a, float4 b) { inline __device__ float4 mul(float4 a, float4 b) {
float4 c; float4 c;
c.x = a.x * b.x; c.x = a.x * b.x;
@ -118,7 +118,7 @@ inline __device__ float4 mul(float4 a, float4 b) {
return c; return c;
} }
template<> template <>
inline __device__ float4 mul(float a, float4 b) { inline __device__ float4 mul(float a, float4 b) {
float4 c; float4 c;
c.x = a * b.x; c.x = a * b.x;
@ -129,9 +129,7 @@ inline __device__ float4 mul(float a, float4 b) {
} }
// Vector fused multiply-add. // Vector fused multiply-add.
inline __device__ float fma(float a, float b, float c) { inline __device__ float fma(float a, float b, float c) { return a * b + c; }
return a * b + c;
}
inline __device__ float2 fma(float2 a, float2 b, float2 c) { inline __device__ float2 fma(float2 a, float2 b, float2 c) {
float2 d; float2 d;
@ -182,35 +180,33 @@ inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) {
} }
// Vector sum. // Vector sum.
template<> template <>
inline __device__ float sum(float v) { inline __device__ float sum(float v) {
return v; return v;
} }
template<> template <>
inline __device__ float sum(float2 v) { inline __device__ float sum(float2 v) {
return v.x + v.y; return v.x + v.y;
} }
template<> template <>
inline __device__ float sum(float4 v) { inline __device__ float sum(float4 v) {
return v.x + v.y + v.z + v.w; return v.x + v.y + v.z + v.w;
} }
template<> template <>
inline __device__ float sum(Float4_ v) { inline __device__ float sum(Float4_ v) {
return v.x.x + v.x.y + v.y.x + v.y.y; return v.x.x + v.x.y + v.y.x + v.y.y;
} }
template<> template <>
inline __device__ float sum(Float8_ v) { inline __device__ float sum(Float8_ v) {
return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y; return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y;
} }
// Vector dot product. // Vector dot product.
inline __device__ float dot(float a, float b) { inline __device__ float dot(float a, float b) { return a * b; }
return a * b;
}
inline __device__ float dot(float2 a, float2 b) { inline __device__ float dot(float2 a, float2 b) {
float2 c = mul<float2, float2, float2>(a, b); float2 c = mul<float2, float2, float2>(a, b);
@ -232,42 +228,24 @@ inline __device__ float dot(Float8_ a, Float8_ b) {
} }
// From float to float. // From float to float.
inline __device__ void from_float(float& dst, float src) { inline __device__ void from_float(float& dst, float src) { dst = src; }
dst = src;
}
inline __device__ void from_float(float2& dst, float2 src) { inline __device__ void from_float(float2& dst, float2 src) { dst = src; }
dst = src;
}
inline __device__ void from_float(float4& dst, float4 src) { inline __device__ void from_float(float4& dst, float4 src) { dst = src; }
dst = src;
}
// From float to float. // From float to float.
inline __device__ float to_float(float u) { inline __device__ float to_float(float u) { return u; }
return u;
}
inline __device__ float2 to_float(float2 u) { inline __device__ float2 to_float(float2 u) { return u; }
return u;
}
inline __device__ float4 to_float(float4 u) { inline __device__ float4 to_float(float4 u) { return u; }
return u;
}
inline __device__ Float4_ to_float(Float4_ u) { inline __device__ Float4_ to_float(Float4_ u) { return u; }
return u;
}
inline __device__ Float8_ to_float(Float8_ u) { inline __device__ Float8_ to_float(Float8_ u) { return u; }
return u;
}
// Zero-out a variable. // Zero-out a variable.
inline __device__ void zero(float& dst) { inline __device__ void zero(float& dst) { dst = 0.f; }
dst = 0.f;
}
} // namespace vllm } // namespace vllm

View File

@ -4,38 +4,38 @@
#include <stdint.h> #include <stdint.h>
#ifdef ENABLE_FP8 #ifdef ENABLE_FP8
#ifndef USE_ROCM #ifndef USE_ROCM
#include <cuda_fp8.h> #include <cuda_fp8.h>
#endif // USE_ROCM #endif // USE_ROCM
#endif // ENABLE_FP8 #endif // ENABLE_FP8
namespace vllm { namespace vllm {
enum class Fp8KVCacheDataType { enum class Fp8KVCacheDataType {
kAuto = 0, kAuto = 0,
kFp8E4M3 = 1, kFp8E4M3 = 1,
kFp8E5M2 = 2, kFp8E5M2 = 2,
}; };
// fp8 vector types for quantization of kv cache // fp8 vector types for quantization of kv cache
template<> template <>
struct Vec<uint8_t, 1> { struct Vec<uint8_t, 1> {
using Type = uint8_t; using Type = uint8_t;
}; };
template<> template <>
struct Vec<uint8_t, 2> { struct Vec<uint8_t, 2> {
using Type = uint16_t; using Type = uint16_t;
}; };
template<> template <>
struct Vec<uint8_t, 4> { struct Vec<uint8_t, 4> {
using Type = uint32_t; using Type = uint32_t;
}; };
template<> template <>
struct Vec<uint8_t, 8> { struct Vec<uint8_t, 8> {
using Type = uint2; using Type = uint2;
}; };
} // namespace vllm } // namespace vllm

View File

@ -5,36 +5,24 @@
#include <map> #include <map>
#include <vector> #include <vector>
void swap_blocks( void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
torch::Tensor& src, const torch::Tensor& block_mapping);
torch::Tensor& dst,
const torch::Tensor& block_mapping);
void copy_blocks( void copy_blocks(std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& key_caches, std::vector<torch::Tensor>& value_caches,
std::vector<torch::Tensor>& value_caches, const torch::Tensor& block_mapping);
const torch::Tensor& block_mapping);
void reshape_and_cache( void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key, torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& value, torch::Tensor& slot_mapping,
torch::Tensor& key_cache, const std::string& kv_cache_dtype, const float kv_scale);
torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype,
const float kv_scale);
void reshape_and_cache_flash( void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key, torch::Tensor& key_cache,
torch::Tensor& value, torch::Tensor& value_cache,
torch::Tensor& key_cache, torch::Tensor& slot_mapping,
torch::Tensor& value_cache, const std::string& kv_cache_dtype);
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype);
// Just for unittest // Just for unittest
void convert_fp8( void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
torch::Tensor& dst_cache, const float scale, const std::string& kv_cache_dtype);
torch::Tensor& src_cache,
const float scale,
const std::string& kv_cache_dtype);

View File

@ -6,9 +6,9 @@
#include "dispatch_utils.h" #include "dispatch_utils.h"
#ifdef USE_ROCM #ifdef USE_ROCM
#include "quantization/fp8/amd/quant_utils.cuh" #include "quantization/fp8/amd/quant_utils.cuh"
#else #else
#include "quantization/fp8/nvidia/quant_utils.cuh" #include "quantization/fp8/nvidia/quant_utils.cuh"
#endif #endif
#include <algorithm> #include <algorithm>
@ -18,20 +18,17 @@
#ifdef USE_ROCM #ifdef USE_ROCM
#include <hip/hip_bf16.h> #include <hip/hip_bf16.h>
typedef __hip_bfloat16 __nv_bfloat16; typedef __hip_bfloat16 __nv_bfloat16;
#endif #endif
void swap_blocks( void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
torch::Tensor& src, const torch::Tensor& block_mapping) {
torch::Tensor& dst,
const torch::Tensor& block_mapping) {
torch::Device src_device = src.device(); torch::Device src_device = src.device();
torch::Device dst_device = dst.device(); torch::Device dst_device = dst.device();
cudaMemcpyKind memcpy_type; cudaMemcpyKind memcpy_type;
if (src_device.is_cuda() && dst_device.is_cuda()) { if (src_device.is_cuda() && dst_device.is_cuda()) {
TORCH_CHECK( TORCH_CHECK(src_device.index() == dst_device.index(),
src_device.index() == dst_device.index(), "src and dst must be on the same GPU");
"src and dst must be on the same GPU");
memcpy_type = cudaMemcpyDeviceToDevice; memcpy_type = cudaMemcpyDeviceToDevice;
} else if (src_device.is_cuda() && dst_device.is_cpu()) { } else if (src_device.is_cuda() && dst_device.is_cpu()) {
memcpy_type = cudaMemcpyDeviceToHost; memcpy_type = cudaMemcpyDeviceToHost;
@ -46,11 +43,12 @@ void swap_blocks(
// synchronization. // synchronization.
TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU"); TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU");
char *src_ptr = static_cast<char*>(src.data_ptr()); char* src_ptr = static_cast<char*>(src.data_ptr());
char *dst_ptr = static_cast<char*>(dst.data_ptr()); char* dst_ptr = static_cast<char*>(dst.data_ptr());
const int64_t block_size_in_bytes = src.element_size() * src[0].numel(); const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device); const at::cuda::OptionalCUDAGuard device_guard(
src_device.is_cuda() ? src_device : dst_device);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// NOTE(woosuk): This can be slow if the number of blocks is large. // NOTE(woosuk): This can be slow if the number of blocks is large.
const int64_t num_blocks = block_mapping.size(0); const int64_t num_blocks = block_mapping.size(0);
@ -59,29 +57,25 @@ void swap_blocks(
int64_t dst_block_number = block_mapping[i][1].item<int64_t>(); int64_t dst_block_number = block_mapping[i][1].item<int64_t>();
int64_t src_offset = src_block_number * block_size_in_bytes; int64_t src_offset = src_block_number * block_size_in_bytes;
int64_t dst_offset = dst_block_number * block_size_in_bytes; int64_t dst_offset = dst_block_number * block_size_in_bytes;
cudaMemcpyAsync( cudaMemcpyAsync(dst_ptr + dst_offset, src_ptr + src_offset,
dst_ptr + dst_offset, block_size_in_bytes, memcpy_type, stream);
src_ptr + src_offset,
block_size_in_bytes,
memcpy_type,
stream);
} }
} }
namespace vllm { namespace vllm {
// Grid: (num_layers, num_pairs) // Grid: (num_layers, num_pairs)
template<typename scalar_t> template <typename scalar_t>
__global__ void copy_blocks_kernel( __global__ void copy_blocks_kernel(int64_t* key_cache_ptrs,
int64_t* key_cache_ptrs, int64_t* value_cache_ptrs,
int64_t* value_cache_ptrs, const int64_t* __restrict__ block_mapping,
const int64_t* __restrict__ block_mapping, const int numel_per_block) {
const int numel_per_block) {
const int layer_idx = blockIdx.x; const int layer_idx = blockIdx.x;
const int pair_idx = blockIdx.y; const int pair_idx = blockIdx.y;
scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]); scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]); scalar_t* value_cache =
reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
int64_t src_block_number = block_mapping[2 * pair_idx]; int64_t src_block_number = block_mapping[2 * pair_idx];
int64_t dst_block_number = block_mapping[2 * pair_idx + 1]; int64_t dst_block_number = block_mapping[2 * pair_idx + 1];
@ -99,12 +93,11 @@ __global__ void copy_blocks_kernel(
} }
} }
} // namespace vllm } // namespace vllm
void copy_blocks( void copy_blocks(std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& key_caches, std::vector<torch::Tensor>& value_caches,
std::vector<torch::Tensor>& value_caches, const torch::Tensor& block_mapping) {
const torch::Tensor& block_mapping) {
int num_layers = key_caches.size(); int num_layers = key_caches.size();
TORCH_CHECK(num_layers == value_caches.size()); TORCH_CHECK(num_layers == value_caches.size());
if (num_layers == 0) { if (num_layers == 0) {
@ -118,8 +111,10 @@ void copy_blocks(
int64_t key_cache_ptrs[num_layers]; int64_t key_cache_ptrs[num_layers];
int64_t value_cache_ptrs[num_layers]; int64_t value_cache_ptrs[num_layers];
for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) { for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
key_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr()); key_cache_ptrs[layer_idx] =
value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr()); reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
value_cache_ptrs[layer_idx] =
reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
} }
// block_mapping is a 2D tensor with shape (num_pairs, 2). // block_mapping is a 2D tensor with shape (num_pairs, 2).
@ -127,10 +122,12 @@ void copy_blocks(
// Move the data structures to the GPU. // Move the data structures to the GPU.
// NOTE: This synchronizes the CPU and GPU. // NOTE: This synchronizes the CPU and GPU.
torch::Tensor key_cache_ptrs_tensor = torch::from_blob( torch::Tensor key_cache_ptrs_tensor =
key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device); torch::from_blob(key_cache_ptrs, {num_layers}, torch::kInt64)
torch::Tensor value_cache_ptrs_tensor = torch::from_blob( .to(cache_device);
value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device); torch::Tensor value_cache_ptrs_tensor =
torch::from_blob(value_cache_ptrs, {num_layers}, torch::kInt64)
.to(cache_device);
// Launch the kernel. // Launch the kernel.
const int numel_per_block = key_caches[0][0].numel(); const int numel_per_block = key_caches[0][0].numel();
@ -139,31 +136,28 @@ void copy_blocks(
const at::cuda::OptionalCUDAGuard device_guard(cache_device); const at::cuda::OptionalCUDAGuard device_guard(cache_device);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES( VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] { key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>( vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
key_cache_ptrs_tensor.data_ptr<int64_t>(), key_cache_ptrs_tensor.data_ptr<int64_t>(),
value_cache_ptrs_tensor.data_ptr<int64_t>(), value_cache_ptrs_tensor.data_ptr<int64_t>(),
block_mapping.data_ptr<int64_t>(), block_mapping.data_ptr<int64_t>(), numel_per_block);
numel_per_block); }));
}));
} }
namespace vllm { namespace vllm {
template<typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt> template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void reshape_and_cache_kernel( __global__ void reshape_and_cache_kernel(
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x,
cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size] // block_size, x]
const int64_t* __restrict__ slot_mapping, // [num_tokens] cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size,
const int key_stride, // block_size]
const int value_stride, const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int num_heads, const int key_stride, const int value_stride, const int num_heads,
const int head_size, const int head_size, const int block_size, const int x,
const int block_size, const float kv_scale) {
const int x,
const float kv_scale) {
const int64_t token_idx = blockIdx.x; const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx]; const int64_t slot_idx = slot_mapping[token_idx];
if (slot_idx < 0) { if (slot_idx < 0) {
@ -184,40 +178,39 @@ __global__ void reshape_and_cache_kernel(
const int x_idx = head_offset / x; const int x_idx = head_offset / x;
const int x_offset = head_offset % x; const int x_offset = head_offset % x;
const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x const int64_t tgt_key_idx =
+ head_idx * (head_size / x) * block_size * x block_idx * num_heads * (head_size / x) * block_size * x +
+ x_idx * block_size * x head_idx * (head_size / x) * block_size * x + x_idx * block_size * x +
+ block_offset * x block_offset * x + x_offset;
+ x_offset; const int64_t tgt_value_idx =
const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size block_idx * num_heads * head_size * block_size +
+ head_idx * head_size * block_size head_idx * head_size * block_size + head_offset * block_size +
+ head_offset * block_size block_offset;
+ block_offset;
scalar_t tgt_key = key[src_key_idx]; scalar_t tgt_key = key[src_key_idx];
scalar_t tgt_value = value[src_value_idx]; scalar_t tgt_value = value[src_value_idx];
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
key_cache[tgt_key_idx] = tgt_key; key_cache[tgt_key_idx] = tgt_key;
value_cache[tgt_value_idx] = tgt_value; value_cache[tgt_value_idx] = tgt_value;
} else { } else {
key_cache[tgt_key_idx] = fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, kv_scale); key_cache[tgt_key_idx] =
value_cache[tgt_value_idx] = fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, kv_scale); fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, kv_scale);
value_cache[tgt_value_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, kv_scale);
} }
} }
} }
template<typename scalar_t> template <typename scalar_t>
__global__ void reshape_and_cache_flash_kernel( __global__ void reshape_and_cache_flash_kernel(
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
scalar_t* __restrict__ k_cache, // [num_blocks, block_size, num_heads, head_size] scalar_t* __restrict__ k_cache, // [num_blocks, block_size, num_heads,
scalar_t* __restrict__ v_cache, // [num_blocks, block_size, num_heads, head_size] // head_size]
const int64_t* __restrict__ slot_mapping, // [num_tokens] scalar_t* __restrict__ v_cache, // [num_blocks, block_size, num_heads,
const int block_stride, // head_size]
const int key_stride, const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int value_stride, const int block_stride, const int key_stride, const int value_stride,
const int num_heads, const int num_heads, const int head_size, const int block_size) {
const int head_size,
const int block_size) {
const int64_t token_idx = blockIdx.x; const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx]; const int64_t slot_idx = slot_mapping[token_idx];
// NOTE: slot_idx can be -1 if the token is padded // NOTE: slot_idx can be -1 if the token is padded
@ -232,43 +225,37 @@ __global__ void reshape_and_cache_flash_kernel(
const int64_t src_value_idx = token_idx * value_stride + i; const int64_t src_value_idx = token_idx * value_stride + i;
const int head_idx = i / head_size; const int head_idx = i / head_size;
const int head_offset = i % head_size; const int head_offset = i % head_size;
const int64_t tgt_value_idx = block_idx * block_stride const int64_t tgt_value_idx = block_idx * block_stride +
+ block_offset * num_heads * head_size block_offset * num_heads * head_size +
+ head_idx * head_size head_idx * head_size + head_offset;
+ head_offset;
k_cache[tgt_value_idx] = key[src_key_idx]; k_cache[tgt_value_idx] = key[src_key_idx];
v_cache[tgt_value_idx] = value[src_value_idx]; v_cache[tgt_value_idx] = value[src_value_idx];
} }
} }
} // namespace vllm } // namespace vllm
// KV_T is the stored data type of kv-cache. // KV_T is the stored data type of kv-cache.
// CACHE_T is the data type of key and value tensors. // CACHE_T is the data type of key and value tensors.
// KV_DTYPE is the real data type of kv-cache. // KV_DTYPE is the real data type of kv-cache.
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \ #define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \
vllm::reshape_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE><<<grid, block, 0, stream>>>( \ vllm::reshape_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE> \
reinterpret_cast<KV_T*>(key.data_ptr()), \ <<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(value.data_ptr()), \ reinterpret_cast<KV_T*>(key.data_ptr()), \
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \ reinterpret_cast<KV_T*>(value.data_ptr()), \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \ reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), \ reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
key_stride, \ slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
value_stride, \ num_heads, head_size, block_size, x, kv_scale);
num_heads, \
head_size, \
block_size, \
x, \
kv_scale);
void reshape_and_cache( void reshape_and_cache(
torch::Tensor& key, // [num_tokens, num_heads, head_size] torch::Tensor& key, // [num_tokens, num_heads, head_size]
torch::Tensor& value, // [num_tokens, num_heads, head_size] torch::Tensor& value, // [num_tokens, num_heads, head_size]
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] torch::Tensor&
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor& slot_mapping, // [num_tokens] torch::Tensor&
const std::string& kv_cache_dtype, value_cache, // [num_blocks, num_heads, head_size, block_size]
const float kv_scale) torch::Tensor& slot_mapping, // [num_tokens]
{ const std::string& kv_cache_dtype, const float kv_scale) {
int num_tokens = key.size(0); int num_tokens = key.size(0);
int num_heads = key.size(1); int num_heads = key.size(1);
int head_size = key.size(2); int head_size = key.size(2);
@ -283,17 +270,17 @@ void reshape_and_cache(
const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype, CALL_RESHAPE_AND_CACHE) DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
CALL_RESHAPE_AND_CACHE)
} }
void reshape_and_cache_flash( void reshape_and_cache_flash(
torch::Tensor& key, // [num_tokens, num_heads, head_size] torch::Tensor& key, // [num_tokens, num_heads, head_size]
torch::Tensor& value, // [num_tokens, num_heads, head_size] torch::Tensor& value, // [num_tokens, num_heads, head_size]
torch::Tensor& k_cache, // [num_blocks, block_size, num_heads, head_size] torch::Tensor& k_cache, // [num_blocks, block_size, num_heads, head_size]
torch::Tensor& v_cache, // [num_blocks, block_size, num_heads, head_size] torch::Tensor& v_cache, // [num_blocks, block_size, num_heads, head_size]
torch::Tensor& slot_mapping, // [num_tokens] torch::Tensor& slot_mapping, // [num_tokens]
const std::string& kv_cache_dtype) const std::string& kv_cache_dtype) {
{
// FIXME: only support auto datatype, does not support fp8 // FIXME: only support auto datatype, does not support fp8
if (kv_cache_dtype != "auto") { if (kv_cache_dtype != "auto") {
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
@ -313,62 +300,47 @@ void reshape_and_cache_flash(
const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(
key.scalar_type(), key.scalar_type(), "reshape_and_cache_flash", [&] {
"reshape_and_cache_flash", vllm::reshape_and_cache_flash_kernel<scalar_t>
[&] { <<<grid, block, 0, stream>>>(
vllm::reshape_and_cache_flash_kernel<scalar_t><<<grid, block, 0, stream>>>( key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(), k_cache.data_ptr<scalar_t>(), v_cache.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(), slot_mapping.data_ptr<int64_t>(), block_stride, key_stride,
k_cache.data_ptr<scalar_t>(), value_stride, num_heads, head_size, block_size);
v_cache.data_ptr<scalar_t>(), });
slot_mapping.data_ptr<int64_t>(),
block_stride,
key_stride,
value_stride,
num_heads,
head_size,
block_size);
});
} }
namespace vllm { namespace vllm {
template<typename Tout, typename Tin, Fp8KVCacheDataType kv_dt> template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__global__ void convert_fp8_kernel( __global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache,
const Tin* __restrict__ src_cache, Tout* __restrict__ dst_cache,
Tout* __restrict__ dst_cache, const float kv_scale,
const float kv_scale, const int64_t block_stride) {
const int64_t block_stride) {
const int64_t block_idx = blockIdx.x; const int64_t block_idx = blockIdx.x;
for (int i = threadIdx.x; i < block_stride; i += blockDim.x) { for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
int64_t idx = block_idx * block_stride + i; int64_t idx = block_idx * block_stride + i;
dst_cache[idx] = fp8::scaled_convert<Tout, Tin, kv_dt>(src_cache[idx], kv_scale); dst_cache[idx] =
fp8::scaled_convert<Tout, Tin, kv_dt>(src_cache[idx], kv_scale);
} }
} }
} // namespace vllm } // namespace vllm
#define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \ #define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \
vllm::convert_fp8_kernel<Tout, Tin, KV_DTYPE><<<grid, block, 0, stream>>>( \ vllm::convert_fp8_kernel<Tout, Tin, KV_DTYPE><<<grid, block, 0, stream>>>( \
reinterpret_cast<Tin*>(src_cache.data_ptr()), \ reinterpret_cast<Tin*>(src_cache.data_ptr()), \
reinterpret_cast<Tout*>(dst_cache.data_ptr()), \ reinterpret_cast<Tout*>(dst_cache.data_ptr()), kv_scale, block_stride);
kv_scale, \
block_stride);
// Only for testing. // Only for testing.
void convert_fp8( void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
torch::Tensor& dst_cache, const float kv_scale, const std::string& kv_cache_dtype) {
torch::Tensor& src_cache,
const float kv_scale,
const std::string& kv_cache_dtype)
{
torch::Device src_device = src_cache.device(); torch::Device src_device = src_cache.device();
torch::Device dst_device = dst_cache.device(); torch::Device dst_device = dst_cache.device();
TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU") TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU")
TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU") TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU")
TORCH_CHECK( TORCH_CHECK(src_device.index() == dst_device.index(),
src_device.index() == dst_device.index(), "src and dst must be on the same GPU");
"src and dst must be on the same GPU");
at::cuda::OptionalCUDAGuard device_guard(src_device); at::cuda::OptionalCUDAGuard device_guard(src_device);
int64_t num_blocks = src_cache.size(0); int64_t num_blocks = src_cache.size(0);
@ -398,13 +370,15 @@ void convert_fp8(
} else if (src_cache.dtype() == at::ScalarType::Half) { } else if (src_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kFp8E4M3); CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (src_cache.dtype() == at::ScalarType::BFloat16) { } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, vllm::Fp8KVCacheDataType::kFp8E4M3); CALL_CONVERT_FP8(uint8_t, __nv_bfloat16,
vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (dst_cache.dtype() == at::ScalarType::Float) { } else if (dst_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (dst_cache.dtype() == at::ScalarType::Half) { } else if (dst_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) { } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); CALL_CONVERT_FP8(__nv_bfloat16, uint8_t,
vllm::Fp8KVCacheDataType::kFp8E4M3);
} }
} else { } else {
TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype); TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype);

View File

@ -1,10 +1,10 @@
#include "cpu_types.hpp" #include "cpu_types.hpp"
namespace { namespace {
template <typename scalar_t, vec_op::FP32Vec8 (*func)(const vec_op::FP32Vec8 &), template <typename scalar_t, vec_op::FP32Vec8 (*func)(const vec_op::FP32Vec8&),
bool is_gated> bool is_gated>
void activation_kernel(int num_tokens, int d, scalar_t *__restrict__ input, void activation_kernel(int num_tokens, int d, scalar_t* __restrict__ input,
scalar_t *__restrict__ output) { scalar_t* __restrict__ output) {
using scalar_vec_t = vec_op::vec_t<scalar_t>; using scalar_vec_t = vec_op::vec_t<scalar_t>;
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
@ -34,13 +34,13 @@ void activation_kernel(int num_tokens, int d, scalar_t *__restrict__ input,
} }
} }
FORCE_INLINE vec_op::FP32Vec8 silu_act(const vec_op::FP32Vec8 &x) { FORCE_INLINE vec_op::FP32Vec8 silu_act(const vec_op::FP32Vec8& x) {
const vec_op::FP32Vec8 zeros(0.0); const vec_op::FP32Vec8 zeros(0.0);
const vec_op::FP32Vec8 ones(1.0); const vec_op::FP32Vec8 ones(1.0);
return x / (ones + (zeros - x).exp()); return x / (ones + (zeros - x).exp());
} }
FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8 &x) { FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8& x) {
const vec_op::FP32Vec8 ones(1.0); const vec_op::FP32Vec8 ones(1.0);
const vec_op::FP32Vec8 w1(0.79788456f); const vec_op::FP32Vec8 w1(0.79788456f);
const vec_op::FP32Vec8 w2(0.044715f); const vec_op::FP32Vec8 w2(0.044715f);
@ -50,7 +50,7 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8 &x) {
return w3 * x * (ones + t); return w3 * x * (ones + t);
} }
FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8 &x) { FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8& x) {
const vec_op::FP32Vec8 ones(1.0); const vec_op::FP32Vec8 ones(1.0);
const vec_op::FP32Vec8 w1(0.79788456f); const vec_op::FP32Vec8 w1(0.79788456f);
const vec_op::FP32Vec8 w2(0.044715f); const vec_op::FP32Vec8 w2(0.044715f);
@ -59,14 +59,14 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8 &x) {
return w3 * x * (ones + t); return w3 * x * (ones + t);
} }
FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8 &x) { FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8& x) {
const vec_op::FP32Vec8 ones(1.0); const vec_op::FP32Vec8 ones(1.0);
const vec_op::FP32Vec8 w1(M_SQRT1_2); const vec_op::FP32Vec8 w1(M_SQRT1_2);
const vec_op::FP32Vec8 w2(0.5); const vec_op::FP32Vec8 w2(0.5);
return x * w2 * (ones + (x * w1).er()); return x * w2 * (ones + (x * w1).er());
} }
FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8 &x) { FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8& x) {
const vec_op::FP32Vec8 ones(1.0); const vec_op::FP32Vec8 ones(1.0);
const vec_op::FP32Vec8 w1(M_SQRT2 * M_2_SQRTPI * 0.5); const vec_op::FP32Vec8 w1(M_SQRT2 * M_2_SQRTPI * 0.5);
const vec_op::FP32Vec8 w2(0.5); const vec_op::FP32Vec8 w2(0.5);
@ -75,40 +75,36 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8 &x) {
const vec_op::FP32Vec8 inner = w1 * (x + x_3 * w3); const vec_op::FP32Vec8 inner = w1 * (x + x_3 * w3);
return x * w2 * (ones + inner.tanh()); return x * w2 * (ones + inner.tanh());
} }
}; // namespace }; // namespace
void silu_and_mul(torch::Tensor &out, torch::Tensor &input) { void silu_and_mul(torch::Tensor& out, torch::Tensor& input) {
int num_tokens = input.numel() / input.size(-1); int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2; int d = input.size(-1) / 2;
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "silu_and_mul_impl", [&] {
input.scalar_type(), "silu_and_mul_impl", [&] { CPU_KERNEL_GUARD_IN(silu_and_mul_impl)
CPU_KERNEL_GUARD_IN(silu_and_mul_impl) activation_kernel<scalar_t, silu_act, true>(
activation_kernel<scalar_t, silu_act, true>(num_tokens, d, num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
input.data_ptr<scalar_t>(), CPU_KERNEL_GUARD_OUT(silu_and_mul_impl)
out.data_ptr<scalar_t>()); });
CPU_KERNEL_GUARD_OUT(silu_and_mul_impl)
});
} }
void gelu_and_mul(torch::Tensor &out, // [..., d] void gelu_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor &input) // [..., 2 * d] torch::Tensor& input) // [..., 2 * d]
{ {
int num_tokens = input.numel() / input.size(-1); int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2; int d = input.size(-1) / 2;
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_and_mul_impl", [&] {
input.scalar_type(), "gelu_and_mul_impl", [&] { CPU_KERNEL_GUARD_IN(gelu_and_mul_impl)
CPU_KERNEL_GUARD_IN(gelu_and_mul_impl) activation_kernel<scalar_t, gelu_act, true>(
activation_kernel<scalar_t, gelu_act, true>(num_tokens, d, num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
input.data_ptr<scalar_t>(), CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl)
out.data_ptr<scalar_t>()); });
CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl)
});
} }
void gelu_tanh_and_mul(torch::Tensor &out, // [..., d] void gelu_tanh_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor &input) // [..., 2 * d] torch::Tensor& input) // [..., 2 * d]
{ {
int num_tokens = input.numel() / input.size(-1); int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2; int d = input.size(-1) / 2;
@ -123,7 +119,7 @@ void gelu_tanh_and_mul(torch::Tensor &out, // [..., d]
}); });
} }
void gelu_new(torch::Tensor &out, torch::Tensor &input) { void gelu_new(torch::Tensor& out, torch::Tensor& input) {
int num_tokens = input.numel() / input.size(-1); int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1); int d = input.size(-1);
@ -135,7 +131,7 @@ void gelu_new(torch::Tensor &out, torch::Tensor &input) {
}); });
} }
void gelu_fast(torch::Tensor &out, torch::Tensor &input) { void gelu_fast(torch::Tensor& out, torch::Tensor& input) {
int num_tokens = input.numel() / input.size(-1); int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1); int d = input.size(-1);

View File

@ -2,7 +2,8 @@
namespace { namespace {
template <typename scalar_t> struct KernelVecType { template <typename scalar_t>
struct KernelVecType {
using q_load_vec_type = void; using q_load_vec_type = void;
using q_vec_type = void; using q_vec_type = void;
using k_load_vec_type = void; using k_load_vec_type = void;
@ -11,7 +12,8 @@ template <typename scalar_t> struct KernelVecType {
using v_load_vec_type = void; using v_load_vec_type = void;
}; };
template <> struct KernelVecType<float> { template <>
struct KernelVecType<float> {
using q_load_vec_type = vec_op::FP32Vec4; using q_load_vec_type = vec_op::FP32Vec4;
using q_vec_type = vec_op::FP32Vec16; using q_vec_type = vec_op::FP32Vec16;
using k_load_vec_type = vec_op::FP32Vec16; using k_load_vec_type = vec_op::FP32Vec16;
@ -21,7 +23,8 @@ template <> struct KernelVecType<float> {
}; };
#ifdef __AVX512BF16__ #ifdef __AVX512BF16__
template <> struct KernelVecType<c10::BFloat16> { template <>
struct KernelVecType<c10::BFloat16> {
using q_load_vec_type = vec_op::BF16Vec8; using q_load_vec_type = vec_op::BF16Vec8;
using q_vec_type = vec_op::BF16Vec32; using q_vec_type = vec_op::BF16Vec32;
using k_load_vec_type = vec_op::BF16Vec32; using k_load_vec_type = vec_op::BF16Vec32;
@ -30,7 +33,8 @@ template <> struct KernelVecType<c10::BFloat16> {
using v_load_vec_type = vec_op::BF16Vec16; using v_load_vec_type = vec_op::BF16Vec16;
}; };
#else #else
template <> struct KernelVecType<c10::BFloat16> { template <>
struct KernelVecType<c10::BFloat16> {
using q_load_vec_type = vec_op::BF16Vec8; using q_load_vec_type = vec_op::BF16Vec8;
using q_vec_type = vec_op::FP32Vec16; using q_vec_type = vec_op::FP32Vec16;
using k_load_vec_type = vec_op::BF16Vec16; using k_load_vec_type = vec_op::BF16Vec16;
@ -41,7 +45,7 @@ template <> struct KernelVecType<c10::BFloat16> {
#endif #endif
template <typename T> template <typename T>
FORCE_INLINE std::pair<T, T> reduceSoftmax(T *data, const int size, FORCE_INLINE std::pair<T, T> reduceSoftmax(T* data, const int size,
const int capacity) { const int capacity) {
T max = data[0]; T max = data[0];
for (int i = 1; i < size; ++i) { for (int i = 1; i < size; ++i) {
@ -67,10 +71,11 @@ FORCE_INLINE std::pair<T, T> reduceSoftmax(T *data, const int size,
} }
template <typename T> template <typename T>
FORCE_INLINE std::pair<T, T> FORCE_INLINE std::pair<T, T> reduceSoftmaxAlibi(T* data, const int size,
reduceSoftmaxAlibi(T *data, const int size, const int capacity, const int capacity,
const float alibi_slope, const int start_index, const float alibi_slope,
const int seq_len) { const int start_index,
const int seq_len) {
data[0] += alibi_slope * (start_index - seq_len + 1); data[0] += alibi_slope * (start_index - seq_len + 1);
T max = data[0]; T max = data[0];
for (int i = 1; i < size; ++i) { for (int i = 1; i < size; ++i) {
@ -98,7 +103,7 @@ reduceSoftmaxAlibi(T *data, const int size, const int capacity,
} }
template <typename T> template <typename T>
FORCE_INLINE void reducePartitonSoftmax(const T *max_data, T *sum_data, FORCE_INLINE void reducePartitonSoftmax(const T* max_data, T* sum_data,
const int size) { const int size) {
T max = max_data[0]; T max = max_data[0];
for (int i = 1; i < size; ++i) { for (int i = 1; i < size; ++i) {
@ -132,9 +137,9 @@ struct reduceQKBlockKernel {
static_assert(k_load_vec_type::get_elem_num() % x == 0); static_assert(k_load_vec_type::get_elem_num() % x == 0);
static_assert(q_load_vec_type::get_elem_num() * sizeof(scalar_t) == 16); static_assert(q_load_vec_type::get_elem_num() * sizeof(scalar_t) == 16);
FORCE_INLINE static void call(const scalar_t *__restrict__ q, FORCE_INLINE static void call(const scalar_t* __restrict__ q,
const scalar_t *__restrict__ k_block, const scalar_t* __restrict__ k_block,
float *__restrict__ logits, float scale, float* __restrict__ logits, float scale,
const int token_num) { const int token_num) {
const int group_num = (token_num + TOKEN_PER_GROUP - 1) / TOKEN_PER_GROUP; const int group_num = (token_num + TOKEN_PER_GROUP - 1) / TOKEN_PER_GROUP;
@ -196,8 +201,8 @@ struct reduceQKBlockKernel {
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE,
int HEAD_PARTITION_SIZE, typename acc_t> int HEAD_PARTITION_SIZE, typename acc_t>
FORCE_INLINE void reduceValueBlock(const float *prob, const scalar_t *v_block, FORCE_INLINE void reduceValueBlock(const float* prob, const scalar_t* v_block,
acc_t &&acc) { acc_t&& acc) {
using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type; using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
constexpr int ELEM_NUM = v_load_vec_type::get_elem_num(); constexpr int ELEM_NUM = v_load_vec_type::get_elem_num();
static_assert(BLOCK_SIZE == ELEM_NUM); static_assert(BLOCK_SIZE == ELEM_NUM);
@ -209,27 +214,27 @@ FORCE_INLINE void reduceValueBlock(const float *prob, const scalar_t *v_block,
acc[head_elem_idx] = acc[head_elem_idx] + prob_vec * fp32_v_vec; acc[head_elem_idx] = acc[head_elem_idx] + prob_vec * fp32_v_vec;
}); });
} }
}; // namespace }; // namespace
// Paged attention v1 // Paged attention v1
namespace { namespace {
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE> template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE>
struct paged_attention_v1_impl { struct paged_attention_v1_impl {
static void static void call(
call(scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size] scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t *__restrict__ q, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads, const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x] // head_size/x, block_size, x]
const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads, const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size] // head_size, block_size]
const int num_kv_heads, const float scale, const int num_kv_heads, const float scale,
const int const int* __restrict__ block_tables, // [num_seqs,
*__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] // max_num_blocks_per_seq]
const int *__restrict__ seq_lens, // [num_seqs] const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float *__restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const int num_seqs, const int num_heads) { const int num_seqs, const int num_heads) {
constexpr int x = 16 / sizeof(scalar_t); constexpr int x = 16 / sizeof(scalar_t);
const int num_queries_per_kv = num_heads / num_kv_heads; const int num_queries_per_kv = num_heads / num_kv_heads;
@ -243,32 +248,31 @@ struct paged_attention_v1_impl {
size_t logits_bytes = size_t logits_bytes =
parallel_work_item_num * max_seq_len_padded * sizeof(float); parallel_work_item_num * max_seq_len_padded * sizeof(float);
float *logits = (float *)std::aligned_alloc( float* logits = (float*)std::aligned_alloc(
64, logits_bytes); // Cacheline alignment for each context token. 64, logits_bytes); // Cacheline alignment for each context token.
// [parallel_work_item_num, max_seq_len_padded] // [parallel_work_item_num, max_seq_len_padded]
#pragma omp parallel for collapse(2) schedule(dynamic, 1) #pragma omp parallel for collapse(2) schedule(dynamic, 1)
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
int seq_len = seq_lens[seq_idx]; int seq_len = seq_lens[seq_idx];
const int *seq_block_table = const int* seq_block_table =
block_tables + max_num_blocks_per_seq * seq_idx; block_tables + max_num_blocks_per_seq * seq_idx;
const int block_num = (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE; const int block_num = (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
const int64_t kv_head_idx = head_idx / num_queries_per_kv; const int64_t kv_head_idx = head_idx / num_queries_per_kv;
const scalar_t *__restrict__ q_vec_ptr = const scalar_t* __restrict__ q_vec_ptr =
q + seq_idx * q_stride + head_idx * HEAD_SIZE; q + seq_idx * q_stride + head_idx * HEAD_SIZE;
const int last_block_token_num = const int last_block_token_num = seq_len - (block_num - 1) * BLOCK_SIZE;
seq_len - (block_num - 1) * BLOCK_SIZE; float* __restrict__ thread_block_logits =
float *__restrict__ thread_block_logits =
logits + omp_get_thread_num() * max_seq_len_padded; logits + omp_get_thread_num() * max_seq_len_padded;
// Compute logits // Compute logits
for (int block_idx = 0; block_idx < block_num; ++block_idx) { for (int block_idx = 0; block_idx < block_num; ++block_idx) {
const int64_t physical_block_idx = seq_block_table[block_idx]; const int64_t physical_block_idx = seq_block_table[block_idx];
const scalar_t *__restrict__ k_block_cache_ptr = const scalar_t* __restrict__ k_block_cache_ptr =
k_cache + physical_block_idx * kv_block_stride + k_cache + physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride; kv_head_idx * kv_head_stride;
float *__restrict__ head_block_logits = float* __restrict__ head_block_logits =
thread_block_logits + block_idx * BLOCK_SIZE; thread_block_logits + block_idx * BLOCK_SIZE;
reduceQKBlockKernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, x>::call( reduceQKBlockKernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, x>::call(
@ -282,8 +286,7 @@ struct paged_attention_v1_impl {
block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0, block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0,
seq_len); seq_len);
} else { } else {
reduceSoftmax(thread_block_logits, seq_len, reduceSoftmax(thread_block_logits, seq_len, block_num * BLOCK_SIZE);
block_num * BLOCK_SIZE);
} }
// Compute value // Compute value
@ -293,14 +296,14 @@ struct paged_attention_v1_impl {
for (int head_part_idx = 0; head_part_idx < head_partition_num; for (int head_part_idx = 0; head_part_idx < head_partition_num;
++head_part_idx) { ++head_part_idx) {
vec_op::FP32Vec16 accums[head_elem_num_per_partition]; vec_op::FP32Vec16 accums[head_elem_num_per_partition];
scalar_t *__restrict__ out_ptr = scalar_t* __restrict__ out_ptr =
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE +
head_part_idx * head_elem_num_per_partition; head_part_idx * head_elem_num_per_partition;
for (int block_idx = 0; block_idx < block_num; ++block_idx) { for (int block_idx = 0; block_idx < block_num; ++block_idx) {
const int64_t physical_block_idx = seq_block_table[block_idx]; const int64_t physical_block_idx = seq_block_table[block_idx];
const float *__restrict__ prob_vec_ptr = const float* __restrict__ prob_vec_ptr =
thread_block_logits + block_idx * BLOCK_SIZE; thread_block_logits + block_idx * BLOCK_SIZE;
const scalar_t *__restrict__ v_block_cache_ptr = const scalar_t* __restrict__ v_block_cache_ptr =
v_cache + physical_block_idx * kv_block_stride + v_cache + physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride + kv_head_idx * kv_head_stride +
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
@ -311,7 +314,7 @@ struct paged_attention_v1_impl {
if (block_idx != block_num - 1) { if (block_idx != block_num - 1) {
const int64_t next_physical_block_idx = const int64_t next_physical_block_idx =
seq_block_table[block_idx + 1]; seq_block_table[block_idx + 1];
const scalar_t *__restrict__ next_v_block_cache_ptr = const scalar_t* __restrict__ next_v_block_cache_ptr =
v_cache + next_physical_block_idx * kv_block_stride + v_cache + next_physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride + kv_head_idx * kv_head_stride +
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
@ -340,16 +343,16 @@ struct paged_attention_v1_impl {
#define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ #define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \
paged_attention_v1_impl<T, HEAD_SIZE, BLOCK_SIZE>::call( \ paged_attention_v1_impl<T, HEAD_SIZE, BLOCK_SIZE>::call( \
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \
num_heads); num_heads);
template <typename T, int BLOCK_SIZE> template <typename T, int BLOCK_SIZE>
void paged_attention_v1_impl_launcher( void paged_attention_v1_impl_launcher(
torch::Tensor &out, torch::Tensor &query, torch::Tensor &key_cache, torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor &value_cache, int num_kv_heads, float scale, torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor &block_tables, torch::Tensor &seq_lens, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
int max_seq_len, const c10::optional<torch::Tensor> &alibi_slopes) { const c10::optional<torch::Tensor>& alibi_slopes) {
int num_seqs = query.size(0); int num_seqs = query.size(0);
int num_heads = query.size(1); int num_heads = query.size(1);
int head_size = query.size(2); int head_size = query.size(2);
@ -359,67 +362,66 @@ void paged_attention_v1_impl_launcher(
int kv_head_stride = key_cache.stride(1); int kv_head_stride = key_cache.stride(1);
// NOTE: alibi_slopes is optional. // NOTE: alibi_slopes is optional.
const float *alibi_slopes_ptr = const float* alibi_slopes_ptr =
alibi_slopes alibi_slopes
? reinterpret_cast<const float *>(alibi_slopes.value().data_ptr()) ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr; : nullptr;
T *out_ptr = reinterpret_cast<T *>(out.data_ptr()); T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
T *query_ptr = reinterpret_cast<T *>(query.data_ptr()); T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
T *key_cache_ptr = reinterpret_cast<T *>(key_cache.data_ptr()); T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
T *value_cache_ptr = reinterpret_cast<T *>(value_cache.data_ptr()); T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
int *block_tables_ptr = block_tables.data_ptr<int>(); int* block_tables_ptr = block_tables.data_ptr<int>();
int *seq_lens_ptr = seq_lens.data_ptr<int>(); int* seq_lens_ptr = seq_lens.data_ptr<int>();
switch (head_size) { switch (head_size) {
case 64: case 64:
LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
break; break;
case 80: case 80:
LAUNCH_V1_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); LAUNCH_V1_ATTENTION_KERNEL(T, 80, BLOCK_SIZE);
break; break;
case 96: case 96:
LAUNCH_V1_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); LAUNCH_V1_ATTENTION_KERNEL(T, 96, BLOCK_SIZE);
break; break;
case 112: case 112:
LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE);
break; break;
case 128: case 128:
LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
break; break;
case 256: case 256:
LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
break; break;
default: default:
TORCH_CHECK(false, "Unsupported head size: ", head_size); TORCH_CHECK(false, "Unsupported head size: ", head_size);
break; break;
} }
} }
#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ #define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
paged_attention_v1_impl_launcher<T, BLOCK_SIZE>( \ paged_attention_v1_impl_launcher<T, BLOCK_SIZE>( \
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
seq_lens, max_seq_len, alibi_slopes); seq_lens, max_seq_len, alibi_slopes);
#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ #define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
switch (block_size) { \ switch (block_size) { \
case 16: \ case 16: \
CALL_V1_KERNEL_LAUNCHER(T, 16); \ CALL_V1_KERNEL_LAUNCHER(T, 16); \
break; \ break; \
default: \ default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \ break; \
} }
} // namespace } // namespace
void paged_attention_v1(torch::Tensor &out, torch::Tensor &query, void paged_attention_v1(torch::Tensor& out, torch::Tensor& query,
torch::Tensor &key_cache, torch::Tensor &value_cache, torch::Tensor& key_cache, torch::Tensor& value_cache,
int num_kv_heads, float scale, int num_kv_heads, float scale,
torch::Tensor &block_tables, torch::Tensor& block_tables, torch::Tensor& seq_lens,
torch::Tensor &seq_lens, int block_size, int block_size, int max_seq_len,
int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const c10::optional<torch::Tensor> &alibi_slopes, const std::string& kv_cache_dtype, float kv_scale) {
const std::string &kv_cache_dtype, float kv_scale) {
TORCH_CHECK(kv_scale == 1.0f); TORCH_CHECK(kv_scale == 1.0f);
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl", VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl",
[&] { [&] {
@ -434,23 +436,24 @@ namespace {
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int PARTITION_SIZE> template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int PARTITION_SIZE>
struct paged_attention_v2_impl { struct paged_attention_v2_impl {
static void call( static void call(
scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size] scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
float *__restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ exp_sums, // [num_seqs, num_heads,
float // max_num_partitions]
*__restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ max_logits, // [num_seqs, num_heads,
scalar_t *__restrict__ tmp_out, // [num_seqs, num_heads, // max_num_partitions]
// max_num_partitions, head_size] scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
const scalar_t *__restrict__ q, // [num_seqs, num_heads, head_size] // max_num_partitions, head_size]
const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads, const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
// head_size/x, block_size, x] const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads, // head_size/x, block_size, x]
// head_size, block_size] const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_kv_heads, const float scale, const int num_kv_heads, const float scale,
const int const int* __restrict__ block_tables, // [num_seqs,
*__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] // max_num_blocks_per_seq]
const int *__restrict__ seq_lens, // [num_seqs] const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float *__restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const int num_seqs, const int num_heads, const int max_num_partitions) { const int num_seqs, const int num_heads, const int max_num_partitions) {
constexpr int x = 16 / sizeof(scalar_t); constexpr int x = 16 / sizeof(scalar_t);
@ -468,8 +471,7 @@ struct paged_attention_v2_impl {
const int seq_len = seq_lens[seq_idx]; const int seq_len = seq_lens[seq_idx];
const int start_token_idx = partition_idx * PARTITION_SIZE; const int start_token_idx = partition_idx * PARTITION_SIZE;
if (start_token_idx >= seq_len) if (start_token_idx >= seq_len) continue;
continue;
const int partition_num = const int partition_num =
(seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
@ -477,15 +479,14 @@ struct paged_attention_v2_impl {
const int token_num = const int token_num =
(std::min(seq_len, start_token_idx + PARTITION_SIZE) - (std::min(seq_len, start_token_idx + PARTITION_SIZE) -
start_token_idx); start_token_idx);
const int block_num = const int block_num = (token_num + BLOCK_SIZE - 1) / BLOCK_SIZE;
(token_num + BLOCK_SIZE - 1) / BLOCK_SIZE;
const int last_block_token_num = const int last_block_token_num =
token_num - (block_num - 1) * BLOCK_SIZE; token_num - (block_num - 1) * BLOCK_SIZE;
const int *seq_block_table = block_tables + const int* seq_block_table = block_tables +
max_num_blocks_per_seq * seq_idx + max_num_blocks_per_seq * seq_idx +
start_token_idx / BLOCK_SIZE; start_token_idx / BLOCK_SIZE;
const int64_t kv_head_idx = head_idx / num_queries_per_kv; const int64_t kv_head_idx = head_idx / num_queries_per_kv;
const scalar_t *__restrict__ q_vec_ptr = const scalar_t* __restrict__ q_vec_ptr =
q + seq_idx * q_stride + head_idx * HEAD_SIZE; q + seq_idx * q_stride + head_idx * HEAD_SIZE;
float logits[PARTITION_SIZE] __attribute__((aligned(64))) = {0}; float logits[PARTITION_SIZE] __attribute__((aligned(64))) = {0};
@ -493,10 +494,10 @@ struct paged_attention_v2_impl {
// Compute logits // Compute logits
for (int block_idx = 0; block_idx < block_num; ++block_idx) { for (int block_idx = 0; block_idx < block_num; ++block_idx) {
const int64_t physical_block_idx = seq_block_table[block_idx]; const int64_t physical_block_idx = seq_block_table[block_idx];
const scalar_t *__restrict__ k_block_cache_ptr = const scalar_t* __restrict__ k_block_cache_ptr =
k_cache + physical_block_idx * kv_block_stride + k_cache + physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride; kv_head_idx * kv_head_stride;
float *__restrict__ head_block_logits = float* __restrict__ head_block_logits =
logits + block_idx * BLOCK_SIZE; logits + block_idx * BLOCK_SIZE;
reduceQKBlockKernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, x>::call( reduceQKBlockKernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, x>::call(
@ -510,13 +511,13 @@ struct paged_attention_v2_impl {
logits, token_num, block_num * BLOCK_SIZE, logits, token_num, block_num * BLOCK_SIZE,
alibi_slopes[head_idx], start_token_idx, seq_len); alibi_slopes[head_idx], start_token_idx, seq_len);
} else { } else {
max_and_sum = reduceSoftmax(logits, token_num, max_and_sum =
block_num * BLOCK_SIZE); reduceSoftmax(logits, token_num, block_num * BLOCK_SIZE);
} }
auto &&[max_logit, exp_sum] = max_and_sum; auto&& [max_logit, exp_sum] = max_and_sum;
scalar_t *__restrict__ output_buffer = nullptr; scalar_t* __restrict__ output_buffer = nullptr;
if (!no_reduce) { if (!no_reduce) {
auto idx = seq_idx * num_heads * max_num_partitions + auto idx = seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions + partition_idx; head_idx * max_num_partitions + partition_idx;
@ -538,13 +539,13 @@ struct paged_attention_v2_impl {
for (int head_part_idx = 0; head_part_idx < head_partition_num; for (int head_part_idx = 0; head_part_idx < head_partition_num;
++head_part_idx) { ++head_part_idx) {
vec_op::FP32Vec16 accums[head_elem_num_per_partition]; vec_op::FP32Vec16 accums[head_elem_num_per_partition];
scalar_t *__restrict__ out_ptr = scalar_t* __restrict__ out_ptr =
output_buffer + head_part_idx * head_elem_num_per_partition; output_buffer + head_part_idx * head_elem_num_per_partition;
for (int block_idx = 0; block_idx < block_num; ++block_idx) { for (int block_idx = 0; block_idx < block_num; ++block_idx) {
const int64_t physical_block_idx = seq_block_table[block_idx]; const int64_t physical_block_idx = seq_block_table[block_idx];
const float *__restrict__ prob_vec_ptr = const float* __restrict__ prob_vec_ptr =
logits + block_idx * BLOCK_SIZE; logits + block_idx * BLOCK_SIZE;
const scalar_t *__restrict__ v_block_cache_ptr = const scalar_t* __restrict__ v_block_cache_ptr =
v_cache + physical_block_idx * kv_block_stride + v_cache + physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride + kv_head_idx * kv_head_stride +
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
@ -555,7 +556,7 @@ struct paged_attention_v2_impl {
if (block_idx != block_num - 1) { if (block_idx != block_num - 1) {
const int64_t next_physical_block_idx = const int64_t next_physical_block_idx =
seq_block_table[block_idx + 1]; seq_block_table[block_idx + 1];
const scalar_t *__restrict__ next_v_block_cache_ptr = const scalar_t* __restrict__ next_v_block_cache_ptr =
v_cache + next_physical_block_idx * kv_block_stride + v_cache + next_physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride + kv_head_idx * kv_head_stride +
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
@ -587,8 +588,7 @@ struct paged_attention_v2_impl {
const int partition_num = const int partition_num =
(seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
if (partition_num == 1) if (partition_num == 1) continue;
continue;
reducePartitonSoftmax( reducePartitonSoftmax(
max_logits + seq_idx * num_heads * max_num_partitions + max_logits + seq_idx * num_heads * max_num_partitions +
@ -603,11 +603,11 @@ struct paged_attention_v2_impl {
using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type; using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
static_assert(v_load_vec_type::get_elem_num() == BLOCK_SIZE); static_assert(v_load_vec_type::get_elem_num() == BLOCK_SIZE);
constexpr int head_elem_num_per_group = constexpr int head_elem_num_per_group =
16; // Note: didn't align with the cacheline size, due to some HEAD_SIZE 16; // Note: didn't align with the cacheline size, due to some
// didn't align with 64 bytes // HEAD_SIZE didn't align with 64 bytes
static_assert(HEAD_SIZE % head_elem_num_per_group == 0); static_assert(HEAD_SIZE % head_elem_num_per_group == 0);
constexpr int head_group_num = HEAD_SIZE / head_elem_num_per_group; constexpr int head_group_num = HEAD_SIZE / head_elem_num_per_group;
const float *__restrict__ rescale_factors = exp_sums; const float* __restrict__ rescale_factors = exp_sums;
#pragma omp parallel for collapse(3) schedule(static, 1) #pragma omp parallel for collapse(3) schedule(static, 1)
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
@ -616,17 +616,16 @@ struct paged_attention_v2_impl {
const int partition_num = const int partition_num =
(seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
if (partition_num == 1) if (partition_num == 1) continue;
continue;
const float *__restrict__ seq_head_rescale_factors = const float* __restrict__ seq_head_rescale_factors =
rescale_factors + seq_idx * num_heads * max_num_partitions + rescale_factors + seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions; head_idx * max_num_partitions;
const scalar_t *__restrict__ seq_head_tmp_out = const scalar_t* __restrict__ seq_head_tmp_out =
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE + head_idx * max_num_partitions * HEAD_SIZE +
group_idx * head_elem_num_per_group; group_idx * head_elem_num_per_group;
scalar_t *__restrict__ seq_head_output = scalar_t* __restrict__ seq_head_output =
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE +
group_idx * head_elem_num_per_group; group_idx * head_elem_num_per_group;
@ -645,21 +644,21 @@ struct paged_attention_v2_impl {
} }
}; };
#define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ #define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \
paged_attention_v2_impl<T, HEAD_SIZE, BLOCK_SIZE, PARTITION_SIZE>::call( \ paged_attention_v2_impl<T, HEAD_SIZE, BLOCK_SIZE, PARTITION_SIZE>::call( \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \ out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \
key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, num_seqs, num_heads, \ kv_block_stride, kv_head_stride, num_seqs, num_heads, \
max_num_partitions); max_num_partitions);
template <typename T, int BLOCK_SIZE, int PARTITION_SIZE = 512> template <typename T, int BLOCK_SIZE, int PARTITION_SIZE = 512>
void paged_attention_v2_impl_launcher( void paged_attention_v2_impl_launcher(
torch::Tensor &out, torch::Tensor &exp_sums, torch::Tensor &max_logits, torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor &tmp_out, torch::Tensor &query, torch::Tensor &key_cache, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor &value_cache, int num_kv_heads, float scale, torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor &block_tables, torch::Tensor &seq_lens, int block_size, torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
int max_seq_len, const c10::optional<torch::Tensor> &alibi_slopes) { int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes) {
int num_seqs = query.size(0); int num_seqs = query.size(0);
int num_heads = query.size(1); int num_heads = query.size(1);
int head_size = query.size(2); int head_size = query.size(2);
@ -670,72 +669,72 @@ void paged_attention_v2_impl_launcher(
int max_num_partitions = exp_sums.size(-1); int max_num_partitions = exp_sums.size(-1);
// NOTE: alibi_slopes is optional. // NOTE: alibi_slopes is optional.
const float *alibi_slopes_ptr = const float* alibi_slopes_ptr =
alibi_slopes alibi_slopes
? reinterpret_cast<const float *>(alibi_slopes.value().data_ptr()) ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr; : nullptr;
T *out_ptr = reinterpret_cast<T *>(out.data_ptr()); T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
float *exp_sums_ptr = reinterpret_cast<float *>(exp_sums.data_ptr()); float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
float *max_logits_ptr = reinterpret_cast<float *>(max_logits.data_ptr()); float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
T *tmp_out_ptr = reinterpret_cast<T *>(tmp_out.data_ptr()); T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
T *query_ptr = reinterpret_cast<T *>(query.data_ptr()); T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
T *key_cache_ptr = reinterpret_cast<T *>(key_cache.data_ptr()); T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
T *value_cache_ptr = reinterpret_cast<T *>(value_cache.data_ptr()); T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
int *block_tables_ptr = block_tables.data_ptr<int>(); int* block_tables_ptr = block_tables.data_ptr<int>();
int *seq_lens_ptr = seq_lens.data_ptr<int>(); int* seq_lens_ptr = seq_lens.data_ptr<int>();
switch (head_size) { switch (head_size) {
case 64: case 64:
LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
break; break;
case 80: case 80:
LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE);
break; break;
case 96: case 96:
LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE);
break; break;
case 112: case 112:
LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE);
break; break;
case 128: case 128:
LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
break; break;
case 256: case 256:
LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
break; break;
default: default:
TORCH_CHECK(false, "Unsupported head size: ", head_size); TORCH_CHECK(false, "Unsupported head size: ", head_size);
break; break;
} }
} }
#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ #define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
paged_attention_v2_impl_launcher<T, BLOCK_SIZE>( \ paged_attention_v2_impl_launcher<T, BLOCK_SIZE>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, seq_lens, block_size, \ num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, \
max_seq_len, alibi_slopes); alibi_slopes);
#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ #define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
switch (block_size) { \ switch (block_size) { \
case 16: \ case 16: \
CALL_V2_KERNEL_LAUNCHER(T, 16); \ CALL_V2_KERNEL_LAUNCHER(T, 16); \
break; \ break; \
default: \ default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \ break; \
} }
} // namespace } // namespace
void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums, void paged_attention_v2(torch::Tensor& out, torch::Tensor& exp_sums,
torch::Tensor &max_logits, torch::Tensor &tmp_out, torch::Tensor& max_logits, torch::Tensor& tmp_out,
torch::Tensor &query, torch::Tensor &key_cache, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor &value_cache, int num_kv_heads, torch::Tensor& value_cache, int num_kv_heads,
float scale, torch::Tensor &block_tables, float scale, torch::Tensor& block_tables,
torch::Tensor &seq_lens, int block_size, torch::Tensor& seq_lens, int block_size,
int max_seq_len, int max_seq_len,
const c10::optional<torch::Tensor> &alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string &kv_cache_dtype, float kv_scale) { const std::string& kv_cache_dtype, float kv_scale) {
TORCH_CHECK(kv_scale == 1.0f); TORCH_CHECK(kv_scale == 1.0f);
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl", VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl",
[&] { [&] {

View File

@ -5,25 +5,26 @@
namespace { namespace {
template <typename scalar_t> template <typename scalar_t>
void copy_blocks_cpu_impl( void copy_blocks_cpu_impl(std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor> &key_caches, std::vector<torch::Tensor>& value_caches,
std::vector<torch::Tensor> &value_caches, const torch::Tensor& mapping_pairs,
const torch::Tensor& mapping_pairs, const int element_num_per_block,
const int element_num_per_block, const int layer_num) { const int layer_num) {
const size_t pair_num = mapping_pairs.size(0); const size_t pair_num = mapping_pairs.size(0);
const size_t block_bytes = sizeof(scalar_t) * element_num_per_block; const size_t block_bytes = sizeof(scalar_t) * element_num_per_block;
#pragma omp parallel for collapse(2) #pragma omp parallel for collapse(2)
for (int layer = 0; layer < layer_num; ++layer) { for (int layer = 0; layer < layer_num; ++layer) {
for (size_t pair = 0; pair < pair_num; ++pair) { for (size_t pair = 0; pair < pair_num; ++pair) {
int64_t source_offset = element_num_per_block * mapping_pairs[pair][0].item<int64_t>(); int64_t source_offset =
element_num_per_block * mapping_pairs[pair][0].item<int64_t>();
int64_t target_offset = int64_t target_offset =
element_num_per_block * mapping_pairs[pair][1].item<int64_t>(); element_num_per_block * mapping_pairs[pair][1].item<int64_t>();
scalar_t *key_cache_ptr = key_caches[layer].data_ptr<scalar_t>(); scalar_t* key_cache_ptr = key_caches[layer].data_ptr<scalar_t>();
scalar_t *source_ptr = key_cache_ptr + source_offset; scalar_t* source_ptr = key_cache_ptr + source_offset;
scalar_t *target_ptr = key_cache_ptr + target_offset; scalar_t* target_ptr = key_cache_ptr + target_offset;
std::memcpy(target_ptr, source_ptr, block_bytes); std::memcpy(target_ptr, source_ptr, block_bytes);
scalar_t *value_cache_ptr = value_caches[layer].data_ptr<scalar_t>(); scalar_t* value_cache_ptr = value_caches[layer].data_ptr<scalar_t>();
source_ptr = value_cache_ptr + source_offset; source_ptr = value_cache_ptr + source_offset;
target_ptr = value_cache_ptr + target_offset; target_ptr = value_cache_ptr + target_offset;
std::memcpy(target_ptr, source_ptr, block_bytes); std::memcpy(target_ptr, source_ptr, block_bytes);
@ -33,9 +34,9 @@ void copy_blocks_cpu_impl(
template <typename scalar_t> template <typename scalar_t>
void reshape_and_cache_cpu_impl( void reshape_and_cache_cpu_impl(
const scalar_t *__restrict__ key, const scalar_t *__restrict__ value, const scalar_t* __restrict__ key, const scalar_t* __restrict__ value,
scalar_t *__restrict__ key_cache, scalar_t *__restrict__ value_cache, scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
const int64_t *__restrict__ slot_mapping, const int num_tokens, const int64_t* __restrict__ slot_mapping, const int num_tokens,
const int key_stride, const int value_stride, const int num_heads, const int key_stride, const int value_stride, const int num_heads,
const int head_size, const int block_size, const int x) { const int head_size, const int block_size, const int x) {
const int block_elem_num = num_heads * head_size * block_size; const int block_elem_num = num_heads * head_size * block_size;
@ -48,14 +49,14 @@ void reshape_and_cache_cpu_impl(
int src_key_head_idx = token_idx * key_stride + head_idx * head_size; int src_key_head_idx = token_idx * key_stride + head_idx * head_size;
int src_value_head_idx = int src_value_head_idx =
token_idx * value_stride + head_idx * head_size; token_idx * value_stride + head_idx * head_size;
const scalar_t *src_key_head_ptr = key + src_key_head_idx; const scalar_t* src_key_head_ptr = key + src_key_head_idx;
const scalar_t *src_value_head_ptr = value + src_value_head_idx; const scalar_t* src_value_head_ptr = value + src_value_head_idx;
const int64_t block_index = slot_idx / block_size; const int64_t block_index = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size; const int64_t block_offset = slot_idx % block_size;
scalar_t *target_key_head_ptr = key_cache + scalar_t* target_key_head_ptr = key_cache +
block_elem_num * block_index + block_elem_num * block_index +
head_idx * block_size * head_size; head_idx * block_size * head_size;
scalar_t *target_value_head_ptr = value_cache + scalar_t* target_value_head_ptr = value_cache +
block_elem_num * block_index + block_elem_num * block_index +
head_idx * block_size * head_size; head_idx * block_size * head_size;
@ -79,10 +80,10 @@ void reshape_and_cache_cpu_impl(
} }
} }
} }
}; // namespace }; // namespace
void copy_blocks(std::vector<torch::Tensor> &key_caches, void copy_blocks(std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor> &value_caches, std::vector<torch::Tensor>& value_caches,
const torch::Tensor& block_mapping) { const torch::Tensor& block_mapping) {
unsigned num_layers = key_caches.size(); unsigned num_layers = key_caches.size();
TORCH_CHECK(num_layers == value_caches.size()); TORCH_CHECK(num_layers == value_caches.size());
@ -100,10 +101,10 @@ void copy_blocks(std::vector<torch::Tensor> &key_caches,
}); });
} }
void reshape_and_cache(torch::Tensor &key, torch::Tensor &value, void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor &key_cache, torch::Tensor &value_cache, torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor &slot_mapping, torch::Tensor& slot_mapping,
const std::string &kv_cache_dtype, float kv_scale) { const std::string& kv_cache_dtype, float kv_scale) {
TORCH_CHECK(kv_scale == 1.0f); TORCH_CHECK(kv_scale == 1.0f);
int num_tokens = key.size(0); int num_tokens = key.size(0);
@ -127,7 +128,7 @@ void reshape_and_cache(torch::Tensor &key, torch::Tensor &value,
}); });
} }
void swap_blocks(torch::Tensor &src, torch::Tensor &dst, void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
const torch::Tensor&block_mapping) { const torch::Tensor& block_mapping) {
TORCH_CHECK(false, "swap_blocks is unsupported on CPU.") TORCH_CHECK(false, "swap_blocks is unsupported on CPU.")
} }

View File

@ -2,10 +2,10 @@
namespace { namespace {
template <typename scalar_t> template <typename scalar_t>
void rms_norm_impl(scalar_t *__restrict__ out, void rms_norm_impl(scalar_t* __restrict__ out,
const scalar_t *__restrict__ input, const scalar_t* __restrict__ input,
const scalar_t *__restrict__ weight, const float epsilon, const scalar_t* __restrict__ weight, const float epsilon,
const int num_tokens, const int hidden_size) { const int num_tokens, const int hidden_size) {
using scalar_vec_t = vec_op::vec_t<scalar_t>; using scalar_vec_t = vec_op::vec_t<scalar_t>;
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0); TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0);
@ -41,11 +41,11 @@ void rms_norm_impl(scalar_t *__restrict__ out,
} }
template <typename scalar_t> template <typename scalar_t>
void fused_add_rms_norm_impl(scalar_t *__restrict__ input, void fused_add_rms_norm_impl(scalar_t* __restrict__ input,
scalar_t *__restrict__ residual, scalar_t* __restrict__ residual,
const scalar_t *__restrict__ weight, const scalar_t* __restrict__ weight,
const float epsilon, const int num_tokens, const float epsilon, const int num_tokens,
const int hidden_size) { const int hidden_size) {
using scalar_vec_t = vec_op::vec_t<scalar_t>; using scalar_vec_t = vec_op::vec_t<scalar_t>;
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0); TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0);
@ -85,24 +85,24 @@ void fused_add_rms_norm_impl(scalar_t *__restrict__ input,
} }
} }
} }
} // namespace } // namespace
void rms_norm(torch::Tensor &out, torch::Tensor &input, void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
torch::Tensor &weight, float epsilon) { float epsilon) {
int hidden_size = input.size(-1); int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size; int num_tokens = input.numel() / hidden_size;
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_impl", [&] { VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_impl", [&] {
CPU_KERNEL_GUARD_IN(rms_norm_impl) CPU_KERNEL_GUARD_IN(rms_norm_impl)
rms_norm_impl(out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), rms_norm_impl(out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(), epsilon, num_tokens, weight.data_ptr<scalar_t>(), epsilon, num_tokens,
hidden_size); hidden_size);
CPU_KERNEL_GUARD_OUT(rms_norm_impl) CPU_KERNEL_GUARD_OUT(rms_norm_impl)
}); });
} }
void fused_add_rms_norm(torch::Tensor &input, torch::Tensor &residual, void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor &weight, float epsilon) { torch::Tensor& weight, float epsilon) {
int hidden_size = input.size(-1); int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size; int num_tokens = input.numel() / hidden_size;

View File

@ -4,16 +4,16 @@
namespace { namespace {
template <typename scalar_t> template <typename scalar_t>
void rotary_embedding_impl( void rotary_embedding_impl(
const int64_t const int64_t* __restrict__ positions, // [batch_size, seq_len] or
*__restrict__ positions, // [batch_size, seq_len] or [num_tokens] // [num_tokens]
scalar_t scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads,
*__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or /// head_size] or [num_tokens, num_heads,
/// [num_tokens, num_heads, head_size] /// head_size]
scalar_t scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
*__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or // head_size] or [num_tokens, num_kv_heads,
// [num_tokens, num_kv_heads, head_size] // head_size]
const scalar_t const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
*__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] // 2]
const int rot_dim, const int64_t query_stride, const int64_t key_stride, const int rot_dim, const int64_t query_stride, const int64_t key_stride,
const int num_heads, const int num_kv_heads, const int head_size, const int num_heads, const int num_kv_heads, const int head_size,
const int num_tokens) { const int num_tokens) {
@ -26,7 +26,7 @@ void rotary_embedding_impl(
#pragma omp parallel for #pragma omp parallel for
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
int64_t pos = positions[token_idx]; int64_t pos = positions[token_idx];
const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim; const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
for (int i = 0; i < num_heads; ++i) { for (int i = 0; i < num_heads; ++i) {
const int head_idx = i; const int head_idx = i;
@ -94,16 +94,16 @@ void rotary_embedding_impl(
template <typename scalar_t> template <typename scalar_t>
void rotary_embedding_gptj_impl( void rotary_embedding_gptj_impl(
const int64_t const int64_t* __restrict__ positions, // [batch_size, seq_len] or
*__restrict__ positions, // [batch_size, seq_len] or [num_tokens] // [num_tokens]
scalar_t scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads,
*__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or /// head_size] or [num_tokens, num_heads,
/// [num_tokens, num_heads, head_size] /// head_size]
scalar_t scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
*__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or // head_size] or [num_tokens, num_kv_heads,
// [num_tokens, num_kv_heads, head_size] // head_size]
const scalar_t const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
*__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] // 2]
const int rot_dim, const int64_t query_stride, const int64_t key_stride, const int rot_dim, const int64_t query_stride, const int64_t key_stride,
const int num_heads, const int num_kv_heads, const int head_size, const int num_heads, const int num_kv_heads, const int head_size,
const int num_tokens) { const int num_tokens) {
@ -113,13 +113,13 @@ void rotary_embedding_gptj_impl(
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
for (int i = 0; i < num_heads; ++i) { for (int i = 0; i < num_heads; ++i) {
int64_t pos = positions[token_idx]; int64_t pos = positions[token_idx];
const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim; const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
const scalar_t *cos_cache_ptr = cache_ptr; const scalar_t* cos_cache_ptr = cache_ptr;
const scalar_t *sin_cache_ptr = cache_ptr + embed_dim; const scalar_t* sin_cache_ptr = cache_ptr + embed_dim;
const int head_idx = i; const int head_idx = i;
const int64_t token_head = const int64_t token_head =
token_idx * query_stride + head_idx * head_size; token_idx * query_stride + head_idx * head_size;
scalar_t *head_query = token_head + query; scalar_t* head_query = token_head + query;
for (int j = 0; j < embed_dim; j += 1) { for (int j = 0; j < embed_dim; j += 1) {
const int rot_offset = j; const int rot_offset = j;
const int x_index = 2 * rot_offset; const int x_index = 2 * rot_offset;
@ -141,12 +141,12 @@ void rotary_embedding_gptj_impl(
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
for (int i = 0; i < num_kv_heads; ++i) { for (int i = 0; i < num_kv_heads; ++i) {
int64_t pos = positions[token_idx]; int64_t pos = positions[token_idx];
const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim; const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
const scalar_t *cos_cache_ptr = cache_ptr; const scalar_t* cos_cache_ptr = cache_ptr;
const scalar_t *sin_cache_ptr = cache_ptr + embed_dim; const scalar_t* sin_cache_ptr = cache_ptr + embed_dim;
const int head_idx = i; const int head_idx = i;
const int64_t token_head = token_idx * key_stride + head_idx * head_size; const int64_t token_head = token_idx * key_stride + head_idx * head_size;
scalar_t *head_key = key + token_head; scalar_t* head_key = key + token_head;
for (int j = 0; j < embed_dim; j += 1) { for (int j = 0; j < embed_dim; j += 1) {
const int rot_offset = j; const int rot_offset = j;
const int x_index = 2 * rot_offset; const int x_index = 2 * rot_offset;
@ -164,11 +164,11 @@ void rotary_embedding_gptj_impl(
} }
} }
} }
}; // namespace }; // namespace
void rotary_embedding(torch::Tensor &positions, torch::Tensor &query, void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor &key, int head_size, torch::Tensor& key, int head_size,
torch::Tensor &cos_sin_cache, bool is_neox) { torch::Tensor& cos_sin_cache, bool is_neox) {
int num_tokens = query.numel() / query.size(-1); int num_tokens = query.numel() / query.size(-1);
int rot_dim = cos_sin_cache.size(1); int rot_dim = cos_sin_cache.size(1);
int num_heads = query.size(-1) / head_size; int num_heads = query.size(-1) / head_size;

View File

@ -8,66 +8,37 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
pybind11::module ops = m.def_submodule("ops", "vLLM custom operators"); pybind11::module ops = m.def_submodule("ops", "vLLM custom operators");
// Attention ops // Attention ops
ops.def( ops.def("paged_attention_v1", &paged_attention_v1,
"paged_attention_v1", "Compute the attention between an input query and the cached "
&paged_attention_v1, "keys/values using PagedAttention.");
"Compute the attention between an input query and the cached keys/values using PagedAttention."); ops.def("paged_attention_v2", &paged_attention_v2, "PagedAttention V2.");
ops.def(
"paged_attention_v2",
&paged_attention_v2,
"PagedAttention V2.");
// Activation ops // Activation ops
ops.def( ops.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU.");
"silu_and_mul", ops.def("gelu_and_mul", &gelu_and_mul,
&silu_and_mul, "Activation function used in GeGLU with `none` approximation.");
"Activation function used in SwiGLU."); ops.def("gelu_tanh_and_mul", &gelu_tanh_and_mul,
ops.def( "Activation function used in GeGLU with `tanh` approximation.");
"gelu_and_mul", ops.def("gelu_new", &gelu_new, "GELU implementation used in GPT-2.");
&gelu_and_mul, ops.def("gelu_fast", &gelu_fast, "Approximate GELU implementation.");
"Activation function used in GeGLU with `none` approximation.");
ops.def(
"gelu_tanh_and_mul",
&gelu_tanh_and_mul,
"Activation function used in GeGLU with `tanh` approximation.");
ops.def(
"gelu_new",
&gelu_new,
"GELU implementation used in GPT-2.");
ops.def(
"gelu_fast",
&gelu_fast,
"Approximate GELU implementation.");
// Layernorm // Layernorm
ops.def( ops.def("rms_norm", &rms_norm,
"rms_norm", "Apply Root Mean Square (RMS) Normalization to the input tensor.");
&rms_norm,
"Apply Root Mean Square (RMS) Normalization to the input tensor.");
ops.def( ops.def("fused_add_rms_norm", &fused_add_rms_norm,
"fused_add_rms_norm", "In-place fused Add and RMS Normalization");
&fused_add_rms_norm,
"In-place fused Add and RMS Normalization");
// Rotary embedding // Rotary embedding
ops.def( ops.def("rotary_embedding", &rotary_embedding,
"rotary_embedding", "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
&rotary_embedding,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
// Cache ops // Cache ops
pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
cache_ops.def( cache_ops.def("swap_blocks", &swap_blocks,
"swap_blocks", "Swap in (out) the cache blocks from src to dst");
&swap_blocks, cache_ops.def("copy_blocks", &copy_blocks,
"Swap in (out) the cache blocks from src to dst"); "Copy the cache blocks from src to dst");
cache_ops.def( cache_ops.def("reshape_and_cache", &reshape_and_cache,
"copy_blocks", "Reshape the key and value tensors and cache them");
&copy_blocks,
"Copy the cache blocks from src to dst");
cache_ops.def(
"reshape_and_cache",
&reshape_and_cache,
"Reshape the key and value tensors and cache them");
} }

View File

@ -1,7 +1,7 @@
#pragma once #pragma once
#ifdef USE_ROCM #ifdef USE_ROCM
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
#endif #endif
#ifndef USE_ROCM #ifndef USE_ROCM
@ -17,7 +17,8 @@
#endif #endif
#ifndef USE_ROCM #ifndef USE_ROCM
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask) #define VLLM_SHFL_XOR_SYNC(var, lane_mask) \
__shfl_xor_sync(uint32_t(-1), var, lane_mask)
#else #else
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask) #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
#endif #endif
@ -29,7 +30,8 @@
#endif #endif
#ifndef USE_ROCM #ifndef USE_ROCM
#define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down_sync(uint32_t(-1), var, lane_delta) #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) \
__shfl_down_sync(uint32_t(-1), var, lane_delta)
#else #else
#define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta) #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta)
#endif #endif
@ -41,4 +43,3 @@
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL) hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
#endif #endif

View File

@ -2,9 +2,6 @@
#include <torch/extension.h> #include <torch/extension.h>
int get_device_attribute( int get_device_attribute(int attribute, int device_id);
int attribute,
int device_id);
int get_max_shared_memory_per_block_device_attribute( int get_max_shared_memory_per_block_device_attribute(int device_id);
int device_id);

View File

@ -2,34 +2,28 @@
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
#include <hip/hip_runtime_api.h> #include <hip/hip_runtime_api.h>
#endif #endif
int get_device_attribute( int get_device_attribute(int attribute, int device_id) {
int attribute, int device, value;
int device_id) if (device_id < 0) {
{ cudaGetDevice(&device);
int device, value; } else {
if (device_id < 0) { device = device_id;
cudaGetDevice(&device); }
} cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute),
else { device);
device = device_id; return value;
}
cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute), device);
return value;
} }
int get_max_shared_memory_per_block_device_attribute(int device_id) {
int get_max_shared_memory_per_block_device_attribute( int attribute;
int device_id) // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
{ // cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
int attribute;
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
// cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
#ifdef USE_ROCM #ifdef USE_ROCM
attribute = hipDeviceAttributeMaxSharedMemoryPerBlock; attribute = hipDeviceAttributeMaxSharedMemoryPerBlock;
#else #else
attribute = cudaDevAttrMaxSharedMemoryPerBlockOptin; attribute = cudaDevAttrMaxSharedMemoryPerBlockOptin;
#endif #endif
return get_device_attribute(attribute, device_id); return get_device_attribute(attribute, device_id);
} }

View File

@ -7,11 +7,11 @@
// fake pointer type // fake pointer type
using fptr_t = uint64_t; using fptr_t = uint64_t;
static_assert(sizeof(void *) == sizeof(fptr_t)); static_assert(sizeof(void*) == sizeof(fptr_t));
fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
const std::vector<std::string> &handles, const std::vector<std::string>& handles,
const std::vector<int64_t> &offsets, int rank, const std::vector<int64_t>& offsets, int rank,
bool full_nvlink) { bool full_nvlink) {
int world_size = offsets.size(); int world_size = offsets.size();
if (world_size > 8) if (world_size > 8)
@ -29,7 +29,7 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t)); std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t));
} }
return (fptr_t) new vllm::CustomAllreduce( return (fptr_t) new vllm::CustomAllreduce(
reinterpret_cast<vllm::Signal *>(meta.data_ptr()), rank_data.data_ptr(), reinterpret_cast<vllm::Signal*>(meta.data_ptr()), rank_data.data_ptr(),
rank_data.numel(), ipc_handles, offsets, rank, full_nvlink); rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
} }
@ -49,13 +49,13 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
* 5. A[None].expand(2, -1, -1, -1): Not OK * 5. A[None].expand(2, -1, -1, -1): Not OK
* 6. A[:, 1:, 1:]: Not OK * 6. A[:, 1:, 1:]: Not OK
*/ */
bool _is_weak_contiguous(torch::Tensor &t) { bool _is_weak_contiguous(torch::Tensor& t) {
return t.is_contiguous() || return t.is_contiguous() ||
(t.storage().nbytes() - t.storage_offset() * t.element_size() == (t.storage().nbytes() - t.storage_offset() * t.element_size() ==
t.numel() * t.element_size()); t.numel() * t.element_size());
} }
bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size, bool should_custom_ar(torch::Tensor& inp, int max_size, int world_size,
bool full_nvlink) { bool full_nvlink) {
auto inp_size = inp.numel() * inp.element_size(); auto inp_size = inp.numel() * inp.element_size();
// custom allreduce requires input byte size to be multiples of 16 // custom allreduce requires input byte size to be multiples of 16
@ -67,28 +67,27 @@ bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
return false; return false;
} }
void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out, void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
cudaStream_t stream) { cudaStream_t stream) {
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa); auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
TORCH_CHECK(_is_weak_contiguous(out)); TORCH_CHECK(_is_weak_contiguous(out));
switch (out.scalar_type()) { switch (out.scalar_type()) {
case at::ScalarType::Float: { case at::ScalarType::Float: {
fa->allreduce<float>(stream, reinterpret_cast<float *>(inp.data_ptr()), fa->allreduce<float>(stream, reinterpret_cast<float*>(inp.data_ptr()),
reinterpret_cast<float *>(out.data_ptr()), reinterpret_cast<float*>(out.data_ptr()),
out.numel()); out.numel());
break; break;
} }
case at::ScalarType::Half: { case at::ScalarType::Half: {
fa->allreduce<half>(stream, reinterpret_cast<half *>(inp.data_ptr()), fa->allreduce<half>(stream, reinterpret_cast<half*>(inp.data_ptr()),
reinterpret_cast<half *>(out.data_ptr()), reinterpret_cast<half*>(out.data_ptr()), out.numel());
out.numel());
break; break;
} }
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
case at::ScalarType::BFloat16: { case at::ScalarType::BFloat16: {
fa->allreduce<nv_bfloat16>( fa->allreduce<nv_bfloat16>(
stream, reinterpret_cast<nv_bfloat16 *>(inp.data_ptr()), stream, reinterpret_cast<nv_bfloat16*>(inp.data_ptr()),
reinterpret_cast<nv_bfloat16 *>(out.data_ptr()), out.numel()); reinterpret_cast<nv_bfloat16*>(out.data_ptr()), out.numel());
break; break;
} }
#endif #endif
@ -98,7 +97,7 @@ void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out,
} }
} }
void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) { void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
auto stream = c10::cuda::getCurrentCUDAStream().stream(); auto stream = c10::cuda::getCurrentCUDAStream().stream();
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
@ -106,8 +105,8 @@ void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) {
_all_reduce(_fa, inp, out, stream); _all_reduce(_fa, inp, out, stream);
} }
void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &reg_buffer, void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
torch::Tensor &out) { torch::Tensor& out) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
auto stream = c10::cuda::getCurrentCUDAStream().stream(); auto stream = c10::cuda::getCurrentCUDAStream().stream();
@ -122,27 +121,27 @@ void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &reg_buffer,
} }
void dispose(fptr_t _fa) { void dispose(fptr_t _fa) {
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa); auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
delete fa; delete fa;
} }
int meta_size() { return sizeof(vllm::Signal); } int meta_size() { return sizeof(vllm::Signal); }
void register_buffer(fptr_t _fa, torch::Tensor &t, void register_buffer(fptr_t _fa, torch::Tensor& t,
const std::vector<std::string> &handles, const std::vector<std::string>& handles,
const std::vector<int64_t> &offsets) { const std::vector<int64_t>& offsets) {
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa); auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
fa->register_buffer(handles, offsets, t.data_ptr()); fa->register_buffer(handles, offsets, t.data_ptr());
} }
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta( std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
fptr_t _fa) { fptr_t _fa) {
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa); auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
return fa->get_graph_buffer_ipc_meta(); return fa->get_graph_buffer_ipc_meta();
} }
void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles, void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
const std::vector<std::vector<int64_t>> &offsets) { const std::vector<std::vector<int64_t>>& offsets) {
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa); auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
fa->register_graph_buffers(handles, offsets); fa->register_graph_buffers(handles, offsets);
} }

View File

@ -31,9 +31,9 @@ struct Signal {
alignas(128) uint32_t end[kMaxBlocks][8]; alignas(128) uint32_t end[kMaxBlocks][8];
}; };
struct __align__(16) RankData { const void *__restrict__ ptrs[8]; }; struct __align__(16) RankData { const void* __restrict__ ptrs[8]; };
struct __align__(16) RankSignals { volatile Signal *signals[8]; }; struct __align__(16) RankSignals { volatile Signal* signals[8]; };
// like std::array, but aligned // like std::array, but aligned
template <typename T, int sz> template <typename T, int sz>
@ -68,11 +68,11 @@ DINLINE half downcast_s(float val) {
// scalar add functions // scalar add functions
// for some reason when compiling with Pytorch, the + operator for half and // for some reason when compiling with Pytorch, the + operator for half and
// bfloat is disabled so we call the intrinsics directly // bfloat is disabled so we call the intrinsics directly
DINLINE half &assign_add(half &a, half b) { DINLINE half& assign_add(half& a, half b) {
a = __hadd(a, b); a = __hadd(a, b);
return a; return a;
} }
DINLINE float &assign_add(float &a, float b) { return a += b; } DINLINE float& assign_add(float& a, float b) { return a += b; }
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); } DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); }
@ -80,14 +80,14 @@ template <>
DINLINE nv_bfloat16 downcast_s(float val) { DINLINE nv_bfloat16 downcast_s(float val) {
return __float2bfloat16(val); return __float2bfloat16(val);
} }
DINLINE nv_bfloat16 &assign_add(nv_bfloat16 &a, nv_bfloat16 b) { DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) {
a = __hadd(a, b); a = __hadd(a, b);
return a; return a;
} }
#endif #endif
template <typename T, int N> template <typename T, int N>
DINLINE array_t<T, N> &packed_assign_add(array_t<T, N> &a, array_t<T, N> b) { DINLINE array_t<T, N>& packed_assign_add(array_t<T, N>& a, array_t<T, N> b) {
#pragma unroll #pragma unroll
for (int i = 0; i < N; i++) { for (int i = 0; i < N; i++) {
assign_add(a.data[i], b.data[i]); assign_add(a.data[i], b.data[i]);
@ -128,7 +128,7 @@ DINLINE O downcast(array_t<float, O::size> val) {
// prior memory accesses. Note: volatile writes will not be reordered against // prior memory accesses. Note: volatile writes will not be reordered against
// other volatile writes. // other volatile writes.
template <int ngpus> template <int ngpus>
DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg, DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg,
int rank) { int rank) {
if (threadIdx.x < ngpus) { if (threadIdx.x < ngpus) {
// reset flag for next time // reset flag for next time
@ -137,8 +137,7 @@ DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg,
// Latency = 1 p2p write // Latency = 1 p2p write
sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1; sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1;
// wait until we got true from all ranks // wait until we got true from all ranks
while (!self_sg->start[blockIdx.x][threadIdx.x]) while (!self_sg->start[blockIdx.x][threadIdx.x]);
;
} }
__syncthreads(); __syncthreads();
} }
@ -147,7 +146,7 @@ DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg,
// barrier in the all reduce kernel. If it's the final synchronization barrier, // barrier in the all reduce kernel. If it's the final synchronization barrier,
// we don't need to make any visibility guarantees for prior memory accesses. // we don't need to make any visibility guarantees for prior memory accesses.
template <int ngpus, bool final_sync = false> template <int ngpus, bool final_sync = false>
DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg, DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg,
int rank) { int rank) {
__syncthreads(); __syncthreads();
// eliminate the case that prior writes are not visible after signals become // eliminate the case that prior writes are not visible after signals become
@ -162,14 +161,13 @@ DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg,
// Latency = 1 p2p write // Latency = 1 p2p write
sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1; sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1;
// wait until we got true from all ranks // wait until we got true from all ranks
while (!self_sg->end[blockIdx.x][threadIdx.x]) while (!self_sg->end[blockIdx.x][threadIdx.x]);
;
} }
if constexpr (!final_sync) __syncthreads(); if constexpr (!final_sync) __syncthreads();
} }
template <typename P, int ngpus, typename A> template <typename P, int ngpus, typename A>
DINLINE P packed_reduce(const P *ptrs[], int idx) { DINLINE P packed_reduce(const P* ptrs[], int idx) {
A tmp = upcast(ptrs[0][idx]); A tmp = upcast(ptrs[0][idx]);
#pragma unroll #pragma unroll
for (int i = 1; i < ngpus; i++) { for (int i = 1; i < ngpus; i++) {
@ -180,8 +178,8 @@ DINLINE P packed_reduce(const P *ptrs[], int idx) {
template <typename T, int ngpus> template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1) __global__ void __launch_bounds__(512, 1)
cross_device_reduce_1stage(RankData *_dp, RankSignals sg, cross_device_reduce_1stage(RankData* _dp, RankSignals sg,
volatile Signal *self_sg, T *__restrict__ result, volatile Signal* self_sg, T* __restrict__ result,
int rank, int size) { int rank, int size) {
using P = typename packed_t<T>::P; using P = typename packed_t<T>::P;
using A = typename packed_t<T>::A; using A = typename packed_t<T>::A;
@ -192,21 +190,20 @@ __global__ void __launch_bounds__(512, 1)
// do the actual reduction // do the actual reduction
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) { idx += gridDim.x * blockDim.x) {
((P *)result)[idx] = ((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
packed_reduce<P, ngpus, A>((const P **)&dp.ptrs[0], idx);
} }
end_sync<ngpus, true>(sg, self_sg, rank); end_sync<ngpus, true>(sg, self_sg, rank);
} }
template <typename P> template <typename P>
DINLINE P *get_tmp_buf(volatile Signal *sg) { DINLINE P* get_tmp_buf(volatile Signal* sg) {
return (P *)(((Signal *)sg) + 1); return (P*)(((Signal*)sg) + 1);
} }
template <typename T, int ngpus> template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1) __global__ void __launch_bounds__(512, 1)
cross_device_reduce_2stage(RankData *_dp, RankSignals sg, cross_device_reduce_2stage(RankData* _dp, RankSignals sg,
volatile Signal *self_sg, T *__restrict__ result, volatile Signal* self_sg, T* __restrict__ result,
int rank, int size) { int rank, int size) {
int tid = blockIdx.x * blockDim.x + threadIdx.x; int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = gridDim.x * blockDim.x; int stride = gridDim.x * blockDim.x;
@ -216,12 +213,12 @@ __global__ void __launch_bounds__(512, 1)
int start = rank * part; int start = rank * part;
int end = rank == ngpus - 1 ? size : start + part; int end = rank == ngpus - 1 ? size : start + part;
int largest_part = part + size % ngpus; int largest_part = part + size % ngpus;
const P *ptrs[ngpus]; const P* ptrs[ngpus];
P *tmps[ngpus]; P* tmps[ngpus];
#pragma unroll #pragma unroll
for (int i = 0; i < ngpus; i++) { for (int i = 0; i < ngpus; i++) {
int target = (rank + i) % ngpus; int target = (rank + i) % ngpus;
ptrs[i] = (const P *)_dp->ptrs[target]; ptrs[i] = (const P*)_dp->ptrs[target];
tmps[i] = get_tmp_buf<P>(sg.signals[target]); tmps[i] = get_tmp_buf<P>(sg.signals[target]);
} }
auto tmp_out = tmps[0]; auto tmp_out = tmps[0];
@ -243,7 +240,7 @@ __global__ void __launch_bounds__(512, 1)
int gather_from_rank = ((rank + i) % ngpus); int gather_from_rank = ((rank + i) % ngpus);
if (gather_from_rank == ngpus - 1 || idx < part) { if (gather_from_rank == ngpus - 1 || idx < part) {
int dst_idx = gather_from_rank * part + idx; int dst_idx = gather_from_rank * part + idx;
((P *)result)[dst_idx] = tmps[i][idx]; ((P*)result)[dst_idx] = tmps[i][idx];
} }
} }
} }
@ -261,14 +258,14 @@ class CustomAllreduce {
// below are device pointers // below are device pointers
RankSignals sg_; RankSignals sg_;
std::unordered_map<void *, RankData *> buffers_; std::unordered_map<void*, RankData*> buffers_;
Signal *self_sg_; Signal* self_sg_;
// stores the registered device pointers from all ranks // stores the registered device pointers from all ranks
RankData *d_rank_data_base_, *d_rank_data_end_; RankData *d_rank_data_base_, *d_rank_data_end_;
std::vector<void *> graph_unreg_buffers_; std::vector<void*> graph_unreg_buffers_;
// a map from IPC handles to opened IPC pointers // a map from IPC handles to opened IPC pointers
std::map<IPC_KEY, char *> ipc_handles_; std::map<IPC_KEY, char*> ipc_handles_;
/** /**
* meta is a pointer to device metadata and temporary buffer for allreduce. * meta is a pointer to device metadata and temporary buffer for allreduce.
@ -279,22 +276,22 @@ class CustomAllreduce {
* note: this class does not own any device memory. Any required buffers * note: this class does not own any device memory. Any required buffers
* are passed in from the constructor * are passed in from the constructor
*/ */
CustomAllreduce(Signal *meta, void *rank_data, size_t rank_data_sz, CustomAllreduce(Signal* meta, void* rank_data, size_t rank_data_sz,
const cudaIpcMemHandle_t *handles, const cudaIpcMemHandle_t* handles,
const std::vector<int64_t> &offsets, int rank, const std::vector<int64_t>& offsets, int rank,
bool full_nvlink = true) bool full_nvlink = true)
: rank_(rank), : rank_(rank),
world_size_(offsets.size()), world_size_(offsets.size()),
full_nvlink_(full_nvlink), full_nvlink_(full_nvlink),
self_sg_(meta), self_sg_(meta),
d_rank_data_base_(reinterpret_cast<RankData *>(rank_data)), d_rank_data_base_(reinterpret_cast<RankData*>(rank_data)),
d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) { d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
for (int i = 0; i < world_size_; i++) { for (int i = 0; i < world_size_; i++) {
Signal *rank_sg; Signal* rank_sg;
if (i != rank_) { if (i != rank_) {
char *handle = open_ipc_handle(&handles[i]); char* handle = open_ipc_handle(&handles[i]);
handle += offsets[i]; handle += offsets[i];
rank_sg = (Signal *)handle; rank_sg = (Signal*)handle;
} else { } else {
rank_sg = self_sg_; rank_sg = self_sg_;
} }
@ -302,13 +299,13 @@ class CustomAllreduce {
} }
} }
char *open_ipc_handle(const void *ipc_handle) { char* open_ipc_handle(const void* ipc_handle) {
auto [it, new_handle] = auto [it, new_handle] =
ipc_handles_.insert({*((IPC_KEY *)ipc_handle), nullptr}); ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
if (new_handle) { if (new_handle) {
char *ipc_ptr; char* ipc_ptr;
CUDACHECK(cudaIpcOpenMemHandle((void **)&ipc_ptr, CUDACHECK(cudaIpcOpenMemHandle((void**)&ipc_ptr,
*((const cudaIpcMemHandle_t *)ipc_handle), *((const cudaIpcMemHandle_t*)ipc_handle),
cudaIpcMemLazyEnablePeerAccess)); cudaIpcMemLazyEnablePeerAccess));
it->second = ipc_ptr; it->second = ipc_ptr;
} }
@ -323,7 +320,7 @@ class CustomAllreduce {
std::vector<int64_t> offsets(num_buffers); std::vector<int64_t> offsets(num_buffers);
for (int i = 0; i < num_buffers; i++) { for (int i = 0; i < num_buffers; i++) {
auto ptr = graph_unreg_buffers_[i]; auto ptr = graph_unreg_buffers_[i];
void *base_ptr; void* base_ptr;
// note: must share the base address of each allocation, or we get wrong // note: must share the base address of each allocation, or we get wrong
// address // address
if (cuPointerGetAttribute(&base_ptr, if (cuPointerGetAttribute(&base_ptr,
@ -331,8 +328,8 @@ class CustomAllreduce {
(CUdeviceptr)ptr) != CUDA_SUCCESS) (CUdeviceptr)ptr) != CUDA_SUCCESS)
throw std::runtime_error("failed to get pointer attr"); throw std::runtime_error("failed to get pointer attr");
CUDACHECK(cudaIpcGetMemHandle( CUDACHECK(cudaIpcGetMemHandle(
(cudaIpcMemHandle_t *)&handles[i * handle_sz], base_ptr)); (cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
offsets[i] = ((char *)ptr) - ((char *)base_ptr); offsets[i] = ((char*)ptr) - ((char*)base_ptr);
} }
return std::make_pair(handles, offsets); return std::make_pair(handles, offsets);
} }
@ -344,13 +341,13 @@ class CustomAllreduce {
std::to_string(d_rank_data_base_ + num - d_rank_data_end_)); std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
} }
void register_buffer(const std::vector<std::string> &handles, void register_buffer(const std::vector<std::string>& handles,
const std::vector<int64_t> &offsets, void *self) { const std::vector<int64_t>& offsets, void* self) {
check_rank_data_capacity(); check_rank_data_capacity();
RankData data; RankData data;
for (int i = 0; i < world_size_; i++) { for (int i = 0; i < world_size_; i++) {
if (i != rank_) { if (i != rank_) {
char *handle = open_ipc_handle(handles[i].data()); char* handle = open_ipc_handle(handles[i].data());
handle += offsets[i]; handle += offsets[i];
data.ptrs[i] = handle; data.ptrs[i] = handle;
} else { } else {
@ -371,17 +368,17 @@ class CustomAllreduce {
// got a different address. IPC handles have internal reference counting // got a different address. IPC handles have internal reference counting
// mechanism so overhead should be small. // mechanism so overhead should be small.
void register_graph_buffers( void register_graph_buffers(
const std::vector<std::string> &handles, const std::vector<std::string>& handles,
const std::vector<std::vector<int64_t>> &offsets) { const std::vector<std::vector<int64_t>>& offsets) {
auto num_buffers = graph_unreg_buffers_.size(); auto num_buffers = graph_unreg_buffers_.size();
check_rank_data_capacity(num_buffers); check_rank_data_capacity(num_buffers);
std::vector<RankData> rank_data(num_buffers); std::vector<RankData> rank_data(num_buffers);
for (int i = 0; i < num_buffers; i++) { for (int i = 0; i < num_buffers; i++) {
auto self_ptr = graph_unreg_buffers_[i]; auto self_ptr = graph_unreg_buffers_[i];
auto &rd = rank_data[i]; auto& rd = rank_data[i];
for (int j = 0; j < world_size_; j++) { for (int j = 0; j < world_size_; j++) {
if (j != rank_) { if (j != rank_) {
char *handle = char* handle =
open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]); open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]);
handle += offsets[j][i]; handle += offsets[j][i];
rd.ptrs[j] = handle; rd.ptrs[j] = handle;
@ -405,7 +402,7 @@ class CustomAllreduce {
* will cause contention on NVLink bus. * will cause contention on NVLink bus.
*/ */
template <typename T> template <typename T>
void allreduce(cudaStream_t stream, T *input, T *output, int size, void allreduce(cudaStream_t stream, T* input, T* output, int size,
int threads = 512, int block_limit = 36) { int threads = 512, int block_limit = 36) {
auto d = packed_t<T>::P::size; auto d = packed_t<T>::P::size;
if (size % d != 0) if (size % d != 0)
@ -418,7 +415,7 @@ class CustomAllreduce {
std::to_string(kMaxBlocks) + ". Got " + std::to_string(kMaxBlocks) + ". Got " +
std::to_string(block_limit)); std::to_string(block_limit));
RankData *ptrs; RankData* ptrs;
cudaStreamCaptureStatus status; cudaStreamCaptureStatus status;
CUDACHECK(cudaStreamIsCapturing(stream, &status)); CUDACHECK(cudaStreamIsCapturing(stream, &status));
if (status == cudaStreamCaptureStatusActive) { if (status == cudaStreamCaptureStatusActive) {

View File

@ -48,7 +48,7 @@ __global__ void dummy_kernel() {
} }
template <typename T> template <typename T>
__global__ void set_data(T *data, int size, int myRank) { __global__ void set_data(T* data, int size, int myRank) {
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) { idx += gridDim.x * blockDim.x) {
data[idx] = myRank * 0.11f; data[idx] = myRank * 0.11f;
@ -56,8 +56,8 @@ __global__ void set_data(T *data, int size, int myRank) {
} }
template <typename T> template <typename T>
__global__ void convert_data(const T *data1, const T *data2, double *fdata1, __global__ void convert_data(const T* data1, const T* data2, double* fdata1,
double *fdata2, int size) { double* fdata2, int size) {
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) { idx += gridDim.x * blockDim.x) {
fdata1[idx] = data1[idx]; fdata1[idx] = data1[idx];
@ -65,7 +65,7 @@ __global__ void convert_data(const T *data1, const T *data2, double *fdata1,
} }
} }
__global__ void init_rand(curandState_t *state, int size, int nRanks) { __global__ void init_rand(curandState_t* state, int size, int nRanks) {
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) { idx += gridDim.x * blockDim.x) {
for (int i = 0; i < nRanks; i++) { for (int i = 0; i < nRanks; i++) {
@ -75,7 +75,7 @@ __global__ void init_rand(curandState_t *state, int size, int nRanks) {
} }
template <typename T> template <typename T>
__global__ void gen_data(curandState_t *state, T *data, double *ground_truth, __global__ void gen_data(curandState_t* state, T* data, double* ground_truth,
int myRank, int nRanks, int size) { int myRank, int nRanks, int size) {
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) { idx += gridDim.x * blockDim.x) {
@ -91,9 +91,9 @@ __global__ void gen_data(curandState_t *state, T *data, double *ground_truth,
} }
template <typename T> template <typename T>
void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
int data_size, bool performance_test) { int data_size, bool performance_test) {
T *result; T* result;
cudaStream_t stream; cudaStream_t stream;
CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
CUDACHECK(cudaMalloc(&result, data_size * sizeof(T))); CUDACHECK(cudaMalloc(&result, data_size * sizeof(T)));
@ -101,8 +101,8 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
cudaIpcMemHandle_t self_data_handle; cudaIpcMemHandle_t self_data_handle;
cudaIpcMemHandle_t data_handles[8]; cudaIpcMemHandle_t data_handles[8];
vllm::Signal *buffer; vllm::Signal* buffer;
T *self_data_copy; T* self_data_copy;
/** /**
* Allocate IPC buffer * Allocate IPC buffer
* *
@ -125,22 +125,22 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
MPI_BYTE, data_handles, sizeof(cudaIpcMemHandle_t), MPI_BYTE, data_handles, sizeof(cudaIpcMemHandle_t),
MPI_BYTE, MPI_COMM_WORLD)); MPI_BYTE, MPI_COMM_WORLD));
void *rank_data; void* rank_data;
size_t rank_data_sz = 16 * 1024 * 1024; size_t rank_data_sz = 16 * 1024 * 1024;
CUDACHECK(cudaMalloc(&rank_data, rank_data_sz)); CUDACHECK(cudaMalloc(&rank_data, rank_data_sz));
std::vector<int64_t> offsets(nRanks, 0); std::vector<int64_t> offsets(nRanks, 0);
vllm::CustomAllreduce fa(buffer, rank_data, rank_data_sz, data_handles, vllm::CustomAllreduce fa(buffer, rank_data, rank_data_sz, data_handles,
offsets, myRank); offsets, myRank);
auto *self_data = auto* self_data =
reinterpret_cast<T *>(reinterpret_cast<char *>(buffer) + reinterpret_cast<T*>(reinterpret_cast<char*>(buffer) +
sizeof(vllm::Signal) + data_size * sizeof(T)); sizeof(vllm::Signal) + data_size * sizeof(T));
// hack buffer registration // hack buffer registration
{ {
std::vector<std::string> handles; std::vector<std::string> handles;
handles.reserve(nRanks); handles.reserve(nRanks);
for (int i = 0; i < nRanks; i++) { for (int i = 0; i < nRanks; i++) {
char *begin = (char *)&data_handles[i]; char* begin = (char*)&data_handles[i];
char *end = (char *)&data_handles[i + 1]; char* end = (char*)&data_handles[i + 1];
handles.emplace_back(begin, end); handles.emplace_back(begin, end);
} }
std::vector<int64_t> offsets(nRanks, std::vector<int64_t> offsets(nRanks,
@ -148,9 +148,9 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
fa.register_buffer(handles, offsets, self_data); fa.register_buffer(handles, offsets, self_data);
} }
double *ground_truth; double* ground_truth;
CUDACHECK(cudaMallocHost(&ground_truth, data_size * sizeof(double))); CUDACHECK(cudaMallocHost(&ground_truth, data_size * sizeof(double)));
curandState_t *states; curandState_t* states;
CUDACHECK(cudaMalloc(&states, sizeof(curandState_t) * nRanks * data_size)); CUDACHECK(cudaMalloc(&states, sizeof(curandState_t) * nRanks * data_size));
init_rand<<<108, 1024, 0, stream>>>(states, data_size, nRanks); init_rand<<<108, 1024, 0, stream>>>(states, data_size, nRanks);
gen_data<T><<<108, 1024, 0, stream>>>(states, self_data, ground_truth, myRank, gen_data<T><<<108, 1024, 0, stream>>>(states, self_data, ground_truth, myRank,
@ -287,7 +287,7 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
CUDACHECK(cudaStreamDestroy(stream)); CUDACHECK(cudaStreamDestroy(stream));
} }
int main(int argc, char **argv) { int main(int argc, char** argv) {
int nRanks, myRank; int nRanks, myRank;
MPICHECK(MPI_Init(&argc, &argv)); MPICHECK(MPI_Init(&argc, &argv));
MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &myRank)); MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &myRank));
@ -296,7 +296,7 @@ int main(int argc, char **argv) {
ncclUniqueId id; ncclUniqueId id;
ncclComm_t comm; ncclComm_t comm;
if (myRank == 0) ncclGetUniqueId(&id); if (myRank == 0) ncclGetUniqueId(&id);
MPICHECK(MPI_Bcast(static_cast<void *>(&id), sizeof(id), MPI_BYTE, 0, MPICHECK(MPI_Bcast(static_cast<void*>(&id), sizeof(id), MPI_BYTE, 0,
MPI_COMM_WORLD)); MPI_COMM_WORLD));
NCCLCHECK(ncclCommInitRank(&comm, nRanks, id, myRank)); NCCLCHECK(ncclCommInitRank(&comm, nRanks, id, myRank));

View File

@ -6,32 +6,30 @@
#include <torch/extension.h> #include <torch/extension.h>
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \ #define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \ #define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \ AT_DISPATCH_SWITCH(TYPE, NAME, \
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__)) VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \ #define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ #define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))

View File

@ -11,26 +11,24 @@
#include <hip/hip_bf16.h> #include <hip/hip_bf16.h>
#include <hip/hip_fp16.h> #include <hip/hip_fp16.h>
using __nv_bfloat16 = __hip_bfloat16; using __nv_bfloat16 = __hip_bfloat16;
using __nv_bfloat162 = __hip_bfloat162; using __nv_bfloat162 = __hip_bfloat162;
#endif #endif
namespace vllm { namespace vllm {
// TODO(woosuk): Further optimize this kernel. // TODO(woosuk): Further optimize this kernel.
template<typename scalar_t> template <typename scalar_t>
__global__ void rms_norm_kernel( __global__ void rms_norm_kernel(
scalar_t* __restrict__ out, // [..., hidden_size] scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size] const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size] const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const float epsilon, const int num_tokens, const int hidden_size) {
const int num_tokens,
const int hidden_size) {
__shared__ float s_variance; __shared__ float s_variance;
float variance = 0.0f; float variance = 0.0f;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
const float x = (float) input[blockIdx.x * hidden_size + idx]; const float x = (float)input[blockIdx.x * hidden_size + idx];
variance += x * x; variance += x * x;
} }
variance = blockReduceSum<float>(variance); variance = blockReduceSum<float>(variance);
@ -40,12 +38,12 @@ __global__ void rms_norm_kernel(
__syncthreads(); __syncthreads();
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float) input[blockIdx.x * hidden_size + idx]; float x = (float)input[blockIdx.x * hidden_size + idx];
out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx]; out[blockIdx.x * hidden_size + idx] =
((scalar_t)(x * s_variance)) * weight[idx];
} }
} }
/* Converter structs for the conversion from torch types to HIP/CUDA types, /* Converter structs for the conversion from torch types to HIP/CUDA types,
and the associated type conversions within HIP/CUDA. These helpers need and the associated type conversions within HIP/CUDA. These helpers need
to be implemented for now because the relevant type conversion to be implemented for now because the relevant type conversion
@ -56,46 +54,63 @@ __global__ void rms_norm_kernel(
If false, the optimized kernel is not used for the corresponding torch type. If false, the optimized kernel is not used for the corresponding torch type.
If true, the struct should be fully defined as shown in the examples below. If true, the struct should be fully defined as shown in the examples below.
*/ */
template<typename torch_type> template <typename torch_type>
struct _typeConvert { static constexpr bool exists = false; }; struct _typeConvert {
static constexpr bool exists = false;
};
#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)) #if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
// CUDA < 12.0 runs into issues with packed type conversion // CUDA < 12.0 runs into issues with packed type conversion
template<> template <>
struct _typeConvert<c10::Half> { struct _typeConvert<c10::Half> {
static constexpr bool exists = true; static constexpr bool exists = true;
using hip_type = __half; using hip_type = __half;
using packed_hip_type = __half2; using packed_hip_type = __half2;
__device__ static inline float convert(hip_type x) { return __half2float(x); } __device__ static inline float convert(hip_type x) { return __half2float(x); }
__device__ static inline float2 convert(packed_hip_type x) { return __half22float2(x); } __device__ static inline float2 convert(packed_hip_type x) {
__device__ static inline hip_type convert(float x) { return __float2half_rn(x); } return __half22float2(x);
__device__ static inline packed_hip_type convert(float2 x) { return __float22half2_rn(x); } }
__device__ static inline hip_type convert(float x) {
return __float2half_rn(x);
}
__device__ static inline packed_hip_type convert(float2 x) {
return __float22half2_rn(x);
}
}; };
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// CUDA_ARCH < 800 does not have BF16 support // CUDA_ARCH < 800 does not have BF16 support
// TODO: Add in ROCm support once public headers handle bf16 maturely // TODO: Add in ROCm support once public headers handle bf16 maturely
template<> template <>
struct _typeConvert<c10::BFloat16> { struct _typeConvert<c10::BFloat16> {
static constexpr bool exists = true; static constexpr bool exists = true;
using hip_type = __nv_bfloat16; using hip_type = __nv_bfloat16;
using packed_hip_type = __nv_bfloat162; using packed_hip_type = __nv_bfloat162;
__device__ static inline float convert(hip_type x) { return __bfloat162float(x); } __device__ static inline float convert(hip_type x) {
__device__ static inline float2 convert(packed_hip_type x) { return __bfloat1622float2(x); } return __bfloat162float(x);
__device__ static inline hip_type convert(float x) { return __float2bfloat16(x); } }
__device__ static inline packed_hip_type convert(float2 x) { return __float22bfloat162_rn(x); } __device__ static inline float2 convert(packed_hip_type x) {
return __bfloat1622float2(x);
}
__device__ static inline hip_type convert(float x) {
return __float2bfloat16(x);
}
__device__ static inline packed_hip_type convert(float2 x) {
return __float22bfloat162_rn(x);
}
}; };
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)) #endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >=
// 12000))
/* Vector POD struct to generate vectorized and packed FP16/BF16 ops /* Vector POD struct to generate vectorized and packed FP16/BF16 ops
for appropriate specializations of fused_add_rms_norm_kernel. for appropriate specializations of fused_add_rms_norm_kernel.
Only functions that are necessary in that kernel are implemented. Only functions that are necessary in that kernel are implemented.
Alignment to 16 bytes is required to use 128-bit global memory ops. Alignment to 16 bytes is required to use 128-bit global memory ops.
*/ */
template<typename scalar_t, int width> template <typename scalar_t, int width>
struct alignas(16) _f16Vec { struct alignas(16) _f16Vec {
/* Not theoretically necessary that width is a power of 2 but should /* Not theoretically necessary that width is a power of 2 but should
almost always be the case for optimization purposes */ almost always be the case for optimization purposes */
@ -108,51 +123,49 @@ struct alignas(16) _f16Vec {
__device__ _f16Vec& operator+=(const _f16Vec<scalar_t, width>& other) { __device__ _f16Vec& operator+=(const _f16Vec<scalar_t, width>& other) {
if constexpr (width % 2 == 0) { if constexpr (width % 2 == 0) {
#pragma unroll #pragma unroll
for (int i = 0; i < width; i += 2) { for (int i = 0; i < width; i += 2) {
T2 temp{data[i], data[i+1]}; T2 temp{data[i], data[i + 1]};
temp += T2{other.data[i], other.data[i+1]}; temp += T2{other.data[i], other.data[i + 1]};
data[i] = temp.x; data[i] = temp.x;
data[i+1] = temp.y; data[i + 1] = temp.y;
} }
} else { } else {
#pragma unroll #pragma unroll
for (int i = 0; i < width; ++i) for (int i = 0; i < width; ++i) data[i] += other.data[i];
data[i] += other.data[i];
} }
return *this; return *this;
} }
__device__ _f16Vec& operator*=(const _f16Vec<scalar_t, width>& other) { __device__ _f16Vec& operator*=(const _f16Vec<scalar_t, width>& other) {
if constexpr (width % 2 == 0) { if constexpr (width % 2 == 0) {
#pragma unroll #pragma unroll
for (int i = 0; i < width; i += 2) { for (int i = 0; i < width; i += 2) {
T2 temp{data[i], data[i+1]}; T2 temp{data[i], data[i + 1]};
temp *= T2{other.data[i], other.data[i+1]}; temp *= T2{other.data[i], other.data[i + 1]};
data[i] = temp.x; data[i] = temp.x;
data[i+1] = temp.y; data[i + 1] = temp.y;
} }
} else { } else {
#pragma unroll #pragma unroll
for (int i = 0; i < width; ++i) for (int i = 0; i < width; ++i) data[i] *= other.data[i];
data[i] *= other.data[i];
} }
return *this; return *this;
} }
__device__ _f16Vec& operator*=(const float scale) { __device__ _f16Vec& operator*=(const float scale) {
if constexpr (width % 2 == 0) { if constexpr (width % 2 == 0) {
#pragma unroll #pragma unroll
for (int i = 0; i < width; i += 2) { for (int i = 0; i < width; i += 2) {
float2 temp_f = Converter::convert(T2{data[i], data[i+1]}); float2 temp_f = Converter::convert(T2{data[i], data[i + 1]});
temp_f.x *= scale; temp_f.x *= scale;
temp_f.y *= scale; temp_f.y *= scale;
T2 temp = Converter::convert(temp_f); T2 temp = Converter::convert(temp_f);
data[i] = temp.x; data[i] = temp.x;
data[i+1] = temp.y; data[i + 1] = temp.y;
} }
} else { } else {
#pragma unroll #pragma unroll
for (int i = 0; i < width; ++i) { for (int i = 0; i < width; ++i) {
float temp = Converter::convert(data[i]) * scale; float temp = Converter::convert(data[i]) * scale;
data[i] = Converter::convert(temp); data[i] = Converter::convert(temp);
@ -164,13 +177,13 @@ struct alignas(16) _f16Vec {
__device__ float sum_squares() const { __device__ float sum_squares() const {
float result = 0.0f; float result = 0.0f;
if constexpr (width % 2 == 0) { if constexpr (width % 2 == 0) {
#pragma unroll #pragma unroll
for (int i = 0; i < width; i += 2) { for (int i = 0; i < width; i += 2) {
float2 z = Converter::convert(T2{data[i], data[i+1]}); float2 z = Converter::convert(T2{data[i], data[i + 1]});
result += z.x * z.x + z.y * z.y; result += z.x * z.x + z.y * z.y;
} }
} else { } else {
#pragma unroll #pragma unroll
for (int i = 0; i < width; ++i) { for (int i = 0; i < width; ++i) {
float x = Converter::convert(data[i]); float x = Converter::convert(data[i]);
result += x * x; result += x * x;
@ -184,15 +197,13 @@ struct alignas(16) _f16Vec {
Additional optimizations we can make in this case are Additional optimizations we can make in this case are
packed and vectorized operations, which help with the packed and vectorized operations, which help with the
memory latency bottleneck. */ memory latency bottleneck. */
template<typename scalar_t, int width> template <typename scalar_t, int width>
__global__ std::enable_if_t< __global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
(width > 0) && _typeConvert<scalar_t>::exists> fused_add_rms_norm_kernel( fused_add_rms_norm_kernel(
scalar_t* __restrict__ input, // [..., hidden_size] scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual, // [..., hidden_size] scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size] const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const float epsilon, const int num_tokens, const int hidden_size) {
const int num_tokens,
const int hidden_size) {
// Sanity checks on our vector struct and type-punned pointer arithmetic // Sanity checks on our vector struct and type-punned pointer arithmetic
static_assert(std::is_pod_v<_f16Vec<scalar_t, width>>); static_assert(std::is_pod_v<_f16Vec<scalar_t, width>>);
static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width); static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);
@ -203,9 +214,12 @@ __global__ std::enable_if_t<
/* These and the argument pointers are all declared `restrict` as they are /* These and the argument pointers are all declared `restrict` as they are
not aliased in practice. Argument pointers should not be dereferenced not aliased in practice. Argument pointers should not be dereferenced
in this kernel as that would be undefined behavior */ in this kernel as that would be undefined behavior */
auto* __restrict__ input_v = reinterpret_cast<_f16Vec<scalar_t, width>*>(input); auto* __restrict__ input_v =
auto* __restrict__ residual_v = reinterpret_cast<_f16Vec<scalar_t, width>*>(residual); reinterpret_cast<_f16Vec<scalar_t, width>*>(input);
auto* __restrict__ weight_v = reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight); auto* __restrict__ residual_v =
reinterpret_cast<_f16Vec<scalar_t, width>*>(residual);
auto* __restrict__ weight_v =
reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx; int id = blockIdx.x * vec_hidden_size + idx;
@ -218,7 +232,8 @@ __global__ std::enable_if_t<
calculation of max_block_size in fused_add_rms_norm */ calculation of max_block_size in fused_add_rms_norm */
if (num_tokens < 256) { if (num_tokens < 256) {
variance = blockReduceSum<float, 1024>(variance); variance = blockReduceSum<float, 1024>(variance);
} else variance = blockReduceSum<float, 256>(variance); } else
variance = blockReduceSum<float, 256>(variance);
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon); s_variance = rsqrtf(variance / hidden_size + epsilon);
} }
@ -233,26 +248,23 @@ __global__ std::enable_if_t<
} }
} }
/* Generic fused_add_rms_norm_kernel /* Generic fused_add_rms_norm_kernel
The width field is not used here but necessary for other specializations. The width field is not used here but necessary for other specializations.
*/ */
template<typename scalar_t, int width> template <typename scalar_t, int width>
__global__ std::enable_if_t< __global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
(width == 0) || !_typeConvert<scalar_t>::exists> fused_add_rms_norm_kernel( fused_add_rms_norm_kernel(
scalar_t* __restrict__ input, // [..., hidden_size] scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual, // [..., hidden_size] scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size] const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const float epsilon, const int num_tokens, const int hidden_size) {
const int num_tokens,
const int hidden_size) {
__shared__ float s_variance; __shared__ float s_variance;
float variance = 0.0f; float variance = 0.0f;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
scalar_t z = input[blockIdx.x * hidden_size + idx]; scalar_t z = input[blockIdx.x * hidden_size + idx];
z += residual[blockIdx.x * hidden_size + idx]; z += residual[blockIdx.x * hidden_size + idx];
float x = (float) z; float x = (float)z;
variance += x * x; variance += x * x;
residual[blockIdx.x * hidden_size + idx] = z; residual[blockIdx.x * hidden_size + idx] = z;
} }
@ -260,25 +272,26 @@ __global__ std::enable_if_t<
calculation of max_block_size in fused_add_rms_norm */ calculation of max_block_size in fused_add_rms_norm */
if (num_tokens < 256) { if (num_tokens < 256) {
variance = blockReduceSum<float, 1024>(variance); variance = blockReduceSum<float, 1024>(variance);
} else variance = blockReduceSum<float, 256>(variance); } else
variance = blockReduceSum<float, 256>(variance);
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon); s_variance = rsqrtf(variance / hidden_size + epsilon);
} }
__syncthreads(); __syncthreads();
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float) residual[blockIdx.x * hidden_size + idx]; float x = (float)residual[blockIdx.x * hidden_size + idx];
input[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx]; input[blockIdx.x * hidden_size + idx] =
((scalar_t)(x * s_variance)) * weight[idx];
} }
} }
} // namespace vllm } // namespace vllm
void rms_norm( void rms_norm(torch::Tensor& out, // [..., hidden_size]
torch::Tensor& out, // [..., hidden_size] torch::Tensor& input, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size] torch::Tensor& weight, // [hidden_size]
torch::Tensor& weight, // [hidden_size] float epsilon) {
float epsilon) {
int hidden_size = input.size(-1); int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size; int num_tokens = input.numel() / hidden_size;
@ -286,40 +299,27 @@ void rms_norm(
dim3 block(std::min(hidden_size, 1024)); dim3 block(std::min(hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
input.scalar_type(), vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
"rms_norm_kernel", out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
[&] { weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>( });
out.data_ptr<scalar_t>(),
input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
epsilon,
num_tokens,
hidden_size);
});
} }
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \ #define LAUNCH_FUSED_ADD_RMS_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES( \ VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), \ input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \
"fused_add_rms_norm_kernel", \ vllm::fused_add_rms_norm_kernel<scalar_t, width> \
[&] { \ <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(), \
vllm::fused_add_rms_norm_kernel \ residual.data_ptr<scalar_t>(), \
<scalar_t, width><<<grid, block, 0, stream>>>( \ weight.data_ptr<scalar_t>(), epsilon, \
input.data_ptr<scalar_t>(), \ num_tokens, hidden_size); \
residual.data_ptr<scalar_t>(), \ });
weight.data_ptr<scalar_t>(), \
epsilon, \
num_tokens, \
hidden_size); \
});
void fused_add_rms_norm( void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size] torch::Tensor& residual, // [..., hidden_size]
torch::Tensor& residual, // [..., hidden_size] torch::Tensor& weight, // [hidden_size]
torch::Tensor& weight, // [hidden_size] float epsilon) {
float epsilon) {
int hidden_size = input.size(-1); int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size; int num_tokens = input.numel() / hidden_size;
@ -342,8 +342,8 @@ void fused_add_rms_norm(
auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr()); auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr()); auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr()); auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
bool ptrs_are_aligned = inp_ptr % 16 == 0 && res_ptr % 16 == 0 \ bool ptrs_are_aligned =
&& wt_ptr % 16 == 0; inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
if (ptrs_are_aligned && hidden_size % 8 == 0) { if (ptrs_are_aligned && hidden_size % 8 == 0) {
LAUNCH_FUSED_ADD_RMS_NORM(8); LAUNCH_FUSED_ADD_RMS_NORM(8);
} else { } else {

View File

@ -3,5 +3,6 @@
#include <torch/extension.h> #include <torch/extension.h>
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("topk_softmax", &topk_softmax, "Apply topk softmax to the gating outputs."); m.def("topk_softmax", &topk_softmax,
"Apply topk softmax to the gating outputs.");
} }

View File

@ -2,8 +2,6 @@
#include <torch/extension.h> #include <torch/extension.h>
void topk_softmax( void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
torch::Tensor& topk_weights, torch::Tensor& token_expert_indices,
torch::Tensor& topk_indices, torch::Tensor& gating_output);
torch::Tensor& token_expert_indices,
torch::Tensor& gating_output);

View File

@ -7,119 +7,128 @@
#include "cuda_compat.h" #include "cuda_compat.h"
#include "dispatch_utils.h" #include "dispatch_utils.h"
#define CEILDIV(x,y) (((x) + (y) - 1) / (y)) #define CEILDIV(x, y) (((x) + (y) - 1) / (y))
namespace vllm { namespace vllm {
namespace { namespace {
__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, int32_t col) { __device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
// don't worry about overflow because num_experts is relatively small int32_t col) {
return row * total_col + col; // don't worry about overflow because num_experts is relatively small
} return row * total_col + col;
} }
} // namespace
template <typename scalar_t> template <typename scalar_t>
__global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids, __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
int32_t *sorted_token_ids, int32_t* sorted_token_ids,
int32_t *expert_ids, int32_t* expert_ids,
int32_t *total_tokens_post_pad, int32_t* total_tokens_post_pad,
int32_t num_experts, int32_t num_experts,
int32_t block_size, int32_t block_size, size_t numel) {
size_t numel) { const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); const size_t start_idx = threadIdx.x * tokens_per_thread;
const size_t start_idx = threadIdx.x * tokens_per_thread;
extern __shared__ int32_t shared_mem[]; extern __shared__ int32_t shared_mem[];
int32_t* tokens_cnts = shared_mem; // 2d tensor with shape (num_experts + 1, num_experts) int32_t* tokens_cnts =
int32_t* cumsum = shared_mem + (num_experts + 1) * num_experts; // 1d tensor with shape (num_experts + 1) shared_mem; // 2d tensor with shape (num_experts + 1, num_experts)
int32_t* cumsum =
shared_mem + (num_experts + 1) *
num_experts; // 1d tensor with shape (num_experts + 1)
for (int i = 0; i < num_experts; ++i) { for (int i = 0; i < num_experts; ++i) {
tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
}
/**
* In the first step we compute token_cnts[thread_index + 1][expert_index],
* which counts how many tokens in the token shard of thread_index are
* assigned to expert expert_index.
*/
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])];
}
__syncthreads();
// For each expert we accumulate the token counts from the different threads.
tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
for (int i = 1; i <= blockDim.x; ++i) {
tokens_cnts[index(num_experts, i, threadIdx.x)] +=
tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
}
__syncthreads();
// We accumulate the token counts of all experts in thread 0.
if (threadIdx.x == 0) {
cumsum[0] = 0;
for (int i = 1; i <= num_experts; ++i) {
cumsum[i] = cumsum[i - 1] +
CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)],
block_size) *
block_size;
} }
*total_tokens_post_pad = cumsum[num_experts];
}
/** __syncthreads();
* In the first step we compute token_cnts[thread_index + 1][expert_index],
* which counts how many tokens in the token shard of thread_index are assigned
* to expert expert_index.
*/
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])];
}
__syncthreads(); /**
* For each expert, each thread processes the tokens of the corresponding
* blocks and stores the corresponding expert_id for each block.
*/
for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
i += block_size) {
expert_ids[i / block_size] = threadIdx.x;
}
// For each expert we accumulate the token counts from the different threads. /**
tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0; * Each thread processes a token shard, calculating the index of each token
for (int i = 1; i <= blockDim.x; ++i) { * after sorting by expert number. Given the example topk_ids =
tokens_cnts[index(num_experts, i, threadIdx.x)] += tokens_cnts[index(num_experts, i-1, threadIdx.x)]; * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *,
} * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a
* padding value(preset in python).
__syncthreads(); */
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
// We accumulate the token counts of all experts in thread 0. int32_t expert_id = topk_ids[i];
if (threadIdx.x == 0) { /** The cumsum[expert_id] stores the starting index of the tokens that the
cumsum[0] = 0; * expert with expert_id needs to process, and
for (int i = 1; i <= num_experts; ++i) { * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
cumsum[i] = cumsum[i-1] + CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], block_size) * block_size; * processed by the expert with expert_id within the current thread's token
} * shard.
*total_tokens_post_pad = cumsum[num_experts]; */
} int32_t rank_post_pad =
tokens_cnts[index(num_experts, threadIdx.x, expert_id)] +
__syncthreads(); cumsum[expert_id];
sorted_token_ids[rank_post_pad] = i;
/** ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
* For each expert, each thread processes the tokens of the corresponding blocks }
* and stores the corresponding expert_id for each block.
*/
for (int i = cumsum[threadIdx.x];i < cumsum[threadIdx.x + 1];i += block_size) {
expert_ids[i / block_size] = threadIdx.x;
}
/**
* Each thread processes a token shard, calculating the index of each token after
* sorting by expert number. Given the example topk_ids = [0,1,2,1,2,3,0,3,4] and
* block_size = 4, then the output would be [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *],
* where * represents a padding value(preset in python).
*/
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
int32_t expert_id = topk_ids[i];
/** The cumsum[expert_id] stores the starting index of the tokens that the
* expert with expert_id needs to process, and tokens_cnts[threadIdx.x][expert_id]
* stores the indices of the tokens processed by the expert with expert_id within
* the current thread's token shard.
*/
int32_t rank_post_pad = tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + cumsum[expert_id];
sorted_token_ids[rank_post_pad] = i;
++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
}
}
} }
} // namespace vllm
void moe_align_block_size( void moe_align_block_size(torch::Tensor topk_ids, int num_experts,
torch::Tensor topk_ids, int block_size, torch::Tensor sorted_token_ids,
int num_experts, torch::Tensor experts_ids,
int block_size, torch::Tensor num_tokens_post_pad) {
torch::Tensor sorted_token_ids, const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
torch::Tensor experts_ids, VLLM_DISPATCH_INTEGRAL_TYPES(
torch::Tensor num_tokens_post_pad) { topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
VLLM_DISPATCH_INTEGRAL_TYPES( // tensors
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { const int32_t shared_mem =
// calc needed amount of shared mem for `tokens_cnts` and `cumsum` tensors ((num_experts + 1) * num_experts + (num_experts + 1)) *
const int32_t shared_mem = ((num_experts + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t); sizeof(int32_t);
// set dynamic shared mem // set dynamic shared mem
auto kernel = vllm::moe_align_block_size_kernel<scalar_t>; auto kernel = vllm::moe_align_block_size_kernel<scalar_t>;
AT_CUDA_CHECK( AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize((void *)kernel, shared_mem)); (void*)kernel, shared_mem));
kernel<<<1, num_experts, shared_mem, stream>>>( kernel<<<1, num_experts, shared_mem, stream>>>(
topk_ids.data_ptr<scalar_t>(), topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
sorted_token_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(), experts_ids.data_ptr<int32_t>(),
num_tokens_post_pad.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
num_experts,
block_size,
topk_ids.numel()); topk_ids.numel());
}); });
} }

View File

@ -2,224 +2,136 @@
#include <torch/extension.h> #include <torch/extension.h>
void paged_attention_v1( void paged_attention_v1(torch::Tensor& out, torch::Tensor& query,
torch::Tensor& out, torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& query, int num_kv_heads, float scale,
torch::Tensor& key_cache, torch::Tensor& block_tables, torch::Tensor& seq_lens,
torch::Tensor& value_cache, int block_size, int max_seq_len,
int num_kv_heads, const c10::optional<torch::Tensor>& alibi_slopes,
float scale, const std::string& kv_cache_dtype, float kv_scale);
torch::Tensor& block_tables,
torch::Tensor& seq_lens,
int block_size,
int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype,
float kv_scale);
void paged_attention_v2( void paged_attention_v2(torch::Tensor& out, torch::Tensor& exp_sums,
torch::Tensor& out, torch::Tensor& max_logits, torch::Tensor& tmp_out,
torch::Tensor& exp_sums, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& max_logits, torch::Tensor& value_cache, int num_kv_heads,
torch::Tensor& tmp_out, float scale, torch::Tensor& block_tables,
torch::Tensor& query, torch::Tensor& seq_lens, int block_size,
torch::Tensor& key_cache, int max_seq_len,
torch::Tensor& value_cache, const c10::optional<torch::Tensor>& alibi_slopes,
int num_kv_heads, const std::string& kv_cache_dtype, float kv_scale);
float scale,
torch::Tensor& block_tables,
torch::Tensor& seq_lens,
int block_size,
int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype,
float kv_scale);
void rms_norm( void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
torch::Tensor& out, float epsilon);
torch::Tensor& input,
torch::Tensor& weight,
float epsilon);
void fused_add_rms_norm( void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor& input, torch::Tensor& weight, float epsilon);
torch::Tensor& residual,
torch::Tensor& weight,
float epsilon);
void rotary_embedding( void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor& positions, torch::Tensor& key, int head_size,
torch::Tensor& query, torch::Tensor& cos_sin_cache, bool is_neox);
torch::Tensor& key,
int head_size,
torch::Tensor& cos_sin_cache,
bool is_neox);
void batched_rotary_embedding( void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor& positions, torch::Tensor& key, int head_size,
torch::Tensor& query, torch::Tensor& cos_sin_cache, bool is_neox,
torch::Tensor& key, int rot_dim,
int head_size, torch::Tensor& cos_sin_cache_offsets);
torch::Tensor& cos_sin_cache,
bool is_neox,
int rot_dim,
torch::Tensor& cos_sin_cache_offsets);
void silu_and_mul( void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
torch::Tensor& out,
torch::Tensor& input);
void gelu_and_mul( void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);
torch::Tensor& out,
torch::Tensor& input);
void gelu_tanh_and_mul( void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
torch::Tensor& out,
torch::Tensor& input);
void gelu_new( void gelu_new(torch::Tensor& out, torch::Tensor& input);
torch::Tensor& out,
torch::Tensor& input);
void gelu_fast( void gelu_fast(torch::Tensor& out, torch::Tensor& input);
torch::Tensor& out,
torch::Tensor& input);
#ifndef USE_ROCM #ifndef USE_ROCM
torch::Tensor aqlm_gemm( torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
const torch::Tensor& input, const torch::Tensor& codebooks,
const torch::Tensor& codes, const torch::Tensor& scales,
const torch::Tensor& codebooks, const torch::Tensor& codebook_partition_sizes,
const torch::Tensor& scales, const std::optional<torch::Tensor>& bias);
const torch::Tensor& codebook_partition_sizes,
const std::optional<torch::Tensor>& bias
);
torch::Tensor aqlm_dequant( torch::Tensor aqlm_dequant(const torch::Tensor& codes,
const torch::Tensor& codes, const torch::Tensor& codebooks,
const torch::Tensor& codebooks, const torch::Tensor& codebook_partition_sizes);
const torch::Tensor& codebook_partition_sizes
);
torch::Tensor awq_gemm( torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
torch::Tensor _in_feats, torch::Tensor _scaling_factors, torch::Tensor _zeros,
torch::Tensor _kernel, int split_k_iters);
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
int split_k_iters);
torch::Tensor awq_dequantize( torch::Tensor awq_dequantize(torch::Tensor _kernel,
torch::Tensor _kernel, torch::Tensor _scaling_factors,
torch::Tensor _scaling_factors, torch::Tensor _zeros, int split_k_iters, int thx,
torch::Tensor _zeros, int thy);
int split_k_iters,
int thx,
int thy);
torch::Tensor marlin_gemm( torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& a, torch::Tensor& b_scales, torch::Tensor& workspace,
torch::Tensor& b_q_weight, int64_t size_m, int64_t size_n, int64_t size_k);
torch::Tensor& b_scales,
torch::Tensor& workspace,
int64_t size_m,
int64_t size_n,
int64_t size_k);
torch::Tensor gptq_marlin_24_gemm( torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor &a, torch::Tensor& b_meta,
torch::Tensor &b_q_weight, torch::Tensor& b_scales,
torch::Tensor &b_meta, torch::Tensor& workspace, int64_t num_bits,
torch::Tensor &b_scales, int64_t size_m, int64_t size_n,
torch::Tensor &workspace, int64_t size_k);
int64_t num_bits,
int64_t size_m,
int64_t size_n,
int64_t size_k);
torch::Tensor gptq_marlin_gemm( torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor &a, torch::Tensor& b_scales, torch::Tensor& g_idx,
torch::Tensor &b_q_weight, torch::Tensor& perm, torch::Tensor& workspace,
torch::Tensor &b_scales, int64_t num_bits, int64_t size_m, int64_t size_n,
torch::Tensor &g_idx, int64_t size_k, bool is_k_full);
torch::Tensor &perm,
torch::Tensor &workspace,
int64_t num_bits,
int64_t size_m,
int64_t size_n,
int64_t size_k,
bool is_k_full);
torch::Tensor gptq_marlin_repack( torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
torch::Tensor &b_q_weight, int64_t size_k, int64_t size_n,
torch::Tensor &perm, int64_t num_bits);
int64_t size_k,
int64_t size_n,
int64_t num_bits);
int cutlass_scaled_mm_dq( int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor& out, torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const &a, torch::Tensor const& b_scales);
torch::Tensor const &b,
torch::Tensor const &a_scales,
torch::Tensor const &b_scales);
#endif #endif
void squeezellm_gemm( void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor vec, torch::Tensor lookup_table);
torch::Tensor mat,
torch::Tensor mul,
torch::Tensor lookup_table);
torch::Tensor gptq_gemm( torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
torch::Tensor a, torch::Tensor b_gptq_qzeros,
torch::Tensor b_q_weight, torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
torch::Tensor b_gptq_qzeros, bool use_exllama, int bit);
torch::Tensor b_gptq_scales,
torch::Tensor b_g_idx,
bool use_exllama,
int bit);
void gptq_shuffle( void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int bit);
torch::Tensor q_weight,
torch::Tensor q_perm,
int bit);
void static_scaled_fp8_quant( void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& out, torch::Tensor& scale);
torch::Tensor& input,
torch::Tensor& scale);
void dynamic_scaled_fp8_quant( void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& out, torch::Tensor& scale);
torch::Tensor& input,
torch::Tensor& scale);
void moe_align_block_size( void moe_align_block_size(torch::Tensor topk_ids, int num_experts,
torch::Tensor topk_ids, int block_size, torch::Tensor sorted_token_ids,
int num_experts, torch::Tensor experts_ids,
int block_size, torch::Tensor num_tokens_post_pad);
torch::Tensor sorted_token_ids,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad);
#ifndef USE_ROCM #ifndef USE_ROCM
using fptr_t = uint64_t; using fptr_t = uint64_t;
fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
const std::vector<std::string> &handles, const std::vector<std::string>& handles,
const std::vector<int64_t> &offsets, int rank, const std::vector<int64_t>& offsets, int rank,
bool full_nvlink);
bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
bool full_nvlink); bool full_nvlink);
void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out); bool should_custom_ar(torch::Tensor& inp, int max_size, int world_size,
void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &reg_buffer, bool full_nvlink);
torch::Tensor &out); void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
torch::Tensor& out);
void dispose(fptr_t _fa); void dispose(fptr_t _fa);
int meta_size(); int meta_size();
void register_buffer(fptr_t _fa, torch::Tensor &t, void register_buffer(fptr_t _fa, torch::Tensor& t,
const std::vector<std::string> &handles, const std::vector<std::string>& handles,
const std::vector<int64_t> &offsets); const std::vector<int64_t>& offsets);
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa); std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles, fptr_t _fa);
const std::vector<std::vector<int64_t>> &offsets); void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
const std::vector<std::vector<int64_t>>& offsets);
#endif #endif

View File

@ -7,14 +7,10 @@
namespace vllm { namespace vllm {
template<typename scalar_t, bool IS_NEOX> template <typename scalar_t, bool IS_NEOX>
inline __device__ void apply_token_rotary_embedding( inline __device__ void apply_token_rotary_embedding(
scalar_t* __restrict__ arr, scalar_t* __restrict__ arr, const scalar_t* __restrict__ cos_ptr,
const scalar_t* __restrict__ cos_ptr, const scalar_t* __restrict__ sin_ptr, int rot_offset, int embed_dim) {
const scalar_t* __restrict__ sin_ptr,
int rot_offset,
int embed_dim)
{
int x_index, y_index; int x_index, y_index;
scalar_t cos, sin; scalar_t cos, sin;
if (IS_NEOX) { if (IS_NEOX) {
@ -37,19 +33,17 @@ inline __device__ void apply_token_rotary_embedding(
arr[y_index] = y * cos + x * sin; arr[y_index] = y * cos + x * sin;
} }
template<typename scalar_t, bool IS_NEOX> template <typename scalar_t, bool IS_NEOX>
inline __device__ void apply_rotary_embedding( inline __device__ void apply_rotary_embedding(
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] // head_size] or [num_tokens, num_heads,
const scalar_t* cache_ptr, // head_size]
const int head_size, scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
const int num_heads, // head_size] or [num_tokens, num_kv_heads,
const int num_kv_heads, // head_size]
const int rot_dim, const scalar_t* cache_ptr, const int head_size, const int num_heads,
const int token_idx, const int num_kv_heads, const int rot_dim, const int token_idx,
const int64_t query_stride, const int64_t query_stride, const int64_t key_stride) {
const int64_t key_stride)
{
const int embed_dim = rot_dim / 2; const int embed_dim = rot_dim / 2;
const scalar_t* cos_ptr = cache_ptr; const scalar_t* cos_ptr = cache_ptr;
const scalar_t* sin_ptr = cache_ptr + embed_dim; const scalar_t* sin_ptr = cache_ptr + embed_dim;
@ -59,8 +53,8 @@ inline __device__ void apply_rotary_embedding(
const int head_idx = i / embed_dim; const int head_idx = i / embed_dim;
const int64_t token_head = token_idx * query_stride + head_idx * head_size; const int64_t token_head = token_idx * query_stride + head_idx * head_size;
const int rot_offset = i % embed_dim; const int rot_offset = i % embed_dim;
apply_token_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr, apply_token_rotary_embedding<scalar_t, IS_NEOX>(
sin_ptr, rot_offset, embed_dim); query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
} }
const int nk = num_kv_heads * embed_dim; const int nk = num_kv_heads * embed_dim;
@ -68,62 +62,74 @@ inline __device__ void apply_rotary_embedding(
const int head_idx = i / embed_dim; const int head_idx = i / embed_dim;
const int64_t token_head = token_idx * key_stride + head_idx * head_size; const int64_t token_head = token_idx * key_stride + head_idx * head_size;
const int rot_offset = i % embed_dim; const int rot_offset = i % embed_dim;
apply_token_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr, apply_token_rotary_embedding<scalar_t, IS_NEOX>(
sin_ptr, rot_offset, embed_dim); key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
} }
} }
template<typename scalar_t, bool IS_NEOX> template <typename scalar_t, bool IS_NEOX>
__global__ void rotary_embedding_kernel( __global__ void rotary_embedding_kernel(
const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] const int64_t* __restrict__ positions, // [batch_size, seq_len] or
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] // [num_tokens]
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] // head_size] or [num_tokens, num_heads,
const int rot_dim, // head_size]
const int64_t query_stride, scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
const int64_t key_stride, // head_size] or [num_tokens, num_kv_heads,
const int num_heads, // head_size]
const int num_kv_heads, const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
const int head_size) { // 2]
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
const int num_heads, const int num_kv_heads, const int head_size) {
// Each thread block is responsible for one token. // Each thread block is responsible for one token.
const int token_idx = blockIdx.x; const int token_idx = blockIdx.x;
int64_t pos = positions[token_idx]; int64_t pos = positions[token_idx];
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
apply_rotary_embedding<scalar_t, IS_NEOX>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); apply_rotary_embedding<scalar_t, IS_NEOX>(
query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
token_idx, query_stride, key_stride);
} }
template<typename scalar_t, bool IS_NEOX> template <typename scalar_t, bool IS_NEOX>
__global__ void batched_rotary_embedding_kernel( __global__ void batched_rotary_embedding_kernel(
const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] const int64_t* __restrict__ positions, // [batch_size, seq_len] or
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] // [num_tokens]
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] // head_size] or [num_tokens, num_heads,
const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len] or [num_tokens] // head_size]
const int rot_dim, scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
const int64_t query_stride, // head_size] or [num_tokens, num_kv_heads,
const int64_t key_stride, // head_size]
const int num_heads, const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
const int num_kv_heads, // 2]
const int head_size) { const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len]
// or [num_tokens]
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
const int num_heads, const int num_kv_heads, const int head_size) {
// Each thread block is responsible for one token. // Each thread block is responsible for one token.
const int token_idx = blockIdx.x; const int token_idx = blockIdx.x;
int64_t pos = positions[token_idx]; int64_t pos = positions[token_idx];
int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx]; int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx];
const scalar_t* cache_ptr = cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim; const scalar_t* cache_ptr =
cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim;
apply_rotary_embedding<scalar_t, IS_NEOX>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); apply_rotary_embedding<scalar_t, IS_NEOX>(
query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
token_idx, query_stride, key_stride);
} }
} // namespace vllm } // namespace vllm
void rotary_embedding( void rotary_embedding(
torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens]
torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size] torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size] // [num_tokens, num_heads * head_size]
int head_size, torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or
torch::Tensor& cos_sin_cache, // [max_position, rot_dim] // [num_tokens, num_kv_heads * head_size]
bool is_neox) { int head_size,
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
bool is_neox) {
int64_t num_tokens = query.numel() / query.size(-1); int64_t num_tokens = query.numel() / query.size(-1);
int rot_dim = cos_sin_cache.size(1); int rot_dim = cos_sin_cache.size(1);
int num_heads = query.size(-1) / head_size; int num_heads = query.size(-1) / head_size;
@ -135,36 +141,21 @@ void rotary_embedding(
dim3 block(std::min(num_heads * rot_dim / 2, 512)); dim3 block(std::min(num_heads * rot_dim / 2, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
query.scalar_type(), if (is_neox) {
"rotary_embedding", vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
[&] { positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
if (is_neox) { key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(), rot_dim,
vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>( query_stride, key_stride, num_heads, num_kv_heads, head_size);
positions.data_ptr<int64_t>(), } else {
query.data_ptr<scalar_t>(), vllm::rotary_embedding_kernel<scalar_t, false>
key.data_ptr<scalar_t>(), <<<grid, block, 0, stream>>>(
cos_sin_cache.data_ptr<scalar_t>(), positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
rot_dim, key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
query_stride, rot_dim, query_stride, key_stride, num_heads, num_kv_heads,
key_stride, head_size);
num_heads, }
num_kv_heads, });
head_size);
} else {
vllm::rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(),
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
cos_sin_cache.data_ptr<scalar_t>(),
rot_dim,
query_stride,
key_stride,
num_heads,
num_kv_heads,
head_size);
}
});
} }
/* /*
@ -172,14 +163,15 @@ Batched version of rotary embedding, pack multiple LoRAs together
and process in batched manner. and process in batched manner.
*/ */
void batched_rotary_embedding( void batched_rotary_embedding(
torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens]
torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size] torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size] // [num_tokens, num_heads * head_size]
int head_size, torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or
torch::Tensor& cos_sin_cache, // [max_position, rot_dim] // [num_tokens, num_kv_heads * head_size]
bool is_neox, int head_size,
int rot_dim, torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
torch::Tensor& cos_sin_cache_offsets // [num_tokens] bool is_neox, int rot_dim,
torch::Tensor& cos_sin_cache_offsets // [num_tokens]
) { ) {
int64_t num_tokens = cos_sin_cache_offsets.size(0); int64_t num_tokens = cos_sin_cache_offsets.size(0);
int num_heads = query.size(-1) / head_size; int num_heads = query.size(-1) / head_size;
@ -191,36 +183,21 @@ void batched_rotary_embedding(
dim3 block(std::min(num_heads * rot_dim / 2, 512)); dim3 block(std::min(num_heads * rot_dim / 2, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
query.scalar_type(), if (is_neox) {
"rotary_embedding", vllm::batched_rotary_embedding_kernel<scalar_t, true>
[&] { <<<grid, block, 0, stream>>>(
if (is_neox) { positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
vllm::batched_rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>( key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
positions.data_ptr<int64_t>(), cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
query.data_ptr<scalar_t>(), key_stride, num_heads, num_kv_heads, head_size);
key.data_ptr<scalar_t>(), } else {
cos_sin_cache.data_ptr<scalar_t>(), vllm::batched_rotary_embedding_kernel<scalar_t, false>
cos_sin_cache_offsets.data_ptr<int64_t>(), <<<grid, block, 0, stream>>>(
rot_dim, positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
query_stride, key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
key_stride, cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
num_heads, key_stride, num_heads, num_kv_heads, head_size);
num_kv_heads, }
head_size); });
} else {
vllm::batched_rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(),
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
cos_sin_cache.data_ptr<scalar_t>(),
cos_sin_cache_offsets.data_ptr<int64_t>(),
rot_dim,
query_stride,
key_stride,
num_heads,
num_kv_heads,
head_size);
}
});
} }

View File

@ -8,116 +8,87 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
pybind11::module ops = m.def_submodule("ops", "vLLM custom operators"); pybind11::module ops = m.def_submodule("ops", "vLLM custom operators");
// Attention ops // Attention ops
ops.def( ops.def("paged_attention_v1", &paged_attention_v1,
"paged_attention_v1", "Compute the attention between an input query and the cached "
&paged_attention_v1, "keys/values using PagedAttention.");
"Compute the attention between an input query and the cached keys/values using PagedAttention."); ops.def("paged_attention_v2", &paged_attention_v2, "PagedAttention V2.");
ops.def(
"paged_attention_v2",
&paged_attention_v2,
"PagedAttention V2.");
// Activation ops // Activation ops
ops.def( ops.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU.");
"silu_and_mul", ops.def("gelu_and_mul", &gelu_and_mul,
&silu_and_mul, "Activation function used in GeGLU with `none` approximation.");
"Activation function used in SwiGLU."); ops.def("gelu_tanh_and_mul", &gelu_tanh_and_mul,
ops.def( "Activation function used in GeGLU with `tanh` approximation.");
"gelu_and_mul", ops.def("gelu_new", &gelu_new, "GELU implementation used in GPT-2.");
&gelu_and_mul, ops.def("gelu_fast", &gelu_fast, "Approximate GELU implementation.");
"Activation function used in GeGLU with `none` approximation.");
ops.def(
"gelu_tanh_and_mul",
&gelu_tanh_and_mul,
"Activation function used in GeGLU with `tanh` approximation.");
ops.def(
"gelu_new",
&gelu_new,
"GELU implementation used in GPT-2.");
ops.def(
"gelu_fast",
&gelu_fast,
"Approximate GELU implementation.");
// Layernorm // Layernorm
ops.def( ops.def("rms_norm", &rms_norm,
"rms_norm", "Apply Root Mean Square (RMS) Normalization to the input tensor.");
&rms_norm,
"Apply Root Mean Square (RMS) Normalization to the input tensor.");
ops.def( ops.def("fused_add_rms_norm", &fused_add_rms_norm,
"fused_add_rms_norm", "In-place fused Add and RMS Normalization");
&fused_add_rms_norm,
"In-place fused Add and RMS Normalization");
// Rotary embedding // Rotary embedding
ops.def( ops.def("rotary_embedding", &rotary_embedding,
"rotary_embedding", "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
&rotary_embedding,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
ops.def( ops.def("batched_rotary_embedding", &batched_rotary_embedding,
"batched_rotary_embedding", "Apply GPT-NeoX or GPT-J style rotary embedding to query and key "
&batched_rotary_embedding, "(supports multiple loras)");
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key (supports multiple loras)");
// Quantization ops // Quantization ops
#ifndef USE_ROCM #ifndef USE_ROCM
ops.def("aqlm_gemm", &aqlm_gemm, "Quantized GEMM for AQLM"); ops.def("aqlm_gemm", &aqlm_gemm, "Quantized GEMM for AQLM");
ops.def("aqlm_dequant", &aqlm_dequant, "Decompression method for AQLM"); ops.def("aqlm_dequant", &aqlm_dequant, "Decompression method for AQLM");
ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ"); ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
ops.def("marlin_gemm", &marlin_gemm, "Marlin (Dense) Optimized Quantized GEMM for GPTQ"); ops.def("marlin_gemm", &marlin_gemm,
ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm, "Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ"); "Marlin (Dense) Optimized Quantized GEMM for GPTQ");
ops.def("gptq_marlin_gemm", &gptq_marlin_gemm, "gptq_marlin Optimized Quantized GEMM for GPTQ"); ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm,
ops.def("gptq_marlin_repack", &gptq_marlin_repack, "gptq_marlin repack from GPTQ"); "Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ");
ops.def("gptq_marlin_gemm", &gptq_marlin_gemm,
"gptq_marlin Optimized Quantized GEMM for GPTQ");
ops.def("gptq_marlin_repack", &gptq_marlin_repack,
"gptq_marlin repack from GPTQ");
ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ"); ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
ops.def("cutlass_scaled_mm_dq", &cutlass_scaled_mm_dq, "CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column quantization."); ops.def("cutlass_scaled_mm_dq", &cutlass_scaled_mm_dq,
"CUTLASS w8a8 GEMM, supporting symmetric per-tensor or "
"per-row/column quantization.");
#endif #endif
ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ"); ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ"); ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM"); ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
ops.def("static_scaled_fp8_quant", &static_scaled_fp8_quant, "Compute FP8 quantized tensor for given scaling factor"); ops.def("static_scaled_fp8_quant", &static_scaled_fp8_quant,
ops.def("dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor"); "Compute FP8 quantized tensor for given scaling factor");
ops.def( ops.def("dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant,
"moe_align_block_size", "Compute FP8 quantized tensor and scaling factor");
&moe_align_block_size, ops.def("moe_align_block_size", &moe_align_block_size,
"Aligning the number of tokens to be processed by each expert such that it is divisible by the block size."); "Aligning the number of tokens to be processed by each expert such "
"that it is divisible by the block size.");
// Cache ops // Cache ops
pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
cache_ops.def( cache_ops.def("swap_blocks", &swap_blocks,
"swap_blocks", "Swap in (out) the cache blocks from src to dst");
&swap_blocks, cache_ops.def("copy_blocks", &copy_blocks,
"Swap in (out) the cache blocks from src to dst"); "Copy the cache blocks from src to dst");
cache_ops.def( cache_ops.def("reshape_and_cache", &reshape_and_cache,
"copy_blocks", "Reshape the key and value tensors and cache them");
&copy_blocks, cache_ops.def("reshape_and_cache_flash", &reshape_and_cache_flash,
"Copy the cache blocks from src to dst"); "Reshape the key and value tensors and cache them");
cache_ops.def( cache_ops.def("convert_fp8", &convert_fp8,
"reshape_and_cache", "Convert the key and value cache to fp8 data type");
&reshape_and_cache,
"Reshape the key and value tensors and cache them");
cache_ops.def(
"reshape_and_cache_flash",
&reshape_and_cache_flash,
"Reshape the key and value tensors and cache them");
cache_ops.def(
"convert_fp8",
&convert_fp8,
"Convert the key and value cache to fp8 data type");
// Cuda utils // Cuda utils
pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils"); pybind11::module cuda_utils =
cuda_utils.def( m.def_submodule("cuda_utils", "vLLM cuda utils");
"get_device_attribute", cuda_utils.def("get_device_attribute", &get_device_attribute,
&get_device_attribute, "Gets the specified device attribute.");
"Gets the specified device attribute.");
cuda_utils.def( cuda_utils.def("get_max_shared_memory_per_block_device_attribute",
"get_max_shared_memory_per_block_device_attribute", &get_max_shared_memory_per_block_device_attribute,
&get_max_shared_memory_per_block_device_attribute, "Gets the maximum shared memory per block device attribute.");
"Gets the maximum shared memory per block device attribute.");
#ifndef USE_ROCM #ifndef USE_ROCM
// Custom all-reduce kernels // Custom all-reduce kernels
@ -134,5 +105,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
custom_ar.def("register_graph_buffers", &register_graph_buffers, custom_ar.def("register_graph_buffers", &register_graph_buffers,
"register_graph_buffers"); "register_graph_buffers");
#endif #endif
} }

View File

@ -25,32 +25,28 @@
#include <iostream> #include <iostream>
#include <cstdlib> #include <cstdlib>
namespace vllm { namespace vllm {
namespace aqlm { namespace aqlm {
__global__ void Code1x16MatVec( __global__ void Code1x16MatVec(
const int4* __restrict__ A, const int4* __restrict__ A, const int4* __restrict__ B,
const int4* __restrict__ B, int4* __restrict__ C, const int4* __restrict__ codebook, const int prob_m,
int4* __restrict__ C, const int prob_k,
const int4* __restrict__ codebook, const int4 codebook_a_sizes, // cumulative sizes of A spanning each
const int prob_m, // codebook, at most 3 long.
const int prob_k, const int codebook_stride // as int4.
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long.
const int codebook_stride // as int4.
) { ) {
int a_gl_stride = prob_k / 8 / 8; int a_gl_stride = prob_k / 8 / 8;
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
bool pred = a_gl_rd < prob_m; bool pred = a_gl_rd < prob_m;
if (pred) if (pred) {
{ // advance to the correct codebook, this easy because we only multiply one
// advance to the correct codebook, this easy because we only multiply one column of the codebook. // column of the codebook.
auto codebook_size = &codebook_a_sizes.x; auto codebook_size = &codebook_a_sizes.x;
while (a_gl_rd >= *codebook_size) while (a_gl_rd >= *codebook_size) {
{ codebook += codebook_stride;
codebook += codebook_stride; ++codebook_size;
++codebook_size;
} }
} }
@ -67,8 +63,7 @@ __global__ void Code1x16MatVec(
// We pad shared memory to avoid bank conflicts during reads // We pad shared memory to avoid bank conflicts during reads
__syncthreads(); __syncthreads();
for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) { for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) {
if (b_gl_rd + i < prob_k / 8) if (b_gl_rd + i < prob_k / 8) sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
} }
__syncthreads(); __syncthreads();
b_gl_rd += 32 * 8; b_gl_rd += 32 * 8;
@ -76,22 +71,19 @@ __global__ void Code1x16MatVec(
int b_sh_rd = 9 * (threadIdx.x % 32); int b_sh_rd = 9 * (threadIdx.x % 32);
if (pred && a_gl_rd < a_gl_end) { if (pred && a_gl_rd < a_gl_end) {
const uint16_t* enc = reinterpret_cast<const uint16_t*>(&A[a_gl_rd]); const uint16_t* enc = reinterpret_cast<const uint16_t*>(&A[a_gl_rd]);
#pragma unroll #pragma unroll
for (int i = 0; i < 8; i++) { for (int i = 0; i < 8; i++) {
uint32_t dec[4]; uint32_t dec[4];
// We bypass the L1 cache to avoid massive amounts of memory streaming that doesn't // We bypass the L1 cache to avoid massive amounts of memory streaming
// actually help us; this brings > 2x speedup. // that doesn't actually help us; this brings > 2x speedup.
asm volatile ( asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
"ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];" : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3])
: "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3]) : "l"((void*)&codebook[enc[i]]));
: "l"((void*) &codebook[enc[i]])
);
half2* a = reinterpret_cast<half2*>(&dec); half2* a = reinterpret_cast<half2*>(&dec);
half2* b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]); half2* b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]);
half2 res2 = {}; half2 res2 = {};
#pragma unroll #pragma unroll
for (int j = 0; j < 4; j++) for (int j = 0; j < 4; j++) res2 = __hfma2(a[j], b[j], res2);
res2 = __hfma2(a[j], b[j], res2);
res += __half2float(res2.x) + __half2float(res2.y); res += __half2float(res2.x) + __half2float(res2.y);
b_sh_rd++; b_sh_rd++;
} }
@ -100,37 +92,33 @@ __global__ void Code1x16MatVec(
} }
if (pred) { if (pred) {
#pragma unroll #pragma unroll
for (int i = 16; i > 0; i /= 2) for (int i = 16; i > 0; i /= 2) res += __shfl_down_sync(0xffffffff, res, i);
res += __shfl_down_sync(0xffffffff, res, i);
if (threadIdx.x % 32 == 0) if (threadIdx.x % 32 == 0)
reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res); reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res);
} }
} }
__global__ void Code2x8MatVec( __global__ void Code2x8MatVec(
const int4* __restrict__ A, const int4* __restrict__ A, const int4* __restrict__ B,
const int4* __restrict__ B, int4* __restrict__ C, const int4* __restrict__ codebook, int prob_m,
int4* __restrict__ C, int prob_k,
const int4* __restrict__ codebook, const int4 codebook_a_sizes, // cumulative sizes of A spanning each
int prob_m, // codebook, at most 3 long.
int prob_k, const int codebook_stride // as int4.
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long.
const int codebook_stride // as int4.
) { ) {
int a_gl_stride = prob_k / 8 / 8; int a_gl_stride = prob_k / 8 / 8;
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
bool pred = a_gl_rd < prob_m; bool pred = a_gl_rd < prob_m;
if (pred) if (pred) {
{ // advance to the correct codebook, this easy because we only multiply one
// advance to the correct codebook, this easy because we only multiply one column of the codebook. // column of the codebook.
auto codebook_size = &codebook_a_sizes.x; auto codebook_size = &codebook_a_sizes.x;
while (a_gl_rd >= *codebook_size) while (a_gl_rd >= *codebook_size) {
{ codebook += codebook_stride;
codebook += codebook_stride; ++codebook_size;
++codebook_size;
} }
} }
@ -148,9 +136,8 @@ __global__ void Code2x8MatVec(
for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) { for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) {
int4 dec = codebook[i]; int4 dec = codebook[i];
#pragma unroll #pragma unroll
for (int j = 0; j < 8; j++) for (int j = 0; j < 8; j++) sh_code[8 * i + (j + lane) % 8] = dec;
sh_code[8 * i + (j + lane) % 8] = dec;
} }
__syncthreads(); __syncthreads();
@ -161,8 +148,7 @@ __global__ void Code2x8MatVec(
// We pad shared memory to avoid bank conflicts during reads // We pad shared memory to avoid bank conflicts during reads
__syncthreads(); __syncthreads();
for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) { for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) {
if (b_gl_rd + i < prob_k / 8) if (b_gl_rd + i < prob_k / 8) sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
} }
__syncthreads(); __syncthreads();
b_gl_rd += 32 * 8; b_gl_rd += 32 * 8;
@ -170,13 +156,15 @@ __global__ void Code2x8MatVec(
int b_sh_rd = 9 * (threadIdx.x % 32); int b_sh_rd = 9 * (threadIdx.x % 32);
if (pred && a_gl_rd < a_gl_end) { if (pred && a_gl_rd < a_gl_end) {
const uint8_t* enc = reinterpret_cast<const uint8_t*>(&A[a_gl_rd]); const uint8_t* enc = reinterpret_cast<const uint8_t*>(&A[a_gl_rd]);
#pragma unroll #pragma unroll
for (int i = 0; i < 8; i++) { for (int i = 0; i < 8; i++) {
half2* a0 = reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]); half2* a0 =
half2* a1 = reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]); reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]);
half2* b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]); half2* a1 =
reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]);
half2* b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]);
half2 res2 = {}; half2 res2 = {};
#pragma unroll #pragma unroll
for (int j = 0; j < 4; j++) for (int j = 0; j < 4; j++)
res2 = __hfma2(__hadd2(a0[j], a1[j]), b[j], res2); res2 = __hfma2(__hadd2(a0[j], a1[j]), b[j], res2);
res += __half2float(res2.x) + __half2float(res2.y); res += __half2float(res2.x) + __half2float(res2.y);
@ -187,36 +175,31 @@ __global__ void Code2x8MatVec(
} }
if (pred) { if (pred) {
#pragma unroll #pragma unroll
for (int i = 16; i > 0; i /= 2) for (int i = 16; i > 0; i /= 2) res += __shfl_down_sync(0xffffffff, res, i);
res += __shfl_down_sync(0xffffffff, res, i);
if (threadIdx.x % 32 == 0) if (threadIdx.x % 32 == 0)
reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res); reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res);
} }
} }
__global__ void Code1x16Dequant( __global__ void Code1x16Dequant(
const int4* __restrict__ A, const int4* __restrict__ A, int4* __restrict__ C,
int4* __restrict__ C, const int4* __restrict__ codebook, int prob_m, int prob_k,
const int4* __restrict__ codebook, const int4 codebook_a_sizes, // cumulative sizes of A spanning each
int prob_m, // codebook, at most 3 long, sums to m.
int prob_k, const int codebook_stride // as int4
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, sums to m.
const int codebook_stride // as int4
) { ) {
int a_gl_stride = prob_k / 8 / 8; int a_gl_stride = prob_k / 8 / 8;
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
bool pred = a_gl_rd < prob_m; bool pred = a_gl_rd < prob_m;
if (pred) if (pred) {
{ // advance to the correct codebook, this easy because we only multiply one
// advance to the correct codebook, this easy because we only multiply one column of the codebook. // column of the codebook.
auto codebook_size = &codebook_a_sizes.x; auto codebook_size = &codebook_a_sizes.x;
while (a_gl_rd >= *codebook_size) while (a_gl_rd >= *codebook_size) {
{ codebook += codebook_stride;
codebook += codebook_stride; ++codebook_size;
++codebook_size;
} }
} }
@ -231,17 +214,15 @@ __global__ void Code1x16Dequant(
while (iters--) { while (iters--) {
if (pred && a_gl_rd < a_gl_end) { if (pred && a_gl_rd < a_gl_end) {
const uint16_t* enc = reinterpret_cast<const uint16_t*>(&A[a_gl_rd]); const uint16_t* enc = reinterpret_cast<const uint16_t*>(&A[a_gl_rd]);
#pragma unroll #pragma unroll
for (int i = 0; i < 8; i++) { for (int i = 0; i < 8; i++) {
int4 chunk; int4 chunk;
auto dec = reinterpret_cast<uint32_t*>(&chunk); auto dec = reinterpret_cast<uint32_t*>(&chunk);
// We bypass the L1 cache to avoid massive amounts of memory streaming that doesn't // We bypass the L1 cache to avoid massive amounts of memory streaming
// actually help us; this brings > 2x speedup. // that doesn't actually help us; this brings > 2x speedup.
asm volatile ( asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
"ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];" : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3])
: "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3]) : "l"((void*)&codebook[enc[i]]));
: "l"((void*) &codebook[enc[i]])
);
C[a_gl_rd * 8 + i] = chunk; C[a_gl_rd * 8 + i] = chunk;
} }
@ -250,28 +231,25 @@ __global__ void Code1x16Dequant(
} }
} }
__global__ void Code2x8Dequant( __global__ void Code2x8Dequant(
const int4* __restrict__ A, const int4* __restrict__ A, int4* __restrict__ C,
int4* __restrict__ C, const int4* __restrict__ codebook, int prob_m, int prob_k,
const int4* __restrict__ codebook, const int4
int prob_m, codebook_a_sizes, // cumulative sizes of A spanning each codebook, at
int prob_k, // most 3 long, corresponds to cols.
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, corresponds to cols. const int codebook_stride // as int4
const int codebook_stride // as int4
) { ) {
int a_gl_stride = prob_k / 8 / 8; int a_gl_stride = prob_k / 8 / 8;
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
bool pred = a_gl_rd < prob_m; bool pred = a_gl_rd < prob_m;
if (pred) if (pred) {
{ // advance to the correct codebook, this easy because we only multiply one
// advance to the correct codebook, this easy because we only multiply one column of the codebook. // column of the codebook.
auto codebook_size = &codebook_a_sizes.x; auto codebook_size = &codebook_a_sizes.x;
while (a_gl_rd >= *codebook_size) while (a_gl_rd >= *codebook_size) {
{ codebook += codebook_stride;
codebook += codebook_stride; ++codebook_size;
++codebook_size;
} }
} }
@ -290,9 +268,8 @@ __global__ void Code2x8Dequant(
for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) { for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) {
int4 dec = codebook[i]; int4 dec = codebook[i];
#pragma unroll #pragma unroll
for (int j = 0; j < 8; j++) for (int j = 0; j < 8; j++) sh_code[8 * i + (j + lane) % 8] = dec;
sh_code[8 * i + (j + lane) % 8] = dec;
} }
__syncthreads(); __syncthreads();
@ -302,12 +279,14 @@ __global__ void Code2x8Dequant(
while (iters--) { while (iters--) {
if (pred && a_gl_rd < a_gl_end) { if (pred && a_gl_rd < a_gl_end) {
const uint8_t* enc = reinterpret_cast<const uint8_t*>(&A[a_gl_rd]); const uint8_t* enc = reinterpret_cast<const uint8_t*>(&A[a_gl_rd]);
#pragma unroll #pragma unroll
for (int i = 0; i < 8; i++) { for (int i = 0; i < 8; i++) {
int4 chunk; int4 chunk;
half2* a0 = reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]); half2* a0 =
half2* a1 = reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]); reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]);
#pragma unroll half2* a1 =
reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]);
#pragma unroll
for (int j = 0; j < 4; j++) for (int j = 0; j < 4; j++)
reinterpret_cast<half2*>(&chunk)[j] = __hadd2(a0[j], a1[j]); reinterpret_cast<half2*>(&chunk)[j] = __hadd2(a0[j], a1[j]);
C[a_gl_rd * 8 + i] = chunk; C[a_gl_rd * 8 + i] = chunk;
@ -317,22 +296,15 @@ __global__ void Code2x8Dequant(
} }
} }
inline int ceildiv(int a, int b) { inline int ceildiv(int a, int b) { return (a + b - 1) / b; }
return (a + b - 1) / b;
}
const int THREAD_M = 16; const int THREAD_M = 16;
void code1x16_matvec_cuda( void code1x16_matvec_cuda(const void* __restrict__ A,
const void* __restrict__ A, const void* __restrict__ B, void* __restrict__ C,
const void* __restrict__ B, const void* __restrict__ codebook, int prob_m,
void* __restrict__ C, int prob_k, const int4 codebook_a_sizes,
const void* __restrict__ codebook, const int codebook_stride) {
int prob_m,
int prob_k,
const int4 codebook_a_sizes,
const int codebook_stride
) {
int sms; int sms;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
int waves = 0; int waves = 0;
@ -345,28 +317,16 @@ void code1x16_matvec_cuda(
int blocks = ceildiv(prob_m, thread_m); int blocks = ceildiv(prob_m, thread_m);
int threads = 32 * thread_m; int threads = 32 * thread_m;
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
Code1x16MatVec<<<blocks, threads, 16*32*9, stream>>>( Code1x16MatVec<<<blocks, threads, 16 * 32 * 9, stream>>>(
(const int4*) A, (const int4*)A, (const int4*)B, (int4*)C, (const int4*)codebook, prob_m,
(const int4*) B, prob_k, codebook_a_sizes, codebook_stride);
(int4*) C,
(const int4*) codebook,
prob_m,
prob_k,
codebook_a_sizes,
codebook_stride
);
} }
void code2x8_matvec_cuda( void code2x8_matvec_cuda(const void* __restrict__ A, const void* __restrict__ B,
const void* __restrict__ A, void* __restrict__ C,
const void* __restrict__ B, const void* __restrict__ codebook, int prob_m,
void* __restrict__ C, int prob_k, const int4 codebook_a_sizes,
const void* __restrict__ codebook, const int codebook_stride) {
int prob_m,
int prob_k,
const int4 codebook_a_sizes,
const int codebook_stride
) {
int sms; int sms;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
int waves = 0; int waves = 0;
@ -379,30 +339,20 @@ void code2x8_matvec_cuda(
int blocks = ceildiv(prob_m, thread_m); int blocks = ceildiv(prob_m, thread_m);
int threads = 32 * thread_m; int threads = 32 * thread_m;
int shared = 16 * (2 * 256 * 8 + 32 * 9); int shared = 16 * (2 * 256 * 8 + 32 * 9);
cudaFuncSetAttribute( cudaFuncSetAttribute(Code2x8MatVec,
Code2x8MatVec, cudaFuncAttributeMaxDynamicSharedMemorySize, shared cudaFuncAttributeMaxDynamicSharedMemorySize, shared);
);
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
Code2x8MatVec<<<blocks, threads, shared, stream>>>( Code2x8MatVec<<<blocks, threads, shared, stream>>>(
(const int4*) A, (const int4*)A, (const int4*)B, (int4*)C, (const int4*)codebook, prob_m,
(const int4*) B, prob_k, codebook_a_sizes, codebook_stride);
(int4*) C,
(const int4*) codebook,
prob_m,
prob_k,
codebook_a_sizes,
codebook_stride
);
} }
void code1x16_dequant_cuda( void code1x16_dequant_cuda(
const void* __restrict__ A, const void* __restrict__ A, void* __restrict__ C,
void* __restrict__ C, const void* __restrict__ codebook, int prob_m, int prob_k,
const void* __restrict__ codebook, const int4 codebook_a_sizes, // cumulative sizes of A spanning each
int prob_m, // codebook, at most 3 long.
int prob_k, const int codebook_stride // as int4.
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long.
const int codebook_stride // as int4.
) { ) {
int sms; int sms;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
@ -417,25 +367,21 @@ void code1x16_dequant_cuda(
int threads = 32 * thread_m; int threads = 32 * thread_m;
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
Code1x16Dequant<<<blocks, threads, 0, stream>>>( Code1x16Dequant<<<blocks, threads, 0, stream>>>(
(const int4*) A, (const int4*)A, (int4*)C, (const int4*)codebook, prob_m, prob_k,
(int4*) C, codebook_a_sizes, // cumulative sizes of A spanning each codebook, at
(const int4*) codebook, // most 3 long.
prob_m, codebook_stride // as int4.
prob_k,
codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long.
codebook_stride // as int4.
); );
} }
// Dequantizes the code and codebook into weights. // Dequantizes the code and codebook into weights.
void code2x8_dequant_cuda( void code2x8_dequant_cuda(
const void* __restrict__ A, const void* __restrict__ A, void* __restrict__ C,
void* __restrict__ C, const void* __restrict__ codebook, int prob_m, int prob_k,
const void* __restrict__ codebook, const int4
int prob_m, codebook_a_sizes, // cumulative sizes of A spanning each codebook, at
int prob_k, // most 3 long, corresponds to cols.
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, corresponds to cols. const int codebook_stride // as int4
const int codebook_stride // as int4
) { ) {
int sms; int sms;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
@ -451,74 +397,50 @@ void code2x8_dequant_cuda(
int shared = 16 * (2 * 256 * 8 + 32 * 9); int shared = 16 * (2 * 256 * 8 + 32 * 9);
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cudaFuncSetAttribute( cudaFuncSetAttribute(Code2x8Dequant,
Code2x8Dequant, cudaFuncAttributeMaxDynamicSharedMemorySize, shared cudaFuncAttributeMaxDynamicSharedMemorySize, shared);
);
Code2x8Dequant<<<blocks, threads, shared, stream>>>( Code2x8Dequant<<<blocks, threads, shared, stream>>>(
(const int4*) A, (const int4*)A, (int4*)C, (const int4*)codebook, prob_m, prob_k,
(int4*) C, codebook_a_sizes, codebook_stride);
(const int4*) codebook,
prob_m,
prob_k,
codebook_a_sizes,
codebook_stride
);
} }
int codebook_stride(const torch::Tensor& codebooks) int codebook_stride(const torch::Tensor& codebooks) {
{
return codebooks.stride(0) * codebooks.element_size() / sizeof(int4); return codebooks.stride(0) * codebooks.element_size() / sizeof(int4);
} }
void code1x16_matvec( void code1x16_matvec(
const torch::Tensor& A, const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& C,
const torch::Tensor& B, const torch::Tensor& codebook,
torch::Tensor& C, const int4 codebook_a_sizes // cumulative sizes of A spanning each
const torch::Tensor& codebook, // codebook, at most 3 long.
const int4 codebook_a_sizes // cumulative sizes of A spanning each codebook, at most 3 long.
) { ) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
int prob_m = C.size(0); int prob_m = C.size(0);
int prob_k = B.size(0); int prob_k = B.size(0);
code1x16_matvec_cuda( code1x16_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(),
A.data_ptr(), codebook.data_ptr(), prob_m, prob_k, codebook_a_sizes,
B.data_ptr(), codebook_stride(codebook));
C.data_ptr(),
codebook.data_ptr(),
prob_m,
prob_k,
codebook_a_sizes,
codebook_stride(codebook)
);
} }
torch::Tensor code1x16_matmat( torch::Tensor code1x16_matmat(const torch::Tensor& input,
const torch::Tensor& input, const torch::Tensor& codes,
const torch::Tensor& codes, const torch::Tensor& codebooks,
const torch::Tensor& codebooks, const torch::Tensor& scales,
const torch::Tensor& scales, const int4 codebook_a_sizes,
const int4 codebook_a_sizes, const std::optional<torch::Tensor>& bias) {
const std::optional<torch::Tensor>& bias) {
auto input_sizes = input.sizes(); auto input_sizes = input.sizes();
auto out_features = codes.size(0) * codebooks.size(2); auto out_features = codes.size(0) * codebooks.size(2);
auto flat_input = input.reshape({-1, input.size(-1)}); auto flat_input = input.reshape({-1, input.size(-1)});
auto flat_output = torch::empty({flat_input.size(0), out_features}, auto flat_output = torch::empty(
torch::TensorOptions() {flat_input.size(0), out_features},
.dtype(input.dtype()) torch::TensorOptions().dtype(input.dtype()).device(input.device()));
.device(input.device())
);
for (int i = 0; i < flat_input.size(0); ++i) { for (int i = 0; i < flat_input.size(0); ++i) {
auto input_vec = flat_input.index({i}); auto input_vec = flat_input.index({i});
auto output_vec = flat_output.index({i}); auto output_vec = flat_output.index({i});
code1x16_matvec( code1x16_matvec(codes.squeeze(2), input_vec, output_vec, codebooks,
codes.squeeze(2), codebook_a_sizes);
input_vec,
output_vec,
codebooks,
codebook_a_sizes
);
} }
flat_output *= scales.flatten().unsqueeze(0); flat_output *= scales.flatten().unsqueeze(0);
@ -533,55 +455,35 @@ torch::Tensor code1x16_matmat(
return output; return output;
} }
void code2x8_matvec( void code2x8_matvec(const torch::Tensor& A, const torch::Tensor& B,
const torch::Tensor& A, torch::Tensor& C, const torch::Tensor& codebook,
const torch::Tensor& B, const int4 codebook_a_sizes) {
torch::Tensor& C,
const torch::Tensor& codebook,
const int4 codebook_a_sizes
) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
int prob_m = C.size(0); int prob_m = C.size(0);
int prob_k = B.size(0); int prob_k = B.size(0);
code2x8_matvec_cuda( code2x8_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(),
A.data_ptr(), codebook.data_ptr(), prob_m, prob_k, codebook_a_sizes,
B.data_ptr(), 2 * codebook_stride(codebook));
C.data_ptr(),
codebook.data_ptr(),
prob_m,
prob_k,
codebook_a_sizes,
2 * codebook_stride(codebook)
);
} }
torch::Tensor code2x8_matmat( torch::Tensor code2x8_matmat(const torch::Tensor& input,
const torch::Tensor& input, const torch::Tensor& codes,
const torch::Tensor& codes, const torch::Tensor& codebooks,
const torch::Tensor& codebooks, const torch::Tensor& scales,
const torch::Tensor& scales, const int4 codebook_a_sizes,
const int4 codebook_a_sizes, const std::optional<torch::Tensor>& bias) {
const std::optional<torch::Tensor>& bias
) {
auto input_sizes = input.sizes(); auto input_sizes = input.sizes();
auto out_features = codes.size(0) * codebooks.size(2); auto out_features = codes.size(0) * codebooks.size(2);
auto flat_input = input.reshape({-1, input.size(-1)}); auto flat_input = input.reshape({-1, input.size(-1)});
auto flat_output = torch::empty({flat_input.size(0), out_features}, auto flat_output = torch::empty(
torch::TensorOptions() {flat_input.size(0), out_features},
.dtype(input.dtype()) torch::TensorOptions().dtype(input.dtype()).device(input.device()));
.device(input.device())
);
for (int i = 0; i < flat_input.size(0); ++i) { for (int i = 0; i < flat_input.size(0); ++i) {
auto input_vec = flat_input.index({i}); auto input_vec = flat_input.index({i});
auto output_vec = flat_output.index({i}); auto output_vec = flat_output.index({i});
code2x8_matvec( code2x8_matvec(codes.squeeze(2), input_vec, output_vec, codebooks,
codes.squeeze(2), codebook_a_sizes);
input_vec,
output_vec,
codebooks,
codebook_a_sizes
);
} }
flat_output *= scales.flatten().unsqueeze(0); flat_output *= scales.flatten().unsqueeze(0);
if (bias.has_value()) { if (bias.has_value()) {
@ -596,64 +498,56 @@ torch::Tensor code2x8_matmat(
} }
// Accumulate the partition sizes. // Accumulate the partition sizes.
int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes) int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes) {
{
int4 cumulative_sizes; int4 cumulative_sizes;
auto cumulative_size = &cumulative_sizes.x; auto cumulative_size = &cumulative_sizes.x;
int i = 0; int i = 0;
int last = 0; int last = 0;
assert(codebook_partition_sizes.size(0) <= 4); assert(codebook_partition_sizes.size(0) <= 4);
for (; i < codebook_partition_sizes.size(0); ++i, ++cumulative_size) for (; i < codebook_partition_sizes.size(0); ++i, ++cumulative_size) {
{
*cumulative_size = codebook_partition_sizes[i].item<int>() + last; *cumulative_size = codebook_partition_sizes[i].item<int>() + last;
last = *cumulative_size; last = *cumulative_size;
} }
// fill in the rest with unreachable. // fill in the rest with unreachable.
for (; i < 4; ++i, ++cumulative_size) for (; i < 4; ++i, ++cumulative_size) {
{ *cumulative_size = last * 10;
*cumulative_size = last*10;
} }
return cumulative_sizes; return cumulative_sizes;
} }
} // namespace aqlm } // namespace aqlm
} // namespace vllm } // namespace vllm
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
torch::Tensor aqlm_gemm( const torch::Tensor& codebooks,
const torch::Tensor& input, const torch::Tensor& scales,
const torch::Tensor& codes, const torch::Tensor& codebook_partition_sizes,
const torch::Tensor& codebooks, const std::optional<torch::Tensor>& bias) {
const torch::Tensor& scales, int4 cumulative_sizes =
const torch::Tensor& codebook_partition_sizes, vllm::aqlm::accumulate_sizes(codebook_partition_sizes);
const std::optional<torch::Tensor>& bias
)
{
int4 cumulative_sizes = vllm::aqlm::accumulate_sizes(codebook_partition_sizes);
int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0); int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0);
int const entries = codebooks.size(1); int const entries = codebooks.size(1);
if (nbooks == 1 && entries == (1 << 16)) if (nbooks == 1 && entries == (1 << 16)) {
{ return vllm::aqlm::code1x16_matmat(input, codes, codebooks, scales,
return vllm::aqlm::code1x16_matmat(input, codes, codebooks, scales, cumulative_sizes, bias); cumulative_sizes, bias);
} }
if (nbooks == 2 && entries == (1 << 8)) if (nbooks == 2 && entries == (1 << 8)) {
{ return vllm::aqlm::code2x8_matmat(input, codes, codebooks, scales,
return vllm::aqlm::code2x8_matmat(input, codes, codebooks, scales, cumulative_sizes, bias); cumulative_sizes, bias);
} }
TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, " entries is not currently supported.") TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries,
" entries is not currently supported.")
return {}; return {};
} }
torch::Tensor aqlm_dequant( torch::Tensor aqlm_dequant(const torch::Tensor& codes,
const torch::Tensor& codes, const torch::Tensor& codebooks,
const torch::Tensor& codebooks, const torch::Tensor& codebook_partition_sizes) {
const torch::Tensor& codebook_partition_sizes int4 cumulative_sizes =
) vllm::aqlm::accumulate_sizes(codebook_partition_sizes);
{
int4 cumulative_sizes = vllm::aqlm::accumulate_sizes(codebook_partition_sizes);
int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0); int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0);
int const entries = codebooks.size(1); int const entries = codebooks.size(1);
@ -668,45 +562,37 @@ torch::Tensor aqlm_dequant(
assert(out_features = codebook_partition_sizes.sum().item<int>()); assert(out_features = codebook_partition_sizes.sum().item<int>());
auto weights = torch::empty({out_features, in_features}, auto weights = torch::empty({out_features, in_features},
torch::TensorOptions() torch::TensorOptions()
.dtype(codebooks.dtype()) .dtype(codebooks.dtype())
.device(codebooks.device()) .device(codebooks.device()));
);
if (nbooks == 1 && entries == (1 << 16)) if (nbooks == 1 && entries == (1 << 16)) {
{ vllm::aqlm::code1x16_dequant_cuda(codes.data_ptr(), weights.data_ptr(),
vllm::aqlm::code1x16_dequant_cuda( codebooks.data_ptr(), out_features,
codes.data_ptr(), in_features, cumulative_sizes,
weights.data_ptr(), vllm::aqlm::codebook_stride(codebooks));
codebooks.data_ptr(),
out_features,
in_features,
cumulative_sizes,
vllm::aqlm::codebook_stride(codebooks));
// if you wanted to flip to scaling the weights, (though it's 30%-ish slower and not consistent with gemv implementation.) // if you wanted to flip to scaling the weights, (though it's 30%-ish slower
// weights *= scales.index({"...", 0, 0}); // and not consistent with gemv implementation.) weights *=
// scales.index({"...", 0, 0});
return weights; return weights;
} }
if (nbooks == 2 && entries == (1 << 8)) if (nbooks == 2 && entries == (1 << 8)) {
{ vllm::aqlm::code2x8_dequant_cuda(codes.data_ptr(), weights.data_ptr(),
vllm::aqlm::code2x8_dequant_cuda( codebooks.data_ptr(), out_features,
codes.data_ptr(), in_features, cumulative_sizes,
weights.data_ptr(), vllm::aqlm::codebook_stride(codebooks));
codebooks.data_ptr(),
out_features,
in_features,
cumulative_sizes,
vllm::aqlm::codebook_stride(codebooks));
// if you wanted to flip to scaling the weights, (though it's 30%-ish slower and not consistent with gemv implementation) // if you wanted to flip to scaling the weights, (though it's 30%-ish slower
// weights *= scales.index({"...", 0, 0}); // and not consistent with gemv implementation) weights *=
// scales.index({"...", 0, 0});
return weights; return weights;
} }
TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, " entries is not currently supported.") TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries,
" entries is not currently supported.")
return {}; return {};
} }

View File

@ -1,11 +1,11 @@
/* /*
Adapted from https://github.com/mit-han-lab/llm-awq Adapted from https://github.com/mit-han-lab/llm-awq
Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h Modified from NVIDIA FasterTransformer:
https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
@article{lin2023awq, @article{lin2023awq,
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration}, title={AWQ: Activation-aware Weight Quantization for LLM Compression and
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song}, Acceleration}, author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang,
journal={arXiv}, Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023}
year={2023}
} }
*/ */
@ -14,74 +14,88 @@ Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransfor
namespace vllm { namespace vllm {
namespace awq { namespace awq {
__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) {
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
assert(false); assert(false);
#else #else
uint4 result; uint4 result;
uint32_t* h = reinterpret_cast<uint32_t*>(&result); uint32_t* h = reinterpret_cast<uint32_t*>(&result);
uint32_t const i4s = reinterpret_cast<uint32_t const&>(source); uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);
// First, we extract the i4s and construct an intermediate fp16 number. // First, we extract the i4s and construct an intermediate fp16 number.
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint32_t BOTTOM_MASK = 0x000f000f; static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
static constexpr uint32_t TOP_MASK = 0x00f000f0; static constexpr uint32_t TOP_MASK = 0x00f000f0;
static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;
// Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing // Note that the entire sequence only requires 1 shift instruction. This is
// format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions. // thanks to the register packing format and the fact that we force our
// In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and // integers to be unsigned, and account for this in the fp16 subtractions. In
// elt_67 to fp16 without having to shift them to the bottom bits before hand. // addition, I exploit the fact that sub and fma have the same throughput in
// order to convert elt_23 and elt_67 to fp16 without having to shift them to
// the bottom bits before hand.
// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW
// immediately before required. // dependency if we issue immediately before required.
const uint32_t top_i4s = i4s >> 8; const uint32_t top_i4s = i4s >> 8;
// Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[0]) : "=r"(h[0])
: "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
// Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 "n"(immLut));
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
: "=r"(h[1]) asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); : "=r"(h[1])
// Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" "n"(immLut));
: "=r"(h[2]) // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
: "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
// Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 : "=r"(h[2])
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
: "=r"(h[3]) "n"(immLut));
: "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[3])
: "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
"n"(immLut));
// I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the // I use inline PTX below because I am not sure if the compiler will emit
// half2 ctor. In this case, I chose performance reliability over code readability. // float2half instructions if I use the half2 ctor. In this case, I chose
// performance reliability over code readability.
// This is the half2 {1032, 1032} represented as an integer. // This is the half2 {1032, 1032} represented as an integer.
// static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408; // static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
// Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7] // Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7]
static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400; static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
// This is the half2 {1 / 16, 1 / 16} represented as an integer. // This is the half2 {1 / 16, 1 / 16} represented as an integer.
static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
// This is the half2 {-72, -72} represented as an integer. // This is the half2 {-72, -72} represented as an integer.
// static constexpr uint32_t NEG_72 = 0xd480d480; // static constexpr uint32_t NEG_72 = 0xd480d480;
// Haotian: Let's use {-64, -64}. // Haotian: Let's use {-64, -64}.
static constexpr uint32_t NEG_64 = 0xd400d400; static constexpr uint32_t NEG_64 = 0xd400d400;
// Finally, we construct the output numbers. // Finally, we construct the output numbers.
// Convert elt_01 // Convert elt_01
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); asm volatile("sub.f16x2 %0, %1, %2;\n"
// Convert elt_23 : "=r"(h[0])
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_45 // Convert elt_23
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
// Convert elt_67 : "=r"(h[1])
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
// Convert elt_45
asm volatile("sub.f16x2 %0, %1, %2;\n"
: "=r"(h[2])
: "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_67
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
: "=r"(h[3])
: "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
return result; return result;
#endif #endif
} }
} // namespace awq } // namespace awq
} // namespace vllm } // namespace vllm

View File

@ -1,14 +1,12 @@
/* /*
Adapted from https://github.com/mit-han-lab/llm-awq Adapted from https://github.com/mit-han-lab/llm-awq
@article{lin2023awq, @article{lin2023awq,
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration}, title={AWQ: Activation-aware Weight Quantization for LLM Compression and
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song}, Acceleration}, author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang,
journal={arXiv}, Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023}
year={2023}
} }
*/ */
#include <torch/extension.h> #include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
@ -20,26 +18,20 @@ namespace vllm {
namespace awq { namespace awq {
// Pack two half values. // Pack two half values.
static inline __device__ __host__ unsigned static inline __device__ __host__ unsigned __pack_half2(const half x,
__pack_half2(const half x, const half y) { const half y) {
unsigned v0 = *((unsigned short *)&x); unsigned v0 = *((unsigned short*)&x);
unsigned v1 = *((unsigned short *)&y); unsigned v1 = *((unsigned short*)&y);
return (v1 << 16) | v0; return (v1 << 16) | v0;
} }
template<int N> template <int N>
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32( __global__ void __launch_bounds__(64)
int G, gemm_forward_4bit_cuda_m16nXk32(int G, int split_k_iters,
int split_k_iters, half* __restrict__ A, int* __restrict__ B,
half* __restrict__ A, half* __restrict__ scaling_factors,
int* __restrict__ B, int* __restrict__ zeros, int M, int IC,
half* __restrict__ scaling_factors, int OC, half* __restrict__ C) {
int* __restrict__ zeros,
int M,
int IC,
int OC,
half* __restrict__ C)
{
// Only support matrix n = 64 or 128 // Only support matrix n = 64 or 128
assert(N == 64 || N == 128); assert(N == 64 || N == 128);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
@ -70,43 +62,46 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
static constexpr int row_stride = 2 * 32 * 8 / N; static constexpr int row_stride = 2 * 32 * 8 / N;
bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < N; bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < N;
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id bool ld_A_flag =
(blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp +
threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id
// bool wb_C_flag = (threadIdx.x / 4) < M; // bool wb_C_flag = (threadIdx.x / 4) < M;
half* A_ptr = A half* A_ptr =
+ (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC A +
+ (((int)threadIdx.x) % (32 / 8)) * 8; (((int)blockIdx_y) / j_factors1 * 16 +
(((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) *
IC +
(((int)threadIdx.x) % (32 / 8)) * 8;
int* B_ptr = B int* B_ptr = B + ((int)threadIdx.y) * (OC / 8) * (256 / N) +
+ ((int)threadIdx.y) * (OC / 8) * (256 / N) (((int)threadIdx.x) / (N / 8)) * (OC / 8) +
+ (((int)threadIdx.x) / (N / 8)) * (OC / 8) (((int)blockIdx_y) % j_factors1) * (N / 8) +
+ (((int)blockIdx_y) % j_factors1) * (N / 8) (((int)threadIdx.x) % (N / 8)) * 1;
+ (((int)threadIdx.x) % (N / 8)) * 1; // Why * 1 in the above line?
// Why * 1 in the above line?
half* A_shared_ptr = A_shared half* A_shared_ptr = A_shared +
+ ((int)threadIdx.y) * row_stride_warp * (32 + 8) ((int)threadIdx.y) * row_stride_warp * (32 + 8) +
+ (((int)threadIdx.x) / (32 / 8)) * (32 + 8) (((int)threadIdx.x) / (32 / 8)) * (32 + 8) +
+ (((int)threadIdx.x) % (32 / 8) ) * 8; (((int)threadIdx.x) % (32 / 8)) * 8;
half* B_shared_ptr = B_shared half* B_shared_ptr = B_shared +
+ ((int)threadIdx.y) * (row_stride / 2) * (N + 8) ((int)threadIdx.y) * (row_stride / 2) * (N + 8) +
+ (((int)threadIdx.x) / (N / 8)) * (N + 8) (((int)threadIdx.x) / (N / 8)) * (N + 8) +
+ (((int)threadIdx.x) % (N / 8)) * 8; (((int)threadIdx.x) % (N / 8)) * 8;
int* zeros_ptr = zeros int* zeros_ptr = zeros + (((int)blockIdx_y) % j_factors1) * (N / 8) +
+ (((int)blockIdx_y) % j_factors1) * (N / 8) ((int)threadIdx.x) % (N / 8);
+ ((int)threadIdx.x) % (N / 8);
half* scaling_factors_ptr = scaling_factors half* scaling_factors_ptr = scaling_factors +
+ (((int)blockIdx_y) % j_factors1) * N (((int)blockIdx_y) % j_factors1) * N +
+ (((int)threadIdx.x) % (N / 8)) * 8; (((int)threadIdx.x) % (N / 8)) * 8;
half* C_ptr = C half* C_ptr =
+ static_cast<long long>(blockIdx_z) * M * OC // blockIdz.x -> split_k dim C +
+ (((int)blockIdx_y) % j_factors1) * N static_cast<long long>(blockIdx_z) * M * OC // blockIdz.x -> split_k dim
+ ((int)threadIdx.y) * (N / 2) + (((int)blockIdx_y) % j_factors1) * N + ((int)threadIdx.y) * (N / 2) +
+ (((int)threadIdx.x) % 4) * 2; (((int)threadIdx.x) % 4) * 2;
// preload s.f. and zeros // preload s.f. and zeros
int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters; int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
@ -115,57 +110,83 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z; int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
__syncthreads(); __syncthreads();
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
if (ld_A_flag) if (ld_A_flag) {
{
*(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32)); *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
} } else {
else
{
*(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0); *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
} }
// for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) { // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8)); uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded); uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC)); uint4 B_loaded_scale =
*(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
/* /*
if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){ if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 &&
printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w); threadIdx.y == 0){ printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x,
B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x,
B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
} }
*/ */
// uint4 B_loaded_scale = make_uint4(0, 0, 0, 0); // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8); int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < N / 16; ++ax0_ax1_fused_0) { for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < N / 16; ++ax0_ax1_fused_0) {
// B: 32 x 136 (128+8) float16 // B: 32 x 136 (128+8) float16
// each warp: 32 x 4 // each warp: 32 x 4
// each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4 // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus
// *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8))); // zero -> WB UINT4
// row stride in shared memory: (NWARPS * 32 * 8 / cta_N) // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) *
uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8)); // 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15)
// * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 *
// 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) *
// 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) *
// 8))); row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
uint32_t B_loaded =
*(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
//uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8); // uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N /
// 8)) * 8);
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8); // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x
// % (cta_N / 8)) * 8);
// - zero and * scale // - zero and * scale
// TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale. // TODO (Haotian): can save 4 assembly instructions if sormulate as deq =
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); // q * scale - zero * scale.
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); asm volatile("sub.f16x2 %0, %1, %2;\n"
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); : "=r"(B_loaded_fp16.x)
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); : "=r"(B_loaded_fp16.x)
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); asm volatile("sub.f16x2 %0, %1, %2;\n"
: "=r"(B_loaded_fp16.y)
: "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
: "=r"(B_loaded_fp16.y)
: "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n"
: "=r"(B_loaded_fp16.z)
: "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
: "=r"(B_loaded_fp16.z)
: "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n"
: "=r"(B_loaded_fp16.w)
: "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
: "=r"(B_loaded_fp16.w)
: "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
/* /*
if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){ if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 ==
printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w); 0 && threadIdx.x == 17 && threadIdx.y == 0){ printf("[x] %X %X %X %X\n",
B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
} }
*/ */
// write back // write back
*(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) = B_loaded_fp16; *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) =
B_loaded_fp16;
} }
__syncthreads(); __syncthreads();
@ -173,112 +194,179 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
{ {
unsigned int addr; unsigned int addr;
__asm__ __volatile__( __asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, "
: "=r"(addr) "addr; }\n"
: "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8)))) : "=r"(addr)
); : "l"((void*)((&(A_shared[(k_0_1 * 16)])) +
(((((int)threadIdx.x) & 15) * 40) +
((((int)threadIdx.x) >> 4) * 8)))));
__asm__ __volatile__( __asm__ __volatile__(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16" "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
"{%0, %1, %2, %3}, [%4];\n" "{%0, %1, %2, %3}, [%4];\n"
: "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3]) : "=r"(((unsigned*)(A_shared_warp + 0))[0]),
: "r"(addr) "=r"(((unsigned*)(A_shared_warp + 0))[1]),
); "=r"(((unsigned*)(A_shared_warp + 0))[2]),
"=r"(((unsigned*)(A_shared_warp + 0))[3])
: "r"(addr));
} }
for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) { for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) {
{ {
unsigned int addr; unsigned int addr;
__asm__ __volatile__( __asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, "
: "=r"(addr) "addr; }\n"
: "l"((void *)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + (((int)threadIdx.y) * (N / 2))) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * (N + 8)) + ((((int)threadIdx.x) >> 4) * 8)))) : "=r"(addr)
); : "l"((void*)((&(B_shared[(((k_0_1 * (N * 16 + 128)) +
(((int)threadIdx.y) * (N / 2))) +
(ax1_0 * 16))])) +
(((((int)threadIdx.x) & 15) * (N + 8)) +
((((int)threadIdx.x) >> 4) * 8)))));
__asm__ __volatile__( __asm__ __volatile__(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
"{%0, %1, %2, %3}, [%4];\n" "{%0, %1, %2, %3}, [%4];\n"
: "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3]) : "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[0]),
: "r"(addr) "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[1]),
); "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[2]),
"=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[3])
: "r"(addr));
} }
} }
for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) { for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
{ {
__asm__ __volatile__( __asm__ __volatile__(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
"=f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
"=f"(((float*)(C_warp + (j_0_4 * 8)))[3])
: "r"(((unsigned*)(A_shared_warp + 0))[0]),
"r"(((unsigned*)(A_shared_warp + 0))[1]),
"r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]),
"f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
"f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
"f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
"f"(((float*)(C_warp + (j_0_4 * 8)))[3]));
} }
{ {
__asm__ __volatile__( __asm__ __volatile__(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])
: "r"(((unsigned*)(A_shared_warp + 0))[0]),
"r"(((unsigned*)(A_shared_warp + 0))[1]),
"r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]),
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]));
} }
{ {
__asm__ __volatile__( __asm__ __volatile__(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
: "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
"=f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
"=f"(((float*)(C_warp + (j_0_4 * 8)))[3])
: "r"(((unsigned*)(A_shared_warp + 0))[2]),
"r"(((unsigned*)(A_shared_warp + 0))[3]),
"r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]),
"f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
"f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
"f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
"f"(((float*)(C_warp + (j_0_4 * 8)))[3]));
} }
{ {
__asm__ __volatile__( __asm__ __volatile__(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
: "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])
: "r"(((unsigned*)(A_shared_warp + 0))[2]),
"r"(((unsigned*)(A_shared_warp + 0))[3]),
"r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]),
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]));
} }
#else #else
{ {
__asm__ __volatile__( __asm__ __volatile__(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, "
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) "%13};\n"
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
"=f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
"=f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
"=f"(((float*)(C_warp + (j_0_4 * 8)))[3])
: "r"(((unsigned*)(A_shared_warp + 0))[0]),
"r"(((unsigned*)(A_shared_warp + 0))[1]),
"r"(((unsigned*)(A_shared_warp + 0))[2]),
"r"(((unsigned*)(A_shared_warp + 0))[3]),
"r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]),
"r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]),
"f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
"f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
"f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
"f"(((float*)(C_warp + (j_0_4 * 8)))[3]));
} }
{ {
__asm__ __volatile__( __asm__ __volatile__(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, "
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) "%13};\n"
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])
: "r"(((unsigned*)(A_shared_warp + 0))[0]),
"r"(((unsigned*)(A_shared_warp + 0))[1]),
"r"(((unsigned*)(A_shared_warp + 0))[2]),
"r"(((unsigned*)(A_shared_warp + 0))[3]),
"r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]),
"r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]),
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]));
} }
#endif #endif
} }
} }
} }
// TODO: Shang: Hoist loop invariance. // TODO: Shang: Hoist loop invariance.
for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) { for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) {
for (int local_id = 0; local_id < 8; ++local_id) { for (int local_id = 0; local_id < 8; ++local_id) {
int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8; int row_offset = (((int)blockIdx_y) / j_factors1) * 16 +
if (row_offset < M) ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
{ if (row_offset < M) {
*(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]); *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 +
local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
} }
} }
} }
#endif #endif
} }
__global__ void __launch_bounds__(64) dequantize_weights( __global__ void __launch_bounds__(64)
int* __restrict__ B, dequantize_weights(int* __restrict__ B, half* __restrict__ scaling_factors,
half* __restrict__ scaling_factors, int* __restrict__ zeros, half* __restrict__ C, int G) {
int* __restrict__ zeros,
half* __restrict__ C,
int G
)
{
int j_factors1 = 4; int j_factors1 = 4;
int row_stride2 = 4; int row_stride2 = 4;
int split_k_iters = 1; int split_k_iters = 1;
@ -310,14 +398,30 @@ __global__ void __launch_bounds__(64) dequantize_weights(
uint32_t B_loaded = *(uint32_t*)B_ptr2; uint32_t B_loaded = *(uint32_t*)B_ptr2;
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); asm volatile("sub.f16x2 %0, %1, %2;\n"
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); : "=r"(B_loaded_fp16.x)
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); : "=r"(B_loaded_fp16.x)
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); asm volatile("sub.f16x2 %0, %1, %2;\n"
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); : "=r"(B_loaded_fp16.y)
: "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
: "=r"(B_loaded_fp16.y)
: "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n"
: "=r"(B_loaded_fp16.z)
: "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
: "=r"(B_loaded_fp16.z)
: "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n"
: "=r"(B_loaded_fp16.w)
: "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
: "=r"(B_loaded_fp16.w)
: "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
*(uint4*)B_shared_ptr2 = B_loaded_fp16; *(uint4*)B_shared_ptr2 = B_loaded_fp16;
@ -326,58 +430,57 @@ __global__ void __launch_bounds__(64) dequantize_weights(
} }
} }
} // namespace awq } // namespace awq
} // namespace vllm } // namespace vllm
torch::Tensor awq_dequantize( torch::Tensor awq_dequantize(torch::Tensor _kernel,
torch::Tensor _kernel, torch::Tensor _scaling_factors,
torch::Tensor _scaling_factors, torch::Tensor _zeros, int split_k_iters, int thx,
torch::Tensor _zeros, int thy) {
int split_k_iters, int in_c = _kernel.size(0);
int thx, int qout_c = _kernel.size(1);
int thy) int out_c = qout_c * 8;
{ int G = in_c / _scaling_factors.size(0);
int in_c = _kernel.size(0);
int qout_c = _kernel.size(1);
int out_c = qout_c * 8;
int G = in_c / _scaling_factors.size(0);
int x_thread = thx; int x_thread = thx;
int y_thread = thy; int y_thread = thy;
int x_blocks = 1; int x_blocks = 1;
int y_blocks = 1; int y_blocks = 1;
if (thx==0) { if (thx == 0) {
x_thread = qout_c; x_thread = qout_c;
} }
if (thy==0) { if (thy == 0) {
y_thread = in_c; y_thread = in_c;
} }
if (thx==0 && thy==0) { if (thx == 0 && thy == 0) {
x_thread = 8; x_thread = 8;
y_thread = 8; y_thread = 8;
x_blocks = (int)(qout_c / 8); x_blocks = (int)(qout_c / 8);
y_blocks = (int)(in_c / 8); y_blocks = (int)(in_c / 8);
} }
const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors)); const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors));
auto options = torch::TensorOptions().dtype(_scaling_factors.dtype()).device(_scaling_factors.device()); auto options = torch::TensorOptions()
at::Tensor _de_kernel = torch::empty({in_c, out_c}, options); .dtype(_scaling_factors.dtype())
.device(_scaling_factors.device());
at::Tensor _de_kernel = torch::empty({in_c, out_c}, options);
auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>()); auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
auto de_kernel = reinterpret_cast<half*>(_de_kernel.data_ptr<at::Half>()); auto de_kernel = reinterpret_cast<half*>(_de_kernel.data_ptr<at::Half>());
auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>()); auto scaling_factors =
auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>()); reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
dim3 num_blocks(x_blocks, y_blocks); dim3 num_blocks(x_blocks, y_blocks);
dim3 threads_per_block(x_thread, y_thread); dim3 threads_per_block(x_thread, y_thread);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
vllm::awq::dequantize_weights<<<num_blocks, threads_per_block, 0, stream>>>( vllm::awq::dequantize_weights<<<num_blocks, threads_per_block, 0, stream>>>(
kernel, scaling_factors, zeros, de_kernel, G); kernel, scaling_factors, zeros, de_kernel, G);
return _de_kernel; return _de_kernel;
} }
// in_feats: M, IC [float16] // in_feats: M, IC [float16]
@ -386,61 +489,61 @@ torch::Tensor awq_dequantize(
// zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b] // zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b]
// assume that batch_size < 16 for now // assume that batch_size < 16 for now
torch::Tensor awq_gemm( torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
torch::Tensor _in_feats, torch::Tensor _scaling_factors, torch::Tensor _zeros,
torch::Tensor _kernel, int split_k_iters) {
torch::Tensor _scaling_factors, int num_in_feats = _in_feats.size(0);
torch::Tensor _zeros, int num_in_channels = _in_feats.size(1);
int split_k_iters) const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));
{
int num_in_feats = _in_feats.size(0);
int num_in_channels = _in_feats.size(1);
const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));
auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device()); auto options = torch::TensorOptions()
at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options); .dtype(_in_feats.dtype())
int num_out_feats = _out_feats.size(-2); .device(_in_feats.device());
int num_out_channels = _out_feats.size(-1); at::Tensor _out_feats =
torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options);
int num_out_feats = _out_feats.size(-2);
int num_out_channels = _out_feats.size(-1);
auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>()); auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>());
auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>()); auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>()); auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>()); auto scaling_factors =
auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>()); reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
int group_size = num_in_channels / _scaling_factors.size(0); auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
int group_size = num_in_channels / _scaling_factors.size(0);
if (num_out_channels % 64 != 0) if (num_out_channels % 64 != 0)
throw std::invalid_argument("OC is not multiple of cta_N = 64"); throw std::invalid_argument("OC is not multiple of cta_N = 64");
if (num_out_channels % 8 != 0) if (num_out_channels % 8 != 0)
throw std::invalid_argument("OC is not multiple of pack_num = 8"); throw std::invalid_argument("OC is not multiple of pack_num = 8");
if (group_size % 32 != 0) if (group_size % 32 != 0)
throw std::invalid_argument("Group size should be a multiple of 32"); throw std::invalid_argument("Group size should be a multiple of 32");
if (num_out_channels % group_size != 0) if (num_out_channels % group_size != 0)
throw std::invalid_argument("OC is not multiple of Group size"); throw std::invalid_argument("OC is not multiple of Group size");
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (num_out_channels % 128 == 0) if (num_out_channels % 128 == 0) {
{ int j_factors1 = num_out_channels / 128 / 1;
int j_factors1 = num_out_channels / 128 / 1; dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); // threadIdx.x: 32
// threadIdx.x: 32 // threadIdx.y: i_factors[2] * j_factors[2]
// threadIdx.y: i_factors[2] * j_factors[2] dim3 threads_per_block(32, 2);
dim3 threads_per_block(32, 2); vllm::awq::gemm_forward_4bit_cuda_m16nXk32<128>
vllm::awq::gemm_forward_4bit_cuda_m16nXk32<128><<<num_blocks, threads_per_block, 0, stream>>>( <<<num_blocks, threads_per_block, 0, stream>>>(
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros,
num_out_channels, out_feats); num_in_feats, num_in_channels, num_out_channels, out_feats);
} } else if (num_out_channels % 64 == 0) {
else if (num_out_channels % 64 == 0) int j_factors1 = num_out_channels / 64 / 1;
{ dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 *
int j_factors1 = num_out_channels / 64 / 1; split_k_iters);
dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
// threadIdx.x: 32 // threadIdx.x: 32
// threadIdx.y: i_factors[2] * j_factors[2] // threadIdx.y: i_factors[2] * j_factors[2]
dim3 threads_per_block(32, 2); dim3 threads_per_block(32, 2);
vllm::awq::gemm_forward_4bit_cuda_m16nXk32<64><<<num_blocks, threads_per_block, 0, stream>>>( vllm::awq::gemm_forward_4bit_cuda_m16nXk32<64>
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, <<<num_blocks, threads_per_block, 0, stream>>>(
num_out_channels, out_feats); group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros,
} num_in_feats, num_in_channels, num_out_channels, out_feats);
return _out_feats.sum(0); }
return _out_feats.sum(0);
} }

View File

@ -117,10 +117,10 @@ struct cutlass_2x_gemm {
}; };
template <typename Gemm> template <typename Gemm>
void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const &b, torch::Tensor const& b,
torch::Tensor const &a_scales, torch::Tensor const& a_scales,
torch::Tensor const &b_scales) { torch::Tensor const& b_scales) {
using ElementAB = typename Gemm::ElementAB; using ElementAB = typename Gemm::ElementAB;
using ElementD = typename Gemm::ElementD; using ElementD = typename Gemm::ElementD;
@ -136,9 +136,9 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a,
using StrideC = Stride<int64_t, Int<1>, Int<0>>; using StrideC = Stride<int64_t, Int<1>, Int<0>>;
StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
auto a_ptr = static_cast<ElementAB const *>(a.data_ptr()); auto a_ptr = static_cast<ElementAB const*>(a.data_ptr());
auto b_ptr = static_cast<ElementAB const *>(b.data_ptr()); auto b_ptr = static_cast<ElementAB const*>(b.data_ptr());
auto c_ptr = static_cast<ElementD *>(out.data_ptr()); auto c_ptr = static_cast<ElementD*>(out.data_ptr());
auto a_scales_ptr = a_scales.data_ptr<float>(); auto a_scales_ptr = a_scales.data_ptr<float>();
auto b_scales_ptr = b_scales.data_ptr<float>(); auto b_scales_ptr = b_scales.data_ptr<float>();
@ -196,10 +196,10 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a,
} // namespace } // namespace
void cutlass_scaled_mm_dq_sm75(torch::Tensor &out, torch::Tensor const &a, void cutlass_scaled_mm_dq_sm75(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const &b, torch::Tensor const& b,
torch::Tensor const &a_scales, torch::Tensor const& a_scales,
torch::Tensor const &b_scales) { torch::Tensor const& b_scales) {
TORCH_CHECK(a.dtype() == torch::kInt8); TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(b.dtype() == torch::kInt8); TORCH_CHECK(b.dtype() == torch::kInt8);
TORCH_CHECK(a_scales.dtype() == torch::kFloat32); TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
@ -223,10 +223,10 @@ void cutlass_scaled_mm_dq_sm75(torch::Tensor &out, torch::Tensor const &a,
} }
} }
void cutlass_scaled_mm_dq_sm80(torch::Tensor &out, torch::Tensor const &a, void cutlass_scaled_mm_dq_sm80(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const &b, torch::Tensor const& b,
torch::Tensor const &a_scales, torch::Tensor const& a_scales,
torch::Tensor const &b_scales) { torch::Tensor const& b_scales) {
TORCH_CHECK(a.dtype() == torch::kInt8); TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(b.dtype() == torch::kInt8); TORCH_CHECK(b.dtype() == torch::kInt8);
TORCH_CHECK(a_scales.dtype() == torch::kFloat32); TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
@ -250,10 +250,10 @@ void cutlass_scaled_mm_dq_sm80(torch::Tensor &out, torch::Tensor const &a,
} }
} }
void cutlass_scaled_mm_dq_sm89(torch::Tensor &out, torch::Tensor const &a, void cutlass_scaled_mm_dq_sm89(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const &b, torch::Tensor const& b,
torch::Tensor const &a_scales, torch::Tensor const& a_scales,
torch::Tensor const &b_scales) { torch::Tensor const& b_scales) {
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>; using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;

View File

@ -120,10 +120,10 @@ struct cutlass_3x_gemm {
}; };
template <typename Gemm> template <typename Gemm>
void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const &b, torch::Tensor const& b,
torch::Tensor const &a_scales, torch::Tensor const& a_scales,
torch::Tensor const &b_scales) { torch::Tensor const& b_scales) {
using ElementAB = typename Gemm::ElementAB; using ElementAB = typename Gemm::ElementAB;
using ElementD = typename Gemm::ElementD; using ElementD = typename Gemm::ElementD;
@ -146,12 +146,12 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a,
using GemmKernel = typename Gemm::GemmKernel; using GemmKernel = typename Gemm::GemmKernel;
typename GemmKernel::ProblemShape prob_shape{m, n, k, 1}; typename GemmKernel::ProblemShape prob_shape{m, n, k, 1};
auto a_ptr = static_cast<ElementAB *>(a.data_ptr()); auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
auto b_ptr = static_cast<ElementAB *>(b.data_ptr()); auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr, typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr,
b_stride}; b_stride};
auto c_ptr = static_cast<ElementD *>(out.data_ptr()); auto c_ptr = static_cast<ElementD*>(out.data_ptr());
typename GemmKernel::EpilogueArguments epilogue_args{ typename GemmKernel::EpilogueArguments epilogue_args{
{}, c_ptr, c_stride, c_ptr, c_stride}; {}, c_ptr, c_stride, c_ptr, c_stride};
@ -183,10 +183,10 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a,
} }
} // namespace } // namespace
void cutlass_scaled_mm_dq_sm90(torch::Tensor &out, torch::Tensor const &a, void cutlass_scaled_mm_dq_sm90(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const &b, torch::Tensor const& b,
torch::Tensor const &a_scales, torch::Tensor const& a_scales,
torch::Tensor const &b_scales) { torch::Tensor const& b_scales) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32); TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32); TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

View File

@ -2,29 +2,29 @@
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <torch/extension.h> #include <torch/extension.h>
void cutlass_scaled_mm_dq_sm75(torch::Tensor &c, torch::Tensor const &a, void cutlass_scaled_mm_dq_sm75(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const &b, torch::Tensor const& b,
torch::Tensor const &a_scales, torch::Tensor const& a_scales,
torch::Tensor const &b_scales); torch::Tensor const& b_scales);
void cutlass_scaled_mm_dq_sm80(torch::Tensor &c, torch::Tensor const &a, void cutlass_scaled_mm_dq_sm80(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const &b, torch::Tensor const& b,
torch::Tensor const &a_scales, torch::Tensor const& a_scales,
torch::Tensor const &b_scales); torch::Tensor const& b_scales);
void cutlass_scaled_mm_dq_sm89(torch::Tensor &c, torch::Tensor const &a, void cutlass_scaled_mm_dq_sm89(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const &b, torch::Tensor const& b,
torch::Tensor const &a_scales, torch::Tensor const& a_scales,
torch::Tensor const &b_scales); torch::Tensor const& b_scales);
void cutlass_scaled_mm_dq_sm90(torch::Tensor &c, torch::Tensor const &a, void cutlass_scaled_mm_dq_sm90(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const &b, torch::Tensor const& b,
torch::Tensor const &a_scales, torch::Tensor const& a_scales,
torch::Tensor const &b_scales); torch::Tensor const& b_scales);
void cutlass_scaled_mm_dq(torch::Tensor &c, torch::Tensor const &a, void cutlass_scaled_mm_dq(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const &b, torch::Tensor const &a_scales, torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const &b_scales) { torch::Tensor const& b_scales) {
int32_t major_capability; int32_t major_capability;
int32_t minor_capability; int32_t minor_capability;
cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor, cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
@ -36,14 +36,15 @@ void cutlass_scaled_mm_dq(torch::Tensor &c, torch::Tensor const &a,
// Checks for conformality // Checks for conformality
TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2); TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) && TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
b.size(1) == c.size(1)); b.size(1) == c.size(1));
TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0)); TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1)); TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));
// Check for strides and alignment // Check for strides and alignment
TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
TORCH_CHECK(b.stride(0) == 1); // Column-major TORCH_CHECK(b.stride(0) == 1); // Column-major
TORCH_CHECK(c.stride(0) % 16 == 0 && b.stride(1) % 16 == 0); // 16 Byte Alignment TORCH_CHECK(c.stride(0) % 16 == 0 &&
b.stride(1) % 16 == 0); // 16 Byte Alignment
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
at::cuda::OptionalCUDAGuard const device_guard(device_of(a)); at::cuda::OptionalCUDAGuard const device_guard(device_of(a));

View File

@ -1,167 +1,137 @@
#pragma once #pragma once
#ifdef __HIPCC__ #ifdef __HIPCC__
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
#else #else
#include <type_traits> #include <type_traits>
#include <stdint.h> #include <stdint.h>
#include <math.h> #include <math.h>
#include <iostream> #include <iostream>
#endif #endif
#include "hip_float8_impl.h" #include "hip_float8_impl.h"
struct alignas(1) hip_fp8 struct alignas(1) hip_fp8 {
{ struct from_bits_t {};
struct from_bits_t HIP_FP8_HOST_DEVICE static constexpr from_bits_t from_bits() {
{ return from_bits_t();
}; }
HIP_FP8_HOST_DEVICE static constexpr from_bits_t from_bits() { return from_bits_t(); } uint8_t data;
uint8_t data;
hip_fp8() = default; hip_fp8() = default;
HIP_FP8_HOST_DEVICE constexpr hip_fp8(const hip_fp8&) = default; HIP_FP8_HOST_DEVICE constexpr hip_fp8(const hip_fp8&) = default;
HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v) = delete; HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v) = delete;
explicit HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v, from_bits_t) explicit HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v, from_bits_t)
: data(v) : data(v) {}
{
}
#ifdef __HIP__MI300__ #ifdef __HIP__MI300__
// NOTE: ON-DEVICE... always optimal bias // NOTE: ON-DEVICE... always optimal bias
explicit HIP_FP8_DEVICE hip_fp8(float v) explicit HIP_FP8_DEVICE hip_fp8(float v)
: data(hip_fp8_impl::to_fp8_from_fp32(v)) : data(hip_fp8_impl::to_fp8_from_fp32(v)) {}
{
}
explicit HIP_FP8_DEVICE hip_fp8(_Float16 v) explicit HIP_FP8_DEVICE hip_fp8(_Float16 v)
: hip_fp8(static_cast<float>(v)) : hip_fp8(static_cast<float>(v)) {}
{
}
// Host only implementation using s/w simulation // Host only implementation using s/w simulation
explicit HIP_FP8_HOST explicit HIP_FP8_HOST
#else // __HIP__MI300__ #else // __HIP__MI300__
// both Host and DEVICE for non-MI300 using s/w simulation // both Host and DEVICE for non-MI300 using s/w simulation
explicit HIP_FP8_HOST_DEVICE explicit HIP_FP8_HOST_DEVICE
#endif // __HIP__MI300__ #endif // __HIP__MI300__
hip_fp8(float v) hip_fp8(float v) {
{ data = hip_fp8_impl::to_float8<4, 3, float, true /*negative_zero_nan*/,
data = hip_fp8_impl::to_float8<4, 3, float, true /*negative_zero_nan*/, true /*clip*/>(v); true /*clip*/>(v);
} }
explicit HIP_FP8_HOST_DEVICE hip_fp8(double v) explicit HIP_FP8_HOST_DEVICE hip_fp8(double v)
: hip_fp8(static_cast<float>(v)) : hip_fp8(static_cast<float>(v)) {}
{
}
#ifdef __HIP__MI300__ #ifdef __HIP__MI300__
// upcast using device specific intrinsic // upcast using device specific intrinsic
explicit inline HIP_FP8_DEVICE operator float() const explicit inline HIP_FP8_DEVICE operator float() const {
{ float fval;
float fval; uint32_t i32val = static_cast<uint32_t>(data);
uint32_t i32val = static_cast<uint32_t>(data);
// upcast // upcast
asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val)); asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0"
: "=v"(fval)
: "v"(i32val));
return fval; return fval;
} }
explicit inline HIP_FP8_HOST operator float() const explicit inline HIP_FP8_HOST operator float() const
#else // __HIP__MI300__ #else // __HIP__MI300__
explicit inline HIP_FP8_HOST_DEVICE operator float() const explicit inline HIP_FP8_HOST_DEVICE operator float() const
#endif // __HIP__MI300__ #endif // __HIP__MI300__
{ {
return hip_fp8_impl::from_float8<4, 3, float, true /*negative_zero_nan*/>(data); return hip_fp8_impl::from_float8<4, 3, float, true /*negative_zero_nan*/>(
} data);
}
}; };
namespace std namespace std {
{ inline hip_fp8 sin(hip_fp8 a) { return hip_fp8(sinf(float(a))); }
inline hip_fp8 sin(hip_fp8 a) inline hip_fp8 cos(hip_fp8 a) { return hip_fp8(cosf(float(a))); }
{ HIP_FP8_HOST_DEVICE constexpr hip_fp8 real(const hip_fp8& a) { return a; }
return hip_fp8(sinf(float(a))); } // namespace std
}
inline hip_fp8 cos(hip_fp8 a)
{
return hip_fp8(cosf(float(a)));
}
HIP_FP8_HOST_DEVICE constexpr hip_fp8 real(const hip_fp8& a)
{
return a;
}
} // namespace std
// Special operator overloading // Special operator overloading
inline std::ostream& operator<<(std::ostream& os, const hip_fp8& f8) inline std::ostream& operator<<(std::ostream& os, const hip_fp8& f8) {
{ return os << float(f8);
return os << float(f8);
} }
// all + operator overloading with mixed types // all + operator overloading with mixed types
// mixed types, always converts to f32, does computation in f32, and returns float // mixed types, always converts to f32, does computation in f32, and returns
inline HIP_FP8_HOST_DEVICE float operator+(const float fa, hip_fp8 b) // float
{ inline HIP_FP8_HOST_DEVICE float operator+(const float fa, hip_fp8 b) {
return (fa + float(b)); return (fa + float(b));
} }
inline HIP_FP8_HOST_DEVICE float operator+(hip_fp8 a, const float fb) inline HIP_FP8_HOST_DEVICE float operator+(hip_fp8 a, const float fb) {
{ return (float(a) + fb);
return (float(a) + fb);
} }
inline HIP_FP8_HOST_DEVICE hip_fp8 operator+(hip_fp8 a, hip_fp8 b) inline HIP_FP8_HOST_DEVICE hip_fp8 operator+(hip_fp8 a, hip_fp8 b) {
{ return hip_fp8(float(a) + float(b));
return hip_fp8(float(a) + float(b));
} }
inline HIP_FP8_HOST_DEVICE hip_fp8& operator+=(hip_fp8& a, hip_fp8 b) inline HIP_FP8_HOST_DEVICE hip_fp8& operator+=(hip_fp8& a, hip_fp8 b) {
{ return a = hip_fp8(float(a) + float(b));
return a = hip_fp8(float(a) + float(b));
} }
// overloading multiplication, always returns float, // overloading multiplication, always returns float,
inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, hip_fp8 b) inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, hip_fp8 b) {
{ return float(a) * float(b);
return float(a) * float(b);
} }
inline HIP_FP8_HOST_DEVICE float operator*(float a, hip_fp8 b) inline HIP_FP8_HOST_DEVICE float operator*(float a, hip_fp8 b) {
{ return (a * float(b));
return (a * float(b));
} }
inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, float b) inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, float b) {
{ return (float(a) * b);
return (float(a) * b);
} }
inline HIP_FP8_HOST_DEVICE float operator*(int32_t a, hip_fp8 b) inline HIP_FP8_HOST_DEVICE float operator*(int32_t a, hip_fp8 b) {
{ return ((float)a * float(b));
return ((float)a * float(b));
} }
inline HIP_FP8_HOST_DEVICE float operator*(double a, hip_fp8 b) inline HIP_FP8_HOST_DEVICE float operator*(double a, hip_fp8 b) {
{ return ((float)a * float(b));
return ((float)a * float(b));
} }
// overloading for compare // overloading for compare
inline HIP_FP8_HOST_DEVICE bool operator==(hip_fp8 a, hip_fp8 b) inline HIP_FP8_HOST_DEVICE bool operator==(hip_fp8 a, hip_fp8 b) {
{ return (a.data == b.data);
return (a.data == b.data);
} }
inline HIP_FP8_HOST_DEVICE bool operator!=(hip_fp8 a, hip_fp8 b) inline HIP_FP8_HOST_DEVICE bool operator!=(hip_fp8 a, hip_fp8 b) {
{ return (a.data != b.data);
return (a.data != b.data);
} }
inline HIP_FP8_HOST_DEVICE bool operator>=(hip_fp8 a, hip_fp8 b) inline HIP_FP8_HOST_DEVICE bool operator>=(hip_fp8 a, hip_fp8 b) {
{ return static_cast<float>(a) >= static_cast<float>(b);
return static_cast<float>(a) >= static_cast<float>(b);
} }
inline HIP_FP8_HOST_DEVICE bool operator>(hip_fp8 a, hip_fp8 b) inline HIP_FP8_HOST_DEVICE bool operator>(hip_fp8 a, hip_fp8 b) {
{ return static_cast<float>(a) > static_cast<float>(b);
return static_cast<float>(a) > static_cast<float>(b);
} }

View File

@ -1,316 +1,316 @@
#pragma once #pragma once
#if defined(__HIPCC__) && (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) #if defined(__HIPCC__) && \
#define __HIP__MI300__ (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
#define __HIP__MI300__
#endif #endif
#ifdef __HIPCC__ #ifdef __HIPCC__
#define HIP_FP8_HOST_DEVICE __host__ __device__ #define HIP_FP8_HOST_DEVICE __host__ __device__
#define HIP_FP8_HOST __host__ #define HIP_FP8_HOST __host__
#define HIP_FP8_DEVICE __device__ #define HIP_FP8_DEVICE __device__
#else #else
#define HIP_FP8_HOST_DEVICE #define HIP_FP8_HOST_DEVICE
#define HIP_FP8_HOST #define HIP_FP8_HOST
#define HIP_FP8_DEVICE #define HIP_FP8_DEVICE
#endif #endif
namespace hip_fp8_impl namespace hip_fp8_impl {
{
#ifdef __HIP__MI300__ #ifdef __HIP__MI300__
HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v) HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v) {
{ uint8_t i8data;
uint8_t i8data; union {
union { float fval;
float fval; uint32_t i32val;
uint32_t i32val; uint8_t i8val[4]; // NOTE: not endian independent
uint8_t i8val[4]; // NOTE: not endian independent } val;
} val;
uint32_t ival = 0; uint32_t ival = 0;
val.fval = v; val.fval = v;
if ((val.i32val & 0x7F800000) != 0x7F800000) { /// propagate NAN/INF, no clipping if ((val.i32val & 0x7F800000) !=
val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0); 0x7F800000) { /// propagate NAN/INF, no clipping
} val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
}
ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival,
false); // false -> WORD0 false); // false -> WORD0
val.i32val = ival; val.i32val = ival;
i8data = val.i8val[0]; i8data = val.i8val[0];
return i8data; return i8data;
} }
#endif // __HIP__MI300__ #endif // __HIP__MI300__
HIP_FP8_HOST inline int clz(uint32_t x) HIP_FP8_HOST inline int clz(uint32_t x) { return __builtin_clz(x); }
{
return __builtin_clz(x);
}
#if defined(__HIPCC__) || defined(__CUDA_ARCH__) #if defined(__HIPCC__) || defined(__CUDA_ARCH__)
HIP_FP8_DEVICE inline int clz(uint32_t x) HIP_FP8_DEVICE inline int clz(uint32_t x) { return __clz(x); }
{
return __clz(x);
}
#endif #endif
template <int we, int wm, typename T, bool negative_zero_nan, bool clip> template <int we, int wm, typename T, bool negative_zero_nan, bool clip>
HIP_FP8_HOST_DEVICE uint8_t to_float8(T _x, bool stoch = false, uint32_t rng = 0) HIP_FP8_HOST_DEVICE uint8_t to_float8(T _x, bool stoch = false,
{ uint32_t rng = 0) {
#ifdef __HIPCC__ #ifdef __HIPCC__
constexpr bool is_half = std::is_same<T, _Float16>::value; constexpr bool is_half = std::is_same<T, _Float16>::value;
#else #else
constexpr bool is_half = false; constexpr bool is_half = false;
#endif #endif
constexpr bool is_float = std::is_same<T, float>::value; constexpr bool is_float = std::is_same<T, float>::value;
static_assert(wm + we == 7, "wm+we==7"); static_assert(wm + we == 7, "wm+we==7");
static_assert(is_half || is_float, "Only half and float can be cast to f8"); static_assert(is_half || is_float, "Only half and float can be cast to f8");
const int mfmt = (sizeof(T) == 4) ? 23 : 10; const int mfmt = (sizeof(T) == 4) ? 23 : 10;
uint32_t x; uint32_t x;
if (sizeof(T) == 4) {
x = reinterpret_cast<uint32_t&>(_x);
} else {
x = reinterpret_cast<uint16_t&>(_x);
}
uint32_t head, mantissa;
int exponent, bias;
uint32_t sign;
if (sizeof(T) == 4) {
head = x & 0xFF800000;
mantissa = x & 0x7FFFFF;
exponent = (head >> 23) & 0xFF;
sign = head >> 31;
bias = 127;
} else {
head = x & 0xFC00;
mantissa = x & 0x3FF;
exponent = (head >> 10) & 0x1F;
sign = head >> 15;
bias = 15;
}
uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm);
// Deal with inf and NaNs
if (negative_zero_nan) {
if (sizeof(T) == 4) { if (sizeof(T) == 4) {
x = reinterpret_cast<uint32_t&>(_x); if ((x & 0x7F800000) == 0x7F800000) {
return 0x80;
}
} else { } else {
x = reinterpret_cast<uint16_t&>(_x); // if(__hisinf(x) || __hisnan(x))
if ((x & 0x7C00) == 0x7C00) {
return 0x80;
}
} }
} else {
uint32_t head, mantissa;
int exponent, bias;
uint32_t sign;
if (sizeof(T) == 4) { if (sizeof(T) == 4) {
head = x & 0xFF800000; if ((x & 0x7F800000) == 0x7F800000) {
mantissa = x & 0x7FFFFF; return signed_inf + (mantissa != 0 ? 1 : 0);
exponent = (head >> 23) & 0xFF; }
sign = head >> 31;
bias = 127;
} else { } else {
head = x & 0xFC00; if ((x & 0x7C00) == 0x7C00) {
mantissa = x & 0x3FF; return signed_inf + (mantissa != 0 ? 1 : 0);
exponent = (head >> 10) & 0x1F; }
sign = head >> 15;
bias = 15;
} }
}
if (x == 0) {
return 0;
}
uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm); // First need to check if it is normal or denorm as there is a difference of
// implicit 1 Then need to adjust the exponent to align with the F8 exponent,
// in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng
// to mantissa and truncate. And for RNE, no need to add rng. Then probably
// need to check whether there is carry and adjust exponent and mantissa again
// Deal with inf and NaNs // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent
if (negative_zero_nan) { // bits
if (sizeof(T) == 4) { const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0);
if ((x & 0x7F800000) == 0x7F800000) { const int f8_denormal_act_exponent =
return 0x80; 1 - f8_bias; // actual exponent of f8 denormal
} // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
} else { // f8_exponent is the converted f8 exponent with bias encoding
// if(__hisinf(x) || __hisnan(x)) // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
if ((x & 0x7C00) == 0x7C00) { // the difference needs to be adjusted and mantissa shifted
return 0x80; int act_exponent, f8_exponent, exponent_diff;
}
}
} else {
if (sizeof(T) == 4) {
if ((x & 0x7F800000) == 0x7F800000) {
return signed_inf + (mantissa != 0 ? 1 : 0);
}
} else {
if ((x & 0x7C00) == 0x7C00) {
return signed_inf + (mantissa != 0 ? 1 : 0);
}
}
}
if (x == 0) {
return 0;
}
// First need to check if it is normal or denorm as there is a difference of if (exponent == 0) { // fp32/fp16 is in denormal.
// implicit 1 Then need to adjust the exponent to align with the F8 exponent, /* fp32 denormal is below 2^-127 so it is usually not a concern here, we
// in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng
// to mantissa and truncate. And for RNE, no need to add rng. Then probably
// need to check whether there is carry and adjust exponent and mantissa again
// For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent
// bits
const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0);
const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
// f8_exponent is the converted f8 exponent with bias encoding
// exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
// the difference needs to be adjusted and mantissa shifted
int act_exponent, f8_exponent, exponent_diff;
if (exponent == 0) { // fp32/fp16 is in denormal.
/* fp32 denormal is below 2^-127 so it is usually not a concern here, we
mostly concern fp16 here. In this case, f8 is usually in denormal. But there mostly concern fp16 here. In this case, f8 is usually in denormal. But there
could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has
exponent bias 16. It means that there are some numbers in fp16 denormal but they exponent bias 16. It means that there are some numbers in fp16 denormal but they
are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
(NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */ (NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */
act_exponent = exponent - bias + 1; act_exponent = exponent - bias + 1;
exponent_diff = f8_denormal_act_exponent - act_exponent; // actual exponent is exponent-bias+1 as it is denormal exponent_diff =
} else { // fp32/fp16 is normal with implicit 1 f8_denormal_act_exponent -
act_exponent = exponent - bias; act_exponent; // actual exponent is exponent-bias+1 as it is denormal
if (act_exponent <= f8_denormal_act_exponent) { } else { // fp32/fp16 is normal with implicit 1
/* This is the case where fp32/fp16 is normal but it is in f8 denormal act_exponent = exponent - bias;
range. For example fp8 nanoo mode, denormal exponent is -7, but if the if (act_exponent <= f8_denormal_act_exponent) {
fp32/fp16 actual exponent is -7, it is actually larger due to the implicit 1, /* This is the case where fp32/fp16 is normal but it is in f8 denormal
Therefore it needs to be adjust to -6 and mantissa shift right by 1. range. For example fp8 nanoo mode, denormal exponent is -7, but if the
So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */ fp32/fp16 actual exponent is -7, it is actually larger due to the implicit 1,
exponent_diff = f8_denormal_act_exponent - act_exponent; Therefore it needs to be adjust to -6 and mantissa shift right by 1.
} else { // both fp32/fp16 and f8 are in normal range So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
exponent_diff = 0; // exponent_diff=0 does not mean there is no difference exponent_diff = f8_denormal_act_exponent - act_exponent;
// for this case, } else { // both fp32/fp16 and f8 are in normal range
// act_exponent could be larger. Just that it does not need shift mantissa exponent_diff = 0; // exponent_diff=0 does not mean there is no
} // difference for this case, act_exponent could be
mantissa += (1 << mfmt); // Add the implicit 1 into mantissa // larger. Just that it does not need shift mantissa
} }
mantissa += (1 << mfmt); // Add the implicit 1 into mantissa
}
bool midpoint = (mantissa & ((1 << (mfmt - wm + exponent_diff)) - 1)) == bool midpoint = (mantissa & ((1 << (mfmt - wm + exponent_diff)) - 1)) ==
static_cast<uint32_t>(1 << (mfmt - wm + exponent_diff - 1)); static_cast<uint32_t>(1 << (mfmt - wm + exponent_diff - 1));
/* This part is a bit tricky. The judgment of whether it is a tie needs to be /* This part is a bit tricky. The judgment of whether it is a tie needs to be
done before we shift right as shift right could rip off some residual part done before we shift right as shift right could rip off some residual part
and make something not midpoint look like midpoint. For example, the fp16 and make something not midpoint look like midpoint. For example, the fp16
number 0x1002 (0 00100 0000000010), it is larger than midpoint, but after number 0x1002 (0 00100 0000000010), it is larger than midpoint, but after
shift right by 4 bits, it would look like midpoint. shift right by 4 bits, it would look like midpoint.
*/ */
if (exponent_diff > 0) { if (exponent_diff > 0) {
mantissa >>= exponent_diff; mantissa >>= exponent_diff;
} else if (exponent_diff == -1) { } else if (exponent_diff == -1) {
mantissa <<= -exponent_diff; mantissa <<= -exponent_diff;
}
bool implicit_one = mantissa & (1 << mfmt);
// if there is no implicit 1, it means the f8 is denormal and need to adjust
// to denorm exponent
f8_exponent = (act_exponent + exponent_diff) /*actual f8 exponent*/ +
f8_bias - (implicit_one ? 0 : 1);
// Now we have the exponent and mantissa adjusted
uint32_t drop_mask = (1 << (mfmt - wm)) - 1;
bool odd = mantissa & (1 << (mfmt - wm)); // if the least significant bit
// that is not truncated is 1
mantissa +=
(stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) &
drop_mask;
// Now we deal with overflow
if (f8_exponent == 0) {
if ((1 << mfmt) & mantissa) {
f8_exponent = 1; // denormal overflow to become normal, promote exponent
} }
bool implicit_one = mantissa & (1 << mfmt); } else {
// if there is no implicit 1, it means the f8 is denormal and need to adjust if ((1 << (mfmt + 1)) & mantissa) {
// to denorm exponent mantissa >>= 1;
f8_exponent = (act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1); f8_exponent++;
}
}
// Now we have the exponent and mantissa adjusted mantissa >>= (mfmt - wm);
uint32_t drop_mask = (1 << (mfmt - wm)) - 1;
bool odd = mantissa & (1 << (mfmt - wm)); // if the least significant bit that
// is not truncated is 1
mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask;
// Now we deal with overflow // above range: quantize to maximum possible float of the same sign
if (f8_exponent == 0) { const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2);
if ((1 << mfmt) & mantissa) { if (f8_exponent > max_exp) {
f8_exponent = 1; // denormal overflow to become normal, promote exponent if (clip) {
} mantissa = (1 << wm) - 1;
f8_exponent = max_exp;
} else { } else {
if ((1 << (mfmt + 1)) & mantissa) { return signed_inf;
mantissa >>= 1;
f8_exponent++;
}
} }
}
mantissa >>= (mfmt - wm); if (f8_exponent == 0 && mantissa == 0) {
return negative_zero_nan ? 0 : (sign << 7);
// above range: quantize to maximum possible float of the same sign }
const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2); mantissa &= (1 << wm) - 1;
if (f8_exponent > max_exp) { return (sign << 7) | (f8_exponent << wm) | mantissa;
if (clip) {
mantissa = (1 << wm) - 1;
f8_exponent = max_exp;
} else {
return signed_inf;
}
}
if (f8_exponent == 0 && mantissa == 0) {
return negative_zero_nan ? 0 : (sign << 7);
}
mantissa &= (1 << wm) - 1;
return (sign << 7) | (f8_exponent << wm) | mantissa;
} }
template <int we, int wm, typename T = float, bool negative_zero_nan = true> template <int we, int wm, typename T = float, bool negative_zero_nan = true>
inline HIP_FP8_HOST_DEVICE T from_float8(uint8_t x) inline HIP_FP8_HOST_DEVICE T from_float8(uint8_t x) {
{
#ifdef __HIPCC__ #ifdef __HIPCC__
constexpr bool is_half = std::is_same<T, _Float16>::value; constexpr bool is_half = std::is_same<T, _Float16>::value;
#else #else
constexpr bool is_half = false; constexpr bool is_half = false;
#endif #endif
constexpr bool is_float = std::is_same<T, float>::value; constexpr bool is_float = std::is_same<T, float>::value;
static_assert(is_half || is_float, "only half and float are supported"); static_assert(is_half || is_float, "only half and float are supported");
constexpr int weo = is_half ? 5 : 8; constexpr int weo = is_half ? 5 : 8;
constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7); constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7);
T fInf, fNegInf, fNaN, fNeg0; T fInf, fNegInf, fNaN, fNeg0;
#ifdef __HIPCC__ #ifdef __HIPCC__
if (is_half) { if (is_half) {
const uint16_t ihInf = 0x7C00; const uint16_t ihInf = 0x7C00;
const uint16_t ihNegInf = 0xFC00; const uint16_t ihNegInf = 0xFC00;
const uint16_t ihNaN = 0x7C01; const uint16_t ihNaN = 0x7C01;
const uint16_t ihNeg0 = 0x8000; const uint16_t ihNeg0 = 0x8000;
fInf = reinterpret_cast<const _Float16&>(ihInf); fInf = reinterpret_cast<const _Float16&>(ihInf);
fNegInf = reinterpret_cast<const _Float16&>(ihNegInf); fNegInf = reinterpret_cast<const _Float16&>(ihNegInf);
fNaN = reinterpret_cast<const _Float16&>(ihNaN); fNaN = reinterpret_cast<const _Float16&>(ihNaN);
fNeg0 = reinterpret_cast<const _Float16&>(ihNeg0); fNeg0 = reinterpret_cast<const _Float16&>(ihNeg0);
} else } else
#endif #endif
if (is_float) { if (is_float) {
const uint32_t ifInf = 0x7F800000; const uint32_t ifInf = 0x7F800000;
const uint32_t ifNegInf = 0xFF800000; const uint32_t ifNegInf = 0xFF800000;
const uint32_t ifNaN = 0x7F800001; const uint32_t ifNaN = 0x7F800001;
const uint32_t ifNeg0 = 0x80000000; const uint32_t ifNeg0 = 0x80000000;
fInf = reinterpret_cast<const float&>(ifInf); fInf = reinterpret_cast<const float&>(ifInf);
fNegInf = reinterpret_cast<const float&>(ifNegInf); fNegInf = reinterpret_cast<const float&>(ifNegInf);
fNaN = reinterpret_cast<const float&>(ifNaN); fNaN = reinterpret_cast<const float&>(ifNaN);
fNeg0 = reinterpret_cast<const float&>(ifNeg0); fNeg0 = reinterpret_cast<const float&>(ifNeg0);
} }
if (x == 0) { if (x == 0) {
return 0; return 0;
} }
uint32_t sign = x >> 7; uint32_t sign = x >> 7;
uint32_t mantissa = x & ((1 << wm) - 1); uint32_t mantissa = x & ((1 << wm) - 1);
int exponent = (x & 0x7F) >> wm; int exponent = (x & 0x7F) >> wm;
if (negative_zero_nan) { if (negative_zero_nan) {
if (x == 0x80) { if (x == 0x80) {
return fNaN; return fNaN;
}
} else {
if (x == 0x80) {
return fNeg0;
}
if (exponent == ((1 << we) - 1)) {
return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN;
}
} }
typename std::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type retval; } else {
if (we == 5 && is_half && !negative_zero_nan) { if (x == 0x80) {
retval = x << 8; return fNeg0;
return reinterpret_cast<const T&>(retval);
} }
if (exponent == ((1 << we) - 1)) {
const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0); return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN;
// subnormal input
if (exponent == 0) {
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
int sh = 1 + clz(mantissa) - (32 - wm);
mantissa <<= sh;
exponent += 1 - sh;
mantissa &= ((1 << wm) - 1);
}
exponent += exp_low_cutoff - 1;
mantissa <<= wmo - wm;
// subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
if (exponent <= 0) {
mantissa |= 1 << wmo;
mantissa >>= 1 - exponent;
exponent = 0;
}
if (sizeof(T) == 2) {
retval = (sign << 15) | (exponent << 10) | mantissa;
} else {
retval = (sign << 31) | (exponent << 23) | mantissa;
} }
}
typename std::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type retval;
if (we == 5 && is_half && !negative_zero_nan) {
retval = x << 8;
return reinterpret_cast<const T&>(retval); return reinterpret_cast<const T&>(retval);
}
const int exp_low_cutoff =
(1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0);
// subnormal input
if (exponent == 0) {
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
int sh = 1 + clz(mantissa) - (32 - wm);
mantissa <<= sh;
exponent += 1 - sh;
mantissa &= ((1 << wm) - 1);
}
exponent += exp_low_cutoff - 1;
mantissa <<= wmo - wm;
// subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
if (exponent <= 0) {
mantissa |= 1 << wmo;
mantissa >>= 1 - exponent;
exponent = 0;
}
if (sizeof(T) == 2) {
retval = (sign << 15) | (exponent << 10) | mantissa;
} else {
retval = (sign << 31) | (exponent << 23) | mantissa;
}
return reinterpret_cast<const T&>(retval);
} }
} // namespace hip_fp8_impl } // namespace hip_fp8_impl

View File

@ -9,566 +9,567 @@
#include "../../../attention/dtype_float32.cuh" #include "../../../attention/dtype_float32.cuh"
#include "../../../attention/dtype_bfloat16.cuh" #include "../../../attention/dtype_bfloat16.cuh"
namespace vllm namespace vllm {
{
#ifdef USE_ROCM #ifdef USE_ROCM
namespace fp8 { namespace fp8 {
#ifdef ENABLE_FP8 #ifdef ENABLE_FP8
template <typename Tout, typename Tin> template <typename Tout, typename Tin>
__inline__ __device__ Tout vec_conversion(const Tin& x) __inline__ __device__ Tout vec_conversion(const Tin& x) {
{ return x;
return x;
} }
template <typename Tout, typename Tin> template <typename Tout, typename Tin>
__inline__ __device__ Tout scaled_vec_conversion(const Tin& x, const float scale) __inline__ __device__ Tout scaled_vec_conversion(const Tin& x,
{ const float scale) {
return x; return x;
} }
// fp8 -> half // fp8 -> half
template <> template <>
__inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(const uint8_t& a) __inline__ __device__ uint16_t
{ vec_conversion<uint16_t, uint8_t>(const uint8_t& a) {
hip_fp8 f8{a, hip_fp8::from_bits()}; hip_fp8 f8{a, hip_fp8::from_bits()};
__half_raw res; __half_raw res;
res.data = static_cast<float>(f8); res.data = static_cast<float>(f8);
return res.x; return res.x;
} }
// fp8x2 -> half2 // fp8x2 -> half2
template <> template <>
__inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(const uint16_t& a) __inline__ __device__ uint32_t
{ vec_conversion<uint32_t, uint16_t>(const uint16_t& a) {
#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) #if defined(__HIP__MI300__) && \
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
union { const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
__half2_raw h2r; union {
uint32_t ui32; __half2_raw h2r;
} tmp; uint32_t ui32;
tmp.h2r.x.data = f2[0]; } tmp;
tmp.h2r.y.data = f2[1]; tmp.h2r.x.data = f2[0];
return tmp.ui32; tmp.h2r.y.data = f2[1];
#else return tmp.ui32;
union { #else
uint16_t u16[2]; union {
uint32_t u32; uint16_t u16[2];
} tmp; uint32_t u32;
} tmp;
tmp.u16[0] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a)); tmp.u16[0] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a));
tmp.u16[1] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a >> 8U)); tmp.u16[1] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a >> 8U));
return tmp.u32; return tmp.u32;
#endif #endif
} }
// fp8x4 -> half2x2 // fp8x4 -> half2x2
template <> template <>
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a) __inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a) {
{ union {
union { uint2 u32x2;
uint2 u32x2; uint32_t u32[2];
uint32_t u32[2]; } tmp;
} tmp; tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a); tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U)); return tmp.u32x2;
return tmp.u32x2;
} }
// fp8x8 -> half2x4 // fp8x8 -> half2x4
template <> template <>
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a) __inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a) {
{ union {
union { uint4 u64x2;
uint4 u64x2; uint2 u64[2];
uint2 u64[2]; } tmp;
} tmp; tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x); tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y); return tmp.u64x2;
return tmp.u64x2;
} }
using __nv_bfloat16 = __hip_bfloat16; using __nv_bfloat16 = __hip_bfloat16;
// fp8 -> __nv_bfloat16 // fp8 -> __nv_bfloat16
template <> template <>
__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) __inline__ __device__ __nv_bfloat16
{ vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) {
hip_fp8 f8{a, hip_fp8::from_bits()}; hip_fp8 f8{a, hip_fp8::from_bits()};
float f{f8}; float f{f8};
return __float2bfloat16(f); return __float2bfloat16(f);
} }
using __nv_bfloat162 = __hip_bfloat162; using __nv_bfloat162 = __hip_bfloat162;
// fp8x2 -> __nv_bfloat162 // fp8x2 -> __nv_bfloat162
template <> template <>
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) __inline__ __device__ __nv_bfloat162
{ vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) {
__nv_bfloat162 res; __nv_bfloat162 res;
res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a); res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U)); res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
return res; return res;
} }
// fp8x4 -> bf16_4_t // fp8x4 -> bf16_4_t
template <> template <>
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a) __inline__ __device__ bf16_4_t
{ vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a) {
bf16_4_t res; bf16_4_t res;
res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a); res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U)); res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
return res; return res;
} }
// fp8x8 -> bf16_8_t // fp8x8 -> bf16_8_t
template <> template <>
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a) __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a) {
{ bf16_4_t tmp1, tmp2;
bf16_4_t tmp1, tmp2; tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x); tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y); bf16_8_t res;
bf16_8_t res; res.x = tmp1.x;
res.x = tmp1.x; res.y = tmp1.y;
res.y = tmp1.y; res.z = tmp2.x;
res.z = tmp2.x; res.w = tmp2.y;
res.w = tmp2.y; return res;
return res;
} }
// fp8 -> float // fp8 -> float
template <> template <>
__inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a) __inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a) {
{ hip_fp8 fp8{a, hip_fp8::from_bits()};
hip_fp8 fp8{a, hip_fp8::from_bits()}; return static_cast<float>(fp8);
return static_cast<float>(fp8);
} }
// fp8x2 -> float2 // fp8x2 -> float2
template <> template <>
__inline__ __device__ float2 vec_conversion<float2, uint16_t>(const uint16_t& a) __inline__ __device__ float2
{ vec_conversion<float2, uint16_t>(const uint16_t& a) {
#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) #if defined(__HIP__MI300__) && \
float2 res; defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); float2 res;
res.x = f2[0]; const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
res.y = f2[1]; res.x = f2[0];
return res; res.y = f2[1];
#else return res;
float2 res; #else
res.x = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a)); float2 res;
res.y = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U)); res.x = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a));
return res; res.y = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U));
#endif return res;
#endif
} }
// fp8x4 -> float4 // fp8x4 -> float4
template <> template <>
__inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(const uint32_t& a) __inline__ __device__ Float4_
{ vec_conversion<Float4_, uint32_t>(const uint32_t& a) {
Float4_ res; Float4_ res;
res.x = vec_conversion<float2, uint16_t>((uint16_t)a); res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U)); res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
return res; return res;
} }
// fp8x8 -> float8 // fp8x8 -> float8
template <> template <>
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a) __inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a) {
{ Float4_ tmp1, tmp2;
Float4_ tmp1, tmp2; tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
tmp1 = vec_conversion<Float4_, uint32_t>(a.x); tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
tmp2 = vec_conversion<Float4_, uint32_t>(a.y); Float8_ res;
Float8_ res; res.x = tmp1.x;
res.x = tmp1.x; res.y = tmp1.y;
res.y = tmp1.y; res.z = tmp2.x;
res.z = tmp2.x; res.w = tmp2.y;
res.w = tmp2.y; return res;
return res;
} }
// half -> fp8 // half -> fp8
template <> template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(const uint16_t& a) __inline__ __device__ uint8_t
{ vec_conversion<uint8_t, uint16_t>(const uint16_t& a) {
__half_raw tmp; __half_raw tmp;
tmp.x = a; tmp.x = a;
hip_fp8 f8{static_cast<float>(tmp.data)}; hip_fp8 f8{static_cast<float>(tmp.data)};
return f8.data; return f8.data;
} }
// bf16 -> fp8 // bf16 -> fp8
template <> template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a) __inline__ __device__ uint8_t
{ vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a) {
hip_fp8 res{__bfloat162float(a)}; hip_fp8 res{__bfloat162float(a)};
return res.data; return res.data;
} }
// float -> fp8 // float -> fp8
template <> template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a) __inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a) {
{ hip_fp8 f8(a);
hip_fp8 f8(a); return f8.data;
return f8.data;
} }
// fp8x4 -> float4 // fp8x4 -> float4
template <> template <>
__inline__ __device__ float4 vec_conversion<float4, uint32_t>(const uint32_t& a) __inline__ __device__ float4
{ vec_conversion<float4, uint32_t>(const uint32_t& a) {
Float4_ tmp = vec_conversion<Float4_, uint32_t>(a); Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
return res; return res;
} }
// float2 -> half2 // float2 -> half2
template <> template <>
__inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(const float2& a) __inline__ __device__ uint32_t
{ vec_conversion<uint32_t, float2>(const float2& a) {
union { union {
half2 float16; half2 float16;
uint32_t uint32; uint32_t uint32;
}; };
float16 = __float22half2_rn(a); float16 = __float22half2_rn(a);
return uint32; return uint32;
} }
// Float4 -> half2x2 // Float4 -> half2x2
template <> template <>
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a) __inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a) {
{ uint2 b;
uint2 b; float2 val;
float2 val; val.x = a.x.x;
val.x = a.x.x; val.y = a.x.y;
val.y = a.x.y; b.x = vec_conversion<uint32_t, float2>(val);
b.x = vec_conversion<uint32_t, float2>(val);
val.x = a.y.x; val.x = a.y.x;
val.y = a.y.y; val.y = a.y.y;
b.y = vec_conversion<uint32_t, float2>(val); b.y = vec_conversion<uint32_t, float2>(val);
return b; return b;
} }
// Float4 -> float4 // Float4 -> float4
template <> template <>
__inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a) __inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a) {
{ float4 b;
float4 b; b.x = a.x.x;
b.x = a.x.x; b.y = a.x.y;
b.y = a.x.y; b.z = a.y.x;
b.z = a.y.x; b.w = a.y.y;
b.w = a.y.y; return b;
return b;
} }
// Float8 -> half2x4 // Float8 -> half2x4
template <> template <>
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a) __inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a) {
{ uint4 b;
uint4 b; b.x = vec_conversion<uint32_t, float2>(a.x);
b.x = vec_conversion<uint32_t, float2>(a.x); b.y = vec_conversion<uint32_t, float2>(a.y);
b.y = vec_conversion<uint32_t, float2>(a.y); b.z = vec_conversion<uint32_t, float2>(a.z);
b.z = vec_conversion<uint32_t, float2>(a.z); b.w = vec_conversion<uint32_t, float2>(a.w);
b.w = vec_conversion<uint32_t, float2>(a.w); return b;
return b;
} }
// float2 -> bfloat162 // float2 -> bfloat162
template <> template <>
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2& a) __inline__ __device__ __nv_bfloat162
{ vec_conversion<__nv_bfloat162, float2>(const float2& a) {
__nv_bfloat162 b = __float22bfloat162_rn(a); __nv_bfloat162 b = __float22bfloat162_rn(a);
return b; return b;
} }
// Float4 -> bfloat162x2 // Float4 -> bfloat162x2
template <> template <>
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(const Float4_& a) __inline__ __device__ bf16_4_t
{ vec_conversion<bf16_4_t, Float4_>(const Float4_& a) {
bf16_4_t b; bf16_4_t b;
b.x = __float22bfloat162_rn(a.x); b.x = __float22bfloat162_rn(a.x);
b.y = __float22bfloat162_rn(a.y); b.y = __float22bfloat162_rn(a.y);
return b; return b;
} }
// Float8 -> bfloat162x4 // Float8 -> bfloat162x4
template <> template <>
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(const Float8_& a) __inline__ __device__ bf16_8_t
{ vec_conversion<bf16_8_t, Float8_>(const Float8_& a) {
bf16_8_t b; bf16_8_t b;
b.x = __float22bfloat162_rn(a.x); b.x = __float22bfloat162_rn(a.x);
b.y = __float22bfloat162_rn(a.y); b.y = __float22bfloat162_rn(a.y);
b.z = __float22bfloat162_rn(a.z); b.z = __float22bfloat162_rn(a.z);
b.w = __float22bfloat162_rn(a.w); b.w = __float22bfloat162_rn(a.w);
return b; return b;
} }
/* Scaled and vectorized conversions, for data exchange between high and low
precision domains
/* Scaled and vectorized conversions, for data exchange between high and low precision domains Convention of the scale in API, e.g: FP8_data = Quantization(
High_Precision_data / scale ) s.t. Quantize(HP / scale) => FP8 Dequant(FP8) *
Convention of the scale in API, e.g: FP8_data = Quantization( High_Precision_data / scale ) scale => HP
s.t.
Quantize(HP / scale) => FP8
Dequant(FP8) * scale => HP
*/ */
// fp8 -> half // fp8 -> half
template <> template <>
__inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, const float scale) __inline__ __device__ uint16_t
{ scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, const float scale) {
hip_fp8 f8{a, hip_fp8::from_bits()}; hip_fp8 f8{a, hip_fp8::from_bits()};
__half_raw res; __half_raw res;
res.data = static_cast<float>(f8) * scale; res.data = static_cast<float>(f8) * scale;
return res.x; return res.x;
} }
// fp8x2 -> half2 // fp8x2 -> half2
template <> template <>
__inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, const float scale) __inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(
{ const uint16_t& a, const float scale) {
#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) #if defined(__HIP__MI300__) && \
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
union { const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
__half2_raw h2r; union {
uint32_t ui32; __half2_raw h2r;
} tmp; uint32_t ui32;
tmp.h2r.x.data = f2[0] * scale; } tmp;
tmp.h2r.y.data = f2[1] * scale; tmp.h2r.x.data = f2[0] * scale;
return tmp.ui32; tmp.h2r.y.data = f2[1] * scale;
#else return tmp.ui32;
union { #else
uint16_t u16[2]; union {
uint32_t u32; uint16_t u16[2];
} tmp; uint32_t u32;
} tmp;
tmp.u16[0] = scaled_vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a), scale); tmp.u16[0] =
tmp.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a >> 8U), scale); scaled_vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a), scale);
return tmp.u32; tmp.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>(
#endif static_cast<uint8_t>(a >> 8U), scale);
return tmp.u32;
#endif
} }
// fp8x4 -> half2x2 // fp8x4 -> half2x2
template <> template <>
__inline__ __device__ uint2 scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, const float scale) __inline__ __device__ uint2
{ scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, const float scale) {
union { union {
uint2 u32x2; uint2 u32x2;
uint32_t u32[2]; uint32_t u32[2];
} tmp; } tmp;
tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale); tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale);
tmp.u32[1] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale); tmp.u32[1] =
return tmp.u32x2; scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale);
return tmp.u32x2;
} }
// fp8x8 -> half2x4 // fp8x8 -> half2x4
template <> template <>
__inline__ __device__ uint4 scaled_vec_conversion<uint4, uint2>(const uint2& a, const float scale) __inline__ __device__ uint4
{ scaled_vec_conversion<uint4, uint2>(const uint2& a, const float scale) {
union { union {
uint4 u64x2; uint4 u64x2;
uint2 u64[2]; uint2 u64[2];
} tmp; } tmp;
tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale); tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale);
tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale); tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale);
return tmp.u64x2; return tmp.u64x2;
} }
using __nv_bfloat16 = __hip_bfloat16; using __nv_bfloat16 = __hip_bfloat16;
// fp8 -> __nv_bfloat16 // fp8 -> __nv_bfloat16
template <> template <>
__inline__ __device__ __nv_bfloat16 scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, const float scale) __inline__ __device__ __nv_bfloat16
{ scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a,
hip_fp8 f8{a, hip_fp8::from_bits()}; const float scale) {
float f{f8}; hip_fp8 f8{a, hip_fp8::from_bits()};
return __float2bfloat16(f * scale); float f{f8};
return __float2bfloat16(f * scale);
} }
using __nv_bfloat162 = __hip_bfloat162; using __nv_bfloat162 = __hip_bfloat162;
// fp8x2 -> __nv_bfloat162 // fp8x2 -> __nv_bfloat162
template <> template <>
__inline__ __device__ __nv_bfloat162 scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a, const float scale) __inline__ __device__ __nv_bfloat162
{ scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a,
__nv_bfloat162 res; const float scale) {
res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale); __nv_bfloat162 res;
res.y = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale); res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale);
return res; res.y =
scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale);
return res;
} }
// fp8x4 -> bf16_4_t // fp8x4 -> bf16_4_t
template <> template <>
__inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a, const float scale) __inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
{ const uint32_t& a, const float scale) {
bf16_4_t res; bf16_4_t res;
res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale); res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale);
res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), scale); res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U),
return res; scale);
return res;
} }
// fp8x8 -> bf16_8_t // fp8x8 -> bf16_8_t
template <> template <>
__inline__ __device__ bf16_8_t scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, const float scale) __inline__ __device__ bf16_8_t
{ scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, const float scale) {
bf16_4_t tmp1, tmp2; bf16_4_t tmp1, tmp2;
tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale); tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale);
tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale); tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale);
bf16_8_t res; bf16_8_t res;
res.x = tmp1.x; res.x = tmp1.x;
res.y = tmp1.y; res.y = tmp1.y;
res.z = tmp2.x; res.z = tmp2.x;
res.w = tmp2.y; res.w = tmp2.y;
return res; return res;
} }
// fp8 -> float // fp8 -> float
template <> template <>
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(const uint8_t& a, const float scale) __inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
{ const uint8_t& a, const float scale) {
hip_fp8 fp8{a, hip_fp8::from_bits()}; hip_fp8 fp8{a, hip_fp8::from_bits()};
return static_cast<float>(fp8) * scale; return static_cast<float>(fp8) * scale;
} }
// fp8x2 -> float2 // fp8x2 -> float2
template <> template <>
__inline__ __device__ float2 scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, const float scale) __inline__ __device__ float2
{ scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, const float scale) {
#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) #if defined(__HIP__MI300__) && \
float2 res; defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); float2 res;
res.x = f2[0] * scale; const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
res.y = f2[1] * scale; res.x = f2[0] * scale;
return res; res.y = f2[1] * scale;
#else return res;
float2 res; #else
res.x = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a), scale); float2 res;
res.y = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U), scale); res.x = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a), scale);
return res; res.y = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U),
#endif scale);
return res;
#endif
} }
// fp8x4 -> float4 // fp8x4 -> float4
template <> template <>
__inline__ __device__ Float4_ scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale) __inline__ __device__ Float4_
{ scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale) {
Float4_ res; Float4_ res;
res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale); res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale);
res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale); res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale);
return res; return res;
} }
// fp8x8 -> float8 // fp8x8 -> float8
template <> template <>
__inline__ __device__ Float8_ scaled_vec_conversion<Float8_, uint2>(const uint2& a, const float scale) __inline__ __device__ Float8_
{ scaled_vec_conversion<Float8_, uint2>(const uint2& a, const float scale) {
Float4_ tmp1, tmp2; Float4_ tmp1, tmp2;
tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale); tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale);
tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale); tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale);
Float8_ res; Float8_ res;
res.x = tmp1.x; res.x = tmp1.x;
res.y = tmp1.y; res.y = tmp1.y;
res.z = tmp2.x; res.z = tmp2.x;
res.w = tmp2.y; res.w = tmp2.y;
return res; return res;
} }
/* Quantize(HP / scale) => FP8 */ /* Quantize(HP / scale) => FP8 */
// TODO(Hai): vectorized to add // TODO(Hai): vectorized to add
// half -> fp8 // half -> fp8
template <> template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, const float scale) __inline__ __device__ uint8_t
{ scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, const float scale) {
__half_raw tmp; __half_raw tmp;
tmp.x = a; tmp.x = a;
hip_fp8 f8{static_cast<float>(tmp.data)/scale}; hip_fp8 f8{static_cast<float>(tmp.data) / scale};
return f8.data; return f8.data;
} }
// bf16 -> fp8 // bf16 -> fp8
template <> template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a, const float scale) __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
{ const __nv_bfloat16& a, const float scale) {
hip_fp8 res{__bfloat162float(a)/scale}; hip_fp8 res{__bfloat162float(a) / scale};
return res.data; return res.data;
} }
// float -> fp8 // float -> fp8
template <> template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, float>(const float& a, const float scale) __inline__ __device__ uint8_t
{ scaled_vec_conversion<uint8_t, float>(const float& a, const float scale) {
hip_fp8 f8(a/scale); hip_fp8 f8(a / scale);
return f8.data; return f8.data;
} }
// fp8x4 -> float4 // fp8x4 -> float4
template <> template <>
__inline__ __device__ float4 scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, const float scale) __inline__ __device__ float4
{ scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, const float scale) {
Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale); Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale);
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
return res; return res;
} }
#endif // ENABLE_FP8 #endif // ENABLE_FP8
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt> template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout convert(const Tin &x) { __inline__ __device__ Tout convert(const Tin& x) {
#ifdef ENABLE_FP8 #ifdef ENABLE_FP8
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
return vec_conversion<Tout, Tin>(x); return vec_conversion<Tout, Tin>(x);
} }
#endif #endif
assert(false); assert(false);
} }
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt> template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) { __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
#ifdef ENABLE_FP8 #ifdef ENABLE_FP8
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
return scaled_vec_conversion<Tout, Tin>(x, scale); return scaled_vec_conversion<Tout, Tin>(x, scale);
} }
#endif #endif
assert(false); assert(false);
} }
// The following macro is used to dispatch the conversion function based on the // The following macro is used to dispatch the conversion function based on
// data type of the key and value cache. The FN is a macro that calls a function // the data type of the key and value cache. The FN is a macro that calls a
// with template<typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>. // function with template<typename scalar_t, typename cache_t,
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \ // Fp8KVCacheDataType kv_dt>.
if (KV_DTYPE == "auto") { \ #define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
if (SRC_DTYPE == at::ScalarType::Float) { \ if (KV_DTYPE == "auto") { \
FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \
} else { \
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else { \
if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \
if (SRC_DTYPE == at::ScalarType::Float) { \ if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \ } else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \
} else { \ } else { \
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
} \ } \
} else { \ } else { \
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \ if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \
} \ if (SRC_DTYPE == at::ScalarType::Float) { \
} FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else { \
TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else { \
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
} \
}
} // fp8 } // namespace fp8
#endif // USE_ROCM #endif // USE_ROCM
} // namespace vllm } // namespace vllm

View File

@ -10,17 +10,20 @@
namespace vllm { namespace vllm {
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) { __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
float old; float old;
old = (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) : old = (value >= 0)
__uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value))); ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
: __uint_as_float(
atomicMin((unsigned int*)addr, __float_as_uint(value)));
return old; return old;
} }
#define FP8_E4M3_MAX std::numeric_limits<c10::Float8_e4m3fn>::max() #define FP8_E4M3_MAX std::numeric_limits<c10::Float8_e4m3fn>::max()
template<typename scalar_t> template <typename scalar_t>
__device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(const scalar_t val, const float scale) { __device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(
const scalar_t val, const float scale) {
float x = static_cast<float>(val) / scale; float x = static_cast<float>(val) / scale;
float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX)); float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
return static_cast<c10::Float8_e4m3fn>(r); return static_cast<c10::Float8_e4m3fn>(r);
@ -32,11 +35,10 @@ __device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(const scalar
// So to get the right answer, *scale needs to be initialized to // So to get the right answer, *scale needs to be initialized to
// a value <= 0.0 and we need to wait for all thread blocks to // a value <= 0.0 and we need to wait for all thread blocks to
// finish before consuming *scale. // finish before consuming *scale.
template<typename scalar_t> template <typename scalar_t>
__global__ void segmented_max_reduction( __global__ void segmented_max_reduction(float* __restrict__ scale,
float* __restrict__ scale, const scalar_t* __restrict__ input,
const scalar_t* __restrict__ input, int64_t num_elems) {
int64_t num_elems) {
__shared__ float cache[1024]; __shared__ float cache[1024];
int i = blockDim.x * blockIdx.x + threadIdx.x; int i = blockDim.x * blockIdx.x + threadIdx.x;
@ -56,7 +58,7 @@ __global__ void segmented_max_reduction(
int ib = blockDim.x / 2; int ib = blockDim.x / 2;
while (ib != 0) { while (ib != 0) {
if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) { if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) {
cache[threadIdx.x] = cache[threadIdx.x + ib]; cache[threadIdx.x] = cache[threadIdx.x + ib];
} }
__syncthreads(); __syncthreads();
ib /= 2; ib /= 2;
@ -64,16 +66,16 @@ __global__ void segmented_max_reduction(
// Finally, since cache[0] contains the maximum for this thread block, // Finally, since cache[0] contains the maximum for this thread block,
// atomically write the max to the target location // atomically write the max to the target location
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
atomicMaxFloat(scale, cache[0] / std::numeric_limits<c10::Float8_e4m3fn>::max()); atomicMaxFloat(scale,
cache[0] / std::numeric_limits<c10::Float8_e4m3fn>::max());
} }
} }
template<typename scalar_t> template <typename scalar_t>
__global__ void scaled_fp8_quant_kernel( __global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
c10::Float8_e4m3fn* __restrict__ out, const scalar_t* __restrict__ input,
const scalar_t* __restrict__ input, const float* __restrict__ scale,
const float* __restrict__ scale, int64_t num_elems) {
int64_t num_elems) {
int i = blockDim.x * blockIdx.x + threadIdx.x; int i = blockDim.x * blockIdx.x + threadIdx.x;
while (i < num_elems) { while (i < num_elems) {
out[i] = scaled_fp8_conversion(input[i], *scale); out[i] = scaled_fp8_conversion(input[i], *scale);
@ -81,12 +83,11 @@ __global__ void scaled_fp8_quant_kernel(
} }
} }
} // namespace vllm } // namespace vllm
void static_scaled_fp8_quant( void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch::Tensor& out, // [..., d] torch::Tensor& input, // [..., d]
torch::Tensor& input, // [..., d] torch::Tensor& scale) // [1]
torch::Tensor& scale) // [1]
{ {
int64_t num_tokens = input.numel() / input.size(-1); int64_t num_tokens = input.numel() / input.size(-1);
int64_t num_elems = input.numel(); int64_t num_elems = input.numel();
@ -95,21 +96,16 @@ void static_scaled_fp8_quant(
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
"scaled_fp8_quant_kernel", vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
[&] { out.data_ptr<c10::Float8_e4m3fn>(), input.data_ptr<scalar_t>(),
vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>( scale.data_ptr<float>(), num_elems);
out.data_ptr<c10::Float8_e4m3fn>(),
input.data_ptr<scalar_t>(),
scale.data_ptr<float>(),
num_elems);
}); });
} }
void dynamic_scaled_fp8_quant( void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch::Tensor& out, // [..., d] torch::Tensor& input, // [..., d]
torch::Tensor& input, // [..., d] torch::Tensor& scale) // [1]
torch::Tensor& scale) // [1]
{ {
int64_t num_tokens = input.numel() / input.size(-1); int64_t num_tokens = input.numel() / input.size(-1);
int64_t num_elems = input.numel(); int64_t num_elems = input.numel();
@ -118,18 +114,11 @@ void dynamic_scaled_fp8_quant(
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
"scaled_fp8_quant_kernel", vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>(
[&] { scale.data_ptr<float>(), input.data_ptr<scalar_t>(), num_elems);
vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>( vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
scale.data_ptr<float>(), out.data_ptr<c10::Float8_e4m3fn>(), input.data_ptr<scalar_t>(),
input.data_ptr<scalar_t>(), scale.data_ptr<float>(), num_elems);
num_elems);
vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<c10::Float8_e4m3fn>(),
input.data_ptr<scalar_t>(),
scale.data_ptr<float>(),
num_elems);
}); });
} }

View File

@ -10,9 +10,9 @@ namespace vllm {
#ifndef USE_ROCM #ifndef USE_ROCM
namespace fp8 { namespace fp8 {
#ifdef ENABLE_FP8 #ifdef ENABLE_FP8
#if 0 // Disable the following code to reduce the binary size. #if 0 // Disable the following code to reduce the binary size.
template <typename Tout, typename Tin> template <typename Tout, typename Tin>
__inline__ __device__ Tout __inline__ __device__ Tout
vec_conversion(const Tin &x, const __nv_fp8_interpretation_t fp8_type) { vec_conversion(const Tin &x, const __nv_fp8_interpretation_t fp8_type) {
@ -177,13 +177,13 @@ __inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(
template <> template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>( __inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(
const __nv_bfloat16 &a, const __nv_fp8_interpretation_t fp8_type) { const __nv_bfloat16 &a, const __nv_fp8_interpretation_t fp8_type) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false); assert(false);
#else #else
__nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8( __nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(
__nv_bfloat16_raw(a), __NV_SATFINITE, fp8_type); __nv_bfloat16_raw(a), __NV_SATFINITE, fp8_type);
return (uint8_t)res; return (uint8_t)res;
#endif #endif
} }
// float -> fp8 // float -> fp8
@ -276,7 +276,7 @@ __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(
from_float(b, a); from_float(b, a);
return b; return b;
} }
#endif #endif
/* Scaled and vectorized conversions, for data exchange between high and low /* Scaled and vectorized conversions, for data exchange between high and low
precision domains Convention of the scale in API, e.g: FP8_data = precision domains Convention of the scale in API, e.g: FP8_data =
@ -286,14 +286,14 @@ __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(
template <typename Tout, typename Tin> template <typename Tout, typename Tin>
__inline__ __device__ Tout scaled_vec_conversion( __inline__ __device__ Tout scaled_vec_conversion(
const Tin &x, const float scale, const __nv_fp8_interpretation_t fp8_type) { const Tin& x, const float scale, const __nv_fp8_interpretation_t fp8_type) {
return x; return x;
} }
// fp8 -> half // fp8 -> half
template <> template <>
__inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, uint8_t>( __inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, uint8_t>(
const uint8_t &a, const float scale, const uint8_t& a, const float scale,
const __nv_fp8_interpretation_t fp8_type) { const __nv_fp8_interpretation_t fp8_type) {
__half_raw tmp = __nv_cvt_fp8_to_halfraw(a, fp8_type); __half_raw tmp = __nv_cvt_fp8_to_halfraw(a, fp8_type);
return float_to_half(half_to_float(tmp.x) * scale); return float_to_half(half_to_float(tmp.x) * scale);
@ -302,7 +302,7 @@ __inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, uint8_t>(
// fp8x2 -> half2 // fp8x2 -> half2
template <> template <>
__inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>( __inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(
const uint16_t &a, const float scale, const uint16_t& a, const float scale,
const __nv_fp8_interpretation_t fp8_type) { const __nv_fp8_interpretation_t fp8_type) {
union { union {
uint16_t u16[2]; uint16_t u16[2];
@ -317,7 +317,7 @@ __inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(
// fp8x4 -> half2x2 // fp8x4 -> half2x2
template <> template <>
__inline__ __device__ uint2 scaled_vec_conversion<uint2, uint32_t>( __inline__ __device__ uint2 scaled_vec_conversion<uint2, uint32_t>(
const uint32_t &a, const float scale, const uint32_t& a, const float scale,
const __nv_fp8_interpretation_t fp8_type) { const __nv_fp8_interpretation_t fp8_type) {
union { union {
uint2 u32x2; uint2 u32x2;
@ -333,7 +333,7 @@ __inline__ __device__ uint2 scaled_vec_conversion<uint2, uint32_t>(
// fp8x8 -> half2x4 // fp8x8 -> half2x4
template <> template <>
__inline__ __device__ uint4 __inline__ __device__ uint4
scaled_vec_conversion<uint4, uint2>(const uint2 &a, const float scale, scaled_vec_conversion<uint4, uint2>(const uint2& a, const float scale,
const __nv_fp8_interpretation_t fp8_type) { const __nv_fp8_interpretation_t fp8_type) {
union { union {
uint4 u64x2; uint4 u64x2;
@ -348,7 +348,7 @@ scaled_vec_conversion<uint4, uint2>(const uint2 &a, const float scale,
template <> template <>
__inline__ __device__ __nv_bfloat16 __inline__ __device__ __nv_bfloat16
scaled_vec_conversion<__nv_bfloat16, uint8_t>( scaled_vec_conversion<__nv_bfloat16, uint8_t>(
const uint8_t &a, const float scale, const uint8_t& a, const float scale,
const __nv_fp8_interpretation_t fp8_type) { const __nv_fp8_interpretation_t fp8_type) {
// Note there is no direct convert function from fp8 to bf16. // Note there is no direct convert function from fp8 to bf16.
// fp8 -> half // fp8 -> half
@ -362,7 +362,7 @@ scaled_vec_conversion<__nv_bfloat16, uint8_t>(
template <> template <>
__inline__ __device__ __nv_bfloat162 __inline__ __device__ __nv_bfloat162
scaled_vec_conversion<__nv_bfloat162, uint16_t>( scaled_vec_conversion<__nv_bfloat162, uint16_t>(
const uint16_t &a, const float scale, const uint16_t& a, const float scale,
const __nv_fp8_interpretation_t fp8_type) { const __nv_fp8_interpretation_t fp8_type) {
__nv_bfloat162 res; __nv_bfloat162 res;
res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale, res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale,
@ -375,7 +375,7 @@ scaled_vec_conversion<__nv_bfloat162, uint16_t>(
// fp8x4 -> bf16_4_t // fp8x4 -> bf16_4_t
template <> template <>
__inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>( __inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
const uint32_t &a, const float scale, const uint32_t& a, const float scale,
const __nv_fp8_interpretation_t fp8_type) { const __nv_fp8_interpretation_t fp8_type) {
bf16_4_t res; bf16_4_t res;
res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale, res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale,
@ -388,7 +388,7 @@ __inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
// fp8x8 -> bf16_8_t // fp8x8 -> bf16_8_t
template <> template <>
__inline__ __device__ bf16_8_t scaled_vec_conversion<bf16_8_t, uint2>( __inline__ __device__ bf16_8_t scaled_vec_conversion<bf16_8_t, uint2>(
const uint2 &a, const float scale, const uint2& a, const float scale,
const __nv_fp8_interpretation_t fp8_type) { const __nv_fp8_interpretation_t fp8_type) {
bf16_4_t tmp1, tmp2; bf16_4_t tmp1, tmp2;
tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale, fp8_type); tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale, fp8_type);
@ -404,9 +404,8 @@ __inline__ __device__ bf16_8_t scaled_vec_conversion<bf16_8_t, uint2>(
// fp8 -> float // fp8 -> float
template <> template <>
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>( __inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
const uint8_t &a, const float scale, const uint8_t& a, const float scale,
const __nv_fp8_interpretation_t fp8_type) { const __nv_fp8_interpretation_t fp8_type) {
// fp8 -> half // fp8 -> half
__half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type); __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
uint16_t tmp = res.x; uint16_t tmp = res.x;
@ -418,7 +417,7 @@ __inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
// fp8x2 -> float2 // fp8x2 -> float2
template <> template <>
__inline__ __device__ float2 scaled_vec_conversion<float2, uint16_t>( __inline__ __device__ float2 scaled_vec_conversion<float2, uint16_t>(
const uint16_t &a, const float scale, const uint16_t& a, const float scale,
const __nv_fp8_interpretation_t fp8_type) { const __nv_fp8_interpretation_t fp8_type) {
// fp8x2 -> half2 // fp8x2 -> half2
uint32_t tmp = scaled_vec_conversion<uint32_t, uint16_t>(a, scale, fp8_type); uint32_t tmp = scaled_vec_conversion<uint32_t, uint16_t>(a, scale, fp8_type);
@ -429,7 +428,7 @@ __inline__ __device__ float2 scaled_vec_conversion<float2, uint16_t>(
// fp8x4 -> float4 // fp8x4 -> float4
template <> template <>
__inline__ __device__ Float4_ scaled_vec_conversion<Float4_, uint32_t>( __inline__ __device__ Float4_ scaled_vec_conversion<Float4_, uint32_t>(
const uint32_t &a, const float scale, const uint32_t& a, const float scale,
const __nv_fp8_interpretation_t fp8_type) { const __nv_fp8_interpretation_t fp8_type) {
Float4_ res; Float4_ res;
res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale, fp8_type); res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale, fp8_type);
@ -441,7 +440,7 @@ __inline__ __device__ Float4_ scaled_vec_conversion<Float4_, uint32_t>(
// fp8x8 -> float8 // fp8x8 -> float8
template <> template <>
__inline__ __device__ Float8_ scaled_vec_conversion<Float8_, uint2>( __inline__ __device__ Float8_ scaled_vec_conversion<Float8_, uint2>(
const uint2 &a, const float scale, const uint2& a, const float scale,
const __nv_fp8_interpretation_t fp8_type) { const __nv_fp8_interpretation_t fp8_type) {
Float4_ tmp1, tmp2; Float4_ tmp1, tmp2;
tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale, fp8_type); tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale, fp8_type);
@ -457,7 +456,7 @@ __inline__ __device__ Float8_ scaled_vec_conversion<Float8_, uint2>(
// half -> fp8 // half -> fp8
template <> template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, uint16_t>( __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, uint16_t>(
const uint16_t &a, const float scale, const uint16_t& a, const float scale,
const __nv_fp8_interpretation_t fp8_type) { const __nv_fp8_interpretation_t fp8_type) {
__nv_fp8_storage_t res = __nv_fp8_storage_t res =
__nv_cvt_float_to_fp8(half_to_float(a) / scale, __NV_SATFINITE, fp8_type); __nv_cvt_float_to_fp8(half_to_float(a) / scale, __NV_SATFINITE, fp8_type);
@ -467,21 +466,21 @@ __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, uint16_t>(
// bf16 -> fp8 // bf16 -> fp8
template <> template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>( __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
const __nv_bfloat16 &a, const float scale, const __nv_bfloat16& a, const float scale,
const __nv_fp8_interpretation_t fp8_type) { const __nv_fp8_interpretation_t fp8_type) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false); assert(false);
#else #else
__nv_fp8_storage_t res = __nv_cvt_float_to_fp8(__bfloat162float(a) / scale, __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(__bfloat162float(a) / scale,
__NV_SATFINITE, fp8_type); __NV_SATFINITE, fp8_type);
return (uint8_t)res; return (uint8_t)res;
#endif #endif
} }
// float -> fp8 // float -> fp8
template <> template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, float>( __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, float>(
const float &a, const float scale, const float& a, const float scale,
const __nv_fp8_interpretation_t fp8_type) { const __nv_fp8_interpretation_t fp8_type) {
__nv_fp8_storage_t res = __nv_fp8_storage_t res =
__nv_cvt_float_to_fp8(a / scale, __NV_SATFINITE, fp8_type); __nv_cvt_float_to_fp8(a / scale, __NV_SATFINITE, fp8_type);
@ -491,78 +490,81 @@ __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, float>(
// fp8x4 -> float4 // fp8x4 -> float4
template <> template <>
__inline__ __device__ float4 scaled_vec_conversion<float4, uint32_t>( __inline__ __device__ float4 scaled_vec_conversion<float4, uint32_t>(
const uint32_t &a, const float scale, const uint32_t& a, const float scale,
const __nv_fp8_interpretation_t fp8_type) { const __nv_fp8_interpretation_t fp8_type) {
Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale, fp8_type); Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale, fp8_type);
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
return res; return res;
} }
#endif // ENABLE_FP8 #endif // ENABLE_FP8
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt> template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout convert(const Tin &x) { __inline__ __device__ Tout convert(const Tin& x) {
#if 0 // Disable the following code to reduce the binary size. #if 0 // Disable the following code to reduce the binary size.
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
return vec_conversion<Tout, Tin>(x, __NV_E4M3); return vec_conversion<Tout, Tin>(x, __NV_E4M3);
} else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) { } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
return vec_conversion<Tout, Tin>(x, __NV_E5M2); return vec_conversion<Tout, Tin>(x, __NV_E5M2);
} }
#endif #endif
assert(false); assert(false);
} }
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt> template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) { __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
#ifdef ENABLE_FP8 #ifdef ENABLE_FP8
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E4M3); return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E4M3);
} else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) { } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E5M2); return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E5M2);
} }
#endif #endif
assert(false); assert(false);
} }
// The following macro is used to dispatch the conversion function based on the // The following macro is used to dispatch the conversion function based on
// data type of the key and value cache. The FN is a macro that calls a function // the data type of the key and value cache. The FN is a macro that calls a
// with template<typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>. // function with template<typename scalar_t, typename cache_t,
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \ // Fp8KVCacheDataType kv_dt>.
if (KV_DTYPE == "auto") { \ #define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
if (SRC_DTYPE == at::ScalarType::Float) { \ if (KV_DTYPE == "auto") { \
FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \
} else { \
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else { \
if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \
if (SRC_DTYPE == at::ScalarType::Float) { \ if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \ } else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \
} else { \
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else if (KV_DTYPE == "fp8_e5m2") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else { \ } else { \
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
} \ } \
} else { \ } else { \
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \ if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \
} \ if (SRC_DTYPE == at::ScalarType::Float) { \
} FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else { \
TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else if (KV_DTYPE == "fp8_e5m2") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else { \
TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else { \
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
} \
}
} // namespace fp8 } // namespace fp8
#endif // not USE_ROCM #endif // not USE_ROCM
} // namespace vllm } // namespace vllm

View File

@ -9,54 +9,54 @@ namespace vllm {
namespace gptq { namespace gptq {
// atomicAdd for half types, to support CC < 7.x // atomicAdd for half types, to support CC < 7.x
__device__ __forceinline__ void atomicAdd_half(half* address, half val) __device__ __forceinline__ void atomicAdd_half(half* address, half val) {
{ unsigned int* address_as_ui =
unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2)); (unsigned int*)((char*)address - ((size_t)address & 2));
unsigned int old = *address_as_ui; unsigned int old = *address_as_ui;
unsigned int assumed; unsigned int assumed;
do do {
{ assumed = old;
assumed = old; __half_raw hsum;
__half_raw hsum; hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); half tmpres = __hadd(hsum, val);
half tmpres = __hadd(hsum, val); hsum = __half_raw(tmpres);
hsum = __half_raw(tmpres); old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16)
old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x; : (old & 0xffff0000) | hsum.x;
old = atomicCAS(address_as_ui, assumed, old); old = atomicCAS(address_as_ui, assumed, old);
} } while (assumed != old);
while (assumed != old);
} }
// atomicAdd for half2 types // atomicAdd for half2 types
__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) {
{ unsigned int* address_as_ui = (unsigned int*)address;
unsigned int* address_as_ui = (unsigned int*)address; unsigned int old = *address_as_ui;
unsigned int old = *address_as_ui; unsigned int assumed;
unsigned int assumed; do {
do assumed = old;
{ half2 old_val = *((half2*)&old);
assumed = old; half2 new_val = __hadd2(old_val, val);
half2 old_val = *((half2*)&old); old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
half2 new_val = __hadd2(old_val, val); } while (assumed != old);
old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
}
while (assumed != old);
} }
// //
#if defined(__CUDA_ARCH__) || defined(USE_ROCM) #if defined(__CUDA_ARCH__) || defined(USE_ROCM)
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM) #if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); } __device__ __forceinline__ void atomicAdd(half* address, half val) {
atomicAdd_half(address, val);
}
#if __CUDA_ARCH__ < 600 || defined(USE_ROCM) #if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); } __device__ __forceinline__ void atomicAdd(half2* address, half2 val) {
#endif atomicAdd_half2(address, val);
}
#endif
#endif #endif
#endif #endif
} // namespace gptq } // namespace gptq

View File

@ -1,5 +1,6 @@
/* /*
Adapted from https://github.com/turboderp/exllamav2 and https://github.com/turboderp/exllama Adapted from https://github.com/turboderp/exllamav2 and
https://github.com/turboderp/exllama
*/ */
#ifndef _matrix_view_cuh #ifndef _matrix_view_cuh
@ -13,260 +14,280 @@ Adapted from https://github.com/turboderp/exllamav2 and https://github.com/turbo
namespace vllm { namespace vllm {
namespace gptq { namespace gptq {
class MatrixView_half class MatrixView_half {
{ public:
public: const half* data;
const half* data; const int height;
const int height; const int width;
const int width;
__device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width) __device__ __forceinline__ MatrixView_half(const half* data, const int height,
: data(data), height(height), width(width) const int width)
{ } : data(data), height(height), width(width) {}
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } __device__ __forceinline__ half item(int row, int column) const {
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } return data[row * width + column];
__device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); } }
__device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; } __device__ __forceinline__ half2 item_half2(int row, int column) const {
return ((half2*)data)[(row * width + column) / 2];
}
__device__ __forceinline__ half2 item_half2half2(int row, int column) const {
return __half2half2(data[row * width + column]);
}
__device__ __forceinline__ const half* item_ptr(int row, int column) const {
return &data[row * width + column];
}
__device__ __forceinline__ void item4(half (&items)[4], int row, int column) const __device__ __forceinline__ void item4(half (&items)[4], int row,
{ int column) const {
half2* ptr = (half2*) item_ptr(row, column); half2* ptr = (half2*)item_ptr(row, column);
half2 i01 = ptr[0]; half2 i01 = ptr[0];
half2 i23 = ptr[1]; half2 i23 = ptr[1];
items[0] = __low2half(i01); items[0] = __low2half(i01);
items[1] = __high2half(i01); items[1] = __high2half(i01);
items[2] = __low2half(i23); items[2] = __low2half(i23);
items[3] = __high2half(i23); items[3] = __high2half(i23);
} }
__device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const __device__ __forceinline__ void item4_f(float (&items)[4], int row,
{ int column) const {
half2* ptr = (half2*)item_ptr(row, column); half2* ptr = (half2*)item_ptr(row, column);
half2 i01 = ptr[0]; half2 i01 = ptr[0];
half2 i23 = ptr[1]; half2 i23 = ptr[1];
items[0] = __half2float(__low2half(i01)); items[0] = __half2float(__low2half(i01));
items[1] = __half2float(__high2half(i01)); items[1] = __half2float(__high2half(i01));
items[2] = __half2float(__low2half(i23)); items[2] = __half2float(__low2half(i23));
items[3] = __half2float(__high2half(i23)); items[3] = __half2float(__high2half(i23));
} }
__device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const __device__ __forceinline__ void item4_h2(half2 (&items)[4], int row,
{ int column) const {
half2* ptr = (half2*)item_ptr(row, column); half2* ptr = (half2*)item_ptr(row, column);
half2 i01 = ptr[0]; half2 i01 = ptr[0];
half2 i23 = ptr[1]; half2 i23 = ptr[1];
items[0] = __half2half2(__low2half(i01)); items[0] = __half2half2(__low2half(i01));
items[1] = __half2half2(__high2half(i01)); items[1] = __half2half2(__high2half(i01));
items[2] = __half2half2(__low2half(i23)); items[2] = __half2half2(__low2half(i23));
items[3] = __half2half2(__high2half(i23)); items[3] = __half2half2(__high2half(i23));
} }
}; };
class MatrixView_half_rw class MatrixView_half_rw {
{ public:
public: half* data;
half* data; const int height;
const int height; const int width;
const int width;
__device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width) __device__ __forceinline__ MatrixView_half_rw(half* data, const int height,
: data(data), height(height), width(width) const int width)
{ } : data(data), height(height), width(width) {}
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } __device__ __forceinline__ half item(int row, int column) const {
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } return data[row * width + column];
__device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; } }
__device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; } __device__ __forceinline__ half2 item_half2(int row, int column) const {
__device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; } return ((half2*)data)[(row * width + column) / 2];
}
__device__ __forceinline__ half* item_ptr(int row, int column) {
return &data[row * width + column];
}
__device__ __forceinline__ void set(int row, int column, half value) {
data[row * width + column] = value;
}
__device__ __forceinline__ void set_half2(int row, int column, half2 value) {
((half2*)data)[(row * width + column) / 2] = value;
}
__device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3) __device__ __forceinline__ void set4(int row, int column, half v0, half v1,
{ half v2, half v3) {
half2 v01 = __halves2half2(v0, v1); half2 v01 = __halves2half2(v0, v1);
half2 v23 = __halves2half2(v2, v3); half2 v23 = __halves2half2(v2, v3);
half2* ptr = (half2*) item_ptr(row, column); half2* ptr = (half2*)item_ptr(row, column);
ptr[0] = v01; ptr[0] = v01;
ptr[1] = v23; ptr[1] = v23;
} }
}; };
class MatrixView_q4_row class MatrixView_q4_row {
{ public:
public: const uint32_t* data;
const uint32_t* data; const int height;
const int height; const int width;
const int width;
__device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width) __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data,
: data(data), height(height), width(width) const int height,
{ } const int width)
: data(data), height(height), width(width) {}
__device__ __forceinline__ int item(int row, int column) const __device__ __forceinline__ int item(int row, int column) const {
{ int shift = (column & 0x07) * 4;
int shift = (column & 0x07) * 4; return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
return (data[row * width / 8 + column / 8] >> shift) & 0x0f; }
}
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const __device__ __forceinline__ void item2(int (&items)[2], int row,
{ int column) const {
int shift = (column & 0x07) * 4; int shift = (column & 0x07) * 4;
uint32_t d = data[row * width / 8 + column / 8] >> shift; uint32_t d = data[row * width / 8 + column / 8] >> shift;
items[0] = d & 0x0f; items[0] = d & 0x0f;
items[1] = (d >> 4) & 0x0f; items[1] = (d >> 4) & 0x0f;
} }
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const __device__ __forceinline__ void item4(int (&items)[4], int row,
{ int column) const {
int shift = (column & 0x07) * 4; int shift = (column & 0x07) * 4;
uint32_t d = data[row * width / 8 + column / 8] >> shift; uint32_t d = data[row * width / 8 + column / 8] >> shift;
items[0] = d & 0x0f; items[0] = d & 0x0f;
items[1] = (d >> 4) & 0x0f; items[1] = (d >> 4) & 0x0f;
items[2] = (d >> 8) & 0x0f; items[2] = (d >> 8) & 0x0f;
items[3] = (d >> 12) & 0x0f; items[3] = (d >> 12) & 0x0f;
} }
}; };
class MatrixView_q4_column class MatrixView_q4_column {
{ public:
public: const uint32_t* data;
const uint32_t* data; const int height;
const int height; const int width;
const int width;
__device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width) __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data,
: data(data), height(height), width(width) const int height,
{ } const int width)
: data(data), height(height), width(width) {}
__device__ __forceinline__ int item(int row, int column) const __device__ __forceinline__ int item(int row, int column) const {
{ int shift = (row & 0x07) * 4;
int shift = (row & 0x07) * 4; return (data[row / 8 * width + column] >> shift) & 0x0f;
return (data[row / 8 * width + column] >> shift) & 0x0f; }
}
__device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; } __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) {
__device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; } return data[row / 8 * width + column];
}
__device__ __forceinline__ const uint32_t* item_uint32_ptr(int row,
int column) {
return &data[row / 8 * width + column];
}
}; };
class MatrixView_q2_row class MatrixView_q2_row {
{ public:
public: const uint32_t* data;
const uint32_t* data; const int height;
const int height; const int width;
const int width;
__device__ __forceinline__ MatrixView_q2_row(const uint32_t* data, const int height, const int width) __device__ __forceinline__ MatrixView_q2_row(const uint32_t* data,
: data(data), height(height), width(width) const int height,
{ } const int width)
: data(data), height(height), width(width) {}
__device__ __forceinline__ int item(int row, int column) const __device__ __forceinline__ int item(int row, int column) const {
{ int shift = (column & 0x0f) * 2;
int shift = (column & 0x0f) * 2; return (data[row * width / 16 + column / 16] >> shift) & 0x03;
return (data[row * width / 16 + column / 16] >> shift) & 0x03; }
}
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const __device__ __forceinline__ void item2(int (&items)[2], int row,
{ int column) const {
int shift = (column & 0x0f) * 2; int shift = (column & 0x0f) * 2;
uint32_t d = data[row * width / 16 + column / 16] >> shift; uint32_t d = data[row * width / 16 + column / 16] >> shift;
items[0] = d & 0x03; items[0] = d & 0x03;
items[1] = (d >> 2) & 0x03; items[1] = (d >> 2) & 0x03;
} }
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const __device__ __forceinline__ void item4(int (&items)[4], int row,
{ int column) const {
int shift = (column & 0x0f) * 2; int shift = (column & 0x0f) * 2;
uint32_t d = data[row * width / 16 + column / 16] >> shift; uint32_t d = data[row * width / 16 + column / 16] >> shift;
items[0] = d & 0x03; items[0] = d & 0x03;
items[1] = (d >> 2) & 0x03; items[1] = (d >> 2) & 0x03;
items[2] = (d >> 4) & 0x03; items[2] = (d >> 4) & 0x03;
items[3] = (d >> 6) & 0x03; items[3] = (d >> 6) & 0x03;
} }
}; };
class MatrixView_q3_row class MatrixView_q3_row {
{ public:
public: const uint32_t* data;
const uint32_t* data; const int height;
const int height; const int width;
const int width;
__device__ __forceinline__ MatrixView_q3_row(const uint32_t* data, const int height, const int width) __device__ __forceinline__ MatrixView_q3_row(const uint32_t* data,
: data(data), height(height), width(width) const int height,
{ } const int width)
: data(data), height(height), width(width) {}
__device__ __forceinline__ int item(int row, int column) const __device__ __forceinline__ int item(int row, int column) const {
{ int z_w = column * 3 / 32;
int z_w = column * 3 / 32; int z_mod = column & 0x1f;
int z_mod = column & 0x1f;
if (z_mod == 10) { if (z_mod == 10) {
return (data[row * width * 3 / 32 + z_w] >> 30) | ((data[row * width * 3 / 32 + (z_w + 1)] << 2) & 0x4); return (data[row * width * 3 / 32 + z_w] >> 30) |
} else if (z_mod == 21) { ((data[row * width * 3 / 32 + (z_w + 1)] << 2) & 0x4);
return (data[row * width * 3 / 32 + z_w] >> 31) | ((data[row * width * 3 / 32 + (z_w + 1)] << 1) & 0x6); } else if (z_mod == 21) {
} else if (z_mod < 10) { return (data[row * width * 3 / 32 + z_w] >> 31) |
return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3)) & 0x07; ((data[row * width * 3 / 32 + (z_w + 1)] << 1) & 0x6);
} else if (z_mod < 21) { } else if (z_mod < 10) {
return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 32)) & 0x07; return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3)) & 0x07;
} else { } else if (z_mod < 21) {
return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 64)) & 0x07; return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 32)) & 0x07;
} } else {
return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 64)) & 0x07;
} }
}
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const __device__ __forceinline__ void item4(int (&items)[4], int row,
{ int column) const {
int shift = (column & 0x1f); int shift = (column & 0x1f);
uint32_t d; uint32_t d;
if (shift <= 4) { if (shift <= 4) {
d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3); d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3);
} else if (shift == 8) { } else if (shift == 8) {
d = (data[row * width / 32 * 3 + column * 3 / 32] >> 24) | ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0x0f) << 8); d = (data[row * width / 32 * 3 + column * 3 / 32] >> 24) |
} else if (shift <= 16) { ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0x0f) << 8);
d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 32); } else if (shift <= 16) {
} else if (shift == 20) { d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 32);
d = (data[row * width / 32 * 3 + column * 3 / 32] >> 28) | ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0xff) << 4); } else if (shift == 20) {
} else { d = (data[row * width / 32 * 3 + column * 3 / 32] >> 28) |
d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 64); ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0xff) << 4);
} } else {
items[0] = d & 0x07; d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 64);
items[1] = (d >> 3) & 0x07;
items[2] = (d >> 6) & 0x07;
items[3] = (d >> 9) & 0x07;
} }
items[0] = d & 0x07;
items[1] = (d >> 3) & 0x07;
items[2] = (d >> 6) & 0x07;
items[3] = (d >> 9) & 0x07;
}
}; };
class MatrixView_q8_row class MatrixView_q8_row {
{ public:
public: const uint32_t* data;
const uint32_t* data; const int height;
const int height; const int width;
const int width;
__device__ __forceinline__ MatrixView_q8_row(const uint32_t* data, const int height, const int width) __device__ __forceinline__ MatrixView_q8_row(const uint32_t* data,
: data(data), height(height), width(width) const int height,
{ } const int width)
: data(data), height(height), width(width) {}
__device__ __forceinline__ int item(int row, int column) const __device__ __forceinline__ int item(int row, int column) const {
{ int shift = (column & 0x03) * 8;
int shift = (column & 0x03) * 8; return (data[row * width / 4 + column / 4] >> shift) & 0xff;
return (data[row * width / 4 + column / 4] >> shift) & 0xff; }
}
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const __device__ __forceinline__ void item2(int (&items)[2], int row,
{ int column) const {
int shift = (column & 0x03) * 8; int shift = (column & 0x03) * 8;
uint32_t d = data[row * width / 4 + column / 4] >> shift; uint32_t d = data[row * width / 4 + column / 4] >> shift;
items[0] = d & 0xff; items[0] = d & 0xff;
items[1] = (d >> 8) & 0xff; items[1] = (d >> 8) & 0xff;
} }
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const __device__ __forceinline__ void item4(int (&items)[4], int row,
{ int column) const {
int shift = (column & 0x03) * 2; int shift = (column & 0x03) * 2;
uint32_t d = data[row * width / 4 + column / 4] >> shift; uint32_t d = data[row * width / 4 + column / 4] >> shift;
items[0] = d & 0xff; items[0] = d & 0xff;
items[1] = (d >> 8) & 0xff; items[1] = (d >> 8) & 0xff;
items[2] = (d >> 16) & 0xff; items[2] = (d >> 16) & 0xff;
items[3] = (d >> 24) & 0xff; items[3] = (d >> 24) & 0xff;
} }
}; };
} // namespace gptq } // namespace gptq

File diff suppressed because it is too large Load Diff

View File

@ -14,71 +14,60 @@ namespace gptq {
// //
// ffddbb99 77553311 eeccaa88 66442200 // ffddbb99 77553311 eeccaa88 66442200
__forceinline__ __device__ void shuffle_2bit_16 __forceinline__ __device__ void shuffle_2bit_16(uint32_t* q, int stride) {
( uint32_t qa = q[0];
uint32_t* q, uint32_t qb = 0;
int stride
)
{
uint32_t qa = q[0];
uint32_t qb = 0;
#pragma unroll #pragma unroll
for (int i = 0; i < 8; i++) for (int i = 0; i < 8; i++) {
{ uint32_t qa0 = qa & 0x03;
uint32_t qa0 = qa & 0x03; uint32_t qa1 = (qa & 0x0c) >> 2;
uint32_t qa1 = (qa & 0x0c) >> 2; qa >>= 4;
qa >>= 4; qb |= (qa1 << (i * 2 + 16));
qb |= (qa1 << (i * 2 + 16)); qb |= (qa0 << (i * 2));
qb |= (qa0 << (i * 2)); }
} q[0] = qb;
q[0] = qb;
} }
__forceinline__ __device__ void dequant_2bit_16 __forceinline__ __device__ void dequant_2bit_16(const uint32_t q_0,
( half2 (&dq)[8], int stride,
const uint32_t q_0, const uint32_t zero) {
half2 (&dq)[8], const uint32_t c0 = 0x64006400;
int stride, const half y4_ = __float2half_rn(1.0f / 4.0f);
const uint32_t zero const half y16_ = __float2half_rn(1.0f / 16.0f);
) const half y64_ = __float2half_rn(1.0f / 64.0f);
{ const half2 y4 = __halves2half2(y4_, y4_);
const uint32_t c0 = 0x64006400; const half2 y16 = __halves2half2(y16_, y16_);
const half y4_ = __float2half_rn(1.0f / 4.0f); const half2 y64 = __halves2half2(y64_, y64_);
const half y16_ = __float2half_rn(1.0f / 16.0f);
const half y64_ = __float2half_rn(1.0f / 64.0f);
const half2 y4 = __halves2half2(y4_, y4_);
const half2 y16 = __halves2half2(y16_, y16_);
const half2 y64 = __halves2half2(y64_, y64_);
const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero); const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero);
const half z4_ = __hsub(__int2half_rn(-256), __int2half_rn(zero)); const half z4_ = __hsub(__int2half_rn(-256), __int2half_rn(zero));
const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero)); const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero));
const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero)); const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero));
const half2 z1 = __half2half2(z1_.as_half); const half2 z1 = __half2half2(z1_.as_half);
const half2 z4 = __half2half2(z4_); const half2 z4 = __half2half2(z4_);
const half2 z16 = __half2half2(z16_); const half2 z16 = __half2half2(z16_);
const half2 z64 = __half2half2(z64_); const half2 z64 = __half2half2(z64_);
uint32_t qa = q_0; uint32_t qa = q_0;
half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1]) + 1024 half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1]) + 1024
half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) * 4 + 1024 half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) * 4 + 1024
half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024 half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024
half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024 half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024
qa >>= 8; qa >>= 8;
half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 8]) + 1024 half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 8]) + 1024
half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) * 4 + 1024 half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) * 4 + 1024
half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024 half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024
half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024 half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024
dq[0] = __hadd2(q0.as_half2, z1); dq[0] = __hadd2(q0.as_half2, z1);
dq[1] = __hfma2(q1.as_half2, y4, z4); dq[1] = __hfma2(q1.as_half2, y4, z4);
dq[2] = __hfma2(q2.as_half2, y16, z16); dq[2] = __hfma2(q2.as_half2, y16, z16);
dq[3] = __hfma2(q3.as_half2, y64, z64); dq[3] = __hfma2(q3.as_half2, y64, z64);
dq[4] = __hadd2(q4.as_half2, z1); dq[4] = __hadd2(q4.as_half2, z1);
dq[5] = __hfma2(q5.as_half2, y4, z4); dq[5] = __hfma2(q5.as_half2, y4, z4);
dq[6] = __hfma2(q6.as_half2, y16, z16); dq[6] = __hfma2(q6.as_half2, y16, z16);
dq[7] = __hfma2(q7.as_half2, y64, z64); dq[7] = __hfma2(q7.as_half2, y64, z64);
} }
} // namespace gptq } // namespace gptq

View File

@ -11,128 +11,136 @@ namespace gptq {
// vjjjhhhf ffdddbbb uiiiggge eecccaaa // vjjjhhhf ffdddbbb uiiiggge eecccaaa
// vtttrrrp ppnnnlll usssqqqo oommmkkk // vtttrrrp ppnnnlll usssqqqo oommmkkk
__forceinline__ __device__ void shuffle_3bit_32 __forceinline__ __device__ void shuffle_3bit_32(uint32_t* q, int stride) {
( uint32_t qa = q[0 * stride];
uint32_t* q, uint32_t qb = q[1 * stride];
int stride uint32_t qc = q[2 * stride];
)
{
uint32_t qa = q[0 * stride];
uint32_t qb = q[1 * stride];
uint32_t qc = q[2 * stride];
// qa: aa999888 77766655 54443332 22111000 // qa: aa999888 77766655 54443332 22111000
// qb: lkkkjjji iihhhggg fffeeedd dcccbbba // qb: lkkkjjji iihhhggg fffeeedd dcccbbba
// qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll // qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll
uint32_t qd = qc >> 26; uint32_t qd = qc >> 26;
qc <<= 4; qc <<= 4;
qc |= qb >> 28; qc |= qb >> 28;
qb <<= 2; qb <<= 2;
qb |= qa >> 30; qb |= qa >> 30;
// qa: ..999888 77766655 54443332 22111000 // qa: ..999888 77766655 54443332 22111000
// qb: ..jjjiii hhhgggff feeedddc ccbbbaaa // qb: ..jjjiii hhhgggff feeedddc ccbbbaaa
// qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk // qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk
// qd: vvvuuu // qd: vvvuuu
uint32_t za = 0; uint32_t za = 0;
uint32_t zb = 0; uint32_t zb = 0;
uint32_t zc = 0; uint32_t zc = 0;
for (int i = 0; i < 5; i++) { uint32_t t0 = qa & 0x07; uint32_t t1 = (qa & 0x38) >> 3; qa >>= 6; za |= (t0 << (i * 3)); za |= (t1 << (i * 3 + 16)); } for (int i = 0; i < 5; i++) {
for (int i = 0; i < 5; i++) { uint32_t t0 = qb & 0x07; uint32_t t1 = (qb & 0x38) >> 3; qb >>= 6; zb |= (t0 << (i * 3)); zb |= (t1 << (i * 3 + 16)); } uint32_t t0 = qa & 0x07;
for (int i = 0; i < 5; i++) { uint32_t t0 = qc & 0x07; uint32_t t1 = (qc & 0x38) >> 3; qc >>= 6; zc |= (t0 << (i * 3)); zc |= (t1 << (i * 3 + 16)); } uint32_t t1 = (qa & 0x38) >> 3;
qa >>= 6;
za |= (t0 << (i * 3));
za |= (t1 << (i * 3 + 16));
}
for (int i = 0; i < 5; i++) {
uint32_t t0 = qb & 0x07;
uint32_t t1 = (qb & 0x38) >> 3;
qb >>= 6;
zb |= (t0 << (i * 3));
zb |= (t1 << (i * 3 + 16));
}
for (int i = 0; i < 5; i++) {
uint32_t t0 = qc & 0x07;
uint32_t t1 = (qc & 0x38) >> 3;
qc >>= 6;
zc |= (t0 << (i * 3));
zc |= (t1 << (i * 3 + 16));
}
// za: 9997775 55333111 8886664 44222000 // za: 9997775 55333111 8886664 44222000
// zb: jjjhhhf ffdddbbb iiiggge eecccaaa // zb: jjjhhhf ffdddbbb iiiggge eecccaaa
// zc: tttrrrp ppnnnlll sssqqqo oommmkkk // zc: tttrrrp ppnnnlll sssqqqo oommmkkk
// qd: vvvuuu // qd: vvvuuu
za |= ((qd & 0x01) >> 0) << 15; za |= ((qd & 0x01) >> 0) << 15;
zb |= ((qd & 0x02) >> 1) << 15; zb |= ((qd & 0x02) >> 1) << 15;
zc |= ((qd & 0x04) >> 2) << 15; zc |= ((qd & 0x04) >> 2) << 15;
za |= ((qd & 0x08) >> 3) << 31; za |= ((qd & 0x08) >> 3) << 31;
zb |= ((qd & 0x10) >> 4) << 31; zb |= ((qd & 0x10) >> 4) << 31;
zc |= ((qd & 0x20) >> 5) << 31; zc |= ((qd & 0x20) >> 5) << 31;
// za: v9997775 55333111 u8886664 44222000 (u, v lsb) // za: v9997775 55333111 u8886664 44222000 (u, v lsb)
// zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa // zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa
// zc: vtttrrrp ppnnnlll usssqqqo oommmkkk // zc: vtttrrrp ppnnnlll usssqqqo oommmkkk
q[0 * stride] = za; q[0 * stride] = za;
q[1 * stride] = zb; q[1 * stride] = zb;
q[2 * stride] = zc; q[2 * stride] = zc;
} }
__forceinline__ __device__ void dequant_3bit_32 __forceinline__ __device__ void dequant_3bit_32(const uint32_t q_0,
( const uint32_t q_1,
const uint32_t q_0, const uint32_t q_2,
const uint32_t q_1, half2 (&dq)[16], int stride,
const uint32_t q_2, const uint32_t zero) {
half2 (&dq)[16], const uint32_t c0 = 0x64006400;
int stride, const half y8_ = __float2half_rn(1.0f / 8.0f);
const uint32_t zero const half y64_ = __float2half_rn(1.0f / 64.0f);
) const half2 y8 = __halves2half2(y8_, y8_);
{ const half2 y64 = __halves2half2(y64_, y64_);
const uint32_t c0 = 0x64006400; const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero);
const half y8_ = __float2half_rn(1.0f / 8.0f); const half z8_ = __hsub(__int2half_rn(-128), __int2half_rn(zero));
const half y64_ = __float2half_rn(1.0f / 64.0f); const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero));
const half2 y8 = __halves2half2(y8_, y8_); const half2 z1 = __halves2half2(z1_.as_half, z1_.as_half);
const half2 y64 = __halves2half2(y64_, y64_); const half2 z8 = __halves2half2(z8_, z8_);
const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero); const half2 z64 = __halves2half2(z64_, z64_);
const half z8_ = __hsub(__int2half_rn(-128), __int2half_rn(zero));
const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero));
const half2 z1 = __halves2half2(z1_.as_half, z1_.as_half);
const half2 z8 = __halves2half2(z8_, z8_);
const half2 z64 = __halves2half2(z64_, z64_);
uint32_t qa = q_0; uint32_t qa = q_0;
uint32_t qb = q_1; uint32_t qb = q_1;
uint32_t qc = q_2; uint32_t qc = q_2;
half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1]) + 1024 half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1]) + 1024
half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) * 8 + 1024 half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) * 8 + 1024
qa >>= 6; qa >>= 6;
half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5]) + 1024 half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5]) + 1024
half2_uint32 q3((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) * 8 + 1024 half2_uint32 q3((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) * 8 + 1024
half2_uint32 q4((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024 half2_uint32 q4((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024
qa >>= 9; qa >>= 9;
qa &= 0x00010001; qa &= 0x00010001;
half2_uint32 q5((qb & 0x00070007) | c0); // half2(q[10], q[11]) + 1024 half2_uint32 q5((qb & 0x00070007) | c0); // half2(q[10], q[11]) + 1024
half2_uint32 q6((qb & 0x00380038) | c0); // half2(q[12], q[13]) * 8 + 1024 half2_uint32 q6((qb & 0x00380038) | c0); // half2(q[12], q[13]) * 8 + 1024
qb >>= 6; qb >>= 6;
half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15]) + 1024 half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15]) + 1024
half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) * 8 + 1024 half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) * 8 + 1024
half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024 half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024
qb >>= 8; qb >>= 8;
qb &= 0x00020002; qb &= 0x00020002;
half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21]) + 1024 half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21]) + 1024
half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) * 8 + 1024 half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) * 8 + 1024
qc >>= 6; qc >>= 6;
half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25]) + 1024 half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25]) + 1024
half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) * 8 + 1024 half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) * 8 + 1024
half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024 half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024
qc >>= 7; qc >>= 7;
qc &= 0x00040004; qc &= 0x00040004;
half2_uint32 q15((qa | qb | qc) | c0); half2_uint32 q15((qa | qb | qc) | c0);
dq[ 0] = __hadd2( q0.as_half2, z1); dq[0] = __hadd2(q0.as_half2, z1);
dq[ 1] = __hfma2( q1.as_half2, y8, z8); dq[1] = __hfma2(q1.as_half2, y8, z8);
dq[ 2] = __hadd2( q2.as_half2, z1); dq[2] = __hadd2(q2.as_half2, z1);
dq[ 3] = __hfma2( q3.as_half2, y8, z8); dq[3] = __hfma2(q3.as_half2, y8, z8);
dq[ 4] = __hfma2( q4.as_half2, y64, z64); dq[4] = __hfma2(q4.as_half2, y64, z64);
dq[ 5] = __hadd2( q5.as_half2, z1); dq[5] = __hadd2(q5.as_half2, z1);
dq[ 6] = __hfma2( q6.as_half2, y8, z8); dq[6] = __hfma2(q6.as_half2, y8, z8);
dq[ 7] = __hadd2( q7.as_half2, z1); dq[7] = __hadd2(q7.as_half2, z1);
dq[ 8] = __hfma2( q8.as_half2, y8, z8); dq[8] = __hfma2(q8.as_half2, y8, z8);
dq[ 9] = __hfma2( q9.as_half2, y64, z64); dq[9] = __hfma2(q9.as_half2, y64, z64);
dq[10] = __hadd2(q10.as_half2, z1); dq[10] = __hadd2(q10.as_half2, z1);
dq[11] = __hfma2(q11.as_half2, y8, z8); dq[11] = __hfma2(q11.as_half2, y8, z8);
dq[12] = __hadd2(q12.as_half2, z1); dq[12] = __hadd2(q12.as_half2, z1);
dq[13] = __hfma2(q13.as_half2, y8, z8); dq[13] = __hfma2(q13.as_half2, y8, z8);
dq[14] = __hfma2(q14.as_half2, y64, z64); dq[14] = __hfma2(q14.as_half2, y64, z64);
dq[15] = __hadd2(q15.as_half2, z1); dq[15] = __hadd2(q15.as_half2, z1);
} }
} // namespace gptq } // namespace gptq

View File

@ -13,133 +13,112 @@ namespace gptq {
// //
// 77775555 33331111 66664444 22220000 // 77775555 33331111 66664444 22220000
__forceinline__ __device__ void shuffle_4bit_8 __forceinline__ __device__ void shuffle_4bit_8(uint32_t* q, int stride) {
( uint32_t qa = q[0];
uint32_t* q, uint32_t qb = 0;
int stride
)
{
uint32_t qa = q[0];
uint32_t qb = 0;
#pragma unroll #pragma unroll
for (int i = 0; i < 4; i++) for (int i = 0; i < 4; i++) {
{ uint32_t qa0 = qa & 0x0f;
uint32_t qa0 = qa & 0x0f; uint32_t qa1 = (qa & 0xf0) >> 4;
uint32_t qa1 = (qa & 0xf0) >> 4;
qa >>= 8;
qb |= (qa1 << (i * 4 + 16));
qb |= (qa0 << (i * 4));
}
q[0] = qb;
}
__forceinline__ __device__ void dequant_4bit_8
(
const uint32_t q_0,
half2 (&dq)[4],
int stride,
const uint32_t zero
)
{
const uint32_t c0 = 0x64006400;
const half y16_ = __float2half_rn(1.0f / 16.0f);
const half2 y16 = __halves2half2(y16_, y16_);
const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero);
const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero));
const half2 z1 = __half2half2(z1_.as_half);
const half2 z16 = __half2half2(z16_);
uint32_t qa = q_0;
half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024
half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024
qa >>= 8; qa >>= 8;
half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5]) + 1024 qb |= (qa1 << (i * 4 + 16));
half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024 qb |= (qa0 << (i * 4));
}
dq[0] = __hadd2(q0.as_half2, z1); q[0] = qb;
dq[1] = __hfma2(q1.as_half2, y16, z16);
dq[2] = __hadd2(q2.as_half2, z1);
dq[3] = __hfma2(q3.as_half2, y16, z16);
} }
__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale __forceinline__ __device__ void dequant_4bit_8(const uint32_t q_0,
( half2 (&dq)[4], int stride,
const uint32_t zero, const uint32_t zero) {
const half scale, const uint32_t c0 = 0x64006400;
half2 (&z1z16)[2], const half y16_ = __float2half_rn(1.0f / 16.0f);
half2 (&y1y16)[2] const half2 y16 = __halves2half2(y16_, y16_);
) const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero);
{ const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero));
half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); const half2 z1 = __half2half2(z1_.as_half);
half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); const half2 z16 = __half2half2(z16_);
half2 scale2 = __half2half2(scale); uint32_t qa = q_0;
half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024
half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024
qa >>= 8;
half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5]) + 1024
half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024
z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half)); dq[0] = __hadd2(q0.as_half2, z1);
z1z16[1] = __hmul2(scale2, __half2half2(z16)); dq[1] = __hfma2(q1.as_half2, y16, z16);
dq[2] = __hadd2(q2.as_half2, z1);
const half y1 = __float2half_rn(1.0f); dq[3] = __hfma2(q3.as_half2, y16, z16);
const half y16 = __float2half_rn(1.0f / 16.0f);
y1y16[0] = __hmul2(scale2, __half2half2(y1));
y1y16[1] = __hmul2(scale2, __half2half2(y16));
} }
__forceinline__ __device__ void dequant_4bit_8_prep_zero __forceinline__ __device__ void dequant_4bit_8_prep_zero_scale(
( const uint32_t zero, const half scale, half2 (&z1z16)[2],
const uint32_t zero, half2 (&y1y16)[2]) {
half2(&z1z16)[2], half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
half2(&y1y16)[2] half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));
)
{
half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));
z1z16[0] = __half2half2(z1.as_half); half2 scale2 = __half2half2(scale);
z1z16[1] = __half2half2(z16);
const half y1 = __float2half_rn(1.0f); z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half));
const half y16 = __float2half_rn(1.0f / 16.0f); z1z16[1] = __hmul2(scale2, __half2half2(z16));
y1y16[0] = __half2half2(y1); const half y1 = __float2half_rn(1.0f);
y1y16[1] = __half2half2(y16); const half y16 = __float2half_rn(1.0f / 16.0f);
y1y16[0] = __hmul2(scale2, __half2half2(y1));
y1y16[1] = __hmul2(scale2, __half2half2(y16));
} }
__forceinline__ __device__ void dequant_4bit_8_prep_zero(const uint32_t zero,
half2 (&z1z16)[2],
half2 (&y1y16)[2]) {
half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));
__forceinline__ __device__ void dequant_4bit_8_gptq z1z16[0] = __half2half2(z1.as_half);
( z1z16[1] = __half2half2(z16);
const uint32_t q_0,
half2 (&dq)[4],
half2 (&z1z16)[2],
half2 (&y1y16)[2],
int stride,
bool scaled
)
{
const uint32_t c0 = 0x64006400;
uint32_t qa = q_0; const half y1 = __float2half_rn(1.0f);
half2_uint32 q0((qa & 0x000f000f) | c0); // half2( q[0] + 1024, q[1] + 1024 ) const half y16 = __float2half_rn(1.0f / 16.0f);
half2_uint32 q1((qa & 0x00f000f0) | c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 )
qa >>= 8;
half2_uint32 q2((qa & 0x000f000f) | c0); // half2( q[4] + 1024, q[5] + 1024 )
half2_uint32 q3((qa & 0x00f000f0) | c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 )
if (scaled) y1y16[0] = __half2half2(y1);
{ y1y16[1] = __half2half2(y16);
dq[0] = __hfma2(q0.as_half2, y1y16[0], z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s) }
dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s)
dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]); __forceinline__ __device__ void dequant_4bit_8_gptq(const uint32_t q_0,
dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); half2 (&dq)[4],
} half2 (&z1z16)[2],
else half2 (&y1y16)[2],
{ int stride, bool scaled) {
dq[0] = __hadd2(q0.as_half2, z1z16[0]); // half2( q[0] - z, q[1] - z ) const uint32_t c0 = 0x64006400;
dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] - z, q[3] - z )
dq[2] = __hadd2(q2.as_half2, z1z16[0]); // half2( q[4] - z, q[5] - z ) uint32_t qa = q_0;
dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); // half2( q[6] - z, q[7] - z ) half2_uint32 q0((qa & 0x000f000f) |
} c0); // half2( q[0] + 1024, q[1] + 1024 )
half2_uint32 q1((qa & 0x00f000f0) |
c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 )
qa >>= 8;
half2_uint32 q2((qa & 0x000f000f) |
c0); // half2( q[4] + 1024, q[5] + 1024 )
half2_uint32 q3((qa & 0x00f000f0) |
c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 )
if (scaled) {
dq[0] = __hfma2(q0.as_half2, y1y16[0],
z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s)
dq[1] = __hfma2(q1.as_half2, y1y16[1],
z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s)
dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]);
dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]);
} else {
dq[0] = __hadd2(q0.as_half2, z1z16[0]); // half2( q[0] - z, q[1] - z )
dq[1] = __hfma2(q1.as_half2, y1y16[1],
z1z16[1]); // half2( q[2] - z, q[3] - z )
dq[2] = __hadd2(q2.as_half2, z1z16[0]); // half2( q[4] - z, q[5] - z )
dq[3] = __hfma2(q3.as_half2, y1y16[1],
z1z16[1]); // half2( q[6] - z, q[7] - z )
}
} }
} // namespace gptq } // namespace gptq
} // namespace vllm } // namespace vllm

View File

@ -10,28 +10,18 @@ Copied from https://github.com/turboderp/exllamav2
namespace vllm { namespace vllm {
namespace gptq { namespace gptq {
__forceinline__ __device__ void shuffle_8bit_4 __forceinline__ __device__ void shuffle_8bit_4(uint32_t* q, int stride) {}
(
uint32_t* q,
int stride
)
{
}
__forceinline__ __device__ void dequant_8bit_8 __forceinline__ __device__ void dequant_8bit_8(const uint32_t q_0,
( const uint32_t q_1,
const uint32_t q_0, half2 (&dq)[4], int stride,
const uint32_t q_1, const uint32_t zero) {
half2 (&dq)[4], half dqh[8];
int stride, for (int i = 0; i < 4; i++) dqh[i] = dq_ns(exb(q_0, i * 8, 0xff), zero);
const uint32_t zero for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), zero);
)
{
half dqh[8];
for (int i = 0; i < 4; i++) dqh[i ] = dq_ns(exb(q_0, i * 8, 0xff), zero);
for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), zero);
for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); for (int i = 0; i < 4; i++)
dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
} }
} // namespace gptq } // namespace gptq

View File

@ -8,51 +8,47 @@ Copied from https://github.com/turboderp/exllamav2
namespace vllm { namespace vllm {
namespace gptq { namespace gptq {
union half2_uint32 union half2_uint32 {
{ uint32_t as_uint32;
uint32_t as_uint32; half2 as_half2;
half2 as_half2; __device__ half2_uint32(uint32_t val) : as_uint32(val) {}
__device__ half2_uint32(uint32_t val) : as_uint32(val) {} __device__ half2_uint32(half2 val) : as_half2(val) {}
__device__ half2_uint32(half2 val) : as_half2(val) {}
}; };
union half_uint16 union half_uint16 {
{ uint16_t as_uint16;
uint16_t as_uint16; half as_half;
half as_half; __device__ half_uint16(uint16_t val) : as_uint16(val) {}
__device__ half_uint16(uint16_t val) : as_uint16(val) {} __device__ half_uint16(half val) : as_half(val) {}
__device__ half_uint16(half val) : as_half(val) {}
}; };
// Max_scale premultiplied by 1/256 // Max_scale premultiplied by 1/256
__forceinline__ __device__ half dq_scale(const int qs, const half max_scale) __forceinline__ __device__ half dq_scale(const int qs, const half max_scale) {
{ int qs_i = qs + 1;
int qs_i = qs + 1; half qs_h = __int2half_rn(qs_i * qs_i);
half qs_h = __int2half_rn(qs_i * qs_i); qs_h = __hmul(qs_h, max_scale);
qs_h = __hmul(qs_h, max_scale); return qs_h;
return qs_h;
} }
__forceinline__ __device__ half dq(const int q, const int qzero, const half scale) __forceinline__ __device__ half dq(const int q, const int qzero,
{ const half scale) {
return __hmul(__int2half_rn(q - qzero), scale); return __hmul(__int2half_rn(q - qzero), scale);
} }
__forceinline__ __device__ half dq_ns(const int q, const int qzero) __forceinline__ __device__ half dq_ns(const int q, const int qzero) {
{ // return __hsub(__int2half_rn(q), __int2half_rn(qzero));
//return __hsub(__int2half_rn(q), __int2half_rn(qzero)); return __int2half_rn(q - qzero);
return __int2half_rn(q - qzero);
} }
__forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask) __forceinline__ __device__ int exb(const uint32_t q, const int shift,
{ const int mask) {
return (int)((q >> shift) & mask); return (int)((q >> shift) & mask);
} }
__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask) __forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0,
{ const int shift, const int mask) {
return (int)(__funnelshift_rc(q0, q1, shift) & mask); return (int)(__funnelshift_rc(q0, q1, shift) & mask);
} }
} // namespace gptq } // namespace gptq

File diff suppressed because it is too large Load Diff

View File

@ -11,22 +11,23 @@
namespace gptq_marlin { namespace gptq_marlin {
// 8 warps are a good choice since every SM has 4 schedulers and having more than 1 warp per // 8 warps are a good choice since every SM has 4 schedulers and having more
// schedule allows some more latency hiding. At the same time, we want relatively few warps to have // than 1 warp per schedule allows some more latency hiding. At the same time,
// many registers per warp and small tiles. // we want relatively few warps to have many registers per warp and small tiles.
static constexpr int default_threads = 256; static constexpr int default_threads = 256;
static constexpr int pipe_stages = 4; // 4 pipeline stages fit into shared memory static constexpr int pipe_stages =
4; // 4 pipeline stages fit into shared memory
static constexpr int min_thread_n = 64; static constexpr int min_thread_n = 64;
static constexpr int min_thread_k = 64; static constexpr int min_thread_k = 64;
static constexpr int tile_size = 16; static constexpr int tile_size = 16;
static constexpr int max_par = 16; static constexpr int max_par = 16;
template <typename T, int n> template <typename T, int n>
struct Vec { struct Vec {
T elems[n]; T elems[n];
__device__ T& operator[](int i) { return elems[i]; } __device__ T& operator[](int i) { return elems[i]; }
}; };
@ -35,30 +36,35 @@ using I4 = Vec<int, 4>;
constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; } constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// No support for async // No support for async
#else #else
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) { __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) {
const int BYTES = 16; const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr)); uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile("{\n" asm volatile(
" .reg .pred p;\n" "{\n"
" setp.ne.b32 p, %0, 0;\n" " .reg .pred p;\n"
" @p cp.async.cg.shared.global [%1], [%2], %3;\n" " setp.ne.b32 p, %0, 0;\n"
"}\n" ::"r"((int)pred), " @p cp.async.cg.shared.global [%1], [%2], %3;\n"
"r"(smem), "l"(glob_ptr), "n"(BYTES)); "}\n" ::"r"((int)pred),
"r"(smem), "l"(glob_ptr), "n"(BYTES));
} }
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { __device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
const int BYTES = 16; const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr)); uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile("{\n" asm volatile(
" cp.async.cg.shared.global [%0], [%1], %2;\n" "{\n"
"}\n" ::"r"(smem), " cp.async.cg.shared.global [%0], [%1], %2;\n"
"l"(glob_ptr), "n"(BYTES)); "}\n" ::"r"(smem),
"l"(glob_ptr), "n"(BYTES));
} }
__device__ inline void cp_async_fence() { asm volatile("cp.async.commit_group;\n" ::); } __device__ inline void cp_async_fence() {
asm volatile("cp.async.commit_group;\n" ::);
}
template <int n> template <int n>
__device__ inline void cp_async_wait() { __device__ inline void cp_async_wait() {
@ -67,4 +73,4 @@ __device__ inline void cp_async_wait() {
#endif #endif
} // namespace gptq_marlin } // namespace gptq_marlin

View File

@ -5,58 +5,73 @@
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda_bf16.h> #include <cuda_bf16.h>
namespace gptq_marlin { namespace gptq_marlin {
template <typename scalar_t> template <typename scalar_t>
class ScalarType { class ScalarType {};
};
template <> template <>
class ScalarType<half> { class ScalarType<half> {
public: public:
using scalar_t = half; using scalar_t = half;
using scalar_t2 = half2; using scalar_t2 = half2;
// Matrix fragments for tensor core instructions; their precise layout is // Matrix fragments for tensor core instructions; their precise layout is
// documented here: // documented here:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
using FragA = Vec<half2, 4>; using FragA = Vec<half2, 4>;
using FragB = Vec<half2, 2>; using FragB = Vec<half2, 2>;
using FragC = Vec<float, 4>; using FragC = Vec<float, 4>;
using FragS = Vec<half2, 1>; using FragS = Vec<half2, 1>;
static __device__ float inline num2float(const half x) { return __half2float(x); } static __device__ float inline num2float(const half x) {
return __half2float(x);
}
static __device__ half2 inline num2num2(const half x) { return __half2half2(x); } static __device__ half2 inline num2num2(const half x) {
return __half2half2(x);
}
static __device__ half2 inline nums2num2(const half x1, const half x2) { return __halves2half2(x1, x2); } static __device__ half2 inline nums2num2(const half x1, const half x2) {
return __halves2half2(x1, x2);
}
static __host__ __device__ half inline float2num(const float x) { return __float2half(x); } static __host__ __device__ half inline float2num(const float x) {
return __float2half(x);
}
}; };
template <> template <>
class ScalarType<nv_bfloat16> { class ScalarType<nv_bfloat16> {
public: public:
using scalar_t = nv_bfloat16; using scalar_t = nv_bfloat16;
using scalar_t2 = nv_bfloat162; using scalar_t2 = nv_bfloat162;
using FragA = Vec<nv_bfloat162, 4>; using FragA = Vec<nv_bfloat162, 4>;
using FragB = Vec<nv_bfloat162, 2>; using FragB = Vec<nv_bfloat162, 2>;
using FragC = Vec<float, 4>; using FragC = Vec<float, 4>;
using FragS = Vec<nv_bfloat162, 1>; using FragS = Vec<nv_bfloat162, 1>;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
static __device__ float inline num2float(const nv_bfloat16 x) { return __bfloat162float(x); } static __device__ float inline num2float(const nv_bfloat16 x) {
return __bfloat162float(x);
}
static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) { return __bfloat162bfloat162(x); } static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) {
return __bfloat162bfloat162(x);
}
static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1, const nv_bfloat16 x2) { return __halves2bfloat162(x1, x2); } static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1,
const nv_bfloat16 x2) {
return __halves2bfloat162(x1, x2);
}
static __host__ __device__ nv_bfloat16 inline float2num(const float x) { return __float2bfloat16(x); } static __host__ __device__ nv_bfloat16 inline float2num(const float x) {
return __float2bfloat16(x);
}
#endif #endif
}; };
} } // namespace gptq_marlin
#endif #endif

View File

@ -12,14 +12,14 @@ static constexpr int tile_n_size = tile_k_size * 4;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
template <int const num_threads, int const num_bits, bool const has_perm> template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void __global__ void marlin_repack_kernel(
marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, uint32_t const* __restrict__ b_q_weight_ptr,
uint32_t const *__restrict__ perm_ptr, uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
uint32_t *__restrict__ out_ptr, int size_k, int size_n) {} int size_k, int size_n) {}
} // namespace gptq_marlin } // namespace gptq_marlin
torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm, torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n, int64_t size_k, int64_t size_n,
int64_t num_bits) { int64_t num_bits) {
TORCH_CHECK_NOT_IMPLEMENTED( TORCH_CHECK_NOT_IMPLEMENTED(
@ -30,10 +30,10 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
#else #else
template <int const num_threads, int const num_bits, bool const has_perm> template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void __global__ void marlin_repack_kernel(
marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, uint32_t const* __restrict__ b_q_weight_ptr,
uint32_t const *__restrict__ perm_ptr, uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
uint32_t *__restrict__ out_ptr, int size_k, int size_n) { int size_k, int size_n) {
constexpr int pack_factor = 32 / num_bits; constexpr int pack_factor = 32 / num_bits;
int k_tiles = size_k / tile_k_size; int k_tiles = size_k / tile_k_size;
@ -61,8 +61,8 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
constexpr int perm_size = tile_k_size / 4; constexpr int perm_size = tile_k_size / 4;
int4 *sh_perm_ptr = sh; int4* sh_perm_ptr = sh;
int4 *sh_pipe_ptr = sh_perm_ptr; int4* sh_pipe_ptr = sh_perm_ptr;
if constexpr (has_perm) { if constexpr (has_perm) {
sh_pipe_ptr += perm_size; sh_pipe_ptr += perm_size;
} }
@ -76,7 +76,7 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
auto load_perm_to_shared = [&](int k_tile_id) { auto load_perm_to_shared = [&](int k_tile_id) {
int first_k_int4 = (k_tile_id * tile_k_size) / 4; int first_k_int4 = (k_tile_id * tile_k_size) / 4;
int4 const *perm_int4_ptr = reinterpret_cast<int4 const *>(perm_ptr); int4 const* perm_int4_ptr = reinterpret_cast<int4 const*>(perm_ptr);
if (threadIdx.x < perm_size) { if (threadIdx.x < perm_size) {
sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x]; sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x];
@ -92,22 +92,22 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
int first_n = n_tile_id * tile_n_size; int first_n = n_tile_id * tile_n_size;
int4 *sh_ptr = sh_pipe_ptr + stage_size * pipe; int4* sh_ptr = sh_pipe_ptr + stage_size * pipe;
if constexpr (has_perm) { if constexpr (has_perm) {
if (threadIdx.x < stage_size) { if (threadIdx.x < stage_size) {
int k_id = threadIdx.x / stage_n_threads; int k_id = threadIdx.x / stage_n_threads;
int n_id = threadIdx.x % stage_n_threads; int n_id = threadIdx.x % stage_n_threads;
uint32_t const *sh_perm_int_ptr = uint32_t const* sh_perm_int_ptr =
reinterpret_cast<uint32_t const *>(sh_perm_ptr); reinterpret_cast<uint32_t const*>(sh_perm_ptr);
int src_k = sh_perm_int_ptr[k_id]; int src_k = sh_perm_int_ptr[k_id];
int src_k_packed = src_k / pack_factor; int src_k_packed = src_k / pack_factor;
cp_async4( cp_async4(
&sh_ptr[k_id * stage_n_threads + n_id], &sh_ptr[k_id * stage_n_threads + n_id],
reinterpret_cast<int4 const *>(&( reinterpret_cast<int4 const*>(&(
b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)]))); b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)])));
} }
@ -120,7 +120,7 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
int first_k_packed = first_k / pack_factor; int first_k_packed = first_k / pack_factor;
cp_async4(&sh_ptr[k_id * stage_n_threads + n_id], cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
reinterpret_cast<int4 const *>( reinterpret_cast<int4 const*>(
&(b_q_weight_ptr[(first_k_packed + k_id) * size_n + &(b_q_weight_ptr[(first_k_packed + k_id) * size_n +
first_n + (n_id * 4)]))); first_n + (n_id * 4)])));
} }
@ -151,10 +151,10 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
constexpr int sh_stride = 64; constexpr int sh_stride = 64;
constexpr uint32_t mask = (1 << num_bits) - 1; constexpr uint32_t mask = (1 << num_bits) - 1;
int4 *sh_stage_ptr = sh_pipe_ptr + stage_size * pipe; int4* sh_stage_ptr = sh_pipe_ptr + stage_size * pipe;
uint32_t *sh_stage_int_ptr = reinterpret_cast<uint32_t *>(sh_stage_ptr); uint32_t* sh_stage_int_ptr = reinterpret_cast<uint32_t*>(sh_stage_ptr);
uint32_t *sh_perm_int_ptr = reinterpret_cast<uint32_t *>(sh_perm_ptr); uint32_t* sh_perm_int_ptr = reinterpret_cast<uint32_t*>(sh_perm_ptr);
uint32_t vals[8]; uint32_t vals[8];
@ -176,17 +176,16 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
} }
} else { } else {
uint32_t b1_vals[tile_ints]; uint32_t b1_vals[tile_ints];
uint32_t b2_vals[tile_ints]; uint32_t b2_vals[tile_ints];
#pragma unroll #pragma unroll
for (int i = 0; i < tile_ints; i++) { for (int i = 0; i < tile_ints; i++) {
b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i]; b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i];
b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i]; b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i];
} }
#pragma unroll #pragma unroll
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
int cur_elem = tc_row + tc_offsets[i]; int cur_elem = tc_row + tc_offsets[i];
int cur_int = cur_elem / pack_factor; int cur_int = cur_elem / pack_factor;
@ -206,7 +205,7 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
uint32_t res = 0; uint32_t res = 0;
#pragma unroll #pragma unroll
for (int i = 0; i < 8; i++) { for (int i = 0; i < 8; i++) {
res |= vals[pack_idx[i]] << (i * 4); res |= vals[pack_idx[i]] << (i * 4);
} }
@ -218,7 +217,7 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
uint32_t res1 = 0; uint32_t res1 = 0;
uint32_t res2 = 0; uint32_t res2 = 0;
#pragma unroll #pragma unroll
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
res1 |= vals[pack_idx[i]] << (i * 8); res1 |= vals[pack_idx[i]] << (i * 8);
res2 |= vals[4 + pack_idx[i]] << (i * 8); res2 |= vals[4 + pack_idx[i]] << (i * 8);
@ -230,14 +229,14 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
}; };
auto start_pipes = [&](int k_tile_id, int n_tile_id) { auto start_pipes = [&](int k_tile_id, int n_tile_id) {
#pragma unroll #pragma unroll
for (int pipe = 0; pipe < repack_stages - 1; pipe++) { for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe); fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
} }
wait_for_stage(); wait_for_stage();
}; };
#pragma unroll #pragma unroll
for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) { for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
int n_tile_id = 0; int n_tile_id = 0;
@ -248,7 +247,7 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
start_pipes(k_tile_id, n_tile_id); start_pipes(k_tile_id, n_tile_id);
while (n_tile_id < n_tiles) { while (n_tile_id < n_tiles) {
#pragma unroll #pragma unroll
for (int pipe = 0; pipe < repack_stages; pipe++) { for (int pipe = 0; pipe < repack_stages; pipe++) {
fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
n_tile_id + pipe + repack_stages - 1); n_tile_id + pipe + repack_stages - 1);
@ -260,21 +259,21 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
} }
} }
} // namespace gptq_marlin } // namespace gptq_marlin
#define CALL_IF(NUM_BITS, HAS_PERM) \ #define CALL_IF(NUM_BITS, HAS_PERM) \
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \ else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
cudaFuncSetAttribute( \ cudaFuncSetAttribute( \
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, \ gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, \
NUM_BITS, HAS_PERM>, \ NUM_BITS, HAS_PERM>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, NUM_BITS, \ gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, NUM_BITS, \
HAS_PERM> \ HAS_PERM> \
<<<blocks, gptq_marlin::repack_threads, max_shared_mem, stream>>>( \ <<<blocks, gptq_marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \ b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
} }
torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm, torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n, int64_t size_k, int64_t size_n,
int64_t num_bits) { int64_t num_bits) {
// Verify compatibility with marlin tile of 16x64 // Verify compatibility with marlin tile of 16x64
@ -318,11 +317,10 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
bool has_perm = perm.size(0) != 0; bool has_perm = perm.size(0) != 0;
// Get ptrs // Get ptrs
uint32_t const *b_q_weight_ptr = uint32_t const* b_q_weight_ptr =
reinterpret_cast<uint32_t const *>(b_q_weight.data_ptr()); reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());
uint32_t const *perm_ptr = uint32_t const* perm_ptr = reinterpret_cast<uint32_t const*>(perm.data_ptr());
reinterpret_cast<uint32_t const *>(perm.data_ptr()); uint32_t* out_ptr = reinterpret_cast<uint32_t*>(out.data_ptr());
uint32_t *out_ptr = reinterpret_cast<uint32_t *>(out.data_ptr());
// Get dev info // Get dev info
int dev = b_q_weight.get_device(); int dev = b_q_weight.get_device();

View File

@ -25,7 +25,10 @@
#include <iostream> #include <iostream>
template <typename T> inline std::string str(T x) { return std::to_string(x); } template <typename T>
inline std::string str(T x) {
return std::to_string(x);
}
namespace marlin { namespace marlin {
@ -38,9 +41,10 @@ constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }
// corresponding index accesses must be compile-time constants, which is why we // corresponding index accesses must be compile-time constants, which is why we
// extensively use `#pragma unroll` throughout the kernel code to guarantee // extensively use `#pragma unroll` throughout the kernel code to guarantee
// this. // this.
template <typename T, int n> struct Vec { template <typename T, int n>
struct Vec {
T elems[n]; T elems[n];
__device__ T &operator[](int i) { return elems[i]; } __device__ T& operator[](int i) { return elems[i]; }
}; };
using I4 = Vec<int, 4>; using I4 = Vec<int, 4>;
@ -51,29 +55,32 @@ using I4 = Vec<int, 4>;
using FragA = Vec<half2, 4>; using FragA = Vec<half2, 4>;
using FragB = Vec<half2, 2>; using FragB = Vec<half2, 2>;
using FragC = Vec<float, 4>; using FragC = Vec<float, 4>;
using FragS = Vec<half2, 1>; // quantization scales using FragS = Vec<half2, 1>; // quantization scales
// Predicated asynchronous global->shared copy; used for inputs A where we apply // Predicated asynchronous global->shared copy; used for inputs A where we apply
// predication to handle batchsizes that are not multiples of 16. // predication to handle batchsizes that are not multiples of 16.
__device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr, __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) { bool pred = true) {
const int BYTES = 16; const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr)); uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile("{\n" asm volatile(
" .reg .pred p;\n" "{\n"
" setp.ne.b32 p, %0, 0;\n" " .reg .pred p;\n"
" @p cp.async.cg.shared.global [%1], [%2], %3;\n" " setp.ne.b32 p, %0, 0;\n"
"}\n" ::"r"((int)pred), " @p cp.async.cg.shared.global [%1], [%2], %3;\n"
"r"(smem), "l"(glob_ptr), "n"(BYTES)); "}\n" ::"r"((int)pred),
"r"(smem), "l"(glob_ptr), "n"(BYTES));
} }
// Asynchronous global->shared copy // Asynchronous global->shared copy
__device__ inline void cp_async4(void *smem_ptr, const void *glob_ptr) { __device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
const int BYTES = 16; const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr)); uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile("{\n" asm volatile(
" cp.async.cg.shared.global [%0], [%1], %2;\n" "{\n"
"}\n" :: "r"(smem), "l"(glob_ptr), "n"(BYTES)); " cp.async.cg.shared.global [%0], [%1], %2;\n"
"}\n" ::"r"(smem),
"l"(glob_ptr), "n"(BYTES));
} }
// Async copy fence. // Async copy fence.
@ -82,28 +89,30 @@ __device__ inline void cp_async_fence() {
} }
// Wait until at most `n` async copy stages are still pending. // Wait until at most `n` async copy stages are still pending.
template <int n> __device__ inline void cp_async_wait() { template <int n>
__device__ inline void cp_async_wait() {
asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
} }
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 // m16n8k16 tensor core mma instruction with fp16 inputs and fp32
// output/accumulation. // output/accumulation.
__device__ inline void mma(const FragA &a_frag, const FragB &frag_b, __device__ inline void mma(const FragA& a_frag, const FragB& frag_b,
FragC &frag_c) { FragC& frag_c) {
const uint32_t *a = reinterpret_cast<const uint32_t *>(&a_frag); const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
const uint32_t *b = reinterpret_cast<const uint32_t *>(&frag_b); const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
float *c = reinterpret_cast<float *>(&frag_c); float* c = reinterpret_cast<float*>(&frag_c);
asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " asm volatile(
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
"r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} }
// Instruction for loading a full 16x16 matrix fragment of operand A from shared // Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout. // memory, directly in tensor core layout.
__device__ inline void ldsm4(FragA &frag_a, const void *smem_ptr) { __device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
uint32_t *a = reinterpret_cast<uint32_t *>(&frag_a); uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr)); uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
: "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
@ -113,7 +122,8 @@ __device__ inline void ldsm4(FragA &frag_a, const void *smem_ptr) {
// Lookup-table based 3-input logical operation; explicitly used for // Lookup-table based 3-input logical operation; explicitly used for
// dequantization as the compiler does not seem to automatically recognize it in // dequantization as the compiler does not seem to automatically recognize it in
// all cases. // all cases.
template <int lut> __device__ inline int lop3(int a, int b, int c) { template <int lut>
__device__ inline int lop3(int a, int b, int c) {
int res; int res;
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(res) : "=r"(res)
@ -138,24 +148,24 @@ __device__ inline FragB dequant(int q) {
const int MUL = 0x2c002c00; const int MUL = 0x2c002c00;
const int ADD = 0xd480d480; const int ADD = 0xd480d480;
FragB frag_b; FragB frag_b;
frag_b[0] = __hsub2(*reinterpret_cast<half2 *>(&lo), frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
*reinterpret_cast<const half2 *>(&SUB)); *reinterpret_cast<const half2*>(&SUB));
frag_b[1] = __hfma2(*reinterpret_cast<half2 *>(&hi), frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
*reinterpret_cast<const half2 *>(&MUL), *reinterpret_cast<const half2*>(&MUL),
*reinterpret_cast<const half2 *>(&ADD)); *reinterpret_cast<const half2*>(&ADD));
return frag_b; return frag_b;
} }
// Multiply dequantized values by the corresponding quantization scale; used // Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization. // only for grouped quantization.
__device__ inline void scale(FragB &frag_b, FragS &frag_s, int i) { __device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
half2 s = __half2half2(reinterpret_cast<__half *>(&frag_s)[i]); half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]);
frag_b[0] = __hmul2(frag_b[0], s); frag_b[0] = __hmul2(frag_b[0], s);
frag_b[1] = __hmul2(frag_b[1], s); frag_b[1] = __hmul2(frag_b[1], s);
} }
// Wait until barrier reaches `count`, then lock for current threadblock. // Wait until barrier reaches `count`, then lock for current threadblock.
__device__ inline void barrier_acquire(int *lock, int count) { __device__ inline void barrier_acquire(int* lock, int count) {
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
int state = -1; int state = -1;
do do
@ -170,7 +180,7 @@ __device__ inline void barrier_acquire(int *lock, int count) {
} }
// Release barrier and increment visitation count. // Release barrier and increment visitation count.
__device__ inline void barrier_release(int *lock, bool reset = false) { __device__ inline void barrier_release(int* lock, bool reset = false) {
__syncthreads(); __syncthreads();
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
if (reset) { if (reset) {
@ -187,26 +197,27 @@ __device__ inline void barrier_release(int *lock, bool reset = false) {
} }
} }
template <const int threads, // number of threads in a threadblock template <const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the threadblock // dimension (batchsize) of the
const int thread_n_blocks, // same for n dimension (output) // threadblock
const int thread_k_blocks, // same for k dimension (reduction) const int thread_n_blocks, // same for n dimension (output)
const int stages, // number of stages for the async global->shared const int thread_k_blocks, // same for k dimension (reduction)
// fetch pipeline const int stages, // number of stages for the async global->shared
const int group_blocks = -1 // number of consecutive 16x16 blocks with // fetch pipeline
// a separate quantization scale const int group_blocks = -1 // number of consecutive 16x16 blocks
// with a separate quantization scale
> >
__global__ void __global__ void Marlin(
Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk const int4* __restrict__ A, // fp16 input matrix of shape mxk
const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
int4 *__restrict__ C, // fp16 output buffer of shape mxn int4* __restrict__ C, // fp16 output buffer of shape mxn
const int4 const int4* __restrict__ s, // fp16 quantization scales of shape
*__restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn // (k/groupsize)xn
int prob_m, // batch dimension m int prob_m, // batch dimension m
int prob_n, // output dimension n int prob_n, // output dimension n
int prob_k, // reduction dimension k int prob_k, // reduction dimension k
int *locks // extra global storage for barrier synchronization int* locks // extra global storage for barrier synchronization
) { ) {
// Each threadblock processes one "stripe" of the B matrix with (roughly) the // Each threadblock processes one "stripe" of the B matrix with (roughly) the
// same size, which might involve multiple column "slices" (of width 16 * // same size, which might involve multiple column "slices" (of width 16 *
@ -241,11 +252,11 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
int slice_row = (iters * blockIdx.x) % k_tiles; int slice_row = (iters * blockIdx.x) % k_tiles;
int slice_col_par = (iters * blockIdx.x) / k_tiles; int slice_col_par = (iters * blockIdx.x) / k_tiles;
int slice_col = slice_col_par; int slice_col = slice_col_par;
int slice_iters; // number of threadblock tiles in the current slice int slice_iters; // number of threadblock tiles in the current slice
int slice_count = int slice_count =
0; // total number of active threadblocks in the current slice 0; // total number of active threadblocks in the current slice
int slice_idx; // index of threadblock in current slice; numbered bottom to int slice_idx; // index of threadblock in current slice; numbered bottom to
// top // top
// We can easily implement parallel problem execution by just remapping // We can easily implement parallel problem execution by just remapping
// indices and advancing global pointers // indices and advancing global pointers
@ -261,27 +272,22 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
auto init_slice = [&]() { auto init_slice = [&]() {
slice_iters = slice_iters =
iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;
slice_iters = 0; if (slice_iters == 0) return;
if (slice_iters == 0) if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
return;
if (slice_row + slice_iters > k_tiles)
slice_iters = k_tiles - slice_row;
slice_count = 1; slice_count = 1;
slice_idx = 0; slice_idx = 0;
int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); int col_first = iters * ceildiv(k_tiles * slice_col_par, iters);
if (col_first <= k_tiles * (slice_col_par + 1)) { if (col_first <= k_tiles * (slice_col_par + 1)) {
int col_off = col_first - k_tiles * slice_col_par; int col_off = col_first - k_tiles * slice_col_par;
slice_count = ceildiv(k_tiles - col_off, iters); slice_count = ceildiv(k_tiles - col_off, iters);
if (col_off > 0) if (col_off > 0) slice_count++;
slice_count++;
int delta_first = iters * blockIdx.x - col_first; int delta_first = iters * blockIdx.x - col_first;
if (delta_first < 0 || (col_off == 0 && delta_first == 0)) if (delta_first < 0 || (col_off == 0 && delta_first == 0))
slice_idx = slice_count - 1; slice_idx = slice_count - 1;
else { else {
slice_idx = slice_count - 1 - delta_first / iters; slice_idx = slice_count - 1 - delta_first / iters;
if (col_off > 0) if (col_off > 0) slice_idx--;
slice_idx--;
} }
} }
if (slice_col == n_tiles) { if (slice_col == n_tiles) {
@ -293,29 +299,30 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
}; };
init_slice(); init_slice();
int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory
// We typically use `constexpr` to indicate that this value is a compile-time // We typically use `constexpr` to indicate that this value is a compile-time
// constant // constant
constexpr int a_sh_stride = constexpr int a_sh_stride =
16 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory 16 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory
constexpr int a_gl_rd_delta_o = constexpr int a_gl_rd_delta_o =
16 * thread_k_blocks / 16 * thread_k_blocks /
8; // delta between subsequent A tiles in global memory 8; // delta between subsequent A tiles in global memory
int a_gl_rd_delta_i = int a_gl_rd_delta_i =
a_gl_stride * a_gl_stride *
(threads / a_gl_rd_delta_o); // between subsequent accesses within a tile (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile
constexpr int a_sh_wr_delta = constexpr int a_sh_wr_delta =
a_sh_stride * (threads / a_gl_rd_delta_o); // between shared memory writes a_sh_stride *
(threads / a_gl_rd_delta_o); // between shared memory writes
constexpr int a_sh_rd_delta_o = constexpr int a_sh_rd_delta_o =
2 * ((threads / 32) / 2 * ((threads / 32) /
(thread_n_blocks / 4)); // between shared memory tile reads (thread_n_blocks / 4)); // between shared memory tile reads
constexpr int a_sh_rd_delta_i = constexpr int a_sh_rd_delta_i =
a_sh_stride * 16; // within a shared memory tile a_sh_stride * 16; // within a shared memory tile
constexpr int a_sh_stage = constexpr int a_sh_stage =
a_sh_stride * (16 * thread_m_blocks); // overall size of a tile a_sh_stride * (16 * thread_m_blocks); // overall size of a tile
constexpr int a_sh_wr_iters = constexpr int a_sh_wr_iters =
ceildiv(a_sh_stage, ceildiv(a_sh_stage,
a_sh_wr_delta); // number of shared write iterations for a tile a_sh_wr_delta); // number of shared write iterations for a tile
int b_gl_stride = 16 * prob_n / 32; int b_gl_stride = 16 * prob_n / 32;
constexpr int b_sh_stride = 32 * thread_n_blocks / 4; constexpr int b_sh_stride = 32 * thread_n_blocks / 4;
@ -368,7 +375,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// needed if there are more threads than required for a certain tilesize or // needed if there are more threads than required for a certain tilesize or
// when the batchsize is not a multiple of 16. // when the batchsize is not a multiple of 16.
bool a_sh_wr_pred[a_sh_wr_iters]; bool a_sh_wr_pred[a_sh_wr_iters];
#pragma unroll #pragma unroll
for (int i = 0; i < a_sh_wr_iters; i++) for (int i = 0; i < a_sh_wr_iters; i++)
a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;
bool s_sh_wr_pred = threadIdx.x < s_sh_stride; bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
@ -387,13 +394,13 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// loop unrolls, all shared memory accesses are static, we simply precompute // loop unrolls, all shared memory accesses are static, we simply precompute
// both transformed reads and writes. // both transformed reads and writes.
int a_sh_wr_trans[a_sh_wr_iters]; int a_sh_wr_trans[a_sh_wr_iters];
#pragma unroll #pragma unroll
for (int i = 0; i < a_sh_wr_iters; i++) for (int i = 0; i < a_sh_wr_iters; i++)
a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];
#pragma unroll #pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) { for (int i = 0; i < b_sh_wr_iters; i++) {
#pragma unroll #pragma unroll
for (int j = 0; j < thread_m_blocks; j++) for (int j = 0; j < thread_m_blocks; j++)
a_sh_rd_trans[i][j] = a_sh_rd_trans[i][j] =
transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
@ -403,16 +410,16 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// runtime; we break dependencies between subsequent accesses with a tile by // runtime; we break dependencies between subsequent accesses with a tile by
// maintining multiple pointers (we have enough registers), a tiny // maintining multiple pointers (we have enough registers), a tiny
// optimization. // optimization.
const int4 *B_ptr[b_sh_wr_iters]; const int4* B_ptr[b_sh_wr_iters];
#pragma unroll #pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) for (int i = 0; i < b_sh_wr_iters; i++)
B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
extern __shared__ int4 sh[]; extern __shared__ int4 sh[];
// Shared memory storage for global fetch pipelines. // Shared memory storage for global fetch pipelines.
int4 *sh_a = sh; int4* sh_a = sh;
int4 *sh_b = sh_a + (stages * a_sh_stage); int4* sh_b = sh_a + (stages * a_sh_stage);
int4 *sh_s = sh_b + (stages * b_sh_stage); int4* sh_s = sh_b + (stages * b_sh_stage);
// Register storage for double buffer of shared memory reads. // Register storage for double buffer of shared memory reads.
FragA frag_a[2][thread_m_blocks]; FragA frag_a[2][thread_m_blocks];
I4 frag_b_quant[2]; I4 frag_b_quant[2];
@ -421,34 +428,33 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// Zero accumulators. // Zero accumulators.
auto zero_accums = [&]() { auto zero_accums = [&]() {
#pragma unroll #pragma unroll
for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
reinterpret_cast<float *>(frag_c)[i] = 0; reinterpret_cast<float*>(frag_c)[i] = 0;
}; };
// Asynchronously fetch the next A, B and s tile from global to the next // Asynchronously fetch the next A, B and s tile from global to the next
// shared memory pipeline location. // shared memory pipeline location.
auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
if (pred) { if (pred) {
int4 *sh_a_stage = sh_a + a_sh_stage * pipe; int4* sh_a_stage = sh_a + a_sh_stage * pipe;
#pragma unroll #pragma unroll
for (int i = 0; i < a_sh_wr_iters; i++) { for (int i = 0; i < a_sh_wr_iters; i++) {
cp_async4_pred( cp_async4_pred(
&sh_a_stage[a_sh_wr_trans[i]], &sh_a_stage[a_sh_wr_trans[i]],
&A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],
a_sh_wr_pred[i]); a_sh_wr_pred[i]);
} }
int4 *sh_b_stage = sh_b + b_sh_stage * pipe; int4* sh_b_stage = sh_b + b_sh_stage * pipe;
#pragma unroll #pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) { for (int i = 0; i < b_sh_wr_iters; i++) {
cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]); cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]);
B_ptr[i] += b_gl_rd_delta_o; B_ptr[i] += b_gl_rd_delta_o;
} }
// Only fetch scales if this tile starts a new group // Only fetch scales if this tile starts a new group
if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) { if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) {
int4 *sh_s_stage = sh_s + s_sh_stage * pipe; int4* sh_s_stage = sh_s + s_sh_stage * pipe;
if (s_sh_wr_pred) if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
s_gl_rd += s_gl_rd_delta; s_gl_rd += s_gl_rd_delta;
} }
} }
@ -475,37 +481,35 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// theoretically better attempts have lead to bad instruction ordering by // theoretically better attempts have lead to bad instruction ordering by
// the compiler and correspondingly a noticeable drop in performance. // the compiler and correspondingly a noticeable drop in performance.
if (group_blocks != -1) { if (group_blocks != -1) {
int4 *sh_s_stage = int4* sh_s_stage =
sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks))); (pipe / (group_blocks / thread_k_blocks)));
reinterpret_cast<int4 *>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
} }
int4 *sh_a_stage = sh_a + a_sh_stage * pipe; int4* sh_a_stage = sh_a + a_sh_stage * pipe;
#pragma unroll #pragma unroll
for (int i = 0; i < thread_m_blocks; i++) for (int i = 0; i < thread_m_blocks; i++)
ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
int4 *sh_b_stage = sh_b + b_sh_stage * pipe; int4* sh_b_stage = sh_b + b_sh_stage * pipe;
frag_b_quant[k % 2] = *reinterpret_cast<I4 *>( frag_b_quant[k % 2] = *reinterpret_cast<I4*>(
&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]);
}; };
// Execute the actual tensor core matmul of a sub-tile. // Execute the actual tensor core matmul of a sub-tile.
auto matmul = [&](int k) { auto matmul = [&](int k) {
// We have the m dimension as the inner loop in order to encourage overlapping // We have the m dimension as the inner loop in order to encourage overlapping
// dequantization and matmul operations. // dequantization and matmul operations.
#pragma unroll #pragma unroll
for (int j = 0; j < 4; j++) { for (int j = 0; j < 4; j++) {
int b_quant = frag_b_quant[k % 2][j]; int b_quant = frag_b_quant[k % 2][j];
int b_quant_shift = b_quant >> 8; int b_quant_shift = b_quant >> 8;
FragB frag_b0 = dequant(b_quant); FragB frag_b0 = dequant(b_quant);
// If there are no groups, we can just scale the final output once and can // If there are no groups, we can just scale the final output once and can
// avoid doing so for each weight. // avoid doing so for each weight.
if (group_blocks != -1) if (group_blocks != -1) scale(frag_b0, frag_s[k % 2][j], 0);
scale(frag_b0, frag_s[k % 2][j], 0);
FragB frag_b1 = dequant(b_quant_shift); FragB frag_b1 = dequant(b_quant_shift);
if (group_blocks != -1) if (group_blocks != -1) scale(frag_b1, frag_s[k % 2][j], 1);
scale(frag_b1, frag_s[k % 2][j], 1); #pragma unroll
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) { for (int i = 0; i < thread_m_blocks; i++) {
mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
@ -530,38 +534,38 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// unnecessary read or write iterations, e.g., for two warps we write only // unnecessary read or write iterations, e.g., for two warps we write only
// once by warp 1 and read only once by warp 0. // once by warp 1 and read only once by warp 0.
#pragma unroll #pragma unroll
for (int m_block = 0; m_block < thread_m_blocks; m_block++) { for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
#pragma unroll #pragma unroll
for (int i = red_off; i > 0; i /= 2) { for (int i = red_off; i > 0; i /= 2) {
if (i <= red_idx && red_idx < 2 * i) { if (i <= red_idx && red_idx < 2 * i) {
#pragma unroll #pragma unroll
for (int j = 0; j < 4 * 2; j++) { for (int j = 0; j < 4 * 2; j++) {
int red_sh_wr = int red_sh_wr =
red_sh_delta * j + (red_sh_rd - red_sh_stride * i); red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
if (i < red_off) { if (i < red_off) {
float *c_rd = reinterpret_cast<float *>( float* c_rd =
&sh[red_sh_delta * j + red_sh_rd]); reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
float *c_wr = reinterpret_cast<float *>(&sh[red_sh_wr]); float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
#pragma unroll #pragma unroll
for (int k = 0; k < 4; k++) for (int k = 0; k < 4; k++)
reinterpret_cast<FragC *>(frag_c)[4 * 2 * m_block + j][k] += reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
c_rd[k] + c_wr[k]; c_rd[k] + c_wr[k];
} }
sh[red_sh_wr] = sh[red_sh_wr] =
reinterpret_cast<int4 *>(&frag_c)[4 * 2 * m_block + j]; reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
} }
} }
__syncthreads(); __syncthreads();
} }
if (red_idx == 0) { if (red_idx == 0) {
#pragma unroll #pragma unroll
for (int i = 0; i < 4 * 2; i++) { for (int i = 0; i < 4 * 2; i++) {
float *c_rd = float* c_rd =
reinterpret_cast<float *>(&sh[red_sh_delta * i + red_sh_rd]); reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
#pragma unroll #pragma unroll
for (int j = 0; j < 4; j++) for (int j = 0; j < 4; j++)
reinterpret_cast<FragC *>(frag_c)[4 * 2 * m_block + i][j] += reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
c_rd[j]; c_rd[j];
} }
} }
@ -571,9 +575,9 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
}; };
// Since multiple threadblocks may process parts of the same column slice, we // Since multiple threadblocks may process parts of the same column slice, we
// finally have to globally reduce over the results. As the striped partitioning // finally have to globally reduce over the results. As the striped
// minimizes the number of such reductions and our outputs are usually rather // partitioning minimizes the number of such reductions and our outputs are
// small, we perform this reduction serially in L2 cache. // usually rather small, we perform this reduction serially in L2 cache.
auto global_reduce = [&](bool first = false, bool last = false) { auto global_reduce = [&](bool first = false, bool last = false) {
// We are very careful here to reduce directly in the output buffer to // We are very careful here to reduce directly in the output buffer to
// maximize L2 cache utilization in this step. To do this, we write out // maximize L2 cache utilization in this step. To do this, we write out
@ -592,39 +596,39 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
int row = (threadIdx.x % 32) / 4; int row = (threadIdx.x % 32) / 4;
if (!first) { if (!first) {
// Interestingly, doing direct global accesses here really seems to mess up the // Interestingly, doing direct global accesses here really seems to mess up
// compiler and lead to slowdowns, hence we also use async-copies even though // the compiler and lead to slowdowns, hence we also use async-copies even
// these fetches are not actually asynchronous. // though these fetches are not actually asynchronous.
#pragma unroll #pragma unroll
for (int i = 0; i < thread_m_blocks * 4; i++) { for (int i = 0; i < thread_m_blocks * 4; i++) {
cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], cp_async4_pred(
&C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + &sh[c_sh_wr + c_sh_wr_delta * i],
c_gl_wr_delta_i * (i % 2)], &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
i < (thread_m_blocks - 1) * 4 || c_gl_wr_delta_i * (i % 2)],
8 * (i / 2) + row < prob_m); i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
} }
cp_async_fence(); cp_async_fence();
cp_async_wait<0>(); cp_async_wait<0>();
} }
#pragma unroll #pragma unroll
for (int i = 0; i < thread_m_blocks * 4; i++) { for (int i = 0; i < thread_m_blocks * 4; i++) {
if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
if (!first) { if (!first) {
int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
#pragma unroll #pragma unroll
for (int j = 0; j < 2 * 4; j++) { for (int j = 0; j < 2 * 4; j++) {
reinterpret_cast<float *>( reinterpret_cast<float*>(
&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] +=
__half2float(reinterpret_cast<__half *>(&c_red)[j]); __half2float(reinterpret_cast<__half*>(&c_red)[j]);
} }
} }
if (!last) { if (!last) {
int4 c; int4 c;
#pragma unroll #pragma unroll
for (int j = 0; j < 2 * 4; j++) { for (int j = 0; j < 2 * 4; j++) {
reinterpret_cast<__half *>(&c)[j] = reinterpret_cast<__half*>(&c)[j] =
__float2half(reinterpret_cast<float *>( __float2half(reinterpret_cast<float*>(
&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]);
} }
C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] =
@ -658,17 +662,17 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// We first reorder in shared memory to guarantee the most efficient final // We first reorder in shared memory to guarantee the most efficient final
// global write patterns // global write patterns
auto write = [&](int idx, float c0, float c1, FragS &s) { auto write = [&](int idx, float c0, float c1, FragS& s) {
half2 res = __halves2half2(__float2half(c0), __float2half(c1)); half2 res = __halves2half2(__float2half(c0), __float2half(c1));
if (group_blocks == if (group_blocks ==
-1) // for per-column quantization we finally apply the scale here -1) // for per-column quantization we finally apply the scale here
res = __hmul2(res, s[0]); res = __hmul2(res, s[0]);
((half2 *)sh)[idx] = res; ((half2*)sh)[idx] = res;
}; };
if (threadIdx.x / 32 < thread_n_blocks / 4) { if (threadIdx.x / 32 < thread_n_blocks / 4) {
#pragma unroll #pragma unroll
for (int i = 0; i < thread_m_blocks; i++) { for (int i = 0; i < thread_m_blocks; i++) {
#pragma unroll #pragma unroll
for (int j = 0; j < 4; j++) { for (int j = 0; j < 4; j++) {
int wr = c_sh_wr + 8 * j; int wr = c_sh_wr + 8 * j;
write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],
@ -685,7 +689,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
} }
__syncthreads(); __syncthreads();
#pragma unroll #pragma unroll
for (int i = 0; for (int i = 0;
i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
i++) { i++) {
@ -699,9 +703,8 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// Start global fetch and register load pipelines. // Start global fetch and register load pipelines.
auto start_pipes = [&]() { auto start_pipes = [&]() {
#pragma unroll #pragma unroll
for (int i = 0; i < stages - 1; i++) for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters);
fetch_to_shared(i, i, i < slice_iters);
zero_accums(); zero_accums();
wait_for_stage(); wait_for_stage();
fetch_to_registers(0, 0); fetch_to_registers(0, 0);
@ -711,12 +714,12 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// Main loop. // Main loop.
while (slice_iters) { while (slice_iters) {
// We unroll over both the global fetch and the register load pipeline to ensure // We unroll over both the global fetch and the register load pipeline to
// all shared memory accesses are static. Note that both pipelines have even // ensure all shared memory accesses are static. Note that both pipelines have
// length meaning that the next iteration will always start at index 0. // even length meaning that the next iteration will always start at index 0.
#pragma unroll #pragma unroll
for (int pipe = 0; pipe < stages;) { for (int pipe = 0; pipe < stages;) {
#pragma unroll #pragma unroll
for (int k = 0; k < b_sh_wr_iters; k++) { for (int k = 0; k < b_sh_wr_iters; k++) {
fetch_to_registers(k + 1, pipe % stages); fetch_to_registers(k + 1, pipe % stages);
if (k == b_sh_wr_iters - 2) { if (k == b_sh_wr_iters - 2) {
@ -728,8 +731,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
matmul(k); matmul(k);
} }
slice_iters--; slice_iters--;
if (slice_iters == 0) if (slice_iters == 0) break;
break;
} }
a_gl_rd += a_gl_rd_delta_o * stages; a_gl_rd += a_gl_rd_delta_o * stages;
@ -742,8 +744,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// For per-column scales, we only fetch them here in the final step before // For per-column scales, we only fetch them here in the final step before
// write-out // write-out
if (group_blocks == -1 && last) { if (group_blocks == -1 && last) {
if (s_sh_wr_pred) if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
cp_async_fence(); cp_async_fence();
} }
thread_block_reduce(); thread_block_reduce();
@ -751,17 +752,17 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
cp_async_wait<0>(); cp_async_wait<0>();
__syncthreads(); __syncthreads();
if (threadIdx.x / 32 < thread_n_blocks / 4) { if (threadIdx.x / 32 < thread_n_blocks / 4) {
reinterpret_cast<int4 *>(&frag_s)[0] = sh_s[s_sh_rd + 0]; reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
reinterpret_cast<int4 *>(&frag_s)[1] = sh_s[s_sh_rd + 4]; reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
} }
} }
if (slice_count > 1) { // only globally reduce if there is more than one if (slice_count > 1) { // only globally reduce if there is more than one
// block in a slice // block in a slice
barrier_acquire(&locks[slice_col], slice_idx); barrier_acquire(&locks[slice_col], slice_idx);
global_reduce(slice_idx == 0, last); global_reduce(slice_idx == 0, last);
barrier_release(&locks[slice_col], last); barrier_release(&locks[slice_col], last);
} }
if (last) // only the last block in a slice actually writes the result if (last) // only the last block in a slice actually writes the result
write_result(); write_result();
slice_row = 0; slice_row = 0;
slice_col_par++; slice_col_par++;
@ -770,13 +771,12 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
if (slice_iters) { if (slice_iters) {
a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
(threadIdx.x % a_gl_rd_delta_o); (threadIdx.x % a_gl_rd_delta_o);
#pragma unroll #pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) for (int i = 0; i < b_sh_wr_iters; i++)
B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
if (slice_col == 0) { if (slice_col == 0) {
#pragma unroll #pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
B_ptr[i] -= b_gl_stride;
} }
s_gl_rd = s_sh_stride * slice_col + threadIdx.x; s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
start_pipes(); start_pipes();
@ -787,26 +787,27 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
#else #else
template <const int threads, // number of threads in a threadblock template <const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the threadblock // dimension (batchsize) of the
const int thread_n_blocks, // same for n dimension (output) // threadblock
const int thread_k_blocks, // same for k dimension (reduction) const int thread_n_blocks, // same for n dimension (output)
const int stages, // number of stages for the async global->shared const int thread_k_blocks, // same for k dimension (reduction)
// fetch pipeline const int stages, // number of stages for the async global->shared
const int group_blocks = -1 // number of consecutive 16x16 blocks with // fetch pipeline
// a separate quantization scale const int group_blocks = -1 // number of consecutive 16x16 blocks
// with a separate quantization scale
> >
__global__ void __global__ void Marlin(
Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk const int4* __restrict__ A, // fp16 input matrix of shape mxk
const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
int4 *__restrict__ C, // fp16 output buffer of shape mxn int4* __restrict__ C, // fp16 output buffer of shape mxn
const int4 const int4* __restrict__ s, // fp16 quantization scales of shape
*__restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn // (k/groupsize)xn
int prob_m, // batch dimension m int prob_m, // batch dimension m
int prob_n, // output dimension n int prob_n, // output dimension n
int prob_k, // reduction dimension k int prob_k, // reduction dimension k
int *locks // extra global storage for barrier synchronization int* locks // extra global storage for barrier synchronization
) { ) {
// Marlin is not implemented yet for SM < 8.0 // Marlin is not implemented yet for SM < 8.0
assert(false); assert(false);
@ -819,10 +820,10 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// than 1 warp per schedule allows some more latency hiding. At the same time, // than 1 warp per schedule allows some more latency hiding. At the same time,
// we want relatively few warps to have many registers per warp and small tiles. // we want relatively few warps to have many registers per warp and small tiles.
const int USER_THREADS = const int USER_THREADS =
256; // Note: This is only used with user-provided thread_k/n 256; // Note: This is only used with user-provided thread_k/n
const int STAGES = 4; // 4 pipeline stages fit into shared memory const int STAGES = 4; // 4 pipeline stages fit into shared memory
const int SHARED_MEM = const int SHARED_MEM =
96 * 1024; // max shared memory on compute capability 8.6 (< 8.0) 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0)
static constexpr int min_thread_n = 64; static constexpr int min_thread_n = 64;
static constexpr int min_thread_k = 64; static constexpr int min_thread_k = 64;
@ -831,7 +832,7 @@ static constexpr int tile_size = 16;
static constexpr int max_par = 16; static constexpr int max_par = 16;
static constexpr int pack_factor_4bit = static constexpr int pack_factor_4bit =
8; // We have 8 4-bit vals inside a 32 bit 8; // We have 8 4-bit vals inside a 32 bit
#define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ #define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
GROUP_BLOCKS, NUM_THREADS) \ GROUP_BLOCKS, NUM_THREADS) \
@ -858,23 +859,23 @@ thread_config_t small_batch_thread_configs[] = {
// Ordered by priority // Ordered by priority
// thread_k, thread_n, num_threads // thread_k, thread_n, num_threads
{128, 128, 256}, // Default {128, 128, 256}, // Default
{128, 64, 128}, // Reduce N 2X, same K {128, 64, 128}, // Reduce N 2X, same K
{64, 256, 256}, // Reduce K 2X, increase N 2X {64, 256, 256}, // Reduce K 2X, increase N 2X
{64, 128, 128}, // Reduce K 2X, same N {64, 128, 128}, // Reduce K 2X, same N
}; };
thread_config_t large_batch_thread_configs[] = { thread_config_t large_batch_thread_configs[] = {
// Ordered by priority // Ordered by priority
// thread_k, thread_n, num_threads // thread_k, thread_n, num_threads
{64, 256, 256}, // Default {64, 256, 256}, // Default
{128, 128, 256}, // Reduce N 2X, increase K 2X {128, 128, 256}, // Reduce N 2X, increase K 2X
{64, 128, 128}, // Reduce N 2X, same K {64, 128, 128}, // Reduce N 2X, same K
{128, 64, 128}, // Reduce N 4X, increase K 2X {128, 64, 128}, // Reduce N 4X, increase K 2X
}; };
bool is_valid_config(thread_config_t const &th_config, int prob_m, int prob_n, bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n,
int prob_k) { int prob_k) {
// Sanity // Sanity
if (th_config.thread_k == -1 || th_config.thread_n == -1 || if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
@ -907,7 +908,6 @@ bool is_valid_config(thread_config_t const &th_config, int prob_m, int prob_n,
} }
thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) {
if (prob_m <= 16) { if (prob_m <= 16) {
for (auto th_config : small_batch_thread_configs) { for (auto th_config : small_batch_thread_configs) {
if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { if (is_valid_config(th_config, prob_m, prob_n, prob_k)) {
@ -926,20 +926,20 @@ thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) {
return thread_config_t{-1, -1, -1}; return thread_config_t{-1, -1, -1};
} }
#define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS) \ #define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
__CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
__CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
__CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
__CALL_IF(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ __CALL_IF(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
__CALL_IF(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ __CALL_IF(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
__CALL_IF(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ __CALL_IF(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
__CALL_IF(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ __CALL_IF(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
__CALL_IF(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ __CALL_IF(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
__CALL_IF(4, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) __CALL_IF(4, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)
void marlin_cuda(const void *A, const void *B, void *C, void *s, int prob_m, void marlin_cuda(const void* A, const void* B, void* C, void* s, int prob_m,
int prob_n, int prob_k, void *workspace, int groupsize = -1, int prob_n, int prob_k, void* workspace, int groupsize = -1,
int dev = 0, cudaStream_t stream = 0, int thread_k = -1, int dev = 0, cudaStream_t stream = 0, int thread_k = -1,
int thread_n = -1, int sms = -1, int max_par = 16) { int thread_n = -1, int sms = -1, int max_par = 16) {
int tot_m = prob_m; int tot_m = prob_m;
@ -996,12 +996,12 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, int prob_m,
" is not divisible by group_blocks = ", group_blocks); " is not divisible by group_blocks = ", group_blocks);
} }
const int4 *A_ptr = (const int4 *)A; const int4* A_ptr = (const int4*)A;
const int4 *B_ptr = (const int4 *)B; const int4* B_ptr = (const int4*)B;
int4 *C_ptr = (int4 *)C; int4* C_ptr = (int4*)C;
const int4 *s_ptr = (const int4 *)s; const int4* s_ptr = (const int4*)s;
int *locks = (int *)workspace; int* locks = (int*)workspace;
for (int i = 0; i < tot_m_blocks; i += 4) { for (int i = 0; i < tot_m_blocks; i += 4) {
int thread_m_blocks = tot_m_blocks - i; int thread_m_blocks = tot_m_blocks - i;
@ -1011,8 +1011,7 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, int prob_m,
// Note that parallel > 1 currently only works for inputs without any // Note that parallel > 1 currently only works for inputs without any
// padding // padding
par = (16 * thread_m_blocks - pad) / 64; par = (16 * thread_m_blocks - pad) / 64;
if (par > max_par) if (par > max_par) par = max_par;
par = max_par;
prob_m = 64 * par; prob_m = 64 * par;
i += 4 * (par - 1); i += 4 * (par - 1);
thread_m_blocks = 4; thread_m_blocks = 4;
@ -1041,12 +1040,11 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, int prob_m,
} }
} }
} // namespace marlin } // namespace marlin
torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor &b_scales, torch::Tensor &workspace, torch::Tensor& b_scales, torch::Tensor& workspace,
int64_t size_m, int64_t size_n, int64_t size_k) { int64_t size_m, int64_t size_n, int64_t size_k) {
// Verify M // Verify M
TORCH_CHECK(size_m == a.size(0), TORCH_CHECK(size_m == a.size(0),
"Shape mismatch: a.size(0) = " + str(a.size(0)) + "Shape mismatch: a.size(0) = " + str(a.size(0)) +
@ -1074,9 +1072,9 @@ torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
int actual_size_n = int actual_size_n =
(b_q_weight.size(1) / marlin::tile_size) * marlin::pack_factor_4bit; (b_q_weight.size(1) / marlin::tile_size) * marlin::pack_factor_4bit;
TORCH_CHECK(size_n == actual_size_n, TORCH_CHECK(
"size_n = " + str(size_n) + size_n == actual_size_n,
", actual_size_n = " + str(actual_size_n)); "size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n));
// Verify A device and strides // Verify A device and strides
TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");

View File

@ -26,12 +26,14 @@ constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }
// corresponding index accesses must be compile-time constants, which is why we // corresponding index accesses must be compile-time constants, which is why we
// extensively use `#pragma unroll` throughout the kernel code to guarantee // extensively use `#pragma unroll` throughout the kernel code to guarantee
// this. // this.
template <typename T, int n> struct Vec { template <typename T, int n>
struct Vec {
T elems[n]; T elems[n];
__device__ T &operator[](int i) { return elems[i]; } __device__ T& operator[](int i) { return elems[i]; }
}; };
template <int M_, int N_, int K_> struct ShapeBase { template <int M_, int N_, int K_>
struct ShapeBase {
static constexpr int M = M_, N = N_, K = K_; static constexpr int M = M_, N = N_, K = K_;
}; };
@ -44,6 +46,6 @@ using FragA = Vec<half2, 4>;
using FragB = Vec<half2, 2>; using FragB = Vec<half2, 2>;
using FragM = Vec<uint, 1>; using FragM = Vec<uint, 1>;
using FragC = Vec<float, 4>; using FragC = Vec<float, 4>;
using FragS = Vec<half2, 1>; // quantization scales using FragS = Vec<half2, 1>; // quantization scales
} // namespace marlin_24 } // namespace marlin_24

View File

@ -21,41 +21,44 @@
namespace marlin_24 { namespace marlin_24 {
// Predicated asynchronous global->shared copy; used for inputs A where we apply // Predicated asynchronous global->shared copy; used for inputs A where we apply
// predication to handle batchsizes that are not multiples of 16. // predication to handle batchsizes that are not multiples of 16.
__device__ inline void cp_async4_pred_zfill(void *smem_ptr, __device__ inline void cp_async4_pred_zfill(void* smem_ptr,
const void *glob_ptr, const void* glob_ptr,
bool pred = true, bool pred = true,
const bool zfill = false) { const bool zfill = false) {
const int BYTES = 16; const int BYTES = 16;
int src_in_bytes = (zfill ? 0 : BYTES); int src_in_bytes = (zfill ? 0 : BYTES);
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr)); uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile("{\n" asm volatile(
" .reg .pred p;\n" "{\n"
" setp.ne.b32 p, %0, 0;\n" " .reg .pred p;\n"
" @p cp.async.cg.shared.global [%1], [%2], %3;\n" " setp.ne.b32 p, %0, 0;\n"
"}\n" ::"r"((int)pred), " @p cp.async.cg.shared.global [%1], [%2], %3;\n"
"r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes)); "}\n" ::"r"((int)pred),
"r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes));
} }
__device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr, __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) { bool pred = true) {
const int BYTES = 16; const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr)); uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile("{\n" asm volatile(
" .reg .pred p;\n" "{\n"
" setp.ne.b32 p, %0, 0;\n" " .reg .pred p;\n"
" @p cp.async.cg.shared.global [%1], [%2], %3;\n" " setp.ne.b32 p, %0, 0;\n"
"}\n" ::"r"((int)pred), " @p cp.async.cg.shared.global [%1], [%2], %3;\n"
"r"(smem), "l"(glob_ptr), "n"(BYTES)); "}\n" ::"r"((int)pred),
"r"(smem), "l"(glob_ptr), "n"(BYTES));
} }
// Asynchronous global->shared copy // Asynchronous global->shared copy
__device__ inline void cp_async4(void *smem_ptr, const void *glob_ptr) { __device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
const int BYTES = 16; const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr)); uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile("{\n" asm volatile(
" cp.async.cg.shared.global [%0], [%1], %2;\n" "{\n"
"}\n" ::"r"(smem), " cp.async.cg.shared.global [%0], [%1], %2;\n"
"l"(glob_ptr), "n"(BYTES)); "}\n" ::"r"(smem),
"l"(glob_ptr), "n"(BYTES));
} }
// Async copy fence. // Async copy fence.
@ -64,22 +67,23 @@ __device__ inline void cp_async_fence() {
} }
// Wait until at most `n` async copy stages are still pending. // Wait until at most `n` async copy stages are still pending.
template <int n> __device__ inline void cp_async_wait() { template <int n>
__device__ inline void cp_async_wait() {
asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
} }
// Instruction for loading a full 16x16 matrix fragment of operand A from shared // Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout. // memory, directly in tensor core layout.
__device__ inline void ldsm4(FragA &frag_a, const void *smem_ptr) { __device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
uint32_t *a = reinterpret_cast<uint32_t *>(&frag_a); uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr)); uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
: "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
: "r"(smem)); : "r"(smem));
} }
__device__ inline void ldsm4_m(FragM &frag_m, const void *smem_ptr) { __device__ inline void ldsm4_m(FragM& frag_m, const void* smem_ptr) {
uint32_t *a = reinterpret_cast<uint32_t *>(&frag_m); uint32_t* a = reinterpret_cast<uint32_t*>(&frag_m);
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr)); uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n"
: "=r"(a[0]), "=r"(a[1]) : "=r"(a[0]), "=r"(a[1])
@ -88,8 +92,8 @@ __device__ inline void ldsm4_m(FragM &frag_m, const void *smem_ptr) {
// Instruction for loading a full 16x16 matrix fragment of operand A from shared // Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout. // memory, directly in tensor core layout.
__device__ inline void ldsm4_t(FragA &frag_a, const void *smem_ptr) { __device__ inline void ldsm4_t(FragA& frag_a, const void* smem_ptr) {
uint32_t *a = reinterpret_cast<uint32_t *>(&frag_a); uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr)); uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile( asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0,%1,%2,%3}, [%4];\n" "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0,%1,%2,%3}, [%4];\n"
@ -98,7 +102,7 @@ __device__ inline void ldsm4_t(FragA &frag_a, const void *smem_ptr) {
} }
// Wait until barrier reaches `count`, then lock for current threadblock. // Wait until barrier reaches `count`, then lock for current threadblock.
__device__ inline void barrier_acquire(int *lock, int count) { __device__ inline void barrier_acquire(int* lock, int count) {
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
int state = -1; int state = -1;
do do
@ -113,7 +117,7 @@ __device__ inline void barrier_acquire(int *lock, int count) {
} }
// Release barrier and increment visitation count. // Release barrier and increment visitation count.
__device__ inline void barrier_release(int *lock, bool reset = false) { __device__ inline void barrier_release(int* lock, bool reset = false) {
__syncthreads(); __syncthreads();
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
if (reset) { if (reset) {
@ -129,4 +133,4 @@ __device__ inline void barrier_release(int *lock, bool reset = false) {
: "l"(lock), "r"(val)); : "l"(lock), "r"(val));
} }
} }
} // namespace marlin_24 } // namespace marlin_24

View File

@ -22,51 +22,56 @@ namespace marlin_24 {
// m16n8k32 sparse tensor core mma instruction with fp16 inputs and fp32 // m16n8k32 sparse tensor core mma instruction with fp16 inputs and fp32
// output/accumulation. // output/accumulation.
__device__ inline void mma_sp(const FragB &a_frag0, const FragB &a_frag1, __device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1,
const FragA &frag_b, FragC &frag_c, FragM &frag_m, const FragA& frag_b, FragC& frag_c, FragM& frag_m,
const int psel) { const int psel) {
const uint32_t *a0 = reinterpret_cast<const uint32_t *>(&a_frag0); const uint32_t* a0 = reinterpret_cast<const uint32_t*>(&a_frag0);
const uint32_t *a1 = reinterpret_cast<const uint32_t *>(&a_frag1); const uint32_t* a1 = reinterpret_cast<const uint32_t*>(&a_frag1);
const uint32_t *b = reinterpret_cast<const uint32_t *>(&frag_b); const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
const uint32_t *e = reinterpret_cast<const uint32_t *>(&frag_m); const uint32_t* e = reinterpret_cast<const uint32_t*>(&frag_m);
float *c = reinterpret_cast<float *>(&frag_c); float* c = reinterpret_cast<float*>(&frag_c);
if (psel == 0) { if (psel == 0) {
asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " asm volatile(
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
"{%12,%13,%14,%15}, %16, 0x0;\n" "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) "{%12,%13,%14,%15}, %16, 0x0;\n"
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
"r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), "r"(b[2]),
"f"(c[2]), "f"(c[3]), "r"(e[0])); "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]),
asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " "r"(e[0]));
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " asm volatile(
"{%12,%13,%14,%15}, %16, 0x0;\n" "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
: "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7]) "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), "{%12,%13,%14,%15}, %16, 0x0;\n"
"r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
"f"(c[6]), "f"(c[7]), "r"(e[0])); : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), "r"(b[3]),
"r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), "f"(c[6]), "f"(c[7]),
"r"(e[0]));
} else { } else {
asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " asm volatile(
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
"{%12,%13,%14,%15}, %16, 0x1;\n" "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) "{%12,%13,%14,%15}, %16, 0x1;\n"
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
"r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), "r"(b[2]),
"f"(c[2]), "f"(c[3]), "r"(e[0])); "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]),
asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " "r"(e[0]));
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " asm volatile(
"{%12,%13,%14,%15}, %16, 0x1;\n" "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
: "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7]) "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), "{%12,%13,%14,%15}, %16, 0x1;\n"
"r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
"f"(c[6]), "f"(c[7]), "r"(e[0])); : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), "r"(b[3]),
"r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), "f"(c[6]), "f"(c[7]),
"r"(e[0]));
} }
} }
// Lookup-table based 3-input logical operation; explicitly used for // Lookup-table based 3-input logical operation; explicitly used for
// dequantization as the compiler does not seem to automatically recognize it in // dequantization as the compiler does not seem to automatically recognize it in
// all cases. // all cases.
template <int lut> __device__ inline int lop3(int a, int b, int c) { template <int lut>
__device__ inline int lop3(int a, int b, int c) {
int res; int res;
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(res) : "=r"(res)
@ -120,11 +125,11 @@ __device__ inline FragB dequant_4bit(int q) {
const int ADD = 0xd480d480; const int ADD = 0xd480d480;
FragB frag_b; FragB frag_b;
frag_b[0] = __hsub2(*reinterpret_cast<half2 *>(&lo), frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
*reinterpret_cast<const half2 *>(&SUB)); *reinterpret_cast<const half2*>(&SUB));
frag_b[1] = __hfma2(*reinterpret_cast<half2 *>(&hi), frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
*reinterpret_cast<const half2 *>(&MUL), *reinterpret_cast<const half2*>(&MUL),
*reinterpret_cast<const half2 *>(&ADD)); *reinterpret_cast<const half2*>(&ADD));
return frag_b; return frag_b;
} }
@ -143,24 +148,24 @@ __device__ inline FragB dequant_8bit(int q) {
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
FragB frag_b; FragB frag_b;
frag_b[0] = __hsub2(*reinterpret_cast<half2 *>(&lo), frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
*reinterpret_cast<const half2 *>(&I8s_TO_F16s_MAGIC_NUM)); *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
frag_b[1] = __hsub2(*reinterpret_cast<half2 *>(&hi), frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
*reinterpret_cast<const half2 *>(&I8s_TO_F16s_MAGIC_NUM)); *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
return frag_b; return frag_b;
} }
// Multiply dequantized values by the corresponding quantization scale; used // Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization. // only for grouped quantization.
__device__ inline void scale(FragB &frag_b, FragS &frag_s, int i) { __device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
half2 s = __half2half2(reinterpret_cast<__half *>(&frag_s)[i]); half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]);
frag_b[0] = __hmul2(frag_b[0], s); frag_b[0] = __hmul2(frag_b[0], s);
frag_b[1] = __hmul2(frag_b[1], s); frag_b[1] = __hmul2(frag_b[1], s);
} }
__device__ inline void scale_floats(float *c0, float *c1, float *c2, float *c3, __device__ inline void scale_floats(float* c0, float* c1, float* c2, float* c3,
FragS &s0, float *c4, float *c5, float *c6, FragS& s0, float* c4, float* c5, float* c6,
float *c7, FragS &s1) { float* c7, FragS& s1) {
*c0 = __fmul_rn(*c0, __half2float(s0[0].x)); *c0 = __fmul_rn(*c0, __half2float(s0[0].x));
*c1 = __fmul_rn(*c1, __half2float(s0[0].y)); *c1 = __fmul_rn(*c1, __half2float(s0[0].y));
*c2 = __fmul_rn(*c2, __half2float(s0[1].x)); *c2 = __fmul_rn(*c2, __half2float(s0[1].x));
@ -172,4 +177,4 @@ __device__ inline void scale_floats(float *c0, float *c1, float *c2, float *c3,
*c7 = __fmul_rn(*c7, __half2float(s1[1].y)); *c7 = __fmul_rn(*c7, __half2float(s1[1].y));
} }
} // namespace marlin_24 } // namespace marlin_24

View File

@ -32,12 +32,15 @@
#else #else
#include "common/mem.h" #include "common/mem.h"
#include "common/mma.h" #include "common/mma.h"
#endif #endif
template <typename T> inline std::string str(T x) { return std::to_string(x); } template <typename T>
inline std::string str(T x) {
return std::to_string(x);
}
namespace marlin_24 { namespace marlin_24 {
@ -45,7 +48,7 @@ namespace marlin_24 {
// than 1 warp per schedule allows some more latency hiding. At the same time, // than 1 warp per schedule allows some more latency hiding. At the same time,
// we want relatively few warps to have many registers per warp and small tiles. // we want relatively few warps to have many registers per warp and small tiles.
static constexpr int THREADS = 256; static constexpr int THREADS = 256;
static constexpr int STAGES = 4; // 4 pipeline stages fit into shared memory static constexpr int STAGES = 4; // 4 pipeline stages fit into shared memory
static constexpr int min_thread_n = 128; static constexpr int min_thread_n = 128;
@ -54,35 +57,36 @@ static constexpr int max_par = 16;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
template <const int num_bits, // weight bits template <const int num_bits, // weight bits
const int threads, // number of threads in a threadblock const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the threadblock // dimension (batchsize) of the
const int thread_n_blocks, // same for n dimension (output) // threadblock
const int thread_k_blocks, // same for k dimension (reduction) const int thread_n_blocks, // same for n dimension (output)
const int stages, // number of stages for the async global->shared const int thread_k_blocks, // same for k dimension (reduction)
// fetch pipeline const int stages, // number of stages for the async global->shared
const int group_blocks = -1 // number of consecutive 16x16 blocks with // fetch pipeline
// a separate quantization scale const int group_blocks = -1 // number of consecutive 16x16 blocks
// with a separate quantization scale
> >
__global__ void Marlin_24( __global__ void Marlin_24(
const int4 *__restrict__ A, // fp16 input matrix of shape mxk const int4* __restrict__ A, // fp16 input matrix of shape mxk
const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
const int4 const int4* __restrict__ meta, // 2bit metadata information about 2:4
*__restrict__ meta, // 2bit metadata information about 2:4 format on B // format on B
int4 *__restrict__ C, // fp16 output buffer of shape mxn int4* __restrict__ C, // fp16 output buffer of shape mxn
const int4 const int4* __restrict__ s, // fp16 quantization scales of shape
*__restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn // (k/groupsize)xn
int prob_m, // batch dimension m int prob_m, // batch dimension m
int prob_n, // output dimension n int prob_n, // output dimension n
int prob_k, // reduction dimension k int prob_k, // reduction dimension k
int *locks // extra global storage for barrier synchronization int* locks // extra global storage for barrier synchronization
) {} ) {}
torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor &b_meta, torch::Tensor& b_meta,
torch::Tensor &b_scales, torch::Tensor& b_scales,
torch::Tensor &workspace, int64_t num_bits, torch::Tensor& workspace, int64_t num_bits,
int64_t size_m, int64_t size_n, int64_t size_m, int64_t size_n,
int64_t size_k) { int64_t size_k) {
TORCH_CHECK_NOT_IMPLEMENTED( TORCH_CHECK_NOT_IMPLEMENTED(
@ -92,29 +96,30 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
#else #else
template <const int num_bits, // weight bits template <const int num_bits, // weight bits
const int threads, // number of threads in a threadblock const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the threadblock // dimension (batchsize) of the
const int thread_n_blocks, // same for n dimension (output) // threadblock
const int thread_k_blocks, // same for k dimension (reduction) const int thread_n_blocks, // same for n dimension (output)
const int stages, // number of stages for the async global->shared const int thread_k_blocks, // same for k dimension (reduction)
// fetch pipeline const int stages, // number of stages for the async global->shared
const int group_blocks = -1 // number of consecutive 16x16 blocks with // fetch pipeline
// a separate quantization scale const int group_blocks = -1 // number of consecutive 16x16 blocks
// with a separate quantization scale
> >
__global__ void Marlin_24( __global__ void Marlin_24(
const int4 *__restrict__ A, // fp16 input matrix of shape mxk const int4* __restrict__ A, // fp16 input matrix of shape mxk
const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
const int4 const int4* __restrict__ meta, // 2bit metadata information about 2:4
*__restrict__ meta, // 2bit metadata information about 2:4 format on B // format on B
int4 *__restrict__ C, // fp16 output buffer of shape mxn int4* __restrict__ C, // fp16 output buffer of shape mxn
const int4 const int4* __restrict__ s, // fp16 quantization scales of shape
*__restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn // (k/groupsize)xn
int prob_m, // batch dimension m int prob_m, // batch dimension m
int prob_n, // output dimension n int prob_n, // output dimension n
int prob_k, // reduction dimension k int prob_k, // reduction dimension k
int *locks // extra global storage for barrier synchronization int* locks // extra global storage for barrier synchronization
) { ) {
// Each threadblock processes one "stripe" of the B matrix with (roughly) the // Each threadblock processes one "stripe" of the B matrix with (roughly) the
// same size, which might involve multiple column "slices" (of width 16 * // same size, which might involve multiple column "slices" (of width 16 *
@ -174,27 +179,22 @@ __global__ void Marlin_24(
auto init_slice = [&]() { auto init_slice = [&]() {
slice_iters = slice_iters =
iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;
slice_iters = 0; if (slice_iters == 0) return;
if (slice_iters == 0) if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
return;
if (slice_row + slice_iters > k_tiles)
slice_iters = k_tiles - slice_row;
slice_count = 1; slice_count = 1;
slice_idx = 0; slice_idx = 0;
int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); int col_first = iters * ceildiv(k_tiles * slice_col_par, iters);
if (col_first <= k_tiles * (slice_col_par + 1)) { if (col_first <= k_tiles * (slice_col_par + 1)) {
int col_off = col_first - k_tiles * slice_col_par; int col_off = col_first - k_tiles * slice_col_par;
slice_count = ceildiv(k_tiles - col_off, iters); slice_count = ceildiv(k_tiles - col_off, iters);
if (col_off > 0) if (col_off > 0) slice_count++;
slice_count++;
int delta_first = iters * blockIdx.x - col_first; int delta_first = iters * blockIdx.x - col_first;
if (delta_first < 0 || (col_off == 0 && delta_first == 0)) if (delta_first < 0 || (col_off == 0 && delta_first == 0))
slice_idx = slice_count - 1; slice_idx = slice_count - 1;
else { else {
slice_idx = slice_count - 1 - delta_first / iters; slice_idx = slice_count - 1 - delta_first / iters;
if (col_off > 0) if (col_off > 0) slice_idx--;
slice_idx--;
} }
} }
if (slice_col == n_tiles) { if (slice_col == n_tiles) {
@ -207,7 +207,7 @@ __global__ void Marlin_24(
init_slice(); init_slice();
// RLC: 8 is vec_size -> 128-bit instructions, 8 fp16 elements // RLC: 8 is vec_size -> 128-bit instructions, 8 fp16 elements
int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory
// stride of an A matrix tile in shared memory // stride of an A matrix tile in shared memory
constexpr int a_sh_stride = 32 * thread_k_blocks / 8; constexpr int a_sh_stride = 32 * thread_k_blocks / 8;
@ -239,9 +239,9 @@ __global__ void Marlin_24(
constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
int m_gl_stride = 2 * prob_n / 8; // (16*2*4 / 8) = 16 int m_gl_stride = 2 * prob_n / 8; // (16*2*4 / 8) = 16
constexpr int m_sh_stride = constexpr int m_sh_stride =
(16 * thread_n_blocks) / 4; // #warps n-dim * threads/warp (16 * thread_n_blocks) / 4; // #warps n-dim * threads/warp
int m_gl_rd_delta_o = m_gl_stride * thread_k_blocks; int m_gl_rd_delta_o = m_gl_stride * thread_k_blocks;
int m_gl_rd_delta_i = m_gl_stride * (threads / m_sh_stride); int m_gl_rd_delta_i = m_gl_stride * (threads / m_sh_stride);
constexpr int m_sh_wr_delta = threads / 2; constexpr int m_sh_wr_delta = threads / 2;
@ -305,7 +305,7 @@ __global__ void Marlin_24(
// needed if there are more threads than required for a certain tilesize or // needed if there are more threads than required for a certain tilesize or
// when the batchsize is not a multiple of 16. // when the batchsize is not a multiple of 16.
bool a_sh_wr_pred[a_sh_wr_iters]; bool a_sh_wr_pred[a_sh_wr_iters];
#pragma unroll #pragma unroll
for (int i = 0; i < a_sh_wr_iters; i++) { for (int i = 0; i < a_sh_wr_iters; i++) {
a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;
} }
@ -325,13 +325,13 @@ __global__ void Marlin_24(
// loop unrolls, all shared memory accesses are static, we simply precompute // loop unrolls, all shared memory accesses are static, we simply precompute
// both transformed reads and writes. // both transformed reads and writes.
int a_sh_wr_trans[a_sh_wr_iters]; int a_sh_wr_trans[a_sh_wr_iters];
#pragma unroll #pragma unroll
for (int i = 0; i < a_sh_wr_iters; i++) for (int i = 0; i < a_sh_wr_iters; i++)
a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
int a_sh_rd_trans[2][b_sh_wr_iters][thread_m_blocks]; int a_sh_rd_trans[2][b_sh_wr_iters][thread_m_blocks];
#pragma unroll #pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) { for (int i = 0; i < b_sh_wr_iters; i++) {
#pragma unroll #pragma unroll
for (int j = 0; j < thread_m_blocks; j++) { for (int j = 0; j < thread_m_blocks; j++) {
a_sh_rd_trans[0][i][j] = a_sh_rd_trans[0][i][j] =
transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
@ -344,23 +344,23 @@ __global__ void Marlin_24(
// runtime; we break dependencies between subsequent accesses with a tile by // runtime; we break dependencies between subsequent accesses with a tile by
// maintining multiple pointers (we have enough registers), a tiny // maintining multiple pointers (we have enough registers), a tiny
// optimization. // optimization.
const int4 *B_ptr[b_sh_wr_iters]; const int4* B_ptr[b_sh_wr_iters];
#pragma unroll #pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) for (int i = 0; i < b_sh_wr_iters; i++)
B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
bool m_sh_wr_pred = threadIdx.x < m_sh_wr_delta; bool m_sh_wr_pred = threadIdx.x < m_sh_wr_delta;
const int4 *meta_ptr[m_sh_iters]; const int4* meta_ptr[m_sh_iters];
#pragma unroll #pragma unroll
for (int i = 0; i < m_sh_iters; i++) for (int i = 0; i < m_sh_iters; i++)
meta_ptr[i] = meta + m_gl_rd_delta_i * i + m_gl_rd; meta_ptr[i] = meta + m_gl_rd_delta_i * i + m_gl_rd;
extern __shared__ int4 sh[]; extern __shared__ int4 sh[];
// Shared memory storage for global fetch pipelines. // Shared memory storage for global fetch pipelines.
int4 *sh_a = sh; int4* sh_a = sh;
int4 *sh_b = sh_a + (stages * a_sh_stage); int4* sh_b = sh_a + (stages * a_sh_stage);
int4 *sh_s = sh_b + (stages * b_sh_stage); int4* sh_s = sh_b + (stages * b_sh_stage);
int4 *sh_m = sh_s + (stages * s_sh_stage); int4* sh_m = sh_s + (stages * s_sh_stage);
// Register storage for double buffer of shared memory reads. // Register storage for double buffer of shared memory reads.
FragA frag_a[2][thread_m_blocks][2]; FragA frag_a[2][thread_m_blocks][2];
I4 frag_b_quant[2][b_thread_vecs]; I4 frag_b_quant[2][b_thread_vecs];
@ -370,46 +370,43 @@ __global__ void Marlin_24(
// Zero accumulators. // Zero accumulators.
auto zero_accums = [&]() { auto zero_accums = [&]() {
#pragma unroll #pragma unroll
for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
reinterpret_cast<float *>(frag_c)[i] = 0; reinterpret_cast<float*>(frag_c)[i] = 0;
}; };
// Asynchronously fetch the next A, B and s tile from global to the next // Asynchronously fetch the next A, B and s tile from global to the next
// shared memory pipeline location. // shared memory pipeline location.
auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
if (pred) { if (pred) {
int4 *sh_a_stage = sh_a + a_sh_stage * pipe; int4* sh_a_stage = sh_a + a_sh_stage * pipe;
#pragma unroll #pragma unroll
for (int i = 0; i < a_sh_wr_iters; i++) { for (int i = 0; i < a_sh_wr_iters; i++) {
cp_async4_pred( cp_async4_pred(
&sh_a_stage[a_sh_wr_trans[i]], &sh_a_stage[a_sh_wr_trans[i]],
&A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],
a_sh_wr_pred[i]); a_sh_wr_pred[i]);
} }
int4 *sh_b_stage = sh_b + b_sh_stage * pipe; int4* sh_b_stage = sh_b + b_sh_stage * pipe;
#pragma unroll #pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) { for (int i = 0; i < b_sh_wr_iters; i++) {
#pragma unroll #pragma unroll
for (int j = 0; j < b_thread_vecs; j++) { for (int j = 0; j < b_thread_vecs; j++) {
cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j);
B_ptr[i] + j);
} }
B_ptr[i] += b_gl_rd_delta_o; B_ptr[i] += b_gl_rd_delta_o;
} }
int4 *sh_meta_stage = sh_m + m_sh_stage * pipe; int4* sh_meta_stage = sh_m + m_sh_stage * pipe;
#pragma unroll #pragma unroll
for (int i = 0; i < m_sh_iters; i++) { for (int i = 0; i < m_sh_iters; i++) {
if (m_sh_wr_pred) if (m_sh_wr_pred)
cp_async4(&sh_meta_stage[m_sh_wr_delta * i + m_sh_wr], cp_async4(&sh_meta_stage[m_sh_wr_delta * i + m_sh_wr], meta_ptr[i]);
meta_ptr[i]);
meta_ptr[i] += m_gl_rd_delta_o; meta_ptr[i] += m_gl_rd_delta_o;
} }
// Only fetch scales if this tile starts a new group // Only fetch scales if this tile starts a new group
if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) { if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) {
int4 *sh_s_stage = sh_s + s_sh_stage * pipe; int4* sh_s_stage = sh_s + s_sh_stage * pipe;
if (s_sh_wr_pred) if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
s_gl_rd += s_gl_rd_delta; s_gl_rd += s_gl_rd_delta;
} }
} }
@ -436,13 +433,13 @@ __global__ void Marlin_24(
// theoretically better attempts have lead to bad instruction ordering by // theoretically better attempts have lead to bad instruction ordering by
// the compiler and correspondingly a noticeable drop in performance. // the compiler and correspondingly a noticeable drop in performance.
if (group_blocks != -1) { if (group_blocks != -1) {
int4 *sh_s_stage = int4* sh_s_stage =
sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks))); (pipe / (group_blocks / thread_k_blocks)));
reinterpret_cast<int4 *>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
} }
int4 *sh_a_stage = sh_a + a_sh_stage * pipe; int4* sh_a_stage = sh_a + a_sh_stage * pipe;
#pragma unroll #pragma unroll
for (int i = 0; i < thread_m_blocks; i++) { for (int i = 0; i < thread_m_blocks; i++) {
ldsm4(frag_a[k % 2][i][0], ldsm4(frag_a[k % 2][i][0],
&sh_a_stage[a_sh_rd_trans[0][k % b_sh_wr_iters][i]]); &sh_a_stage[a_sh_rd_trans[0][k % b_sh_wr_iters][i]]);
@ -450,24 +447,24 @@ __global__ void Marlin_24(
&sh_a_stage[a_sh_rd_trans[1][k % b_sh_wr_iters][i]]); &sh_a_stage[a_sh_rd_trans[1][k % b_sh_wr_iters][i]]);
} }
int4 *sh_b_stage = sh_b + b_sh_stage * pipe; int4* sh_b_stage = sh_b + b_sh_stage * pipe;
#pragma unroll #pragma unroll
for (int i = 0; i < b_thread_vecs; i++) { for (int i = 0; i < b_thread_vecs; i++) {
frag_b_quant[k % 2][i] = *reinterpret_cast<I4 *>( frag_b_quant[k % 2][i] = *reinterpret_cast<I4*>(
&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]);
} }
// Load meta with ldsm4 // Load meta with ldsm4
int4 *sh_m_stage = sh_m + m_sh_stage * pipe; int4* sh_m_stage = sh_m + m_sh_stage * pipe;
ldsm4_m(frag_m[k % 2][0], ldsm4_m(frag_m[k % 2][0],
&sh_m_stage[m_sh_rd_delta * (k % m_sh_iters) + m_sh_rd]); &sh_m_stage[m_sh_rd_delta * (k % m_sh_iters) + m_sh_rd]);
}; };
// Execute the actual tensor core matmul of a sub-tile. // Execute the actual tensor core matmul of a sub-tile.
auto matmul = [&](int k) { auto matmul = [&](int k) {
// We have the m dimension as the inner loop in order to encourage overlapping // We have the m dimension as the inner loop in order to encourage overlapping
// dequantization and matmul operations. // dequantization and matmul operations.
#pragma unroll #pragma unroll
for (int j = 0; j < 4; j++) { for (int j = 0; j < 4; j++) {
FragB frag_b0; FragB frag_b0;
FragB frag_b1; FragB frag_b1;
@ -480,7 +477,7 @@ __global__ void Marlin_24(
frag_b1 = dequant_4bit(b_quant_shift); frag_b1 = dequant_4bit(b_quant_shift);
} else { } else {
int *frag_b_quant_ptr = reinterpret_cast<int *>(frag_b_quant[k % 2]); int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k % 2]);
int b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; int b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
int b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; int b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
@ -497,7 +494,7 @@ __global__ void Marlin_24(
scale(frag_b1, frag_s[k % 2][j], 1); scale(frag_b1, frag_s[k % 2][j], 1);
} }
#pragma unroll #pragma unroll
for (int i = 0; i < thread_m_blocks; i++) { for (int i = 0; i < thread_m_blocks; i++) {
mma_sp(frag_b0, frag_b1, frag_a[k % 2][i][0], frag_c[i][j][0], mma_sp(frag_b0, frag_b1, frag_a[k % 2][i][0], frag_c[i][j][0],
frag_m[k % 2][j / 2], j % 2); frag_m[k % 2][j / 2], j % 2);
@ -518,41 +515,41 @@ __global__ void Marlin_24(
int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
(threadIdx.x % b_sh_stride_threads); (threadIdx.x % b_sh_stride_threads);
// Parallel logarithmic shared memory reduction. We make sure to avoid any // Parallel logarithmic shared memory reduction. We make sure to avoid any
// unnecessary read or write iterations, e.g., for two warps we write only once // unnecessary read or write iterations, e.g., for two warps we write only
// by warp 1 and read only once by warp 0. // once by warp 1 and read only once by warp 0.
#pragma unroll #pragma unroll
for (int m_block = 0; m_block < thread_m_blocks; m_block++) { for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
#pragma unroll #pragma unroll
for (int i = red_off; i > 0; i /= 2) { for (int i = red_off; i > 0; i /= 2) {
if (i <= red_idx && red_idx < 2 * i) { if (i <= red_idx && red_idx < 2 * i) {
#pragma unroll #pragma unroll
for (int j = 0; j < 4 * 2; j++) { for (int j = 0; j < 4 * 2; j++) {
int red_sh_wr = int red_sh_wr =
red_sh_delta * j + (red_sh_rd - red_sh_stride * i); red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
if (i < red_off) { if (i < red_off) {
float *c_rd = reinterpret_cast<float *>( float* c_rd =
&sh[red_sh_delta * j + red_sh_rd]); reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
float *c_wr = reinterpret_cast<float *>(&sh[red_sh_wr]); float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
#pragma unroll #pragma unroll
for (int k = 0; k < 4; k++) for (int k = 0; k < 4; k++)
reinterpret_cast<FragC *>(frag_c)[4 * 2 * m_block + j][k] += reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
c_rd[k] + c_wr[k]; c_rd[k] + c_wr[k];
} }
sh[red_sh_wr] = sh[red_sh_wr] =
reinterpret_cast<int4 *>(&frag_c)[4 * 2 * m_block + j]; reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
} }
} }
__syncthreads(); __syncthreads();
} }
if (red_idx == 0) { if (red_idx == 0) {
#pragma unroll #pragma unroll
for (int i = 0; i < 4 * 2; i++) { for (int i = 0; i < 4 * 2; i++) {
float *c_rd = float* c_rd =
reinterpret_cast<float *>(&sh[red_sh_delta * i + red_sh_rd]); reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
#pragma unroll #pragma unroll
for (int j = 0; j < 4; j++) for (int j = 0; j < 4; j++)
reinterpret_cast<FragC *>(frag_c)[4 * 2 * m_block + i][j] += reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
c_rd[j]; c_rd[j];
} }
} }
@ -562,9 +559,9 @@ __global__ void Marlin_24(
}; };
// Since multiple threadblocks may process parts of the same column slice, we // Since multiple threadblocks may process parts of the same column slice, we
// finally have to globally reduce over the results. As the striped partitioning // finally have to globally reduce over the results. As the striped
// minimizes the number of such reductions and our outputs are usually rather // partitioning minimizes the number of such reductions and our outputs are
// small, we perform this reduction serially in L2 cache. // usually rather small, we perform this reduction serially in L2 cache.
auto global_reduce = [&](bool first = false, bool last = false) { auto global_reduce = [&](bool first = false, bool last = false) {
// We are very careful here to reduce directly in the output buffer to // We are very careful here to reduce directly in the output buffer to
// maximize L2 cache utilization in this step. To do this, we write out // maximize L2 cache utilization in this step. To do this, we write out
@ -574,7 +571,7 @@ __global__ void Marlin_24(
int c_gl_stride = prob_n / 8; int c_gl_stride = prob_n / 8;
int c_gl_wr_delta_o = 2 * 4 * c_gl_stride; int c_gl_wr_delta_o = 2 * 4 * c_gl_stride;
int c_gl_wr_delta_i = int c_gl_wr_delta_i =
c_gl_stride; // 8 threads (e.g., 0,4,8,12,16,20,24,28) c_gl_stride; // 8 threads (e.g., 0,4,8,12,16,20,24,28)
int c_gl_wr = 2 * c_gl_stride * (threadIdx.x % 4) + int c_gl_wr = 2 * c_gl_stride * (threadIdx.x % 4) +
8 * (threadIdx.x / 32) + (threadIdx.x % 32) / 4; 8 * (threadIdx.x / 32) + (threadIdx.x % 32) / 4;
c_gl_wr += (2 * thread_n_blocks) * slice_col; c_gl_wr += (2 * thread_n_blocks) * slice_col;
@ -584,10 +581,10 @@ __global__ void Marlin_24(
int col = 2 * ((threadIdx.x % 32) % 4); int col = 2 * ((threadIdx.x % 32) % 4);
if (!first) { if (!first) {
// Interestingly, doing direct global accesses here really seems to mess up the // Interestingly, doing direct global accesses here really seems to mess up
// compiler and lead to slowdowns, hence we also use async-copies even though // the compiler and lead to slowdowns, hence we also use async-copies even
// these fetches are not actually asynchronous. // though these fetches are not actually asynchronous.
#pragma unroll #pragma unroll
for (int i = 0; i < thread_m_blocks * 4; i++) { for (int i = 0; i < thread_m_blocks * 4; i++) {
cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i],
&C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
@ -599,32 +596,32 @@ __global__ void Marlin_24(
cp_async_wait<0>(); cp_async_wait<0>();
} }
#pragma unroll #pragma unroll
for (int i = 0; i < thread_m_blocks * 4; i++) { for (int i = 0; i < thread_m_blocks * 4; i++) {
if (i < (thread_m_blocks - 1) * 4 || if (i < (thread_m_blocks - 1) * 4 ||
8 * (i / 2) + col + (i % 2) < prob_m) { 8 * (i / 2) + col + (i % 2) < prob_m) {
if (!first) { if (!first) {
int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
#pragma unroll #pragma unroll
for (int j2 = 0; j2 < 2; j2++) { for (int j2 = 0; j2 < 2; j2++) {
#pragma unroll #pragma unroll
for (int j1 = 0; j1 < 4; j1++) { for (int j1 = 0; j1 < 4; j1++) {
reinterpret_cast<float *>( reinterpret_cast<float*>(
&frag_c)[4 * 2 * 4 * (i / 4) + 8 * j1 + 2 * j2 + &frag_c)[4 * 2 * 4 * (i / 4) + 8 * j1 + 2 * j2 +
4 * ((i % 4) / 2) + i % 2] += 4 * ((i % 4) / 2) + i % 2] +=
__half2float( __half2float(
reinterpret_cast<__half *>(&c_red)[(j2 * 4 + j1)]); reinterpret_cast<__half*>(&c_red)[(j2 * 4 + j1)]);
} }
} }
} }
if (!last) { if (!last) {
int4 c; int4 c;
#pragma unroll #pragma unroll
for (int j2 = 0; j2 < 2; j2++) { for (int j2 = 0; j2 < 2; j2++) {
#pragma unroll #pragma unroll
for (int j1 = 0; j1 < 4; j1++) { for (int j1 = 0; j1 < 4; j1++) {
reinterpret_cast<__half *>(&c)[(j2 * 4 + j1)] = reinterpret_cast<__half*>(&c)[(j2 * 4 + j1)] =
__float2half(reinterpret_cast<float *>( __float2half(reinterpret_cast<float*>(
&frag_c)[4 * 2 * 4 * (i / 4) + 8 * j1 + 2 * j2 + &frag_c)[4 * 2 * 4 * (i / 4) + 8 * j1 + 2 * j2 +
4 * ((i % 4) / 2) + i % 2]); 4 * ((i % 4) / 2) + i % 2]);
} }
@ -643,9 +640,9 @@ __global__ void Marlin_24(
auto write_result = [&]() { auto write_result = [&]() {
int c_gl_stride = prob_n / 8; int c_gl_stride = prob_n / 8;
constexpr int c_sh_stride = 2 * thread_n_blocks; // RLC: constexpr int c_sh_stride = 2 * thread_n_blocks; // RLC:
constexpr int c_sh_stride_2 = 2 * c_sh_stride + 2; // RLC: constexpr int c_sh_stride_2 = 2 * c_sh_stride + 2; // RLC:
constexpr int c_sh_stride_3 = 2 * (2 * thread_n_blocks) + 2; // RLC: constexpr int c_sh_stride_3 = 2 * (2 * thread_n_blocks) + 2; // RLC:
int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
@ -654,22 +651,22 @@ __global__ void Marlin_24(
c_gl_wr += (2 * thread_n_blocks) * slice_col; c_gl_wr += (2 * thread_n_blocks) * slice_col;
int c_sh_wr = c_sh_stride_2 * ((threadIdx.x % 32) % 4) + int c_sh_wr = c_sh_stride_2 * ((threadIdx.x % 32) % 4) +
((threadIdx.x % 32) / 4); // RLC: ((threadIdx.x % 32) / 4); // RLC:
c_sh_wr += 8 * (threadIdx.x / 32); // 128/4(half4) c_sh_wr += 8 * (threadIdx.x / 32); // 128/4(half4)
constexpr int c_sh_rd_delta = constexpr int c_sh_rd_delta =
c_sh_stride_3 * (threads / (2 * 2 * thread_n_blocks)); // RLC: c_sh_stride_3 * (threads / (2 * 2 * thread_n_blocks)); // RLC:
int c_sh_rd = c_sh_stride_3 * (threadIdx.x / (2 * 2 * thread_n_blocks)) + int c_sh_rd = c_sh_stride_3 * (threadIdx.x / (2 * 2 * thread_n_blocks)) +
(threadIdx.x % (2 * 2 * thread_n_blocks)); (threadIdx.x % (2 * 2 * thread_n_blocks));
int c_gl_wr_end = c_gl_stride * prob_m; int c_gl_wr_end = c_gl_stride * prob_m;
auto write = [&](int idx, float c0, float c1, float c2, float c3, FragS &s0, auto write = [&](int idx, float c0, float c1, float c2, float c3, FragS& s0,
float c4, float c5, float c6, float c7, FragS &s1) { float c4, float c5, float c6, float c7, FragS& s1) {
uint2 res[2]; uint2 res[2];
res[0] = to_half4(c0, c1, c2, c3); res[0] = to_half4(c0, c1, c2, c3);
res[1] = to_half4(c4, c5, c6, c7); res[1] = to_half4(c4, c5, c6, c7);
half2 *tmp = (half2 *)&res; half2* tmp = (half2*)&res;
// for per-column quantization we finally apply the scale here // for per-column quantization we finally apply the scale here
if constexpr (group_blocks == -1 && num_bits == 4) { if constexpr (group_blocks == -1 && num_bits == 4) {
tmp[0] = __hmul2(tmp[0], s0[0]); tmp[0] = __hmul2(tmp[0], s0[0]);
@ -677,12 +674,12 @@ __global__ void Marlin_24(
tmp[2] = __hmul2(tmp[2], s1[0]); tmp[2] = __hmul2(tmp[2], s1[0]);
tmp[3] = __hmul2(tmp[3], s1[1]); tmp[3] = __hmul2(tmp[3], s1[1]);
} }
((int4 *)sh)[idx] = *((int4 *)&res[0]); ((int4*)sh)[idx] = *((int4*)&res[0]);
}; };
// RLC: only warp 0 and 1 baseline example // RLC: only warp 0 and 1 baseline example
if (threadIdx.x / 32 < thread_n_blocks / 4) { if (threadIdx.x / 32 < thread_n_blocks / 4) {
#pragma unroll #pragma unroll
for (int i = 0; i < thread_m_blocks; i++) { for (int i = 0; i < thread_m_blocks; i++) {
int wr = c_sh_wr; int wr = c_sh_wr;
write(wr, frag_c[i][0][0][0], frag_c[i][1][0][0], frag_c[i][2][0][0], write(wr, frag_c[i][0][0][0], frag_c[i][1][0][0], frag_c[i][2][0][0],
@ -707,7 +704,7 @@ __global__ void Marlin_24(
} }
__syncthreads(); __syncthreads();
#pragma unroll #pragma unroll
for (int i = 0; for (int i = 0;
i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
i++) { i++) {
@ -721,9 +718,8 @@ __global__ void Marlin_24(
// Start global fetch and register load pipelines. // Start global fetch and register load pipelines.
auto start_pipes = [&]() { auto start_pipes = [&]() {
#pragma unroll #pragma unroll
for (int i = 0; i < stages - 1; i++) for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters);
fetch_to_shared(i, i, i < slice_iters);
zero_accums(); zero_accums();
wait_for_stage(); wait_for_stage();
fetch_to_registers(0, 0); fetch_to_registers(0, 0);
@ -733,10 +729,10 @@ __global__ void Marlin_24(
// Main loop. // Main loop.
while (slice_iters) { while (slice_iters) {
// We unroll over both the global fetch and the register load pipeline to ensure // We unroll over both the global fetch and the register load pipeline to
// all shared memory accesses are static. Note that both pipelines have even // ensure all shared memory accesses are static. Note that both pipelines have
// length meaning that the next iteration will always start at index 0. // even length meaning that the next iteration will always start at index 0.
#pragma unroll #pragma unroll
for (int pipe = 0; pipe < stages;) { for (int pipe = 0; pipe < stages;) {
fetch_to_shared((pipe + stages - 1) % stages, pipe, fetch_to_shared((pipe + stages - 1) % stages, pipe,
slice_iters >= stages); slice_iters >= stages);
@ -747,8 +743,7 @@ __global__ void Marlin_24(
pipe++; pipe++;
slice_iters--; slice_iters--;
if (slice_iters == 0) if (slice_iters == 0) break;
break;
} }
a_gl_rd += a_gl_rd_delta_o * stages; a_gl_rd += a_gl_rd_delta_o * stages;
@ -762,13 +757,11 @@ __global__ void Marlin_24(
// write-out // write-out
if constexpr (group_blocks == -1) { if constexpr (group_blocks == -1) {
if constexpr (num_bits == 8) { if constexpr (num_bits == 8) {
if (s_sh_wr_pred) if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
cp_async_fence(); cp_async_fence();
} else { } else {
if (last) { if (last) {
if (s_sh_wr_pred) if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
cp_async_fence(); cp_async_fence();
} }
} }
@ -780,14 +773,14 @@ __global__ void Marlin_24(
cp_async_wait<0>(); cp_async_wait<0>();
__syncthreads(); __syncthreads();
if (threadIdx.x / 32 < thread_n_blocks / 4) { if (threadIdx.x / 32 < thread_n_blocks / 4) {
*(float4 *)(frag_s) = *(float4 *)(&sh_s[s_sh_rd]); *(float4*)(frag_s) = *(float4*)(&sh_s[s_sh_rd]);
} }
} else { } else {
if (last) { if (last) {
cp_async_wait<0>(); cp_async_wait<0>();
__syncthreads(); __syncthreads();
if (threadIdx.x / 32 < thread_n_blocks / 4) { if (threadIdx.x / 32 < thread_n_blocks / 4) {
*(float4 *)(frag_s) = *(float4 *)(&sh_s[s_sh_rd]); *(float4*)(frag_s) = *(float4*)(&sh_s[s_sh_rd]);
} }
} }
} }
@ -798,7 +791,7 @@ __global__ void Marlin_24(
// overflow in fp16) // overflow in fp16)
if constexpr (group_blocks == -1 && num_bits == 8) { if constexpr (group_blocks == -1 && num_bits == 8) {
if (threadIdx.x / 32 < thread_n_blocks / 4) { if (threadIdx.x / 32 < thread_n_blocks / 4) {
#pragma unroll #pragma unroll
for (int i = 0; i < thread_m_blocks; i++) { for (int i = 0; i < thread_m_blocks; i++) {
scale_floats(&frag_c[i][0][0][0], &frag_c[i][1][0][0], scale_floats(&frag_c[i][0][0][0], &frag_c[i][1][0][0],
&frag_c[i][2][0][0], &frag_c[i][3][0][0], frag_s[0][0], &frag_c[i][2][0][0], &frag_c[i][3][0][0], frag_s[0][0],
@ -827,13 +820,13 @@ __global__ void Marlin_24(
} }
} }
if (slice_count > 1) { // only globally reduce if there is more than one if (slice_count > 1) { // only globally reduce if there is more than one
// block in a slice // block in a slice
barrier_acquire(&locks[slice_col], slice_idx); barrier_acquire(&locks[slice_col], slice_idx);
global_reduce(slice_idx == 0, last); global_reduce(slice_idx == 0, last);
barrier_release(&locks[slice_col], last); barrier_release(&locks[slice_col], last);
} }
if (last) // only the last block in a slice actually writes the result if (last) // only the last block in a slice actually writes the result
write_result(); write_result();
slice_row = 0; slice_row = 0;
@ -843,19 +836,17 @@ __global__ void Marlin_24(
if (slice_iters) { if (slice_iters) {
a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
(threadIdx.x % a_gl_rd_delta_o); (threadIdx.x % a_gl_rd_delta_o);
#pragma unroll #pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) for (int i = 0; i < b_sh_wr_iters; i++)
B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
#pragma unroll #pragma unroll
for (int i = 0; i < m_sh_iters; i++) for (int i = 0; i < m_sh_iters; i++)
meta_ptr[i] += (m_sh_stride)-m_gl_rd_delta_o * k_tiles; meta_ptr[i] += (m_sh_stride)-m_gl_rd_delta_o * k_tiles;
if (slice_col == 0) { if (slice_col == 0) {
#pragma unroll #pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
B_ptr[i] -= b_gl_stride; #pragma unroll
#pragma unroll for (int i = 0; i < m_sh_iters; i++) meta_ptr[i] -= m_gl_stride;
for (int i = 0; i < m_sh_iters; i++)
meta_ptr[i] -= m_gl_stride;
} }
s_gl_rd = s_sh_stride * slice_col + threadIdx.x; s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
start_pipes(); start_pipes();
@ -866,26 +857,26 @@ __global__ void Marlin_24(
#endif #endif
#define CALL_IF_2_4(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ #define CALL_IF_2_4(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
THREAD_K_BLOCKS, GROUP_BLOCKS) \ THREAD_K_BLOCKS, GROUP_BLOCKS) \
else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \ else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \ thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \ thread_k_blocks == THREAD_K_BLOCKS && \
group_blocks == GROUP_BLOCKS) { \ group_blocks == GROUP_BLOCKS) { \
cudaFuncSetAttribute( \ cudaFuncSetAttribute( \
Marlin_24<NUM_BITS, THREADS, THREAD_N_BLOCKS, THREAD_M_BLOCKS, \ Marlin_24<NUM_BITS, THREADS, THREAD_N_BLOCKS, THREAD_M_BLOCKS, \
THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>, \ THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
Marlin_24<NUM_BITS, THREADS, THREAD_N_BLOCKS, THREAD_M_BLOCKS, \ Marlin_24<NUM_BITS, THREADS, THREAD_N_BLOCKS, THREAD_M_BLOCKS, \
THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS> \ THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS> \
<<<blocks, THREADS, max_shared_mem, stream>>>(A_ptr, B_ptr, meta_ptr, \ <<<blocks, THREADS, max_shared_mem, stream>>>(A_ptr, B_ptr, meta_ptr, \
C_ptr, s_ptr, prob_n, \ C_ptr, s_ptr, prob_n, \
prob_m, prob_k, locks); \ prob_m, prob_k, locks); \
} }
void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C, void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
void *s, int prob_m, int prob_n, int prob_k, void* s, int prob_m, int prob_n, int prob_k,
void *workspace, int num_bits, int groupsize = -1, void* workspace, int num_bits, int groupsize = -1,
int dev = 0, cudaStream_t stream = 0, int thread_k = -1, int dev = 0, cudaStream_t stream = 0, int thread_k = -1,
int thread_m = -1, int sms = -1, int max_par = 16) { int thread_m = -1, int sms = -1, int max_par = 16) {
int tot_n = prob_n; int tot_n = prob_n;
@ -904,8 +895,8 @@ void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C,
if (thread_k == -1 || thread_m == -1) { if (thread_k == -1 || thread_m == -1) {
if (prob_n <= 16) { if (prob_n <= 16) {
// For small batchizes, better partitioningif is slightly more important than // For small batchizes, better partitioningif is slightly more important
// better compute utilization // than better compute utilization
thread_k = 128; thread_k = 128;
thread_m = 128; thread_m = 128;
} else { } else {
@ -914,7 +905,7 @@ void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C,
} }
} }
int thread_k_blocks = thread_k / 32; // 2:4 version with m16n8k32 instruction int thread_k_blocks = thread_k / 32; // 2:4 version with m16n8k32 instruction
int thread_m_blocks = thread_m / 16; int thread_m_blocks = thread_m / 16;
int group_blocks = (groupsize == -1) ? -1 : groupsize / 16; int group_blocks = (groupsize == -1) ? -1 : groupsize / 16;
int blocks = sms; int blocks = sms;
@ -931,13 +922,13 @@ void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C,
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
", ", prob_n, ", ", prob_k, "]"); ", ", prob_n, ", ", prob_k, "]");
const int4 *A_ptr = (const int4 *)A; const int4* A_ptr = (const int4*)A;
const int4 *B_ptr = (const int4 *)B; const int4* B_ptr = (const int4*)B;
const int4 *meta_ptr = (const int4 *)meta; const int4* meta_ptr = (const int4*)meta;
int4 *C_ptr = (int4 *)C; int4* C_ptr = (int4*)C;
const int4 *s_ptr = (const int4 *)s; const int4* s_ptr = (const int4*)s;
int *locks = (int *)workspace; int* locks = (int*)workspace;
for (int i = 0; i < tot_n_blocks; i += 4) { for (int i = 0; i < tot_n_blocks; i += 4) {
int thread_n_blocks = tot_n_blocks - i; int thread_n_blocks = tot_n_blocks - i;
prob_n = tot_n - 16 * i; prob_n = tot_n - 16 * i;
@ -946,8 +937,7 @@ void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C,
// Note that parallel > 1 currently only works for inputs without any // Note that parallel > 1 currently only works for inputs without any
// padding // padding
par = (16 * thread_n_blocks - pad) / 64; par = (16 * thread_n_blocks - pad) / 64;
if (par > max_par) if (par > max_par) par = max_par;
par = max_par;
prob_n = 64 * par; prob_n = 64 * par;
i += 4 * (par - 1); i += 4 * (par - 1);
thread_n_blocks = 4; thread_n_blocks = 4;
@ -959,13 +949,13 @@ void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C,
// the false is start of the CALL_IF macros // the false is start of the CALL_IF macros
if (false) { if (false) {
} // BMxBNxBK, group } // BMxBNxBK, group
// 4-bit // 4-bit
CALL_IF_2_4(4, 8, 1, 4, -1) // e.g., 16x128x128 CALL_IF_2_4(4, 8, 1, 4, -1) // e.g., 16x128x128
CALL_IF_2_4(4, 8, 1, 4, 4) // e.g., 16x128x128, 64 CALL_IF_2_4(4, 8, 1, 4, 4) // e.g., 16x128x128, 64
CALL_IF_2_4(4, 16, 1, 2, -1) // e.g., 16x256x64 CALL_IF_2_4(4, 16, 1, 2, -1) // e.g., 16x256x64
CALL_IF_2_4(4, 16, 1, 2, 4) // e.g., 16x256x64, 64 CALL_IF_2_4(4, 16, 1, 2, 4) // e.g., 16x256x64, 64
CALL_IF_2_4(4, 16, 2, 2, -1) // e.g.. 32x256x64 CALL_IF_2_4(4, 16, 2, 2, -1) // e.g.. 32x256x64
CALL_IF_2_4(4, 16, 2, 2, 4) CALL_IF_2_4(4, 16, 2, 2, 4)
CALL_IF_2_4(4, 16, 3, 2, -1) CALL_IF_2_4(4, 16, 3, 2, -1)
CALL_IF_2_4(4, 16, 3, 2, 4) CALL_IF_2_4(4, 16, 3, 2, 4)
@ -973,11 +963,11 @@ void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C,
CALL_IF_2_4(4, 16, 4, 2, 4) CALL_IF_2_4(4, 16, 4, 2, 4)
// 8-bit // 8-bit
CALL_IF_2_4(8, 8, 1, 4, -1) // e.g., 16x128x128 CALL_IF_2_4(8, 8, 1, 4, -1) // e.g., 16x128x128
CALL_IF_2_4(8, 8, 1, 4, 4) // e.g., 16x128x128, 64 CALL_IF_2_4(8, 8, 1, 4, 4) // e.g., 16x128x128, 64
CALL_IF_2_4(8, 16, 1, 2, -1) // e.g., 16x256x64 CALL_IF_2_4(8, 16, 1, 2, -1) // e.g., 16x256x64
CALL_IF_2_4(8, 16, 1, 2, 4) // e.g., 16x256x64, 64 CALL_IF_2_4(8, 16, 1, 2, 4) // e.g., 16x256x64, 64
CALL_IF_2_4(8, 16, 2, 2, -1) // e.g.. 32x256x64 CALL_IF_2_4(8, 16, 2, 2, -1) // e.g.. 32x256x64
CALL_IF_2_4(8, 16, 2, 2, 4) CALL_IF_2_4(8, 16, 2, 2, 4)
CALL_IF_2_4(8, 16, 3, 2, -1) CALL_IF_2_4(8, 16, 3, 2, -1)
CALL_IF_2_4(8, 16, 3, 2, 4) CALL_IF_2_4(8, 16, 3, 2, 4)
@ -997,12 +987,12 @@ void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C,
} }
} }
} // namespace marlin_24 } // namespace marlin_24
torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor &b_meta, torch::Tensor& b_meta,
torch::Tensor &b_scales, torch::Tensor& b_scales,
torch::Tensor &workspace, int64_t num_bits, torch::Tensor& workspace, int64_t num_bits,
int64_t size_m, int64_t size_n, int64_t size_m, int64_t size_n,
int64_t size_k) { int64_t size_k) {
// Verify num_bits // Verify num_bits
@ -1037,9 +1027,9 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
" is not divisible by tile_size = " + str(marlin_24::tile_size)); " is not divisible by tile_size = " + str(marlin_24::tile_size));
int actual_size_n = (b_q_weight.size(1) / marlin_24::tile_size) * pack_factor; int actual_size_n = (b_q_weight.size(1) / marlin_24::tile_size) * pack_factor;
TORCH_CHECK(size_n == actual_size_n, TORCH_CHECK(
"size_n = " + str(size_n) + size_n == actual_size_n,
", actual_size_n = " + str(actual_size_n)); "size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n));
// Verify meta // Verify meta
TORCH_CHECK(b_meta.size(0) == size_k / 8 / 2 / 2, TORCH_CHECK(b_meta.size(0) == size_k / 8 / 2 / 2,
@ -1081,7 +1071,7 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
", is not divisible by b_scales.size(0) = " + ", is not divisible by b_scales.size(0) = " +
str(b_scales.size(0))); str(b_scales.size(0)));
groupsize = size_k / b_scales.size(0); groupsize = size_k / b_scales.size(0);
groupsize /= 2; // Because of 24 groupsize /= 2; // Because of 24
} }
// Verify groupsize // Verify groupsize

View File

@ -22,27 +22,23 @@ __device__ inline unsigned int as_unsigned(int i) {
// 4-bit matvec kernel (LUT-based) // 4-bit matvec kernel (LUT-based)
__global__ void NUQ4MatMulKernel( __global__ void NUQ4MatMulKernel(
#ifndef USE_ROCM #ifndef USE_ROCM
const half2* __restrict__ vec, const half2* __restrict__ vec,
#else #else
const __half2* __restrict__ vec, const __half2* __restrict__ vec,
#endif #endif
const int* __restrict__ mat, const int* __restrict__ mat,
#ifndef USE_ROCM #ifndef USE_ROCM
half2* __restrict__ mul, half2* __restrict__ mul,
#else #else
float2* __restrict__ mul, float2* __restrict__ mul,
#endif #endif
const __half* __restrict__ lookup_table, const __half* __restrict__ lookup_table, int height, int width, int batch,
int height, int vec_height) {
int width,
int batch,
int vec_height
) {
const int blockwidth2 = BLOCKWIDTH / 2; const int blockwidth2 = BLOCKWIDTH / 2;
int row = BLOCKHEIGHT4 * blockIdx.x; int row = BLOCKHEIGHT4 * blockIdx.x;
int col = BLOCKWIDTH * blockIdx.y + threadIdx.x; int col = BLOCKWIDTH * blockIdx.y + threadIdx.x;
#ifndef USE_ROCM #ifndef USE_ROCM
__shared__ half2 blockvec[blockwidth2]; __shared__ half2 blockvec[blockwidth2];
@ -73,14 +69,16 @@ __global__ void NUQ4MatMulKernel(
unsigned int tmp1; unsigned int tmp1;
unsigned int lut_index1, lut_index2; unsigned int lut_index1, lut_index2;
for (int b = 0; b < batch; ++b){ for (int b = 0; b < batch; ++b) {
i = width * row + col; i = width * row + col;
res = __int2half_rd(0); res = __int2half_rd(0);
k = 0; k = 0;
__syncthreads(); __syncthreads();
if (threadIdx.x < blockwidth2) if (threadIdx.x < blockwidth2)
blockvec[threadIdx.x] = vec[b * vec_height / 2 + (row / BLOCKHEIGHT4) * blockwidth2 + threadIdx.x]; blockvec[threadIdx.x] =
vec[b * vec_height / 2 + (row / BLOCKHEIGHT4) * blockwidth2 +
threadIdx.x];
__syncthreads(); __syncthreads();
while (k < blockwidth2) { while (k < blockwidth2) {
@ -143,7 +141,8 @@ __global__ void NUQ4MatMulKernel(
#ifndef USE_ROCM #ifndef USE_ROCM
res = __hadd(__hadd(res2.x, res2.y), res); res = __hadd(__hadd(res2.x, res2.y), res);
#else #else
res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)), res); res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)),
res);
#endif #endif
i += width; i += width;
@ -179,46 +178,38 @@ __global__ void NUQ4MatMulKernel(
} }
} }
} // namespace squeezellm } // namespace squeezellm
} // namespace vllm } // namespace vllm
// 4-bit matvec kernel (LUT-based) // 4-bit matvec kernel (LUT-based)
void squeezellm_gemm( void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor vec, torch::Tensor lookup_table) {
torch::Tensor mat,
torch::Tensor mul,
torch::Tensor lookup_table
) {
int height = mat.size(0); int height = mat.size(0);
int width = mat.size(1); int width = mat.size(1);
int batch = vec.size(0); int batch = vec.size(0);
int vec_height = vec.size(1); int vec_height = vec.size(1);
dim3 blocks( dim3 blocks((height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
(height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4, (width + BLOCKWIDTH - 1) / BLOCKWIDTH);
(width + BLOCKWIDTH - 1) / BLOCKWIDTH
);
dim3 threads(BLOCKWIDTH); dim3 threads(BLOCKWIDTH);
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads, 0, stream>>>( vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads, 0, stream>>>(
#ifndef USE_ROCM #ifndef USE_ROCM
(half2*) vec.data<at::Half>(), (half2*)vec.data<at::Half>(),
#else #else
(__half2*) vec.data_ptr<at::Half>(), (__half2*)vec.data_ptr<at::Half>(),
#endif #endif
mat.data_ptr<int>(), mat.data_ptr<int>(),
#ifndef USE_ROCM #ifndef USE_ROCM
(half2*) mul.data<at::Half>(), (half2*)mul.data<at::Half>(), (__half*)lookup_table.data<at::Half>(),
(__half*) lookup_table.data<at::Half>(),
#else #else
(float2*) mul.data_ptr<float>(), (float2*)mul.data_ptr<float>(),
(__half*) lookup_table.data_ptr<at::Half>(), (__half*)lookup_table.data_ptr<at::Half>(),
#endif #endif
height, width, batch, vec_height height, width, batch, vec_height);
);
} }
#undef BLOCKWIDTH #undef BLOCKWIDTH

View File

@ -1,5 +1,6 @@
/* /*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh * Adapted from
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh
* Copyright (c) 2023, The vLLM team. * Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* *
@ -20,12 +21,12 @@
#include "cuda_compat.h" #include "cuda_compat.h"
namespace vllm { namespace vllm {
template<typename T, int numLanes = WARP_SIZE> template <typename T, int numLanes = WARP_SIZE>
__inline__ __device__ T warpReduceSum(T val) { __inline__ __device__ T warpReduceSum(T val) {
static_assert(numLanes > 0 && (numLanes & (numLanes - 1)) == 0, static_assert(numLanes > 0 && (numLanes & (numLanes - 1)) == 0,
"numLanes is not a positive power of 2!"); "numLanes is not a positive power of 2!");
static_assert(numLanes <= WARP_SIZE); static_assert(numLanes <= WARP_SIZE);
#pragma unroll #pragma unroll
for (int mask = numLanes >> 1; mask > 0; mask >>= 1) for (int mask = numLanes >> 1; mask > 0; mask >>= 1)
val += VLLM_SHFL_XOR_SYNC(val, mask); val += VLLM_SHFL_XOR_SYNC(val, mask);
return val; return val;
@ -38,22 +39,23 @@ static constexpr int _nextPow2(unsigned int num) {
} }
/* Calculate the sum of all elements in a block */ /* Calculate the sum of all elements in a block */
template<typename T, int maxBlockSize = 1024> template <typename T, int maxBlockSize = 1024>
__inline__ __device__ T blockReduceSum(T val) { __inline__ __device__ T blockReduceSum(T val) {
static_assert(maxBlockSize <= 1024); static_assert(maxBlockSize <= 1024);
if constexpr (maxBlockSize > WARP_SIZE) { if constexpr (maxBlockSize > WARP_SIZE) {
val = warpReduceSum<T>(val); val = warpReduceSum<T>(val);
// Calculates max number of lanes that need to participate in the last warpReduce // Calculates max number of lanes that need to participate in the last
// warpReduce
constexpr int maxActiveLanes = (maxBlockSize + WARP_SIZE - 1) / WARP_SIZE; constexpr int maxActiveLanes = (maxBlockSize + WARP_SIZE - 1) / WARP_SIZE;
static __shared__ T shared[maxActiveLanes]; static __shared__ T shared[maxActiveLanes];
int lane = threadIdx.x % WARP_SIZE; int lane = threadIdx.x % WARP_SIZE;
int wid = threadIdx.x / WARP_SIZE; int wid = threadIdx.x / WARP_SIZE;
if (lane == 0) if (lane == 0) shared[wid] = val;
shared[wid] = val;
__syncthreads(); __syncthreads();
val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane] : (T)(0.0f); val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane]
: (T)(0.0f);
val = warpReduceSum<T, _nextPow2(maxActiveLanes)>(val); val = warpReduceSum<T, _nextPow2(maxActiveLanes)>(val);
} else { } else {
// A single warpReduce is equal to blockReduce // A single warpReduce is equal to blockReduce
@ -62,4 +64,4 @@ __inline__ __device__ T blockReduceSum(T val) {
return val; return val;
} }
} // namespace vllm } // namespace vllm

View File

@ -26,6 +26,7 @@ RUFF_VERSION=$(ruff --version | awk '{print $2}')
MYPY_VERSION=$(mypy --version | awk '{print $2}') MYPY_VERSION=$(mypy --version | awk '{print $2}')
CODESPELL_VERSION=$(codespell --version) CODESPELL_VERSION=$(codespell --version)
ISORT_VERSION=$(isort --vn) ISORT_VERSION=$(isort --vn)
CLANGFORMAT_VERSION=$(clang-format --version | awk '{print $3}')
# # params: tool name, tool version, required version # # params: tool name, tool version, required version
tool_version_check() { tool_version_check() {
@ -40,6 +41,7 @@ tool_version_check "ruff" $RUFF_VERSION "$(grep "ruff==" requirements-dev.txt |
tool_version_check "mypy" "$MYPY_VERSION" "$(grep mypy requirements-dev.txt | cut -d'=' -f3)" tool_version_check "mypy" "$MYPY_VERSION" "$(grep mypy requirements-dev.txt | cut -d'=' -f3)"
tool_version_check "isort" "$ISORT_VERSION" "$(grep isort requirements-dev.txt | cut -d'=' -f3)" tool_version_check "isort" "$ISORT_VERSION" "$(grep isort requirements-dev.txt | cut -d'=' -f3)"
tool_version_check "codespell" "$CODESPELL_VERSION" "$(grep codespell requirements-dev.txt | cut -d'=' -f3)" tool_version_check "codespell" "$CODESPELL_VERSION" "$(grep codespell requirements-dev.txt | cut -d'=' -f3)"
tool_version_check "clang-format" "$CLANGFORMAT_VERSION" "$(grep clang-format requirements-dev.txt | cut -d'=' -f3)"
YAPF_FLAGS=( YAPF_FLAGS=(
'--recursive' '--recursive'
@ -179,7 +181,6 @@ lint_changed() {
} }
# Run Ruff # Run Ruff
echo 'vLLM ruff:'
### This flag lints individual files. --files *must* be the first command line ### This flag lints individual files. --files *must* be the first command line
### arg to use this option. ### arg to use this option.
if [[ "$1" == '--files' ]]; then if [[ "$1" == '--files' ]]; then
@ -192,6 +193,7 @@ else
# Format only the files that changed in last commit. # Format only the files that changed in last commit.
lint_changed lint_changed
fi fi
echo 'vLLM ruff: Done'
# check spelling of specified files # check spelling of specified files
isort_check() { isort_check() {
@ -233,6 +235,59 @@ else
fi fi
echo 'vLLM isort: Done' echo 'vLLM isort: Done'
# Clang-format section
# Exclude some files for formatting because they are vendored
# NOTE: Keep up to date with .github/workflows/clang-format.yml
CLANG_FORMAT_EXCLUDES=(
'csrc/moe/topk_softmax_kernels.cu'
'csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu'
'csrc/punica/bgmv/bgmv_config.h'
'csrc/punica/bgmv/bgmv_impl.cuh'
'csrc/punica/bgmv/vec_dtypes.cuh'
'csrc/punica/punica_ops.cu'
'csrc/punica/type_convert.h'
)
# Format specified files with clang-format
clang_format() {
clang-format -i "$@"
}
# Format files that differ from main branch with clang-format.
clang_format_changed() {
# The `if` guard ensures that the list of filenames is not empty, which
# could cause clang-format to receive 0 positional arguments, making it hang
# waiting for STDIN.
#
# `diff-filter=ACM` and $MERGEBASE is to ensure we only format files that
# exist on both branches.
MERGEBASE="$(git merge-base origin/main HEAD)"
# Get the list of changed files, excluding the specified ones
changed_files=$(git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.h' '*.cpp' '*.cu' '*.cuh' | grep -vFf <(printf "%s\n" "${CLANG_FORMAT_EXCLUDES[@]}"))
if [ -n "$changed_files" ]; then
echo "$changed_files" | xargs -P 5 clang-format -i
fi
}
# Format all files with clang-format
clang_format_all() {
find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \
| grep -vFf <(printf "%s\n" "${CLANG_FORMAT_EXCLUDES[@]}") \
| xargs clang-format -i
}
# Run clang-format
if [[ "$1" == '--files' ]]; then
clang_format "${@:2}"
elif [[ "$1" == '--all' ]]; then
clang_format_all
else
clang_format_changed
fi
echo 'vLLM clang-format: Done'
if ! git diff --quiet &>/dev/null; then if ! git diff --quiet &>/dev/null; then
echo 'Reformatted files. Please review and stage the changes.' echo 'Reformatted files. Please review and stage the changes.'
echo 'Changes not staged for commit:' echo 'Changes not staged for commit:'

View File

@ -5,6 +5,7 @@ tomli==2.0.1
ruff==0.1.5 ruff==0.1.5
codespell==2.2.6 codespell==2.2.6
isort==5.13.2 isort==5.13.2
clang-format==18.1.5
# type checking # type checking
mypy==1.9.0 mypy==1.9.0