[CI/Build] Enforce style for C++ and CUDA code with clang-format (#4722)

Michael Goin 2024-05-22 03:18:41 -04:00 committed by GitHub
parent 9b9a10d6cb
commit 5f6d10c14c
64 changed files with 6398 additions and 6790 deletions

.clang-format (new file, 26 lines)

@@ -0,0 +1,26 @@
BasedOnStyle: Google
UseTab: Never
IndentWidth: 2
ColumnLimit: 80
# Force pointers to the type for C++.
DerivePointerAlignment: false
PointerAlignment: Left
# Reordering #include statements can (and currently will) introduce errors
SortIncludes: false
# Style choices
AlignConsecutiveAssignments: false
AlignConsecutiveDeclarations: false
IndentPPDirectives: BeforeHash
IncludeCategories:
  - Regex: '^<'
    Priority: 4
  - Regex: '^"(llvm|llvm-c|clang|clang-c|mlir|mlir-c)/'
    Priority: 3
  - Regex: '^"(qoda|\.\.)/'
    Priority: 2
  - Regex: '.*'
    Priority: 1
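
For orientation, here is a small hand-written C++ sketch of roughly what code conforming to the settings above looks like: two-space indentation, an 80-column limit, pointers bound to the type ("int* p" rather than "int *p"), and nested preprocessor directives indented before the "#". The snippet is illustrative only; it is not taken from the repository, and all identifiers in it (scale_buffer, EXAMPLE_BLOCK_SIZE) are made up.

// Illustrative only: a made-up example of code that should satisfy the
// configuration above (Google base style, IndentWidth 2, ColumnLimit 80,
// PointerAlignment: Left, IndentPPDirectives: BeforeHash).
#include <cstdio>

// BeforeHash indents nested preprocessor directives in front of the '#'.
#ifndef EXAMPLE_BLOCK_SIZE
  #define EXAMPLE_BLOCK_SIZE 256
#endif

// With ColumnLimit 80, long parameter lists wrap and continuation lines
// align under the opening parenthesis (Google style); the '*' stays with
// the type because PointerAlignment is Left.
float* scale_buffer(float* buffer, int count, float scale,
                    const float* optional_bias) {
  for (int i = 0; i < count; ++i) {
    buffer[i] *= scale;  // two-space indentation per level
    if (optional_bias != nullptr) {
      buffer[i] += optional_bias[i];
    }
  }
  return buffer;
}

int main() {
  float data[4] = {1.0f, 2.0f, 3.0f, 4.0f};
  float* scaled = scale_buffer(data, 4, 2.0f, nullptr);
  std::printf("%.1f %.1f\n", scaled[0], scaled[3]);
  return EXAMPLE_BLOCK_SIZE > 0 ? 0 : 1;
}

Checking a file like this with clang-format --dry-run --Werror is what the workflow below automates for everything under csrc/.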

.github/workflows/clang-format.yml (new file, 42 lines)

@@ -0,0 +1,42 @@
name: clang-format

on:
  # Trigger the workflow on push or pull request,
  # but only for the main branch
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

jobs:
  clang-format:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.11"]
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install clang-format==18.1.5
      - name: Running clang-format
        run: |
          EXCLUDES=(
            'csrc/moe/topk_softmax_kernels.cu'
            'csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu'
            'csrc/punica/bgmv/bgmv_config.h'
            'csrc/punica/bgmv/bgmv_impl.cuh'
            'csrc/punica/bgmv/vec_dtypes.cuh'
            'csrc/punica/punica_ops.cu'
            'csrc/punica/type_convert.h'
          )
          find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \
            | grep -vFf <(printf "%s\n" "${EXCLUDES[@]}") \
            | xargs clang-format --dry-run --Werror

View File

@ -63,31 +63,25 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \ VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), \ input.scalar_type(), "act_and_mul_kernel", [&] { \
"act_and_mul_kernel", \ vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>> \
[&] { \ <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \ input.data_ptr<scalar_t>(), d); \
out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), \
d); \
}); });
void silu_and_mul( void silu_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d] torch::Tensor& input) // [..., 2 * d]
{ {
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel); LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
} }
void gelu_and_mul( void gelu_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d] torch::Tensor& input) // [..., 2 * d]
{ {
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel); LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
} }
void gelu_tanh_and_mul( void gelu_tanh_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d] torch::Tensor& input) // [..., 2 * d]
{ {
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel); LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel);
@ -118,14 +112,10 @@ __global__ void activation_kernel(
dim3 block(std::min(d, 1024)); \ dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \ VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] { \
input.scalar_type(), \ vllm::activation_kernel<scalar_t, KERNEL<scalar_t>> \
"activation_kernel", \ <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
[&] { \ input.data_ptr<scalar_t>(), d); \
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), \
d); \
}); });
namespace vllm { namespace vllm {
@ -140,21 +130,20 @@ __device__ __forceinline__ T gelu_new_kernel(const T& x) {
template <typename T> template <typename T>
__device__ __forceinline__ T gelu_fast_kernel(const T& x) { __device__ __forceinline__ T gelu_fast_kernel(const T& x) {
const float f = (float)x; const float f = (float)x;
const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x)); const T t =
(T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x));
return ((T)0.5) * x * (((T)1.0) + t); return ((T)0.5) * x * (((T)1.0) + t);
} }
} // namespace vllm } // namespace vllm
void gelu_new( void gelu_new(torch::Tensor& out, // [..., d]
torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., d] torch::Tensor& input) // [..., d]
{ {
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel); LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
} }
void gelu_fast( void gelu_fast(torch::Tensor& out, // [..., d]
torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., d] torch::Tensor& input) // [..., d]
{ {
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel); LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);

View File

@@ -1,5 +1,6 @@
 /*
- * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
+ * Adapted from
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
  * Copyright (c) 2023, The vLLM team.
  * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
  *

View File

@@ -1,5 +1,6 @@
 /*
- * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * Adapted from
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
  * Copyright (c) 2023, The vLLM team.
  * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
  *
@ -82,30 +83,27 @@ inline __device__ float block_sum(float* red_smem, float sum) {
// TODO(woosuk): Merge the last two dimensions of the grid. // TODO(woosuk): Merge the last two dimensions of the grid.
// Grid: (num_heads, num_seqs, max_num_partitions). // Grid: (num_heads, num_seqs, max_num_partitions).
template< template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
typename scalar_t, int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
typename cache_t,
int HEAD_SIZE,
int BLOCK_SIZE,
int NUM_THREADS,
vllm::Fp8KVCacheDataType KV_DTYPE,
int PARTITION_SIZE = 0> // Zero means no partitioning. int PARTITION_SIZE = 0> // Zero means no partitioning.
__device__ void paged_attention_kernel( __device__ void paged_attention_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ max_logits, // [num_seqs, num_heads,
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] // max_num_partitions]
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions,
// head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] // head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_kv_heads, // [num_heads] const int num_kv_heads, // [num_heads]
const float scale, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs] const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const int kv_block_stride,
const int kv_head_stride,
const float kv_scale) { const float kv_scale) {
const int seq_idx = blockIdx.y; const int seq_idx = blockIdx.y;
const int partition_idx = blockIdx.z; const int partition_idx = blockIdx.z;
@ -118,22 +116,29 @@ __device__ void paged_attention_kernel(
} }
const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE); const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks; const int num_blocks_per_partition =
USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
// [start_block_idx, end_block_idx) is the range of blocks to process. // [start_block_idx, end_block_idx) is the range of blocks to process.
const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; const int start_block_idx =
const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks); USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
const int end_block_idx =
MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks);
const int num_blocks = end_block_idx - start_block_idx; const int num_blocks = end_block_idx - start_block_idx;
// [start_token_idx, end_token_idx) is the range of tokens to process. // [start_token_idx, end_token_idx) is the range of tokens to process.
const int start_token_idx = start_block_idx * BLOCK_SIZE; const int start_token_idx = start_block_idx * BLOCK_SIZE;
const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len); const int end_token_idx =
MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len);
const int num_tokens = end_token_idx - start_token_idx; const int num_tokens = end_token_idx - start_token_idx;
constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS constexpr int NUM_THREAD_GROUPS =
NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE
// divides NUM_THREADS
assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); constexpr int NUM_TOKENS_PER_THREAD_GROUP =
DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
const int thread_idx = threadIdx.x; const int thread_idx = threadIdx.x;
const int warp_idx = thread_idx / WARP_SIZE; const int warp_idx = thread_idx / WARP_SIZE;
@ -143,13 +148,14 @@ __device__ void paged_attention_kernel(
const int num_heads = gridDim.x; const int num_heads = gridDim.x;
const int num_queries_per_kv = num_heads / num_kv_heads; const int num_queries_per_kv = num_heads / num_kv_heads;
const int kv_head_idx = head_idx / num_queries_per_kv; const int kv_head_idx = head_idx / num_queries_per_kv;
const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; const float alibi_slope =
alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
// A vector type to store a part of a key or a query. // A vector type to store a part of a key or a query.
// The vector size is configured in such a way that the threads in a thread group // The vector size is configured in such a way that the threads in a thread
// fetch or compute 16 bytes at a time. // group fetch or compute 16 bytes at a time. For example, if the size of a
// For example, if the size of a thread group is 4 and the data type is half, // thread group is 4 and the data type is half, then the vector size is 16 /
// then the vector size is 16 / (4 * sizeof(half)) == 2. // (4 * sizeof(half)) == 2.
constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type; using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type; using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
@ -163,18 +169,21 @@ __device__ void paged_attention_kernel(
// Load the query to registers. // Load the query to registers.
// Each thread in a thread group has a different part of the query. // Each thread in a thread group has a different part of the query.
// For example, if the the thread group size is 4, then the first thread in the group // For example, if the the thread group size is 4, then the first thread in
// has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ... // the group has 0, 4, 8, ... th vectors of the query, and the second thread
// th vectors of the query, and so on. // has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because
// NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous. // q is split from a qkv tensor, it may not be contiguous.
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
#pragma unroll #pragma unroll
for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) { for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
i += NUM_THREAD_GROUPS) {
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE); q_vecs[thread_group_offset][i] =
*reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
} }
__syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a
// memory wall right before we use q_vecs
// Memory planning. // Memory planning.
extern __shared__ char shared_mem[]; extern __shared__ char shared_mem[];
@ -193,44 +202,50 @@ __device__ void paged_attention_kernel(
// Each thread group in a warp fetches a key from the block, and computes // Each thread group in a warp fetches a key from the block, and computes
// dot product with the query. // dot product with the query.
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
// NOTE(woosuk): The block number is stored in int32. However, we cast it to int64 block_idx += NUM_WARPS) {
// because int32 can lead to overflow when this variable is multiplied by large numbers // NOTE(woosuk): The block number is stored in int32. However, we cast it to
// (e.g., kv_block_stride). // int64 because int32 can lead to overflow when this variable is multiplied
const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]); // by large numbers (e.g., kv_block_stride).
const int64_t physical_block_number =
static_cast<int64_t>(block_table[block_idx]);
// Load a key to registers. // Load a key to registers.
// Each thread in a thread group has a different part of the key. // Each thread in a thread group has a different part of the key.
// For example, if the the thread group size is 4, then the first thread in the group // For example, if the the thread group size is 4, then the first thread in
// has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th // the group has 0, 4, 8, ... th vectors of the key, and the second thread
// vectors of the key, and so on. // has 1, 5, 9, ... th vectors of the key, and so on.
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; const int physical_block_offset =
(thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
K_vec k_vecs[NUM_VECS_PER_THREAD]; K_vec k_vecs[NUM_VECS_PER_THREAD];
#pragma unroll #pragma unroll
for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride const cache_t* k_ptr =
+ kv_head_idx * kv_head_stride k_cache + physical_block_number * kv_block_stride +
+ physical_block_offset * x; kv_head_idx * kv_head_stride + physical_block_offset * x;
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
const int offset1 = (vec_idx * VEC_SIZE) / x; const int offset1 = (vec_idx * VEC_SIZE) / x;
const int offset2 = (vec_idx * VEC_SIZE) % x; const int offset2 = (vec_idx * VEC_SIZE) % x;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2); k_vecs[j] = *reinterpret_cast<const K_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2);
} else { } else {
// Vector conversion from Quant_vec to K_vec. // Vector conversion from Quant_vec to K_vec.
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>( Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2); k_ptr + offset1 * BLOCK_SIZE * x + offset2);
k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(k_vec_quant, kv_scale); k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
k_vec_quant, kv_scale);
} }
} }
// Compute dot product. // Compute dot product.
// This includes a reduction across the threads in the same thread group. // This includes a reduction across the threads in the same thread group.
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs); float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(
q_vecs[thread_group_offset], k_vecs);
// Add the ALiBi bias if slopes are given. // Add the ALiBi bias if slopes are given.
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0; qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;
@ -285,13 +300,12 @@ __device__ void paged_attention_kernel(
// If partitioning is enabled, store the max logit and exp_sum. // If partitioning is enabled, store the max logit and exp_sum.
if (USE_PARTITIONING && thread_idx == 0) { if (USE_PARTITIONING && thread_idx == 0) {
float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions float* max_logits_ptr = max_logits +
+ head_idx * max_num_partitions seq_idx * num_heads * max_num_partitions +
+ partition_idx; head_idx * max_num_partitions + partition_idx;
*max_logits_ptr = qk_max; *max_logits_ptr = qk_max;
float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions +
+ head_idx * max_num_partitions head_idx * max_num_partitions + partition_idx;
+ partition_idx;
*exp_sums_ptr = exp_sum; *exp_sums_ptr = exp_sum;
} }
@ -304,7 +318,8 @@ __device__ void paged_attention_kernel(
constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); constexpr int NUM_ROWS_PER_THREAD =
DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy. // NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
float accs[NUM_ROWS_PER_THREAD]; float accs[NUM_ROWS_PER_THREAD];
@ -315,18 +330,21 @@ __device__ void paged_attention_kernel(
scalar_t zero_value; scalar_t zero_value;
zero(zero_value); zero(zero_value);
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
// NOTE(woosuk): The block number is stored in int32. However, we cast it to int64 block_idx += NUM_WARPS) {
// because int32 can lead to overflow when this variable is multiplied by large numbers // NOTE(woosuk): The block number is stored in int32. However, we cast it to
// (e.g., kv_block_stride). // int64 because int32 can lead to overflow when this variable is multiplied
const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]); // by large numbers (e.g., kv_block_stride).
const int64_t physical_block_number =
static_cast<int64_t>(block_table[block_idx]);
const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
L_vec logits_vec; L_vec logits_vec;
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx)); from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx -
start_token_idx));
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
+ kv_head_idx * kv_head_stride; kv_head_idx * kv_head_stride;
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
@ -337,14 +355,17 @@ __device__ void paged_attention_kernel(
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset); v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
} else { } else {
V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset); V_quant_vec v_quant_vec =
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec. // Vector conversion from V_quant_vec to V_vec.
v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec, kv_scale); v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
kv_scale);
} }
if (block_idx == num_seq_blocks - 1) { if (block_idx == num_seq_blocks - 1) {
// NOTE(woosuk): When v_vec contains the tokens that are out of the context, // NOTE(woosuk): When v_vec contains the tokens that are out of the
// we should explicitly zero out the values since they may contain NaNs. // context, we should explicitly zero out the values since they may
// See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 // contain NaNs. See
// https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec); scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll #pragma unroll
for (int j = 0; j < V_VEC_SIZE; j++) { for (int j = 0; j < V_VEC_SIZE; j++) {
@ -367,8 +388,8 @@ __device__ void paged_attention_kernel(
accs[i] = acc; accs[i] = acc;
} }
// NOTE(woosuk): A barrier is required because the shared memory space for logits // NOTE(woosuk): A barrier is required because the shared memory space for
// is reused for the output. // logits is reused for the output.
__syncthreads(); __syncthreads();
// Perform reduction across warps. // Perform reduction across warps.
@ -405,9 +426,9 @@ __device__ void paged_attention_kernel(
// Write the final output. // Write the final output.
if (warp_idx == 0) { if (warp_idx == 0) {
scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE scalar_t* out_ptr =
+ head_idx * max_num_partitions * HEAD_SIZE out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
+ partition_idx * HEAD_SIZE; head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE;
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
@ -419,77 +440,73 @@ __device__ void paged_attention_kernel(
} }
// Grid: (num_heads, num_seqs, 1). // Grid: (num_heads, num_seqs, 1).
template< template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
typename scalar_t,
typename cache_t,
int HEAD_SIZE,
int BLOCK_SIZE,
int NUM_THREADS, int NUM_THREADS,
vllm::Fp8KVCacheDataType KV_DTYPE> vllm::Fp8KVCacheDataType KV_DTYPE>
__global__ void paged_attention_v1_kernel( __global__ void paged_attention_v1_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] // head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_kv_heads, // [num_heads] const int num_kv_heads, // [num_heads]
const float scale, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs] const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const int kv_block_stride,
const int kv_head_stride,
const float kv_scale) { const float kv_scale) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE>( paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
/* exp_sums */ nullptr, /* max_logits */ nullptr, KV_DTYPE>(
out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens, /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale); v_cache, num_kv_heads, scale, block_tables, seq_lens,
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
kv_head_stride, kv_scale);
} }
// Grid: (num_heads, num_seqs, max_num_partitions). // Grid: (num_heads, num_seqs, max_num_partitions).
template< template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
typename scalar_t, int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
typename cache_t,
int HEAD_SIZE,
int BLOCK_SIZE,
int NUM_THREADS,
vllm::Fp8KVCacheDataType KV_DTYPE,
int PARTITION_SIZE> int PARTITION_SIZE>
__global__ void paged_attention_v2_kernel( __global__ void paged_attention_v2_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ max_logits, // [num_seqs, num_heads,
scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] // max_num_partitions]
scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] // head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_kv_heads, // [num_heads] const int num_kv_heads, // [num_heads]
const float scale, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs] const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const int kv_block_stride,
const int kv_head_stride,
const float kv_scale) { const float kv_scale) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE, PARTITION_SIZE>( paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, PARTITION_SIZE>(
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
q_stride, kv_block_stride, kv_head_stride, kv_scale); kv_block_stride, kv_head_stride, kv_scale);
} }
// Grid: (num_heads, num_seqs). // Grid: (num_heads, num_seqs).
template< template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
typename scalar_t,
int HEAD_SIZE,
int NUM_THREADS,
int PARTITION_SIZE> int PARTITION_SIZE>
__global__ void paged_attention_v2_reduce_kernel( __global__ void paged_attention_v2_reduce_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] const float* __restrict__ exp_sums, // [num_seqs, num_heads,
const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] // max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] const float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
const int* __restrict__ seq_lens, // [num_seqs] const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_partitions) { const int max_num_partitions) {
const int num_heads = gridDim.x; const int num_heads = gridDim.x;
@ -499,9 +516,11 @@ __global__ void paged_attention_v2_reduce_kernel(
const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE); const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
if (num_partitions == 1) { if (num_partitions == 1) {
// No need to reduce. Only copy tmp_out to out. // No need to reduce. Only copy tmp_out to out.
scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; scalar_t* out_ptr =
const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+ head_idx * max_num_partitions * HEAD_SIZE; const scalar_t* tmp_out_ptr =
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE;
for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) { for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
out_ptr[i] = tmp_out_ptr[i]; out_ptr[i] = tmp_out_ptr[i];
} }
@ -520,8 +539,9 @@ __global__ void paged_attention_v2_reduce_kernel(
// Load max logits to shared memory. // Load max logits to shared memory.
float* shared_max_logits = reinterpret_cast<float*>(shared_mem); float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions const float* max_logits_ptr = max_logits +
+ head_idx * max_num_partitions; seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions;
float max_logit = -FLT_MAX; float max_logit = -FLT_MAX;
for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
const float l = max_logits_ptr[i]; const float l = max_logits_ptr[i];
@ -550,9 +570,11 @@ __global__ void paged_attention_v2_reduce_kernel(
max_logit = VLLM_SHFL_SYNC(max_logit, 0); max_logit = VLLM_SHFL_SYNC(max_logit, 0);
// Load rescaled exp sums to shared memory. // Load rescaled exp sums to shared memory.
float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions); float* shared_exp_sums =
const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
+ head_idx * max_num_partitions; const float* exp_sums_ptr = exp_sums +
seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions;
float global_exp_sum = 0.0f; float global_exp_sum = 0.0f;
for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
float l = shared_max_logits[i]; float l = shared_max_logits[i];
@ -565,14 +587,17 @@ __global__ void paged_attention_v2_reduce_kernel(
const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f); const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
// Aggregate tmp_out to out. // Aggregate tmp_out to out.
const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE const scalar_t* tmp_out_ptr =
+ head_idx * max_num_partitions * HEAD_SIZE; tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; head_idx * max_num_partitions * HEAD_SIZE;
scalar_t* out_ptr =
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
#pragma unroll #pragma unroll
for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) { for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
float acc = 0.0f; float acc = 0.0f;
for (int j = 0; j < num_partitions; ++j) { for (int j = 0; j < num_partitions; ++j) {
acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum; acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] *
inv_global_exp_sum;
} }
from_float(out_ptr[i], acc); from_float(out_ptr[i], acc);
} }
@ -582,44 +607,25 @@ __global__ void paged_attention_v2_reduce_kernel(
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \ ((void*)vllm::paged_attention_v1_kernel< \
KV_DTYPE>), shared_mem_size); \ T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE>), \
vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \ shared_mem_size); \
KV_DTYPE><<<grid, block, shared_mem_size, stream>>>( \ vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
out_ptr, \ NUM_THREADS, KV_DTYPE> \
query_ptr, \ <<<grid, block, shared_mem_size, stream>>>( \
key_cache_ptr, \ out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
value_cache_ptr, \ scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
num_kv_heads, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
scale, \
block_tables_ptr, \
seq_lens_ptr, \
max_num_blocks_per_seq, \
alibi_slopes_ptr, \
q_stride, \
kv_block_stride, \
kv_head_stride, \
kv_scale); kv_scale);
// TODO(woosuk): Tune NUM_THREADS. // TODO(woosuk): Tune NUM_THREADS.
template< template <typename T, typename CACHE_T, int BLOCK_SIZE,
typename T, vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS = 128>
typename CACHE_T,
int BLOCK_SIZE,
vllm::Fp8KVCacheDataType KV_DTYPE,
int NUM_THREADS = 128>
void paged_attention_v1_launcher( void paged_attention_v1_launcher(
torch::Tensor& out, torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& query, torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& key_cache, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
torch::Tensor& value_cache, const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale) {
int num_kv_heads,
float scale,
torch::Tensor& block_tables,
torch::Tensor& seq_lens,
int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
float kv_scale) {
int num_seqs = query.size(0); int num_seqs = query.size(0);
int num_heads = query.size(1); int num_heads = query.size(1);
int head_size = query.size(2); int head_size = query.size(2);
@ -632,8 +638,9 @@ void paged_attention_v1_launcher(
assert(head_size % thread_group_size == 0); assert(head_size % thread_group_size == 0);
// NOTE: alibi_slopes is optional. // NOTE: alibi_slopes is optional.
const float* alibi_slopes_ptr = alibi_slopes ? const float* alibi_slopes_ptr =
reinterpret_cast<const float*>(alibi_slopes.value().data_ptr()) alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr; : nullptr;
T* out_ptr = reinterpret_cast<T*>(out.data_ptr()); T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
@ -644,7 +651,8 @@ void paged_attention_v1_launcher(
int* seq_lens_ptr = seq_lens.data_ptr<int>(); int* seq_lens_ptr = seq_lens.data_ptr<int>();
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int padded_max_seq_len = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; int padded_max_seq_len =
DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
int logits_size = padded_max_seq_len * sizeof(float); int logits_size = padded_max_seq_len * sizeof(float);
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
// Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
@ -685,17 +693,8 @@ void paged_attention_v1_launcher(
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \ #define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \
paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE>( \ paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE>( \
out, \ out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
query, \ seq_lens, max_seq_len, alibi_slopes, kv_scale);
key_cache, \
value_cache, \
num_kv_heads, \
scale, \
block_tables, \
seq_lens, \
max_seq_len, \
alibi_slopes, \
kv_scale);
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes // NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256. // 1, 2, 4, 64, 128, 256.
@ -718,72 +717,43 @@ void paged_attention_v1_launcher(
void paged_attention_v1( void paged_attention_v1(
torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size] torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] torch::Tensor&
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size]
int num_kv_heads, // [num_heads] int num_kv_heads, // [num_heads]
float scale, float scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& seq_lens, // [num_seqs] torch::Tensor& seq_lens, // [num_seqs]
int block_size, int block_size, int max_seq_len,
int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, const std::string& kv_cache_dtype, float kv_scale){
float kv_scale) {
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, CALL_V1_LAUNCHER_BLOCK_SIZE)
}
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
CALL_V1_LAUNCHER_BLOCK_SIZE)}
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ #define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \ vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
KV_DTYPE, PARTITION_SIZE> \ NUM_THREADS, KV_DTYPE, PARTITION_SIZE> \
<<<grid, block, shared_mem_size, stream>>>( \ <<<grid, block, shared_mem_size, stream>>>( \
exp_sums_ptr, \ exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
max_logits_ptr, \ value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
tmp_out_ptr, \ seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
query_ptr, \ kv_block_stride, kv_head_stride, kv_scale); \
key_cache_ptr, \ vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
value_cache_ptr, \ PARTITION_SIZE> \
num_kv_heads, \
scale, \
block_tables_ptr, \
seq_lens_ptr, \
max_num_blocks_per_seq, \
alibi_slopes_ptr, \
q_stride, \
kv_block_stride, \
kv_head_stride, \
kv_scale); \
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE> \
<<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \ <<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
out_ptr, \ out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
exp_sums_ptr, \
max_logits_ptr, \
tmp_out_ptr, \
seq_lens_ptr, \
max_num_partitions); max_num_partitions);
template< template <typename T, typename CACHE_T, int BLOCK_SIZE,
typename T, vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS = 128,
typename CACHE_T,
int BLOCK_SIZE,
vllm::Fp8KVCacheDataType KV_DTYPE,
int NUM_THREADS = 128,
int PARTITION_SIZE = 512> int PARTITION_SIZE = 512>
void paged_attention_v2_launcher( void paged_attention_v2_launcher(
torch::Tensor& out, torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& exp_sums, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& max_logits, torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& tmp_out, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
torch::Tensor& query, const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale) {
torch::Tensor& key_cache,
torch::Tensor& value_cache,
int num_kv_heads,
float scale,
torch::Tensor& block_tables,
torch::Tensor& seq_lens,
int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
float kv_scale) {
int num_seqs = query.size(0); int num_seqs = query.size(0);
int num_heads = query.size(1); int num_heads = query.size(1);
int head_size = query.size(2); int head_size = query.size(2);
@ -796,8 +766,9 @@ void paged_attention_v2_launcher(
assert(head_size % thread_group_size == 0); assert(head_size % thread_group_size == 0);
// NOTE: alibi_slopes is optional. // NOTE: alibi_slopes is optional.
const float* alibi_slopes_ptr = alibi_slopes ? const float* alibi_slopes_ptr =
reinterpret_cast<const float*>(alibi_slopes.value().data_ptr()) alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr; : nullptr;
T* out_ptr = reinterpret_cast<T*>(out.data_ptr()); T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
@ -855,19 +826,8 @@ void paged_attention_v2_launcher(
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \ #define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \
paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE>( \ paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE>( \
out, \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
exp_sums, \ num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
max_logits, \
tmp_out, \
query, \
key_cache, \
value_cache, \
num_kv_heads, \
scale, \
block_tables, \
seq_lens, \
max_seq_len, \
alibi_slopes, \
kv_scale); kv_scale);
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes // NOTE(woosuk): To reduce the compilation time, we omitted block sizes
@ -892,20 +852,22 @@ void paged_attention_v2(
torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor& tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] torch::Tensor&
tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size] torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] torch::Tensor&
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size]
int num_kv_heads, // [num_heads] int num_kv_heads, // [num_heads]
float scale, float scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& seq_lens, // [num_seqs] torch::Tensor& seq_lens, // [num_seqs]
int block_size, int block_size, int max_seq_len,
int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, const std::string& kv_cache_dtype, float kv_scale) {
float kv_scale) { DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, CALL_V2_LAUNCHER_BLOCK_SIZE) CALL_V2_LAUNCHER_BLOCK_SIZE)
} }
#undef WARP_SIZE #undef WARP_SIZE

View File

@@ -1,5 +1,6 @@
 /*
- * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * Adapted from
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
  * Copyright (c) 2023, The vLLM team.
  * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
  *

View File

@@ -1,6 +1,8 @@
 /*
- * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
- * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
+ * Adapted from
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * and
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
  * Copyright (c) 2023, The vLLM team.
  * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
  *
@ -280,7 +282,8 @@ inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
} }
// Vector fused multiply-add. // Vector fused multiply-add.
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) { inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b,
__nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false); assert(false);
#else #else
@ -288,7 +291,8 @@ inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bf
#endif #endif
} }
inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) { inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b,
__nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false); assert(false);
#else #else

View File

@@ -1,6 +1,8 @@
 /*
- * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
- * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
+ * Adapted from
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * and
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
  * Copyright (c) 2023, The vLLM team.
  * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
  *
@ -130,7 +132,9 @@ inline __device__ uint32_t float2_to_half2(float2 f) {
} tmp; } tmp;
#ifndef USE_ROCM #ifndef USE_ROCM
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n"
: "=r"(tmp.u32)
: "f"(f.y), "f"(f.x));
#else #else
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
@ -327,9 +331,13 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) {
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) { inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
uint32_t d; uint32_t d;
#ifndef USE_ROCM #ifndef USE_ROCM
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
: "=r"(d)
: "r"(a), "r"(b), "r"(c));
#else #else
asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c)); asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n"
: "=v"(d)
: "v"(a), "v"(b), "v"(c));
#endif #endif
return d; return d;
} }
@ -470,13 +478,9 @@ inline __device__ void from_float(uint4& dst, Float8_ src) {
} }
// From float16 to float32. // From float16 to float32.
inline __device__ float to_float(uint16_t u) { inline __device__ float to_float(uint16_t u) { return half_to_float(u); }
return half_to_float(u);
}
inline __device__ float2 to_float(uint32_t u) { inline __device__ float2 to_float(uint32_t u) { return half2_to_float2(u); }
return half2_to_float2(u);
}
inline __device__ Float4_ to_float(uint2 u) { inline __device__ Float4_ to_float(uint2 u) {
Float4_ tmp; Float4_ tmp;
@ -495,8 +499,6 @@ inline __device__ Float8_ to_float(uint4 u) {
} }
// Zero-out a variable. // Zero-out a variable.
inline __device__ void zero(uint16_t& dst) { inline __device__ void zero(uint16_t& dst) { dst = uint16_t(0); }
dst = uint16_t(0);
}
} // namespace vllm } // namespace vllm

View File

@@ -1,6 +1,8 @@
 /*
- * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
- * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
+ * Adapted from
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * and
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
  * Copyright (c) 2023, The vLLM team.
  * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
  *
@ -66,9 +68,7 @@ struct FloatVec<float4> {
}; };
// Vector addition. // Vector addition.
inline __device__ float add(float a, float b) { inline __device__ float add(float a, float b) { return a + b; }
return a + b;
}
inline __device__ float2 add(float2 a, float2 b) { inline __device__ float2 add(float2 a, float2 b) {
float2 c; float2 c;
@ -129,9 +129,7 @@ inline __device__ float4 mul(float a, float4 b) {
} }
// Vector fused multiply-add. // Vector fused multiply-add.
inline __device__ float fma(float a, float b, float c) { inline __device__ float fma(float a, float b, float c) { return a * b + c; }
return a * b + c;
}
inline __device__ float2 fma(float2 a, float2 b, float2 c) { inline __device__ float2 fma(float2 a, float2 b, float2 c) {
float2 d; float2 d;
@ -208,9 +206,7 @@ inline __device__ float sum(Float8_ v) {
} }
// Vector dot product. // Vector dot product.
inline __device__ float dot(float a, float b) { inline __device__ float dot(float a, float b) { return a * b; }
return a * b;
}
inline __device__ float dot(float2 a, float2 b) { inline __device__ float dot(float2 a, float2 b) {
float2 c = mul<float2, float2, float2>(a, b); float2 c = mul<float2, float2, float2>(a, b);
@ -232,42 +228,24 @@ inline __device__ float dot(Float8_ a, Float8_ b) {
} }
// From float to float. // From float to float.
inline __device__ void from_float(float& dst, float src) { inline __device__ void from_float(float& dst, float src) { dst = src; }
dst = src;
}
inline __device__ void from_float(float2& dst, float2 src) { inline __device__ void from_float(float2& dst, float2 src) { dst = src; }
dst = src;
}
inline __device__ void from_float(float4& dst, float4 src) { inline __device__ void from_float(float4& dst, float4 src) { dst = src; }
dst = src;
}
// From float to float. // From float to float.
inline __device__ float to_float(float u) { inline __device__ float to_float(float u) { return u; }
return u;
}
inline __device__ float2 to_float(float2 u) { inline __device__ float2 to_float(float2 u) { return u; }
return u;
}
inline __device__ float4 to_float(float4 u) { inline __device__ float4 to_float(float4 u) { return u; }
return u;
}
inline __device__ Float4_ to_float(Float4_ u) { inline __device__ Float4_ to_float(Float4_ u) { return u; }
return u;
}
inline __device__ Float8_ to_float(Float8_ u) { inline __device__ Float8_ to_float(Float8_ u) { return u; }
return u;
}
// Zero-out a variable. // Zero-out a variable.
inline __device__ void zero(float& dst) { inline __device__ void zero(float& dst) { dst = 0.f; }
dst = 0.f;
}
} // namespace vllm } // namespace vllm
View File
@ -5,36 +5,24 @@
#include <map> #include <map>
#include <vector> #include <vector>
void swap_blocks( void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
torch::Tensor& src,
torch::Tensor& dst,
const torch::Tensor& block_mapping); const torch::Tensor& block_mapping);
void copy_blocks( void copy_blocks(std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& value_caches, std::vector<torch::Tensor>& value_caches,
const torch::Tensor& block_mapping); const torch::Tensor& block_mapping);
void reshape_and_cache( void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key, torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& value,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
torch::Tensor& slot_mapping, torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype, const std::string& kv_cache_dtype, const float kv_scale);
const float kv_scale);
void reshape_and_cache_flash( void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key,
torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& key_cache,
torch::Tensor& value_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping, torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype); const std::string& kv_cache_dtype);
// Just for unittest // Just for unittest
void convert_fp8( void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
torch::Tensor& dst_cache, const float scale, const std::string& kv_cache_dtype);
torch::Tensor& src_cache,
const float scale,
const std::string& kv_cache_dtype);
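The declarations above are the host-facing surface of the KV-cache ops. As a quick illustration of how they fit together (a hypothetical smoke test, not part of this commit; the tensor sizes and the helper name are made up), block_mapping is a (num_pairs, 2) int64 tensor of {src_block, dst_block} pairs, as the kernels later in this diff show:

#include <torch/torch.h>

// Declared in the header above.
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
                 const torch::Tensor& block_mapping);

void cache_smoke_test() {
  // Illustrative sizes only: 4 blocks, 2 heads, head_size 8, block_size 16.
  auto src = torch::randn({4, 2, 8, 16}, torch::kCUDA);
  auto dst = torch::zeros(src.sizes(), src.options().device(torch::kCPU));
  // Copy src block 0 into dst block 1, and src block 2 into dst block 3.
  auto block_mapping =
      torch::tensor({0, 1, 2, 3}, torch::kInt64).reshape({2, 2});
  swap_blocks(src, dst, block_mapping);  // GPU -> CPU swap-out in this sketch
}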
View File
@ -21,16 +21,13 @@
typedef __hip_bfloat16 __nv_bfloat16; typedef __hip_bfloat16 __nv_bfloat16;
#endif #endif
void swap_blocks( void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
torch::Tensor& src,
torch::Tensor& dst,
const torch::Tensor& block_mapping) { const torch::Tensor& block_mapping) {
torch::Device src_device = src.device(); torch::Device src_device = src.device();
torch::Device dst_device = dst.device(); torch::Device dst_device = dst.device();
cudaMemcpyKind memcpy_type; cudaMemcpyKind memcpy_type;
if (src_device.is_cuda() && dst_device.is_cuda()) { if (src_device.is_cuda() && dst_device.is_cuda()) {
TORCH_CHECK( TORCH_CHECK(src_device.index() == dst_device.index(),
src_device.index() == dst_device.index(),
"src and dst must be on the same GPU"); "src and dst must be on the same GPU");
memcpy_type = cudaMemcpyDeviceToDevice; memcpy_type = cudaMemcpyDeviceToDevice;
} else if (src_device.is_cuda() && dst_device.is_cpu()) { } else if (src_device.is_cuda() && dst_device.is_cpu()) {
@ -50,7 +47,8 @@ void swap_blocks(
char* dst_ptr = static_cast<char*>(dst.data_ptr()); char* dst_ptr = static_cast<char*>(dst.data_ptr());
const int64_t block_size_in_bytes = src.element_size() * src[0].numel(); const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device); const at::cuda::OptionalCUDAGuard device_guard(
src_device.is_cuda() ? src_device : dst_device);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// NOTE(woosuk): This can be slow if the number of blocks is large. // NOTE(woosuk): This can be slow if the number of blocks is large.
const int64_t num_blocks = block_mapping.size(0); const int64_t num_blocks = block_mapping.size(0);
@ -59,12 +57,8 @@ void swap_blocks(
int64_t dst_block_number = block_mapping[i][1].item<int64_t>(); int64_t dst_block_number = block_mapping[i][1].item<int64_t>();
int64_t src_offset = src_block_number * block_size_in_bytes; int64_t src_offset = src_block_number * block_size_in_bytes;
int64_t dst_offset = dst_block_number * block_size_in_bytes; int64_t dst_offset = dst_block_number * block_size_in_bytes;
cudaMemcpyAsync( cudaMemcpyAsync(dst_ptr + dst_offset, src_ptr + src_offset,
dst_ptr + dst_offset, block_size_in_bytes, memcpy_type, stream);
src_ptr + src_offset,
block_size_in_bytes,
memcpy_type,
stream);
} }
} }
@ -72,8 +66,7 @@ namespace vllm {
// Grid: (num_layers, num_pairs) // Grid: (num_layers, num_pairs)
template <typename scalar_t> template <typename scalar_t>
__global__ void copy_blocks_kernel( __global__ void copy_blocks_kernel(int64_t* key_cache_ptrs,
int64_t* key_cache_ptrs,
int64_t* value_cache_ptrs, int64_t* value_cache_ptrs,
const int64_t* __restrict__ block_mapping, const int64_t* __restrict__ block_mapping,
const int numel_per_block) { const int numel_per_block) {
@ -81,7 +74,8 @@ __global__ void copy_blocks_kernel(
const int pair_idx = blockIdx.y; const int pair_idx = blockIdx.y;
scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]); scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]); scalar_t* value_cache =
reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
int64_t src_block_number = block_mapping[2 * pair_idx]; int64_t src_block_number = block_mapping[2 * pair_idx];
int64_t dst_block_number = block_mapping[2 * pair_idx + 1]; int64_t dst_block_number = block_mapping[2 * pair_idx + 1];
@ -101,8 +95,7 @@ __global__ void copy_blocks_kernel(
} // namespace vllm } // namespace vllm
void copy_blocks( void copy_blocks(std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& value_caches, std::vector<torch::Tensor>& value_caches,
const torch::Tensor& block_mapping) { const torch::Tensor& block_mapping) {
int num_layers = key_caches.size(); int num_layers = key_caches.size();
@ -118,8 +111,10 @@ void copy_blocks(
int64_t key_cache_ptrs[num_layers]; int64_t key_cache_ptrs[num_layers];
int64_t value_cache_ptrs[num_layers]; int64_t value_cache_ptrs[num_layers];
for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) { for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
key_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr()); key_cache_ptrs[layer_idx] =
value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr()); reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
value_cache_ptrs[layer_idx] =
reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
} }
// block_mapping is a 2D tensor with shape (num_pairs, 2). // block_mapping is a 2D tensor with shape (num_pairs, 2).
@ -127,10 +122,12 @@ void copy_blocks(
// Move the data structures to the GPU. // Move the data structures to the GPU.
// NOTE: This synchronizes the CPU and GPU. // NOTE: This synchronizes the CPU and GPU.
torch::Tensor key_cache_ptrs_tensor = torch::from_blob( torch::Tensor key_cache_ptrs_tensor =
key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device); torch::from_blob(key_cache_ptrs, {num_layers}, torch::kInt64)
torch::Tensor value_cache_ptrs_tensor = torch::from_blob( .to(cache_device);
value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device); torch::Tensor value_cache_ptrs_tensor =
torch::from_blob(value_cache_ptrs, {num_layers}, torch::kInt64)
.to(cache_device);
// Launch the kernel. // Launch the kernel.
const int numel_per_block = key_caches[0][0].numel(); const int numel_per_block = key_caches[0][0].numel();
@ -143,8 +140,7 @@ void copy_blocks(
vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>( vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
key_cache_ptrs_tensor.data_ptr<int64_t>(), key_cache_ptrs_tensor.data_ptr<int64_t>(),
value_cache_ptrs_tensor.data_ptr<int64_t>(), value_cache_ptrs_tensor.data_ptr<int64_t>(),
block_mapping.data_ptr<int64_t>(), block_mapping.data_ptr<int64_t>(), numel_per_block);
numel_per_block);
})); }));
} }
@ -154,15 +150,13 @@ template<typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void reshape_and_cache_kernel( __global__ void reshape_and_cache_kernel(
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x,
cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size] // block_size, x]
cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size,
// block_size]
const int64_t* __restrict__ slot_mapping, // [num_tokens] const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int key_stride, const int key_stride, const int value_stride, const int num_heads,
const int value_stride, const int head_size, const int block_size, const int x,
const int num_heads,
const int head_size,
const int block_size,
const int x,
const float kv_scale) { const float kv_scale) {
const int64_t token_idx = blockIdx.x; const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx]; const int64_t slot_idx = slot_mapping[token_idx];
@ -184,23 +178,24 @@ __global__ void reshape_and_cache_kernel(
const int x_idx = head_offset / x; const int x_idx = head_offset / x;
const int x_offset = head_offset % x; const int x_offset = head_offset % x;
const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x const int64_t tgt_key_idx =
+ head_idx * (head_size / x) * block_size * x block_idx * num_heads * (head_size / x) * block_size * x +
+ x_idx * block_size * x head_idx * (head_size / x) * block_size * x + x_idx * block_size * x +
+ block_offset * x block_offset * x + x_offset;
+ x_offset; const int64_t tgt_value_idx =
const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size block_idx * num_heads * head_size * block_size +
+ head_idx * head_size * block_size head_idx * head_size * block_size + head_offset * block_size +
+ head_offset * block_size block_offset;
+ block_offset;
scalar_t tgt_key = key[src_key_idx]; scalar_t tgt_key = key[src_key_idx];
scalar_t tgt_value = value[src_value_idx]; scalar_t tgt_value = value[src_value_idx];
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
key_cache[tgt_key_idx] = tgt_key; key_cache[tgt_key_idx] = tgt_key;
value_cache[tgt_value_idx] = tgt_value; value_cache[tgt_value_idx] = tgt_value;
} else { } else {
key_cache[tgt_key_idx] = fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, kv_scale); key_cache[tgt_key_idx] =
value_cache[tgt_value_idx] = fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, kv_scale); fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, kv_scale);
value_cache[tgt_value_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, kv_scale);
} }
} }
} }
@ -209,15 +204,13 @@ template<typename scalar_t>
__global__ void reshape_and_cache_flash_kernel( __global__ void reshape_and_cache_flash_kernel(
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
scalar_t* __restrict__ k_cache, // [num_blocks, block_size, num_heads, head_size] scalar_t* __restrict__ k_cache, // [num_blocks, block_size, num_heads,
scalar_t* __restrict__ v_cache, // [num_blocks, block_size, num_heads, head_size] // head_size]
scalar_t* __restrict__ v_cache, // [num_blocks, block_size, num_heads,
// head_size]
const int64_t* __restrict__ slot_mapping, // [num_tokens] const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int block_stride, const int block_stride, const int key_stride, const int value_stride,
const int key_stride, const int num_heads, const int head_size, const int block_size) {
const int value_stride,
const int num_heads,
const int head_size,
const int block_size) {
const int64_t token_idx = blockIdx.x; const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx]; const int64_t slot_idx = slot_mapping[token_idx];
// NOTE: slot_idx can be -1 if the token is padded // NOTE: slot_idx can be -1 if the token is padded
@ -232,10 +225,9 @@ __global__ void reshape_and_cache_flash_kernel(
const int64_t src_value_idx = token_idx * value_stride + i; const int64_t src_value_idx = token_idx * value_stride + i;
const int head_idx = i / head_size; const int head_idx = i / head_size;
const int head_offset = i % head_size; const int head_offset = i % head_size;
const int64_t tgt_value_idx = block_idx * block_stride const int64_t tgt_value_idx = block_idx * block_stride +
+ block_offset * num_heads * head_size block_offset * num_heads * head_size +
+ head_idx * head_size head_idx * head_size + head_offset;
+ head_offset;
k_cache[tgt_value_idx] = key[src_key_idx]; k_cache[tgt_value_idx] = key[src_key_idx];
v_cache[tgt_value_idx] = value[src_value_idx]; v_cache[tgt_value_idx] = value[src_value_idx];
} }
@ -246,29 +238,24 @@ __global__ void reshape_and_cache_flash_kernel(
// CACHE_T is the data type of key and value tensors. // CACHE_T is the data type of key and value tensors.
// KV_DTYPE is the real data type of kv-cache. // KV_DTYPE is the real data type of kv-cache.
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \ #define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \
vllm::reshape_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE><<<grid, block, 0, stream>>>( \ vllm::reshape_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(key.data_ptr()), \ reinterpret_cast<KV_T*>(key.data_ptr()), \
reinterpret_cast<KV_T*>(value.data_ptr()), \ reinterpret_cast<KV_T*>(value.data_ptr()), \
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \ reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \ reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), \ slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
key_stride, \ num_heads, head_size, block_size, x, kv_scale);
value_stride, \
num_heads, \
head_size, \
block_size, \
x, \
kv_scale);
void reshape_and_cache( void reshape_and_cache(
torch::Tensor& key, // [num_tokens, num_heads, head_size] torch::Tensor& key, // [num_tokens, num_heads, head_size]
torch::Tensor& value, // [num_tokens, num_heads, head_size] torch::Tensor& value, // [num_tokens, num_heads, head_size]
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] torch::Tensor&
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size]
torch::Tensor& slot_mapping, // [num_tokens] torch::Tensor& slot_mapping, // [num_tokens]
const std::string& kv_cache_dtype, const std::string& kv_cache_dtype, const float kv_scale) {
const float kv_scale)
{
int num_tokens = key.size(0); int num_tokens = key.size(0);
int num_heads = key.size(1); int num_heads = key.size(1);
int head_size = key.size(2); int head_size = key.size(2);
@ -283,7 +270,8 @@ void reshape_and_cache(
const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype, CALL_RESHAPE_AND_CACHE) DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
CALL_RESHAPE_AND_CACHE)
} }
void reshape_and_cache_flash( void reshape_and_cache_flash(
@ -292,8 +280,7 @@ void reshape_and_cache_flash(
torch::Tensor& k_cache, // [num_blocks, block_size, num_heads, head_size] torch::Tensor& k_cache, // [num_blocks, block_size, num_heads, head_size]
torch::Tensor& v_cache, // [num_blocks, block_size, num_heads, head_size] torch::Tensor& v_cache, // [num_blocks, block_size, num_heads, head_size]
torch::Tensor& slot_mapping, // [num_tokens] torch::Tensor& slot_mapping, // [num_tokens]
const std::string& kv_cache_dtype) const std::string& kv_cache_dtype) {
{
// FIXME: only support auto datatype, does not support fp8 // FIXME: only support auto datatype, does not support fp8
if (kv_cache_dtype != "auto") { if (kv_cache_dtype != "auto") {
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
@ -313,36 +300,28 @@ void reshape_and_cache_flash(
const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(
key.scalar_type(), key.scalar_type(), "reshape_and_cache_flash", [&] {
"reshape_and_cache_flash", vllm::reshape_and_cache_flash_kernel<scalar_t>
[&] { <<<grid, block, 0, stream>>>(
vllm::reshape_and_cache_flash_kernel<scalar_t><<<grid, block, 0, stream>>>( key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(), k_cache.data_ptr<scalar_t>(), v_cache.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(), slot_mapping.data_ptr<int64_t>(), block_stride, key_stride,
k_cache.data_ptr<scalar_t>(), value_stride, num_heads, head_size, block_size);
v_cache.data_ptr<scalar_t>(),
slot_mapping.data_ptr<int64_t>(),
block_stride,
key_stride,
value_stride,
num_heads,
head_size,
block_size);
}); });
} }
namespace vllm { namespace vllm {
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt> template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__global__ void convert_fp8_kernel( __global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache,
const Tin* __restrict__ src_cache,
Tout* __restrict__ dst_cache, Tout* __restrict__ dst_cache,
const float kv_scale, const float kv_scale,
const int64_t block_stride) { const int64_t block_stride) {
const int64_t block_idx = blockIdx.x; const int64_t block_idx = blockIdx.x;
for (int i = threadIdx.x; i < block_stride; i += blockDim.x) { for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
int64_t idx = block_idx * block_stride + i; int64_t idx = block_idx * block_stride + i;
dst_cache[idx] = fp8::scaled_convert<Tout, Tin, kv_dt>(src_cache[idx], kv_scale); dst_cache[idx] =
fp8::scaled_convert<Tout, Tin, kv_dt>(src_cache[idx], kv_scale);
} }
} }
@ -351,23 +330,16 @@ __global__ void convert_fp8_kernel(
#define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \ #define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \
vllm::convert_fp8_kernel<Tout, Tin, KV_DTYPE><<<grid, block, 0, stream>>>( \ vllm::convert_fp8_kernel<Tout, Tin, KV_DTYPE><<<grid, block, 0, stream>>>( \
reinterpret_cast<Tin*>(src_cache.data_ptr()), \ reinterpret_cast<Tin*>(src_cache.data_ptr()), \
reinterpret_cast<Tout*>(dst_cache.data_ptr()), \ reinterpret_cast<Tout*>(dst_cache.data_ptr()), kv_scale, block_stride);
kv_scale, \
block_stride);
// Only for testing. // Only for testing.
void convert_fp8( void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
torch::Tensor& dst_cache, const float kv_scale, const std::string& kv_cache_dtype) {
torch::Tensor& src_cache,
const float kv_scale,
const std::string& kv_cache_dtype)
{
torch::Device src_device = src_cache.device(); torch::Device src_device = src_cache.device();
torch::Device dst_device = dst_cache.device(); torch::Device dst_device = dst_cache.device();
TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU") TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU")
TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU") TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU")
TORCH_CHECK( TORCH_CHECK(src_device.index() == dst_device.index(),
src_device.index() == dst_device.index(),
"src and dst must be on the same GPU"); "src and dst must be on the same GPU");
at::cuda::OptionalCUDAGuard device_guard(src_device); at::cuda::OptionalCUDAGuard device_guard(src_device);
@ -398,13 +370,15 @@ void convert_fp8(
} else if (src_cache.dtype() == at::ScalarType::Half) { } else if (src_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kFp8E4M3); CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (src_cache.dtype() == at::ScalarType::BFloat16) { } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, vllm::Fp8KVCacheDataType::kFp8E4M3); CALL_CONVERT_FP8(uint8_t, __nv_bfloat16,
vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (dst_cache.dtype() == at::ScalarType::Float) { } else if (dst_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (dst_cache.dtype() == at::ScalarType::Half) { } else if (dst_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) { } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); CALL_CONVERT_FP8(__nv_bfloat16, uint8_t,
vllm::Fp8KVCacheDataType::kFp8E4M3);
} }
} else { } else {
TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype); TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype);
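The target-index arithmetic in reshape_and_cache_kernel above is dense in diff form. Restated as standalone helpers (illustrative only, not part of the commit), the two formulas address the paged key cache [num_blocks, num_heads, head_size/x, block_size, x] and value cache [num_blocks, num_heads, head_size, block_size]; in the kernel, block_idx and block_offset are derived from slot_mapping and the thread's position:

#include <cstdint>

// Key cache layout: [num_blocks, num_heads, head_size/x, block_size, x].
inline int64_t key_cache_index(int64_t block_idx, int64_t head_idx,
                               int64_t head_offset, int64_t block_offset,
                               int num_heads, int head_size, int block_size,
                               int x) {
  const int64_t x_idx = head_offset / x;     // which x-wide group in the head
  const int64_t x_offset = head_offset % x;  // position inside that group
  return block_idx * num_heads * (head_size / x) * block_size * x +
         head_idx * (head_size / x) * block_size * x +
         x_idx * block_size * x + block_offset * x + x_offset;
}

// Value cache layout: [num_blocks, num_heads, head_size, block_size].
inline int64_t value_cache_index(int64_t block_idx, int64_t head_idx,
                                 int64_t head_offset, int64_t block_offset,
                                 int num_heads, int head_size, int block_size) {
  return block_idx * num_heads * head_size * block_size +
         head_idx * head_size * block_size +
         head_offset * block_size + block_offset;
}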
View File
@ -81,12 +81,10 @@ void silu_and_mul(torch::Tensor &out, torch::Tensor &input) {
int num_tokens = input.numel() / input.size(-1); int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2; int d = input.size(-1) / 2;
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "silu_and_mul_impl", [&] {
input.scalar_type(), "silu_and_mul_impl", [&] {
CPU_KERNEL_GUARD_IN(silu_and_mul_impl) CPU_KERNEL_GUARD_IN(silu_and_mul_impl)
activation_kernel<scalar_t, silu_act, true>(num_tokens, d, activation_kernel<scalar_t, silu_act, true>(
input.data_ptr<scalar_t>(), num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
out.data_ptr<scalar_t>());
CPU_KERNEL_GUARD_OUT(silu_and_mul_impl) CPU_KERNEL_GUARD_OUT(silu_and_mul_impl)
}); });
} }
@ -97,12 +95,10 @@ void gelu_and_mul(torch::Tensor &out, // [..., d]
int num_tokens = input.numel() / input.size(-1); int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2; int d = input.size(-1) / 2;
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_and_mul_impl", [&] {
input.scalar_type(), "gelu_and_mul_impl", [&] {
CPU_KERNEL_GUARD_IN(gelu_and_mul_impl) CPU_KERNEL_GUARD_IN(gelu_and_mul_impl)
activation_kernel<scalar_t, gelu_act, true>(num_tokens, d, activation_kernel<scalar_t, gelu_act, true>(
input.data_ptr<scalar_t>(), num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
out.data_ptr<scalar_t>());
CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl) CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl)
}); });
} }
View File
@ -2,7 +2,8 @@
namespace { namespace {
template <typename scalar_t> struct KernelVecType { template <typename scalar_t>
struct KernelVecType {
using q_load_vec_type = void; using q_load_vec_type = void;
using q_vec_type = void; using q_vec_type = void;
using k_load_vec_type = void; using k_load_vec_type = void;
@ -11,7 +12,8 @@ template <typename scalar_t> struct KernelVecType {
using v_load_vec_type = void; using v_load_vec_type = void;
}; };
template <> struct KernelVecType<float> { template <>
struct KernelVecType<float> {
using q_load_vec_type = vec_op::FP32Vec4; using q_load_vec_type = vec_op::FP32Vec4;
using q_vec_type = vec_op::FP32Vec16; using q_vec_type = vec_op::FP32Vec16;
using k_load_vec_type = vec_op::FP32Vec16; using k_load_vec_type = vec_op::FP32Vec16;
@ -21,7 +23,8 @@ template <> struct KernelVecType<float> {
}; };
#ifdef __AVX512BF16__ #ifdef __AVX512BF16__
template <> struct KernelVecType<c10::BFloat16> { template <>
struct KernelVecType<c10::BFloat16> {
using q_load_vec_type = vec_op::BF16Vec8; using q_load_vec_type = vec_op::BF16Vec8;
using q_vec_type = vec_op::BF16Vec32; using q_vec_type = vec_op::BF16Vec32;
using k_load_vec_type = vec_op::BF16Vec32; using k_load_vec_type = vec_op::BF16Vec32;
@ -30,7 +33,8 @@ template <> struct KernelVecType<c10::BFloat16> {
using v_load_vec_type = vec_op::BF16Vec16; using v_load_vec_type = vec_op::BF16Vec16;
}; };
#else #else
template <> struct KernelVecType<c10::BFloat16> { template <>
struct KernelVecType<c10::BFloat16> {
using q_load_vec_type = vec_op::BF16Vec8; using q_load_vec_type = vec_op::BF16Vec8;
using q_vec_type = vec_op::FP32Vec16; using q_vec_type = vec_op::FP32Vec16;
using k_load_vec_type = vec_op::BF16Vec16; using k_load_vec_type = vec_op::BF16Vec16;
@ -67,9 +71,10 @@ FORCE_INLINE std::pair<T, T> reduceSoftmax(T *data, const int size,
} }
template <typename T> template <typename T>
FORCE_INLINE std::pair<T, T> FORCE_INLINE std::pair<T, T> reduceSoftmaxAlibi(T* data, const int size,
reduceSoftmaxAlibi(T *data, const int size, const int capacity, const int capacity,
const float alibi_slope, const int start_index, const float alibi_slope,
const int start_index,
const int seq_len) { const int seq_len) {
data[0] += alibi_slope * (start_index - seq_len + 1); data[0] += alibi_slope * (start_index - seq_len + 1);
T max = data[0]; T max = data[0];
@ -215,16 +220,16 @@ FORCE_INLINE void reduceValueBlock(const float *prob, const scalar_t *v_block,
namespace { namespace {
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE> template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE>
struct paged_attention_v1_impl { struct paged_attention_v1_impl {
static void static void call(
call(scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size] scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x] // head_size/x, block_size, x]
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size] // head_size, block_size]
const int num_kv_heads, const float scale, const int num_kv_heads, const float scale,
const int const int* __restrict__ block_tables, // [num_seqs,
*__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] // max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs] const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
@ -257,8 +262,7 @@ struct paged_attention_v1_impl {
const int64_t kv_head_idx = head_idx / num_queries_per_kv; const int64_t kv_head_idx = head_idx / num_queries_per_kv;
const scalar_t* __restrict__ q_vec_ptr = const scalar_t* __restrict__ q_vec_ptr =
q + seq_idx * q_stride + head_idx * HEAD_SIZE; q + seq_idx * q_stride + head_idx * HEAD_SIZE;
const int last_block_token_num = const int last_block_token_num = seq_len - (block_num - 1) * BLOCK_SIZE;
seq_len - (block_num - 1) * BLOCK_SIZE;
float* __restrict__ thread_block_logits = float* __restrict__ thread_block_logits =
logits + omp_get_thread_num() * max_seq_len_padded; logits + omp_get_thread_num() * max_seq_len_padded;
@ -282,8 +286,7 @@ struct paged_attention_v1_impl {
block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0, block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0,
seq_len); seq_len);
} else { } else {
reduceSoftmax(thread_block_logits, seq_len, reduceSoftmax(thread_block_logits, seq_len, block_num * BLOCK_SIZE);
block_num * BLOCK_SIZE);
} }
// Compute value // Compute value
@ -348,8 +351,8 @@ template <typename T, int BLOCK_SIZE>
void paged_attention_v1_impl_launcher( void paged_attention_v1_impl_launcher(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor &block_tables, torch::Tensor &seq_lens, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
int max_seq_len, const c10::optional<torch::Tensor> &alibi_slopes) { const c10::optional<torch::Tensor>& alibi_slopes) {
int num_seqs = query.size(0); int num_seqs = query.size(0);
int num_heads = query.size(1); int num_heads = query.size(1);
int head_size = query.size(2); int head_size = query.size(2);
@ -415,9 +418,8 @@ void paged_attention_v1_impl_launcher(
void paged_attention_v1(torch::Tensor& out, torch::Tensor& query, void paged_attention_v1(torch::Tensor& out, torch::Tensor& query,
torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& key_cache, torch::Tensor& value_cache,
int num_kv_heads, float scale, int num_kv_heads, float scale,
torch::Tensor &block_tables, torch::Tensor& block_tables, torch::Tensor& seq_lens,
torch::Tensor &seq_lens, int block_size, int block_size, int max_seq_len,
int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, float kv_scale) { const std::string& kv_cache_dtype, float kv_scale) {
TORCH_CHECK(kv_scale == 1.0f); TORCH_CHECK(kv_scale == 1.0f);
@ -435,9 +437,10 @@ template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int PARTITION_SIZE>
struct paged_attention_v2_impl { struct paged_attention_v2_impl {
static void call( static void call(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
float *__restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ exp_sums, // [num_seqs, num_heads,
float // max_num_partitions]
*__restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size] // max_num_partitions, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
@ -446,8 +449,8 @@ struct paged_attention_v2_impl {
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size] // head_size, block_size]
const int num_kv_heads, const float scale, const int num_kv_heads, const float scale,
const int const int* __restrict__ block_tables, // [num_seqs,
*__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] // max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs] const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
@ -468,8 +471,7 @@ struct paged_attention_v2_impl {
const int seq_len = seq_lens[seq_idx]; const int seq_len = seq_lens[seq_idx];
const int start_token_idx = partition_idx * PARTITION_SIZE; const int start_token_idx = partition_idx * PARTITION_SIZE;
if (start_token_idx >= seq_len) if (start_token_idx >= seq_len) continue;
continue;
const int partition_num = const int partition_num =
(seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
@ -477,8 +479,7 @@ struct paged_attention_v2_impl {
const int token_num = const int token_num =
(std::min(seq_len, start_token_idx + PARTITION_SIZE) - (std::min(seq_len, start_token_idx + PARTITION_SIZE) -
start_token_idx); start_token_idx);
const int block_num = const int block_num = (token_num + BLOCK_SIZE - 1) / BLOCK_SIZE;
(token_num + BLOCK_SIZE - 1) / BLOCK_SIZE;
const int last_block_token_num = const int last_block_token_num =
token_num - (block_num - 1) * BLOCK_SIZE; token_num - (block_num - 1) * BLOCK_SIZE;
const int* seq_block_table = block_tables + const int* seq_block_table = block_tables +
@ -510,8 +511,8 @@ struct paged_attention_v2_impl {
logits, token_num, block_num * BLOCK_SIZE, logits, token_num, block_num * BLOCK_SIZE,
alibi_slopes[head_idx], start_token_idx, seq_len); alibi_slopes[head_idx], start_token_idx, seq_len);
} else { } else {
max_and_sum = reduceSoftmax(logits, token_num, max_and_sum =
block_num * BLOCK_SIZE); reduceSoftmax(logits, token_num, block_num * BLOCK_SIZE);
} }
auto&& [max_logit, exp_sum] = max_and_sum; auto&& [max_logit, exp_sum] = max_and_sum;
@ -587,8 +588,7 @@ struct paged_attention_v2_impl {
const int partition_num = const int partition_num =
(seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
if (partition_num == 1) if (partition_num == 1) continue;
continue;
reducePartitonSoftmax( reducePartitonSoftmax(
max_logits + seq_idx * num_heads * max_num_partitions + max_logits + seq_idx * num_heads * max_num_partitions +
@ -603,8 +603,8 @@ struct paged_attention_v2_impl {
using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type; using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
static_assert(v_load_vec_type::get_elem_num() == BLOCK_SIZE); static_assert(v_load_vec_type::get_elem_num() == BLOCK_SIZE);
constexpr int head_elem_num_per_group = constexpr int head_elem_num_per_group =
16; // Note: didn't align with the cacheline size, due to some HEAD_SIZE 16; // Note: didn't align with the cacheline size, due to some
// didn't align with 64 bytes // HEAD_SIZE didn't align with 64 bytes
static_assert(HEAD_SIZE % head_elem_num_per_group == 0); static_assert(HEAD_SIZE % head_elem_num_per_group == 0);
constexpr int head_group_num = HEAD_SIZE / head_elem_num_per_group; constexpr int head_group_num = HEAD_SIZE / head_elem_num_per_group;
const float* __restrict__ rescale_factors = exp_sums; const float* __restrict__ rescale_factors = exp_sums;
@ -616,8 +616,7 @@ struct paged_attention_v2_impl {
const int partition_num = const int partition_num =
(seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
if (partition_num == 1) if (partition_num == 1) continue;
continue;
const float* __restrict__ seq_head_rescale_factors = const float* __restrict__ seq_head_rescale_factors =
rescale_factors + seq_idx * num_heads * max_num_partitions + rescale_factors + seq_idx * num_heads * max_num_partitions +
@ -713,8 +712,8 @@ void paged_attention_v2_impl_launcher(
#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ #define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
paged_attention_v2_impl_launcher<T, BLOCK_SIZE>( \ paged_attention_v2_impl_launcher<T, BLOCK_SIZE>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, seq_lens, block_size, \ num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, \
max_seq_len, alibi_slopes); alibi_slopes);
#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ #define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
switch (block_size) { \ switch (block_size) { \

View File
namespace { namespace {
template <typename scalar_t> template <typename scalar_t>
void copy_blocks_cpu_impl( void copy_blocks_cpu_impl(std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor> &key_caches,
std::vector<torch::Tensor>& value_caches, std::vector<torch::Tensor>& value_caches,
const torch::Tensor& mapping_pairs, const torch::Tensor& mapping_pairs,
const int element_num_per_block, const int layer_num) { const int element_num_per_block,
const int layer_num) {
const size_t pair_num = mapping_pairs.size(0); const size_t pair_num = mapping_pairs.size(0);
const size_t block_bytes = sizeof(scalar_t) * element_num_per_block; const size_t block_bytes = sizeof(scalar_t) * element_num_per_block;
#pragma omp parallel for collapse(2) #pragma omp parallel for collapse(2)
for (int layer = 0; layer < layer_num; ++layer) { for (int layer = 0; layer < layer_num; ++layer) {
for (size_t pair = 0; pair < pair_num; ++pair) { for (size_t pair = 0; pair < pair_num; ++pair) {
int64_t source_offset = element_num_per_block * mapping_pairs[pair][0].item<int64_t>(); int64_t source_offset =
element_num_per_block * mapping_pairs[pair][0].item<int64_t>();
int64_t target_offset = int64_t target_offset =
element_num_per_block * mapping_pairs[pair][1].item<int64_t>(); element_num_per_block * mapping_pairs[pair][1].item<int64_t>();
scalar_t* key_cache_ptr = key_caches[layer].data_ptr<scalar_t>(); scalar_t* key_cache_ptr = key_caches[layer].data_ptr<scalar_t>();
View File
@ -87,8 +87,8 @@ void fused_add_rms_norm_impl(scalar_t *__restrict__ input,
} }
} // namespace } // namespace
void rms_norm(torch::Tensor &out, torch::Tensor &input, void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
torch::Tensor &weight, float epsilon) { float epsilon) {
int hidden_size = input.size(-1); int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size; int num_tokens = input.numel() / hidden_size;
View File
@ -4,16 +4,16 @@
namespace { namespace {
template <typename scalar_t> template <typename scalar_t>
void rotary_embedding_impl( void rotary_embedding_impl(
const int64_t const int64_t* __restrict__ positions, // [batch_size, seq_len] or
*__restrict__ positions, // [batch_size, seq_len] or [num_tokens] // [num_tokens]
scalar_t scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads,
*__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or /// head_size] or [num_tokens, num_heads,
/// [num_tokens, num_heads, head_size] /// head_size]
scalar_t scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
*__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or // head_size] or [num_tokens, num_kv_heads,
// [num_tokens, num_kv_heads, head_size] // head_size]
const scalar_t const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
*__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] // 2]
const int rot_dim, const int64_t query_stride, const int64_t key_stride, const int rot_dim, const int64_t query_stride, const int64_t key_stride,
const int num_heads, const int num_kv_heads, const int head_size, const int num_heads, const int num_kv_heads, const int head_size,
const int num_tokens) { const int num_tokens) {
@ -94,16 +94,16 @@ void rotary_embedding_impl(
template <typename scalar_t> template <typename scalar_t>
void rotary_embedding_gptj_impl( void rotary_embedding_gptj_impl(
const int64_t const int64_t* __restrict__ positions, // [batch_size, seq_len] or
*__restrict__ positions, // [batch_size, seq_len] or [num_tokens] // [num_tokens]
scalar_t scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads,
*__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or /// head_size] or [num_tokens, num_heads,
/// [num_tokens, num_heads, head_size] /// head_size]
scalar_t scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
*__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or // head_size] or [num_tokens, num_kv_heads,
// [num_tokens, num_kv_heads, head_size] // head_size]
const scalar_t const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
*__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] // 2]
const int rot_dim, const int64_t query_stride, const int64_t key_stride, const int rot_dim, const int64_t query_stride, const int64_t key_stride,
const int num_heads, const int num_kv_heads, const int head_size, const int num_heads, const int num_kv_heads, const int head_size,
const int num_tokens) { const int num_tokens) {
View File
@ -8,66 +8,37 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
pybind11::module ops = m.def_submodule("ops", "vLLM custom operators"); pybind11::module ops = m.def_submodule("ops", "vLLM custom operators");
// Attention ops // Attention ops
ops.def( ops.def("paged_attention_v1", &paged_attention_v1,
"paged_attention_v1", "Compute the attention between an input query and the cached "
&paged_attention_v1, "keys/values using PagedAttention.");
"Compute the attention between an input query and the cached keys/values using PagedAttention."); ops.def("paged_attention_v2", &paged_attention_v2, "PagedAttention V2.");
ops.def(
"paged_attention_v2",
&paged_attention_v2,
"PagedAttention V2.");
// Activation ops // Activation ops
ops.def( ops.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU.");
"silu_and_mul", ops.def("gelu_and_mul", &gelu_and_mul,
&silu_and_mul,
"Activation function used in SwiGLU.");
ops.def(
"gelu_and_mul",
&gelu_and_mul,
"Activation function used in GeGLU with `none` approximation."); "Activation function used in GeGLU with `none` approximation.");
ops.def( ops.def("gelu_tanh_and_mul", &gelu_tanh_and_mul,
"gelu_tanh_and_mul",
&gelu_tanh_and_mul,
"Activation function used in GeGLU with `tanh` approximation."); "Activation function used in GeGLU with `tanh` approximation.");
ops.def( ops.def("gelu_new", &gelu_new, "GELU implementation used in GPT-2.");
"gelu_new", ops.def("gelu_fast", &gelu_fast, "Approximate GELU implementation.");
&gelu_new,
"GELU implementation used in GPT-2.");
ops.def(
"gelu_fast",
&gelu_fast,
"Approximate GELU implementation.");
// Layernorm // Layernorm
ops.def( ops.def("rms_norm", &rms_norm,
"rms_norm",
&rms_norm,
"Apply Root Mean Square (RMS) Normalization to the input tensor."); "Apply Root Mean Square (RMS) Normalization to the input tensor.");
ops.def( ops.def("fused_add_rms_norm", &fused_add_rms_norm,
"fused_add_rms_norm",
&fused_add_rms_norm,
"In-place fused Add and RMS Normalization"); "In-place fused Add and RMS Normalization");
// Rotary embedding // Rotary embedding
ops.def( ops.def("rotary_embedding", &rotary_embedding,
"rotary_embedding",
&rotary_embedding,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
// Cache ops // Cache ops
pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
cache_ops.def( cache_ops.def("swap_blocks", &swap_blocks,
"swap_blocks",
&swap_blocks,
"Swap in (out) the cache blocks from src to dst"); "Swap in (out) the cache blocks from src to dst");
cache_ops.def( cache_ops.def("copy_blocks", &copy_blocks,
"copy_blocks",
&copy_blocks,
"Copy the cache blocks from src to dst"); "Copy the cache blocks from src to dst");
cache_ops.def( cache_ops.def("reshape_and_cache", &reshape_and_cache,
"reshape_and_cache",
&reshape_and_cache,
"Reshape the key and value tensors and cache them"); "Reshape the key and value tensors and cache them");
} }
View File
@ -17,7 +17,8 @@
#endif #endif
#ifndef USE_ROCM #ifndef USE_ROCM
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask) #define VLLM_SHFL_XOR_SYNC(var, lane_mask) \
__shfl_xor_sync(uint32_t(-1), var, lane_mask)
#else #else
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask) #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
#endif #endif
@ -29,7 +30,8 @@
#endif #endif
#ifndef USE_ROCM #ifndef USE_ROCM
#define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down_sync(uint32_t(-1), var, lane_delta) #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) \
__shfl_down_sync(uint32_t(-1), var, lane_delta)
#else #else
#define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta) #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta)
#endif #endif
@ -41,4 +43,3 @@
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL) hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
#endif #endif
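The shuffle wrappers above exist so the same warp-level code compiles on CUDA and ROCm. A hypothetical usage sketch (not part of the commit), assuming a 32-lane warp:

__device__ __forceinline__ float warp_sum(float val) {
  // Butterfly reduction: after log2(32) steps every lane holds the warp sum.
#pragma unroll
  for (int mask = 16; mask > 0; mask >>= 1)
    val += VLLM_SHFL_XOR_SYNC(val, mask);
  return val;
}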
View File
@ -2,9 +2,6 @@
#include <torch/extension.h> #include <torch/extension.h>
int get_device_attribute( int get_device_attribute(int attribute, int device_id);
int attribute,
int device_id);
int get_max_shared_memory_per_block_device_attribute( int get_max_shared_memory_per_block_device_attribute(int device_id);
int device_id);
View File
@ -2,25 +2,19 @@
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
#include <hip/hip_runtime_api.h> #include <hip/hip_runtime_api.h>
#endif #endif
int get_device_attribute( int get_device_attribute(int attribute, int device_id) {
int attribute,
int device_id)
{
int device, value; int device, value;
if (device_id < 0) { if (device_id < 0) {
cudaGetDevice(&device); cudaGetDevice(&device);
} } else {
else {
device = device_id; device = device_id;
} }
cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute), device); cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute),
device);
return value; return value;
} }
int get_max_shared_memory_per_block_device_attribute(int device_id) {
int get_max_shared_memory_per_block_device_attribute(
int device_id)
{
int attribute; int attribute;
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
// cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74 // cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
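A hypothetical caller sketch for get_device_attribute above (passing -1 queries the current device, as the implementation shows; cudaDeviceAttr enumerators convert implicitly to int):

#include <cuda_runtime.h>

int get_device_attribute(int attribute, int device_id);  // declared above

inline int current_device_sm_count() {
  // e.g. number of SMs on the current device.
  return get_device_attribute(cudaDevAttrMultiProcessorCount, /*device_id=*/-1);
}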
View File
@ -80,8 +80,7 @@ void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out,
} }
case at::ScalarType::Half: { case at::ScalarType::Half: {
fa->allreduce<half>(stream, reinterpret_cast<half*>(inp.data_ptr()), fa->allreduce<half>(stream, reinterpret_cast<half*>(inp.data_ptr()),
reinterpret_cast<half *>(out.data_ptr()), reinterpret_cast<half*>(out.data_ptr()), out.numel());
out.numel());
break; break;
} }
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
View File
@ -137,8 +137,7 @@ DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg,
// Latency = 1 p2p write // Latency = 1 p2p write
sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1; sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1;
// wait until we got true from all ranks // wait until we got true from all ranks
while (!self_sg->start[blockIdx.x][threadIdx.x]) while (!self_sg->start[blockIdx.x][threadIdx.x]);
;
} }
__syncthreads(); __syncthreads();
} }
@ -162,8 +161,7 @@ DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg,
// Latency = 1 p2p write // Latency = 1 p2p write
sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1; sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1;
// wait until we got true from all ranks // wait until we got true from all ranks
while (!self_sg->end[blockIdx.x][threadIdx.x]) while (!self_sg->end[blockIdx.x][threadIdx.x]);
;
} }
if constexpr (!final_sync) __syncthreads(); if constexpr (!final_sync) __syncthreads();
} }
@ -192,8 +190,7 @@ __global__ void __launch_bounds__(512, 1)
// do the actual reduction // do the actual reduction
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) { idx += gridDim.x * blockDim.x) {
((P *)result)[idx] = ((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
packed_reduce<P, ngpus, A>((const P **)&dp.ptrs[0], idx);
} }
end_sync<ngpus, true>(sg, self_sg, rank); end_sync<ngpus, true>(sg, self_sg, rank);
} }
View File
@ -12,8 +12,7 @@
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \ #define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
@ -22,8 +21,8 @@
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \ #define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \ AT_DISPATCH_SWITCH(TYPE, NAME, \
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__)) VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \ #define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
@ -33,5 +32,4 @@
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ #define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
View File
@ -23,9 +23,7 @@ __global__ void rms_norm_kernel(
scalar_t* __restrict__ out, // [..., hidden_size] scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size] const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size] const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const float epsilon, const int num_tokens, const int hidden_size) {
const int num_tokens,
const int hidden_size) {
__shared__ float s_variance; __shared__ float s_variance;
float variance = 0.0f; float variance = 0.0f;
@ -41,11 +39,11 @@ __global__ void rms_norm_kernel(
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)input[blockIdx.x * hidden_size + idx]; float x = (float)input[blockIdx.x * hidden_size + idx];
out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx]; out[blockIdx.x * hidden_size + idx] =
((scalar_t)(x * s_variance)) * weight[idx];
} }
} }
/* Converter structs for the conversion from torch types to HIP/CUDA types, /* Converter structs for the conversion from torch types to HIP/CUDA types,
and the associated type conversions within HIP/CUDA. These helpers need and the associated type conversions within HIP/CUDA. These helpers need
to be implemented for now because the relevant type conversion to be implemented for now because the relevant type conversion
@ -57,7 +55,9 @@ __global__ void rms_norm_kernel(
If true, the struct should be fully defined as shown in the examples below. If true, the struct should be fully defined as shown in the examples below.
*/ */
template <typename torch_type> template <typename torch_type>
struct _typeConvert { static constexpr bool exists = false; }; struct _typeConvert {
static constexpr bool exists = false;
};
#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)) #if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
// CUDA < 12.0 runs into issues with packed type conversion // CUDA < 12.0 runs into issues with packed type conversion
@ -68,9 +68,15 @@ struct _typeConvert<c10::Half> {
using packed_hip_type = __half2; using packed_hip_type = __half2;
__device__ static inline float convert(hip_type x) { return __half2float(x); } __device__ static inline float convert(hip_type x) { return __half2float(x); }
__device__ static inline float2 convert(packed_hip_type x) { return __half22float2(x); } __device__ static inline float2 convert(packed_hip_type x) {
__device__ static inline hip_type convert(float x) { return __float2half_rn(x); } return __half22float2(x);
__device__ static inline packed_hip_type convert(float2 x) { return __float22half2_rn(x); } }
__device__ static inline hip_type convert(float x) {
return __float2half_rn(x);
}
__device__ static inline packed_hip_type convert(float2 x) {
return __float22half2_rn(x);
}
}; };
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
@ -82,13 +88,22 @@ struct _typeConvert<c10::BFloat16> {
using hip_type = __nv_bfloat16; using hip_type = __nv_bfloat16;
using packed_hip_type = __nv_bfloat162; using packed_hip_type = __nv_bfloat162;
  __device__ static inline float convert(hip_type x) {
    return __bfloat162float(x);
  }
  __device__ static inline float2 convert(packed_hip_type x) {
    return __bfloat1622float2(x);
  }
  __device__ static inline hip_type convert(float x) {
    return __float2bfloat16(x);
  }
  __device__ static inline packed_hip_type convert(float2 x) {
    return __float22bfloat162_rn(x);
  }
};
  #endif  // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#endif  // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >=
        // 12000))

/* Vector POD struct to generate vectorized and packed FP16/BF16 ops
   for appropriate specializations of fused_add_rms_norm_kernel.
@@ -117,8 +132,7 @@ struct alignas(16) _f16Vec {
      }
    } else {
  #pragma unroll
      for (int i = 0; i < width; ++i) data[i] += other.data[i];
    }
    return *this;
  }
@@ -134,8 +148,7 @@ struct alignas(16) _f16Vec {
      }
    } else {
  #pragma unroll
      for (int i = 0; i < width; ++i) data[i] *= other.data[i];
    }
    return *this;
  }
@@ -185,14 +198,12 @@ struct alignas(16) _f16Vec {
   packed and vectorized operations, which help with the
   memory latency bottleneck. */
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
fused_add_rms_norm_kernel(
    scalar_t* __restrict__ input,         // [..., hidden_size]
    scalar_t* __restrict__ residual,      // [..., hidden_size]
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float epsilon, const int num_tokens, const int hidden_size) {
  // Sanity checks on our vector struct and type-punned pointer arithmetic
  static_assert(std::is_pod_v<_f16Vec<scalar_t, width>>);
  static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);
@@ -203,9 +214,12 @@ __global__ std::enable_if_t<
  /* These and the argument pointers are all declared `restrict` as they are
     not aliased in practice. Argument pointers should not be dereferenced
     in this kernel as that would be undefined behavior */
  auto* __restrict__ input_v =
      reinterpret_cast<_f16Vec<scalar_t, width>*>(input);
  auto* __restrict__ residual_v =
      reinterpret_cast<_f16Vec<scalar_t, width>*>(residual);
  auto* __restrict__ weight_v =
      reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);

  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
    int id = blockIdx.x * vec_hidden_size + idx;
@@ -218,7 +232,8 @@ __global__ std::enable_if_t<
     calculation of max_block_size in fused_add_rms_norm */
  if (num_tokens < 256) {
    variance = blockReduceSum<float, 1024>(variance);
  } else
    variance = blockReduceSum<float, 256>(variance);
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
@@ -233,19 +248,16 @@ __global__ std::enable_if_t<
  }
}

/* Generic fused_add_rms_norm_kernel
   The width field is not used here but necessary for other specializations.
 */
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
fused_add_rms_norm_kernel(
    scalar_t* __restrict__ input,         // [..., hidden_size]
    scalar_t* __restrict__ residual,      // [..., hidden_size]
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float epsilon, const int num_tokens, const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;
@@ -260,7 +272,8 @@ __global__ std::enable_if_t<
     calculation of max_block_size in fused_add_rms_norm */
  if (num_tokens < 256) {
    variance = blockReduceSum<float, 1024>(variance);
  } else
    variance = blockReduceSum<float, 256>(variance);
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
@@ -268,14 +281,14 @@ __global__ std::enable_if_t<
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float)residual[blockIdx.x * hidden_size + idx];
    input[blockIdx.x * hidden_size + idx] =
        ((scalar_t)(x * s_variance)) * weight[idx];
  }
}

}  // namespace vllm

void rms_norm(torch::Tensor& out,     // [..., hidden_size]
              torch::Tensor& input,   // [..., hidden_size]
              torch::Tensor& weight,  // [hidden_size]
              float epsilon) {
@@ -286,37 +299,24 @@ void rms_norm(
  dim3 block(std::min(hidden_size, 1024));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
    vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
        out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
        weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
  });
}

#define LAUNCH_FUSED_ADD_RMS_NORM(width)                                       \
  VLLM_DISPATCH_FLOATING_TYPES(                                                \
      input.scalar_type(), "fused_add_rms_norm_kernel", [&] {                  \
        vllm::fused_add_rms_norm_kernel<scalar_t, width>                       \
            <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),           \
                                         residual.data_ptr<scalar_t>(),        \
                                         weight.data_ptr<scalar_t>(), epsilon, \
                                         num_tokens, hidden_size);             \
      });

void fused_add_rms_norm(torch::Tensor& input,     // [..., hidden_size]
                        torch::Tensor& residual,  // [..., hidden_size]
                        torch::Tensor& weight,    // [hidden_size]
                        float epsilon) {
@@ -342,8 +342,8 @@ void fused_add_rms_norm(
  auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
  auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
  auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
  bool ptrs_are_aligned =
      inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
  if (ptrs_are_aligned && hidden_size % 8 == 0) {
    LAUNCH_FUSED_ADD_RMS_NORM(8);
  } else {
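The dispatch above only takes the vectorized width-8 path when every tensor pointer is 16-byte aligned and hidden_size is a multiple of 8, because _f16Vec<scalar_t, 8> loads eight 16-bit values (16 bytes) per access; otherwise it falls back to the generic width-0 kernel. A minimal host-side sketch of that rule follows; the helper name pick_rms_norm_width is illustrative and not part of the diff.

#include <cstdint>

// Sketch only: mirrors the alignment check used by fused_add_rms_norm above.
// Returns the kernel width to dispatch: 8 when the vectorized path is safe,
// 0 for the generic fallback.
inline int pick_rms_norm_width(const void* input, const void* residual,
                               const void* weight, int hidden_size) {
  auto aligned16 = [](const void* p) {
    return reinterpret_cast<std::uintptr_t>(p) % 16 == 0;
  };
  bool ptrs_are_aligned =
      aligned16(input) && aligned16(residual) && aligned16(weight);
  return (ptrs_are_aligned && hidden_size % 8 == 0) ? 8 : 0;
}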
@@ -3,5 +3,6 @@
#include <torch/extension.h>

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("topk_softmax", &topk_softmax,
        "Apply topk softmax to the gating outputs.");
}
@@ -2,8 +2,6 @@
#include <torch/extension.h>

void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
                  torch::Tensor& token_expert_indices,
                  torch::Tensor& gating_output);
@@ -12,11 +12,12 @@
namespace vllm {
namespace {
__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
                                         int32_t col) {
  // don't worry about overflow because num_experts is relatively small
  return row * total_col + col;
}
}  // namespace

template <typename scalar_t>
__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
@@ -24,15 +25,17 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
                                            int32_t* expert_ids,
                                            int32_t* total_tokens_post_pad,
                                            int32_t num_experts,
                                            int32_t block_size, size_t numel) {
  const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
  const size_t start_idx = threadIdx.x * tokens_per_thread;

  extern __shared__ int32_t shared_mem[];

  int32_t* tokens_cnts =
      shared_mem;  // 2d tensor with shape (num_experts + 1, num_experts)
  int32_t* cumsum =
      shared_mem + (num_experts + 1) *
                       num_experts;  // 1d tensor with shape (num_experts + 1)

  for (int i = 0; i < num_experts; ++i) {
    tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
@@ -40,8 +43,8 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
  /**
   * In the first step we compute token_cnts[thread_index + 1][expert_index],
   * which counts how many tokens in the token shard of thread_index are
   * assigned to expert expert_index.
   */
  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])];
@@ -52,7 +55,8 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
  // For each expert we accumulate the token counts from the different threads.
  tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
  for (int i = 1; i <= blockDim.x; ++i) {
    tokens_cnts[index(num_experts, i, threadIdx.x)] +=
        tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
  }

  __syncthreads();
@@ -61,7 +65,10 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
  if (threadIdx.x == 0) {
    cumsum[0] = 0;
    for (int i = 1; i <= num_experts; ++i) {
      cumsum[i] = cumsum[i - 1] +
                  CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)],
                          block_size) *
                      block_size;
    }
    *total_tokens_post_pad = cumsum[num_experts];
  }
@@ -69,57 +76,59 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
  __syncthreads();

  /**
   * For each expert, each thread processes the tokens of the corresponding
   * blocks and stores the corresponding expert_id for each block.
   */
  for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
       i += block_size) {
    expert_ids[i / block_size] = threadIdx.x;
  }

  /**
   * Each thread processes a token shard, calculating the index of each token
   * after sorting by expert number. Given the example topk_ids =
   * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *,
   * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a
   * padding value(preset in python).
   */
  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    int32_t expert_id = topk_ids[i];
    /** The cumsum[expert_id] stores the starting index of the tokens that the
     * expert with expert_id needs to process, and
     * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
     * processed by the expert with expert_id within the current thread's token
     * shard.
     */
    int32_t rank_post_pad =
        tokens_cnts[index(num_experts, threadIdx.x, expert_id)] +
        cumsum[expert_id];
    sorted_token_ids[rank_post_pad] = i;
    ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
  }
}
}  // namespace vllm

void moe_align_block_size(torch::Tensor topk_ids, int num_experts,
                          int block_size, torch::Tensor sorted_token_ids,
                          torch::Tensor experts_ids,
                          torch::Tensor num_tokens_post_pad) {
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_INTEGRAL_TYPES(
      topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
        // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
        // tensors
        const int32_t shared_mem =
            ((num_experts + 1) * num_experts + (num_experts + 1)) *
            sizeof(int32_t);

        // set dynamic shared mem
        auto kernel = vllm::moe_align_block_size_kernel<scalar_t>;
        AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
            (void*)kernel, shared_mem));
        kernel<<<1, num_experts, shared_mem, stream>>>(
            topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
            experts_ids.data_ptr<int32_t>(),
            num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
            topk_ids.numel());
      });
}
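The kernel comments above document the padded, expert-sorted layout with a worked example. The following host-side reference (illustrative only, not the CUDA kernel) performs the same bookkeeping on the documented input topk_ids = [0,1,2,1,2,3,0,3,4] with block_size = 4, using -1 as the padding value, and reproduces the stated output.

#include <cstdio>
#include <vector>

// Host reference of the padding/sorting scheme: each expert's token count is
// rounded up to a multiple of block_size, and token indices are scattered
// into the resulting padded layout.
int main() {
  const std::vector<int> topk_ids = {0, 1, 2, 1, 2, 3, 0, 3, 4};
  const int num_experts = 5, block_size = 4, pad = -1;

  std::vector<int> counts(num_experts, 0);
  for (int e : topk_ids) ++counts[e];

  // cumsum[i] = start offset of expert i after per-expert padding.
  std::vector<int> cumsum(num_experts + 1, 0);
  for (int e = 0; e < num_experts; ++e) {
    int padded = (counts[e] + block_size - 1) / block_size * block_size;
    cumsum[e + 1] = cumsum[e] + padded;
  }

  std::vector<int> sorted_token_ids(cumsum[num_experts], pad);
  std::vector<int> fill(num_experts, 0);
  for (int i = 0; i < (int)topk_ids.size(); ++i) {
    int e = topk_ids[i];
    sorted_token_ids[cumsum[e] + fill[e]++] = i;
  }

  for (int v : sorted_token_ids) printf("%d ", v);
  printf("\n");  // prints: 0 6 -1 -1 1 3 -1 -1 2 4 -1 -1 5 7 -1 -1 8 -1 -1 -1
  return 0;
}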
@@ -2,204 +2,115 @@
#include <torch/extension.h>

void paged_attention_v1(torch::Tensor& out, torch::Tensor& query,
                        torch::Tensor& key_cache, torch::Tensor& value_cache,
                        int num_kv_heads, float scale,
                        torch::Tensor& block_tables, torch::Tensor& seq_lens,
                        int block_size, int max_seq_len,
                        const c10::optional<torch::Tensor>& alibi_slopes,
                        const std::string& kv_cache_dtype, float kv_scale);

void paged_attention_v2(torch::Tensor& out, torch::Tensor& exp_sums,
                        torch::Tensor& max_logits, torch::Tensor& tmp_out,
                        torch::Tensor& query, torch::Tensor& key_cache,
                        torch::Tensor& value_cache, int num_kv_heads,
                        float scale, torch::Tensor& block_tables,
                        torch::Tensor& seq_lens, int block_size,
                        int max_seq_len,
                        const c10::optional<torch::Tensor>& alibi_slopes,
                        const std::string& kv_cache_dtype, float kv_scale);

void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
              float epsilon);

void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
                        torch::Tensor& weight, float epsilon);

void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
                      torch::Tensor& key, int head_size,
                      torch::Tensor& cos_sin_cache, bool is_neox);

void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
                              torch::Tensor& key, int head_size,
                              torch::Tensor& cos_sin_cache, bool is_neox,
                              int rot_dim,
                              torch::Tensor& cos_sin_cache_offsets);

void silu_and_mul(torch::Tensor& out, torch::Tensor& input);

void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);

void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);

void gelu_new(torch::Tensor& out, torch::Tensor& input);

void gelu_fast(torch::Tensor& out, torch::Tensor& input);

#ifndef USE_ROCM
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
                        const torch::Tensor& codebooks,
                        const torch::Tensor& scales,
                        const torch::Tensor& codebook_partition_sizes,
                        const std::optional<torch::Tensor>& bias);

torch::Tensor aqlm_dequant(const torch::Tensor& codes,
                           const torch::Tensor& codebooks,
                           const torch::Tensor& codebook_partition_sizes);

torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
                       torch::Tensor _scaling_factors, torch::Tensor _zeros,
                       int split_k_iters);

torch::Tensor awq_dequantize(torch::Tensor _kernel,
                             torch::Tensor _scaling_factors,
                             torch::Tensor _zeros, int split_k_iters, int thx,
                             int thy);

torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                          torch::Tensor& b_scales, torch::Tensor& workspace,
                          int64_t size_m, int64_t size_n, int64_t size_k);

torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                  torch::Tensor& b_meta,
                                  torch::Tensor& b_scales,
                                  torch::Tensor& workspace, int64_t num_bits,
                                  int64_t size_m, int64_t size_n,
                                  int64_t size_k);

torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                               torch::Tensor& b_scales, torch::Tensor& g_idx,
                               torch::Tensor& perm, torch::Tensor& workspace,
                               int64_t num_bits, int64_t size_m, int64_t size_n,
                               int64_t size_k, bool is_k_full);

torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
                                 int64_t size_k, int64_t size_n,
                                 int64_t num_bits);

int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a,
                         torch::Tensor const& b, torch::Tensor const& a_scales,
                         torch::Tensor const& b_scales);
#endif

void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
                     torch::Tensor lookup_table);

torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
                        torch::Tensor b_gptq_qzeros,
                        torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
                        bool use_exllama, int bit);

void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int bit);

void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
                             torch::Tensor& scale);

void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
                              torch::Tensor& scale);

void moe_align_block_size(torch::Tensor topk_ids, int num_experts,
                          int block_size, torch::Tensor sorted_token_ids,
                          torch::Tensor experts_ids,
                          torch::Tensor num_tokens_post_pad);
@@ -219,7 +130,8 @@ int meta_size();
void register_buffer(fptr_t _fa, torch::Tensor& t,
                     const std::vector<std::string>& handles,
                     const std::vector<int64_t>& offsets);
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
    fptr_t _fa);
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
                            const std::vector<std::vector<int64_t>>& offsets);
#endif
@@ -9,12 +9,8 @@ namespace vllm {
template <typename scalar_t, bool IS_NEOX>
inline __device__ void apply_token_rotary_embedding(
    scalar_t* __restrict__ arr, const scalar_t* __restrict__ cos_ptr,
    const scalar_t* __restrict__ sin_ptr, int rot_offset, int embed_dim) {
  int x_index, y_index;
  scalar_t cos, sin;
  if (IS_NEOX) {
@@ -39,17 +35,15 @@ inline __device__ void apply_token_rotary_embedding(
template <typename scalar_t, bool IS_NEOX>
inline __device__ void apply_rotary_embedding(
    scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,
                                   // head_size] or [num_tokens, num_heads,
                                   // head_size]
    scalar_t* __restrict__ key,    // [batch_size, seq_len, num_kv_heads,
                                   // head_size] or [num_tokens, num_kv_heads,
                                   // head_size]
    const scalar_t* cache_ptr, const int head_size, const int num_heads,
    const int num_kv_heads, const int rot_dim, const int token_idx,
    const int64_t query_stride, const int64_t key_stride) {
  const int embed_dim = rot_dim / 2;
  const scalar_t* cos_ptr = cache_ptr;
  const scalar_t* sin_ptr = cache_ptr + embed_dim;
@@ -59,8 +53,8 @@ inline __device__ void apply_rotary_embedding(
    const int head_idx = i / embed_dim;
    const int64_t token_head = token_idx * query_stride + head_idx * head_size;
    const int rot_offset = i % embed_dim;
    apply_token_rotary_embedding<scalar_t, IS_NEOX>(
        query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
  }

  const int nk = num_kv_heads * embed_dim;
@@ -68,59 +62,71 @@ inline __device__ void apply_rotary_embedding(
    const int head_idx = i / embed_dim;
    const int64_t token_head = token_idx * key_stride + head_idx * head_size;
    const int rot_offset = i % embed_dim;
    apply_token_rotary_embedding<scalar_t, IS_NEOX>(
        key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
  }
}

template <typename scalar_t, bool IS_NEOX>
__global__ void rotary_embedding_kernel(
    const int64_t* __restrict__ positions,  // [batch_size, seq_len] or
                                            // [num_tokens]
    scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,
                                   // head_size] or [num_tokens, num_heads,
                                   // head_size]
    scalar_t* __restrict__ key,    // [batch_size, seq_len, num_kv_heads,
                                   // head_size] or [num_tokens, num_kv_heads,
                                   // head_size]
    const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
                                                 // 2]
    const int rot_dim, const int64_t query_stride, const int64_t key_stride,
    const int num_heads, const int num_kv_heads, const int head_size) {
  // Each thread block is responsible for one token.
  const int token_idx = blockIdx.x;
  int64_t pos = positions[token_idx];
  const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;

  apply_rotary_embedding<scalar_t, IS_NEOX>(
      query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
      token_idx, query_stride, key_stride);
}

template <typename scalar_t, bool IS_NEOX>
__global__ void batched_rotary_embedding_kernel(
    const int64_t* __restrict__ positions,  // [batch_size, seq_len] or
                                            // [num_tokens]
    scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,
                                   // head_size] or [num_tokens, num_heads,
                                   // head_size]
    scalar_t* __restrict__ key,    // [batch_size, seq_len, num_kv_heads,
                                   // head_size] or [num_tokens, num_kv_heads,
                                   // head_size]
    const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
                                                 // 2]
    const int64_t* __restrict__ cos_sin_cache_offsets,  // [batch_size, seq_len]
                                                        // or [num_tokens]
    const int rot_dim, const int64_t query_stride, const int64_t key_stride,
    const int num_heads, const int num_kv_heads, const int head_size) {
  // Each thread block is responsible for one token.
  const int token_idx = blockIdx.x;
  int64_t pos = positions[token_idx];
  int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx];
  const scalar_t* cache_ptr =
      cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim;

  apply_rotary_embedding<scalar_t, IS_NEOX>(
      query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
      token_idx, query_stride, key_stride);
}

}  // namespace vllm

void rotary_embedding(
    torch::Tensor& positions,  // [batch_size, seq_len] or [num_tokens]
    torch::Tensor& query,  // [batch_size, seq_len, num_heads * head_size] or
                           // [num_tokens, num_heads * head_size]
    torch::Tensor& key,    // [batch_size, seq_len, num_kv_heads * head_size] or
                           // [num_tokens, num_kv_heads * head_size]
    int head_size,
    torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
    bool is_neox) {
@@ -135,33 +141,18 @@ void rotary_embedding(
  dim3 block(std::min(num_heads * rot_dim / 2, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
    if (is_neox) {
      vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
          positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
          key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(), rot_dim,
          query_stride, key_stride, num_heads, num_kv_heads, head_size);
    } else {
      vllm::rotary_embedding_kernel<scalar_t, false>
          <<<grid, block, 0, stream>>>(
              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
              key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
              rot_dim, query_stride, key_stride, num_heads, num_kv_heads,
              head_size);
    }
  });
@@ -173,12 +164,13 @@ and process in batched manner.
*/
void batched_rotary_embedding(
    torch::Tensor& positions,  // [batch_size, seq_len] or [num_tokens]
    torch::Tensor& query,  // [batch_size, seq_len, num_heads * head_size] or
                           // [num_tokens, num_heads * head_size]
    torch::Tensor& key,    // [batch_size, seq_len, num_kv_heads * head_size] or
                           // [num_tokens, num_kv_heads * head_size]
    int head_size,
    torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
    bool is_neox, int rot_dim,
    torch::Tensor& cos_sin_cache_offsets  // [num_tokens]
) {
  int64_t num_tokens = cos_sin_cache_offsets.size(0);
@@ -191,36 +183,21 @@ void batched_rotary_embedding(
  dim3 block(std::min(num_heads * rot_dim / 2, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
    if (is_neox) {
      vllm::batched_rotary_embedding_kernel<scalar_t, true>
          <<<grid, block, 0, stream>>>(
              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
              key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
              cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
              key_stride, num_heads, num_kv_heads, head_size);
    } else {
      vllm::batched_rotary_embedding_kernel<scalar_t, false>
          <<<grid, block, 0, stream>>>(
              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
              key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
              cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
              key_stride, num_heads, num_kv_heads, head_size);
    }
  });
}
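Both kernel variants reduce to the same per-pair rotation; only the index layout differs between the NeoX style (pairing x[i] with x[i + embed_dim]) and the GPT-J style (pairing adjacent elements). The scalar sketch below is illustrative only and assumes the cos/sin values have already been looked up from the cache; the function name is hypothetical.

// Scalar reference of the rotation the device helpers above perform.
template <bool IS_NEOX>
void rotate_pair_reference(float* head, int rot_offset, int embed_dim,
                           float cos_v, float sin_v) {
  int x_index, y_index;
  if (IS_NEOX) {
    x_index = rot_offset;
    y_index = rot_offset + embed_dim;
  } else {
    x_index = 2 * rot_offset;
    y_index = 2 * rot_offset + 1;
  }
  const float x = head[x_index];
  const float y = head[y_index];
  head[x_index] = x * cos_v - y * sin_v;
  head[y_index] = y * cos_v + x * sin_v;
}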
@@ -8,114 +8,85 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  pybind11::module ops = m.def_submodule("ops", "vLLM custom operators");

  // Attention ops
  ops.def("paged_attention_v1", &paged_attention_v1,
          "Compute the attention between an input query and the cached "
          "keys/values using PagedAttention.");
  ops.def("paged_attention_v2", &paged_attention_v2, "PagedAttention V2.");

  // Activation ops
  ops.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU.");
  ops.def("gelu_and_mul", &gelu_and_mul,
          "Activation function used in GeGLU with `none` approximation.");
  ops.def("gelu_tanh_and_mul", &gelu_tanh_and_mul,
          "Activation function used in GeGLU with `tanh` approximation.");
  ops.def("gelu_new", &gelu_new, "GELU implementation used in GPT-2.");
  ops.def("gelu_fast", &gelu_fast, "Approximate GELU implementation.");

  // Layernorm
  ops.def("rms_norm", &rms_norm,
          "Apply Root Mean Square (RMS) Normalization to the input tensor.");
  ops.def("fused_add_rms_norm", &fused_add_rms_norm,
          "In-place fused Add and RMS Normalization");

  // Rotary embedding
  ops.def("rotary_embedding", &rotary_embedding,
          "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
  ops.def("batched_rotary_embedding", &batched_rotary_embedding,
          "Apply GPT-NeoX or GPT-J style rotary embedding to query and key "
          "(supports multiple loras)");

  // Quantization ops
#ifndef USE_ROCM
  ops.def("aqlm_gemm", &aqlm_gemm, "Quantized GEMM for AQLM");
  ops.def("aqlm_dequant", &aqlm_dequant, "Decompression method for AQLM");
  ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
  ops.def("marlin_gemm", &marlin_gemm,
          "Marlin (Dense) Optimized Quantized GEMM for GPTQ");
  ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm,
          "Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ");
  ops.def("gptq_marlin_gemm", &gptq_marlin_gemm,
          "gptq_marlin Optimized Quantized GEMM for GPTQ");
  ops.def("gptq_marlin_repack", &gptq_marlin_repack,
          "gptq_marlin repack from GPTQ");
  ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
  ops.def("cutlass_scaled_mm_dq", &cutlass_scaled_mm_dq,
          "CUTLASS w8a8 GEMM, supporting symmetric per-tensor or "
          "per-row/column quantization.");
#endif

  ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
  ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
  ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
  ops.def("static_scaled_fp8_quant", &static_scaled_fp8_quant,
          "Compute FP8 quantized tensor for given scaling factor");
  ops.def("dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant,
          "Compute FP8 quantized tensor and scaling factor");
  ops.def("moe_align_block_size", &moe_align_block_size,
          "Aligning the number of tokens to be processed by each expert such "
          "that it is divisible by the block size.");

  // Cache ops
  pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
  cache_ops.def("swap_blocks", &swap_blocks,
                "Swap in (out) the cache blocks from src to dst");
  cache_ops.def("copy_blocks", &copy_blocks,
                "Copy the cache blocks from src to dst");
  cache_ops.def("reshape_and_cache", &reshape_and_cache,
                "Reshape the key and value tensors and cache them");
  cache_ops.def("reshape_and_cache_flash", &reshape_and_cache_flash,
                "Reshape the key and value tensors and cache them");
  cache_ops.def("convert_fp8", &convert_fp8,
                "Convert the key and value cache to fp8 data type");

  // Cuda utils
  pybind11::module cuda_utils =
      m.def_submodule("cuda_utils", "vLLM cuda utils");
  cuda_utils.def("get_device_attribute", &get_device_attribute,
                 "Gets the specified device attribute.");
  cuda_utils.def("get_max_shared_memory_per_block_device_attribute",
                 &get_max_shared_memory_per_block_device_attribute,
                 "Gets the maximum shared memory per block device attribute.");
@@ -134,5 +105,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  custom_ar.def("register_graph_buffers", &register_graph_buffers,
                "register_graph_buffers");
#endif
}
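A minimal standalone sketch of the binding pattern used in this module, shown only to illustrate def_submodule plus def with a docstring; the module name example_ext and the add() function are hypothetical, not vLLM ops.

#include <pybind11/pybind11.h>

namespace {
int add(int a, int b) { return a + b; }
}  // namespace

PYBIND11_MODULE(example_ext, m) {
  // Group related entry points in a submodule, then expose each with a name,
  // a function pointer, and a short docstring.
  pybind11::module ops = m.def_submodule("ops", "example operators");
  ops.def("add", &add, "Add two integers.");
}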
@@ -25,30 +25,26 @@
#include <iostream>
#include <cstdlib>

namespace vllm {
namespace aqlm {

__global__ void Code1x16MatVec(
    const int4* __restrict__ A, const int4* __restrict__ B,
    int4* __restrict__ C, const int4* __restrict__ codebook, const int prob_m,
    const int prob_k,
    const int4 codebook_a_sizes,  // cumulative sizes of A spanning each
                                  // codebook, at most 3 long.
    const int codebook_stride     // as int4.
) {
  int a_gl_stride = prob_k / 8 / 8;
  int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
  bool pred = a_gl_rd < prob_m;

  if (pred) {
    // advance to the correct codebook, this easy because we only multiply one
    // column of the codebook.
    auto codebook_size = &codebook_a_sizes.x;
    while (a_gl_rd >= *codebook_size) {
      codebook += codebook_stride;
      ++codebook_size;
    }
@@ -67,8 +63,7 @@ __global__ void Code1x16MatVec(
    // We pad shared memory to avoid bank conflicts during reads
    __syncthreads();
    for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) {
      if (b_gl_rd + i < prob_k / 8) sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
    }
    __syncthreads();
    b_gl_rd += 32 * 8;
@@ -79,19 +74,16 @@ __global__ void Code1x16MatVec(
#pragma unroll
      for (int i = 0; i < 8; i++) {
        uint32_t dec[4];
        // We bypass the L1 cache to avoid massive amounts of memory streaming
        // that doesn't actually help us; this brings > 2x speedup.
        asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
                     : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3])
                     : "l"((void*)&codebook[enc[i]]));
        half2* a = reinterpret_cast<half2*>(&dec);
        half2* b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]);
        half2 res2 = {};
#pragma unroll
        for (int j = 0; j < 4; j++) res2 = __hfma2(a[j], b[j], res2);
        res += __half2float(res2.x) + __half2float(res2.y);
        b_sh_rd++;
      }
@@ -101,21 +93,18 @@ __global__ void Code1x16MatVec(
  if (pred) {
#pragma unroll
    for (int i = 16; i > 0; i /= 2) res += __shfl_down_sync(0xffffffff, res, i);
    if (threadIdx.x % 32 == 0)
      reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res);
  }
}
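The tail of the kernel above folds each warp's partial sums with __shfl_down_sync, leaving the total in lane 0. A self-contained CUDA sketch of that reduction pattern (demo kernel, not part of the diff):

// Standalone warp-sum demo: each of the 32 lanes holds one partial value and
// the shuffle loop accumulates them into lane 0.
__global__ void warp_sum_demo(const float* in, float* out) {
  float res = in[threadIdx.x];  // assumes blockDim.x == 32
#pragma unroll
  for (int i = 16; i > 0; i /= 2) res += __shfl_down_sync(0xffffffff, res, i);
  if (threadIdx.x % 32 == 0) *out = res;
}
// usage sketch: warp_sum_demo<<<1, 32>>>(d_in, d_out);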
__global__ void Code2x8MatVec( __global__ void Code2x8MatVec(
const int4* __restrict__ A, const int4* __restrict__ A, const int4* __restrict__ B,
const int4* __restrict__ B, int4* __restrict__ C, const int4* __restrict__ codebook, int prob_m,
int4* __restrict__ C,
const int4* __restrict__ codebook,
int prob_m,
int prob_k, int prob_k,
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long. const int4 codebook_a_sizes, // cumulative sizes of A spanning each
// codebook, at most 3 long.
const int codebook_stride // as int4. const int codebook_stride // as int4.
) { ) {
@ -123,12 +112,11 @@ __global__ void Code2x8MatVec(
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
bool pred = a_gl_rd < prob_m; bool pred = a_gl_rd < prob_m;
if (pred) if (pred) {
{ // advance to the correct codebook, this easy because we only multiply one
// advance to the correct codebook, this easy because we only multiply one column of the codebook. // column of the codebook.
auto codebook_size = &codebook_a_sizes.x; auto codebook_size = &codebook_a_sizes.x;
while (a_gl_rd >= *codebook_size) while (a_gl_rd >= *codebook_size) {
{
codebook += codebook_stride; codebook += codebook_stride;
++codebook_size; ++codebook_size;
} }
@ -149,8 +137,7 @@ __global__ void Code2x8MatVec(
for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) { for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) {
int4 dec = codebook[i]; int4 dec = codebook[i];
#pragma unroll #pragma unroll
for (int j = 0; j < 8; j++) for (int j = 0; j < 8; j++) sh_code[8 * i + (j + lane) % 8] = dec;
sh_code[8 * i + (j + lane) % 8] = dec;
} }
__syncthreads(); __syncthreads();
@ -161,8 +148,7 @@ __global__ void Code2x8MatVec(
// We pad shared memory to avoid bank conflicts during reads // We pad shared memory to avoid bank conflicts during reads
__syncthreads(); __syncthreads();
for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) { for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) {
if (b_gl_rd + i < prob_k / 8) if (b_gl_rd + i < prob_k / 8) sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
} }
__syncthreads(); __syncthreads();
b_gl_rd += 32 * 8; b_gl_rd += 32 * 8;
@ -172,8 +158,10 @@ __global__ void Code2x8MatVec(
const uint8_t* enc = reinterpret_cast<const uint8_t*>(&A[a_gl_rd]); const uint8_t* enc = reinterpret_cast<const uint8_t*>(&A[a_gl_rd]);
#pragma unroll #pragma unroll
for (int i = 0; i < 8; i++) { for (int i = 0; i < 8; i++) {
half2* a0 = reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]); half2* a0 =
half2* a1 = reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]); reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]);
half2* a1 =
reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]);
half2* b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]); half2* b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]);
half2 res2 = {}; half2 res2 = {};
#pragma unroll #pragma unroll
@ -188,33 +176,28 @@ __global__ void Code2x8MatVec(
if (pred) { if (pred) {
#pragma unroll #pragma unroll
for (int i = 16; i > 0; i /= 2) for (int i = 16; i > 0; i /= 2) res += __shfl_down_sync(0xffffffff, res, i);
res += __shfl_down_sync(0xffffffff, res, i);
if (threadIdx.x % 32 == 0) if (threadIdx.x % 32 == 0)
reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res); reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res);
} }
} }
__global__ void Code1x16Dequant( __global__ void Code1x16Dequant(
const int4* __restrict__ A, const int4* __restrict__ A, int4* __restrict__ C,
int4* __restrict__ C, const int4* __restrict__ codebook, int prob_m, int prob_k,
const int4* __restrict__ codebook, const int4 codebook_a_sizes, // cumulative sizes of A spanning each
int prob_m, // codebook, at most 3 long, sums to m.
int prob_k,
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, sums to m.
const int codebook_stride // as int4 const int codebook_stride // as int4
) { ) {
int a_gl_stride = prob_k / 8 / 8; int a_gl_stride = prob_k / 8 / 8;
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
bool pred = a_gl_rd < prob_m; bool pred = a_gl_rd < prob_m;
if (pred) if (pred) {
{ // advance to the correct codebook, this easy because we only multiply one
// advance to the correct codebook, this easy because we only multiply one column of the codebook. // column of the codebook.
auto codebook_size = &codebook_a_sizes.x; auto codebook_size = &codebook_a_sizes.x;
while (a_gl_rd >= *codebook_size) while (a_gl_rd >= *codebook_size) {
{
codebook += codebook_stride; codebook += codebook_stride;
++codebook_size; ++codebook_size;
} }
@ -235,13 +218,11 @@ __global__ void Code1x16Dequant(
for (int i = 0; i < 8; i++) { for (int i = 0; i < 8; i++) {
int4 chunk; int4 chunk;
auto dec = reinterpret_cast<uint32_t*>(&chunk); auto dec = reinterpret_cast<uint32_t*>(&chunk);
// We bypass the L1 cache to avoid massive amounts of memory streaming that doesn't // We bypass the L1 cache to avoid massive amounts of memory streaming
// actually help us; this brings > 2x speedup. // that doesn't actually help us; this brings > 2x speedup.
asm volatile ( asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
"ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
: "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3]) : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3])
: "l"((void*) &codebook[enc[i]]) : "l"((void*)&codebook[enc[i]]));
);
C[a_gl_rd * 8 + i] = chunk; C[a_gl_rd * 8 + i] = chunk;
} }
@ -250,26 +231,23 @@ __global__ void Code1x16Dequant(
} }
} }
__global__ void Code2x8Dequant( __global__ void Code2x8Dequant(
const int4* __restrict__ A, const int4* __restrict__ A, int4* __restrict__ C,
int4* __restrict__ C, const int4* __restrict__ codebook, int prob_m, int prob_k,
const int4* __restrict__ codebook, const int4
int prob_m, codebook_a_sizes, // cumulative sizes of A spanning each codebook, at
int prob_k, // most 3 long, corresponds to cols.
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, corresponds to cols.
const int codebook_stride // as int4 const int codebook_stride // as int4
) { ) {
int a_gl_stride = prob_k / 8 / 8; int a_gl_stride = prob_k / 8 / 8;
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
bool pred = a_gl_rd < prob_m; bool pred = a_gl_rd < prob_m;
if (pred) if (pred) {
{ // advance to the correct codebook, this easy because we only multiply one
// advance to the correct codebook, this easy because we only multiply one column of the codebook. // column of the codebook.
auto codebook_size = &codebook_a_sizes.x; auto codebook_size = &codebook_a_sizes.x;
while (a_gl_rd >= *codebook_size) while (a_gl_rd >= *codebook_size) {
{
codebook += codebook_stride; codebook += codebook_stride;
++codebook_size; ++codebook_size;
} }
@ -291,8 +269,7 @@ __global__ void Code2x8Dequant(
for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) { for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) {
int4 dec = codebook[i]; int4 dec = codebook[i];
#pragma unroll #pragma unroll
for (int j = 0; j < 8; j++) for (int j = 0; j < 8; j++) sh_code[8 * i + (j + lane) % 8] = dec;
sh_code[8 * i + (j + lane) % 8] = dec;
} }
__syncthreads(); __syncthreads();
@ -305,8 +282,10 @@ __global__ void Code2x8Dequant(
#pragma unroll #pragma unroll
for (int i = 0; i < 8; i++) { for (int i = 0; i < 8; i++) {
int4 chunk; int4 chunk;
half2* a0 = reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]); half2* a0 =
half2* a1 = reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]); reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]);
half2* a1 =
reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]);
#pragma unroll #pragma unroll
for (int j = 0; j < 4; j++) for (int j = 0; j < 4; j++)
reinterpret_cast<half2*>(&chunk)[j] = __hadd2(a0[j], a1[j]); reinterpret_cast<half2*>(&chunk)[j] = __hadd2(a0[j], a1[j]);
@ -317,22 +296,15 @@ __global__ void Code2x8Dequant(
} }
} }
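In the 2x8 scheme each group of eight output halves is the sum of one codeword from each 256-entry codebook, computed above as four packed `__hadd2` operations. A scalar reference of the same reconstruction, with hypothetical codewords and no `half2` packing:

```cpp
#include <array>
#include <cstdio>

// Illustrative scalar reference for the 2x8 additive reconstruction:
// out[0..7] = codebook0[enc0][0..7] + codebook1[enc1][0..7].
using Codeword = std::array<float, 8>;

static Codeword reconstruct_2x8(const Codeword& cw0, const Codeword& cw1) {
  Codeword out{};
  for (int j = 0; j < 8; ++j) {
    out[j] = cw0[j] + cw1[j];  // kernel: __hadd2 on packed half2 lanes
  }
  return out;
}

int main() {
  Codeword a{1, 2, 3, 4, 5, 6, 7, 8};                       // hypothetical
  Codeword b{0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f};
  Codeword r = reconstruct_2x8(a, b);
  std::printf("first element: %.1f\n", r[0]);  // 1.5
  return 0;
}
```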
inline int ceildiv(int a, int b) { inline int ceildiv(int a, int b) { return (a + b - 1) / b; }
return (a + b - 1) / b;
}
const int THREAD_M = 16; const int THREAD_M = 16;
void code1x16_matvec_cuda( void code1x16_matvec_cuda(const void* __restrict__ A,
const void* __restrict__ A, const void* __restrict__ B, void* __restrict__ C,
const void* __restrict__ B, const void* __restrict__ codebook, int prob_m,
void* __restrict__ C, int prob_k, const int4 codebook_a_sizes,
const void* __restrict__ codebook, const int codebook_stride) {
int prob_m,
int prob_k,
const int4 codebook_a_sizes,
const int codebook_stride
) {
int sms; int sms;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
int waves = 0; int waves = 0;
@ -346,27 +318,15 @@ void code1x16_matvec_cuda(
int threads = 32 * thread_m; int threads = 32 * thread_m;
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
Code1x16MatVec<<<blocks, threads, 16 * 32 * 9, stream>>>( Code1x16MatVec<<<blocks, threads, 16 * 32 * 9, stream>>>(
(const int4*) A, (const int4*)A, (const int4*)B, (int4*)C, (const int4*)codebook, prob_m,
(const int4*) B, prob_k, codebook_a_sizes, codebook_stride);
(int4*) C,
(const int4*) codebook,
prob_m,
prob_k,
codebook_a_sizes,
codebook_stride
);
} }
void code2x8_matvec_cuda( void code2x8_matvec_cuda(const void* __restrict__ A, const void* __restrict__ B,
const void* __restrict__ A,
const void* __restrict__ B,
void* __restrict__ C, void* __restrict__ C,
const void* __restrict__ codebook, const void* __restrict__ codebook, int prob_m,
int prob_m, int prob_k, const int4 codebook_a_sizes,
int prob_k, const int codebook_stride) {
const int4 codebook_a_sizes,
const int codebook_stride
) {
int sms; int sms;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
int waves = 0; int waves = 0;
@ -379,29 +339,19 @@ void code2x8_matvec_cuda(
int blocks = ceildiv(prob_m, thread_m); int blocks = ceildiv(prob_m, thread_m);
int threads = 32 * thread_m; int threads = 32 * thread_m;
int shared = 16 * (2 * 256 * 8 + 32 * 9); int shared = 16 * (2 * 256 * 8 + 32 * 9);
cudaFuncSetAttribute( cudaFuncSetAttribute(Code2x8MatVec,
Code2x8MatVec, cudaFuncAttributeMaxDynamicSharedMemorySize, shared cudaFuncAttributeMaxDynamicSharedMemorySize, shared);
);
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
Code2x8MatVec<<<blocks, threads, shared, stream>>>( Code2x8MatVec<<<blocks, threads, shared, stream>>>(
(const int4*) A, (const int4*)A, (const int4*)B, (int4*)C, (const int4*)codebook, prob_m,
(const int4*) B, prob_k, codebook_a_sizes, codebook_stride);
(int4*) C,
(const int4*) codebook,
prob_m,
prob_k,
codebook_a_sizes,
codebook_stride
);
} }
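Both 2x8 launchers request `16 * (2 * 256 * 8 + 32 * 9)` = 70144 bytes of dynamic shared memory, which exceeds the default 48 KiB per-block limit, so they first raise the per-kernel cap with `cudaFuncSetAttribute(..., cudaFuncAttributeMaxDynamicSharedMemorySize, ...)`. A minimal sketch of that opt-in pattern with a hypothetical kernel (assumes a GPU that supports this much shared memory, e.g. sm_70 or newer):

```cuda
#include <cuda_runtime.h>
#include <cstdio>

__global__ void big_smem_kernel(float* out) {
  extern __shared__ float buf[];  // dynamic shared memory
  buf[threadIdx.x] = static_cast<float>(threadIdx.x);
  __syncthreads();
  if (threadIdx.x == 0) out[blockIdx.x] = buf[0];
}

int main() {
  // Same arithmetic as the 2x8 launchers: 16 bytes per int4 element.
  int shared = 16 * (2 * 256 * 8 + 32 * 9);  // 70144 bytes (> 48 KiB default)
  // Opt this kernel in to the larger dynamic shared memory carve-out.
  cudaFuncSetAttribute(big_smem_kernel,
                       cudaFuncAttributeMaxDynamicSharedMemorySize, shared);
  float* d_out = nullptr;
  cudaMalloc(&d_out, sizeof(float));
  big_smem_kernel<<<1, 32, shared>>>(d_out);
  cudaError_t err = cudaDeviceSynchronize();
  std::printf("launch status: %s\n", cudaGetErrorString(err));
  cudaFree(d_out);
  return 0;
}
```

Without the attribute call, a launch that asks for more than the default dynamic shared memory fails with an invalid-configuration error.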
void code1x16_dequant_cuda( void code1x16_dequant_cuda(
const void* __restrict__ A, const void* __restrict__ A, void* __restrict__ C,
void* __restrict__ C, const void* __restrict__ codebook, int prob_m, int prob_k,
const void* __restrict__ codebook, const int4 codebook_a_sizes, // cumulative sizes of A spanning each
int prob_m, // codebook, at most 3 long.
int prob_k,
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long.
const int codebook_stride // as int4. const int codebook_stride // as int4.
) { ) {
int sms; int sms;
@ -417,24 +367,20 @@ void code1x16_dequant_cuda(
int threads = 32 * thread_m; int threads = 32 * thread_m;
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
Code1x16Dequant<<<blocks, threads, 0, stream>>>( Code1x16Dequant<<<blocks, threads, 0, stream>>>(
(const int4*) A, (const int4*)A, (int4*)C, (const int4*)codebook, prob_m, prob_k,
(int4*) C, codebook_a_sizes, // cumulative sizes of A spanning each codebook, at
(const int4*) codebook, // most 3 long.
prob_m,
prob_k,
codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long.
codebook_stride // as int4. codebook_stride // as int4.
); );
} }
// Dequantizes the code and codebook into weights. // Dequantizes the code and codebook into weights.
void code2x8_dequant_cuda( void code2x8_dequant_cuda(
const void* __restrict__ A, const void* __restrict__ A, void* __restrict__ C,
void* __restrict__ C, const void* __restrict__ codebook, int prob_m, int prob_k,
const void* __restrict__ codebook, const int4
int prob_m, codebook_a_sizes, // cumulative sizes of A spanning each codebook, at
int prob_k, // most 3 long, corresponds to cols.
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, corresponds to cols.
const int codebook_stride // as int4 const int codebook_stride // as int4
) { ) {
int sms; int sms;
@ -451,50 +397,33 @@ void code2x8_dequant_cuda(
int shared = 16 * (2 * 256 * 8 + 32 * 9); int shared = 16 * (2 * 256 * 8 + 32 * 9);
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cudaFuncSetAttribute( cudaFuncSetAttribute(Code2x8Dequant,
Code2x8Dequant, cudaFuncAttributeMaxDynamicSharedMemorySize, shared cudaFuncAttributeMaxDynamicSharedMemorySize, shared);
);
Code2x8Dequant<<<blocks, threads, shared, stream>>>( Code2x8Dequant<<<blocks, threads, shared, stream>>>(
(const int4*) A, (const int4*)A, (int4*)C, (const int4*)codebook, prob_m, prob_k,
(int4*) C, codebook_a_sizes, codebook_stride);
(const int4*) codebook,
prob_m,
prob_k,
codebook_a_sizes,
codebook_stride
);
} }
int codebook_stride(const torch::Tensor& codebooks) int codebook_stride(const torch::Tensor& codebooks) {
{
return codebooks.stride(0) * codebooks.element_size() / sizeof(int4); return codebooks.stride(0) * codebooks.element_size() / sizeof(int4);
} }
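`codebook_stride` converts the stride of the codebook tensor's leading dimension from elements into `int4` units (16 bytes), the granularity at which the kernels hop between codebooks. The arithmetic for a hypothetical fp16 codebook of shape `[num_codebooks, 65536, 8]` (shape assumed purely for illustration):

```cpp
#include <cstdio>

int main() {
  // Hypothetical 1x16 codebook: 65536 entries, 8 fp16 values per codeword.
  long long entries = 65536, dim = 8, element_size = 2;        // bytes per fp16
  long long stride0_elems = entries * dim;                     // codebooks.stride(0)
  long long stride0_int4 = stride0_elems * element_size / 16;  // sizeof(int4)
  std::printf("stride(0) = %lld elements = %lld int4 (one int4 per codeword)\n",
              stride0_elems, stride0_int4);
  return 0;
}
```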
void code1x16_matvec( void code1x16_matvec(
const torch::Tensor& A, const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& C,
const torch::Tensor& B,
torch::Tensor& C,
const torch::Tensor& codebook, const torch::Tensor& codebook,
const int4 codebook_a_sizes // cumulative sizes of A spanning each codebook, at most 3 long. const int4 codebook_a_sizes // cumulative sizes of A spanning each
// codebook, at most 3 long.
) { ) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
int prob_m = C.size(0); int prob_m = C.size(0);
int prob_k = B.size(0); int prob_k = B.size(0);
code1x16_matvec_cuda( code1x16_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(),
A.data_ptr(), codebook.data_ptr(), prob_m, prob_k, codebook_a_sizes,
B.data_ptr(), codebook_stride(codebook));
C.data_ptr(),
codebook.data_ptr(),
prob_m,
prob_k,
codebook_a_sizes,
codebook_stride(codebook)
);
} }
torch::Tensor code1x16_matmat( torch::Tensor code1x16_matmat(const torch::Tensor& input,
const torch::Tensor& input,
const torch::Tensor& codes, const torch::Tensor& codes,
const torch::Tensor& codebooks, const torch::Tensor& codebooks,
const torch::Tensor& scales, const torch::Tensor& scales,
@ -503,22 +432,15 @@ torch::Tensor code1x16_matmat(
auto input_sizes = input.sizes(); auto input_sizes = input.sizes();
auto out_features = codes.size(0) * codebooks.size(2); auto out_features = codes.size(0) * codebooks.size(2);
auto flat_input = input.reshape({-1, input.size(-1)}); auto flat_input = input.reshape({-1, input.size(-1)});
auto flat_output = torch::empty({flat_input.size(0), out_features}, auto flat_output = torch::empty(
torch::TensorOptions() {flat_input.size(0), out_features},
.dtype(input.dtype()) torch::TensorOptions().dtype(input.dtype()).device(input.device()));
.device(input.device())
);
for (int i = 0; i < flat_input.size(0); ++i) { for (int i = 0; i < flat_input.size(0); ++i) {
auto input_vec = flat_input.index({i}); auto input_vec = flat_input.index({i});
auto output_vec = flat_output.index({i}); auto output_vec = flat_output.index({i});
code1x16_matvec( code1x16_matvec(codes.squeeze(2), input_vec, output_vec, codebooks,
codes.squeeze(2), codebook_a_sizes);
input_vec,
output_vec,
codebooks,
codebook_a_sizes
);
} }
flat_output *= scales.flatten().unsqueeze(0); flat_output *= scales.flatten().unsqueeze(0);
@ -533,55 +455,35 @@ torch::Tensor code1x16_matmat(
return output; return output;
} }
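`code1x16_matmat` above (and `code2x8_matmat` below) flattens the input to 2-D, runs one matvec per flattened row, and then applies per-output-channel scales plus an optional bias. A scalar sketch of that outer structure; the container types and the dummy `matvec` are illustrative stand-ins, not the CUDA path:

```cpp
#include <vector>
#include <functional>

// Illustrative: matmat as a loop of per-row matvecs, followed by scaling
// (flat_output *= scales) and an optional bias add.
using Vec = std::vector<float>;

Vec matmat_by_matvec(const std::vector<Vec>& rows,                // [num_rows][in]
                     const std::function<Vec(const Vec&)>& matvec,
                     const Vec& scales,                           // [out]
                     const Vec* bias) {                           // optional
  Vec flat_out;
  for (const Vec& x : rows) {
    Vec y = matvec(x);                 // one GEMV per flattened input row
    for (size_t o = 0; o < y.size(); ++o) {
      y[o] *= scales[o];
      if (bias) y[o] += (*bias)[o];
    }
    flat_out.insert(flat_out.end(), y.begin(), y.end());
  }
  return flat_out;                     // caller reshapes back to [..., out]
}

int main() {
  auto matvec = [](const Vec& x) { return Vec(4, x[0]); };  // dummy GEMV
  Vec scales{1, 2, 3, 4};
  std::vector<Vec> rows{{1.f, 0.f}, {2.f, 0.f}};
  Vec out = matmat_by_matvec(rows, matvec, scales, nullptr);
  return out.size() == 8 ? 0 : 1;
}
```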
void code2x8_matvec( void code2x8_matvec(const torch::Tensor& A, const torch::Tensor& B,
const torch::Tensor& A, torch::Tensor& C, const torch::Tensor& codebook,
const torch::Tensor& B, const int4 codebook_a_sizes) {
torch::Tensor& C,
const torch::Tensor& codebook,
const int4 codebook_a_sizes
) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
int prob_m = C.size(0); int prob_m = C.size(0);
int prob_k = B.size(0); int prob_k = B.size(0);
code2x8_matvec_cuda( code2x8_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(),
A.data_ptr(), codebook.data_ptr(), prob_m, prob_k, codebook_a_sizes,
B.data_ptr(), 2 * codebook_stride(codebook));
C.data_ptr(),
codebook.data_ptr(),
prob_m,
prob_k,
codebook_a_sizes,
2 * codebook_stride(codebook)
);
} }
torch::Tensor code2x8_matmat( torch::Tensor code2x8_matmat(const torch::Tensor& input,
const torch::Tensor& input,
const torch::Tensor& codes, const torch::Tensor& codes,
const torch::Tensor& codebooks, const torch::Tensor& codebooks,
const torch::Tensor& scales, const torch::Tensor& scales,
const int4 codebook_a_sizes, const int4 codebook_a_sizes,
const std::optional<torch::Tensor>& bias const std::optional<torch::Tensor>& bias) {
) {
auto input_sizes = input.sizes(); auto input_sizes = input.sizes();
auto out_features = codes.size(0) * codebooks.size(2); auto out_features = codes.size(0) * codebooks.size(2);
auto flat_input = input.reshape({-1, input.size(-1)}); auto flat_input = input.reshape({-1, input.size(-1)});
auto flat_output = torch::empty({flat_input.size(0), out_features}, auto flat_output = torch::empty(
torch::TensorOptions() {flat_input.size(0), out_features},
.dtype(input.dtype()) torch::TensorOptions().dtype(input.dtype()).device(input.device()));
.device(input.device())
);
for (int i = 0; i < flat_input.size(0); ++i) { for (int i = 0; i < flat_input.size(0); ++i) {
auto input_vec = flat_input.index({i}); auto input_vec = flat_input.index({i});
auto output_vec = flat_output.index({i}); auto output_vec = flat_output.index({i});
code2x8_matvec( code2x8_matvec(codes.squeeze(2), input_vec, output_vec, codebooks,
codes.squeeze(2), codebook_a_sizes);
input_vec,
output_vec,
codebooks,
codebook_a_sizes
);
} }
flat_output *= scales.flatten().unsqueeze(0); flat_output *= scales.flatten().unsqueeze(0);
if (bias.has_value()) { if (bias.has_value()) {
@ -596,21 +498,18 @@ torch::Tensor code2x8_matmat(
} }
// Accumulate the partition sizes. // Accumulate the partition sizes.
int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes) int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes) {
{
int4 cumulative_sizes; int4 cumulative_sizes;
auto cumulative_size = &cumulative_sizes.x; auto cumulative_size = &cumulative_sizes.x;
int i = 0; int i = 0;
int last = 0; int last = 0;
assert(codebook_partition_sizes.size(0) <= 4); assert(codebook_partition_sizes.size(0) <= 4);
for (; i < codebook_partition_sizes.size(0); ++i, ++cumulative_size) for (; i < codebook_partition_sizes.size(0); ++i, ++cumulative_size) {
{
*cumulative_size = codebook_partition_sizes[i].item<int>() + last; *cumulative_size = codebook_partition_sizes[i].item<int>() + last;
last = *cumulative_size; last = *cumulative_size;
} }
// fill in the rest with unreachable. // fill in the rest with unreachable.
for (; i < 4; ++i, ++cumulative_size) for (; i < 4; ++i, ++cumulative_size) {
{
*cumulative_size = last * 10; *cumulative_size = last * 10;
} }
return cumulative_sizes; return cumulative_sizes;
@ -619,41 +518,36 @@ int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes)
} // namespace aqlm } // namespace aqlm
} // namespace vllm } // namespace vllm
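`accumulate_sizes` above turns per-partition output sizes into cumulative sizes packed into an `int4`, padding unused slots with a value no row index can reach so the kernels' advance loop never selects them. A host-side mirror of that logic with an illustrative `std::vector` input:

```cpp
#include <cuda_runtime.h>  // int4, make_int4
#include <vector>
#include <cassert>

// Illustrative mirror of accumulate_sizes(): cumulative partition sizes
// packed into an int4, padded with "unreachable" values.
static int4 accumulate_sizes_ref(const std::vector<int>& partition_sizes) {
  assert(partition_sizes.size() <= 4);
  int4 cumulative = make_int4(0, 0, 0, 0);
  int* cum = &cumulative.x;  // same pointer-walk idiom as the original
  int last = 0;
  size_t i = 0;
  for (; i < partition_sizes.size(); ++i) {
    cum[i] = partition_sizes[i] + last;
    last = cum[i];
  }
  for (; i < 4; ++i) {
    cum[i] = last * 10;  // past-the-end padding, same convention as above
  }
  return cumulative;
}
```

For partition sizes {4096, 4096} this yields {4096, 8192, 81920, 81920}, the last two entries being the unreachable padding.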
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
torch::Tensor aqlm_gemm(
const torch::Tensor& input,
const torch::Tensor& codes,
const torch::Tensor& codebooks, const torch::Tensor& codebooks,
const torch::Tensor& scales, const torch::Tensor& scales,
const torch::Tensor& codebook_partition_sizes, const torch::Tensor& codebook_partition_sizes,
const std::optional<torch::Tensor>& bias const std::optional<torch::Tensor>& bias) {
) int4 cumulative_sizes =
{ vllm::aqlm::accumulate_sizes(codebook_partition_sizes);
int4 cumulative_sizes = vllm::aqlm::accumulate_sizes(codebook_partition_sizes);
int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0); int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0);
int const entries = codebooks.size(1); int const entries = codebooks.size(1);
if (nbooks == 1 && entries == (1 << 16)) if (nbooks == 1 && entries == (1 << 16)) {
{ return vllm::aqlm::code1x16_matmat(input, codes, codebooks, scales,
return vllm::aqlm::code1x16_matmat(input, codes, codebooks, scales, cumulative_sizes, bias); cumulative_sizes, bias);
} }
if (nbooks == 2 && entries == (1 << 8)) if (nbooks == 2 && entries == (1 << 8)) {
{ return vllm::aqlm::code2x8_matmat(input, codes, codebooks, scales,
return vllm::aqlm::code2x8_matmat(input, codes, codebooks, scales, cumulative_sizes, bias); cumulative_sizes, bias);
} }
TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, " entries is not currently supported.") TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries,
" entries is not currently supported.")
return {}; return {};
} }
torch::Tensor aqlm_dequant( torch::Tensor aqlm_dequant(const torch::Tensor& codes,
const torch::Tensor& codes,
const torch::Tensor& codebooks, const torch::Tensor& codebooks,
const torch::Tensor& codebook_partition_sizes const torch::Tensor& codebook_partition_sizes) {
) int4 cumulative_sizes =
{ vllm::aqlm::accumulate_sizes(codebook_partition_sizes);
int4 cumulative_sizes = vllm::aqlm::accumulate_sizes(codebook_partition_sizes);
int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0); int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0);
int const entries = codebooks.size(1); int const entries = codebooks.size(1);
@ -670,43 +564,35 @@ torch::Tensor aqlm_dequant(
auto weights = torch::empty({out_features, in_features}, auto weights = torch::empty({out_features, in_features},
torch::TensorOptions() torch::TensorOptions()
.dtype(codebooks.dtype()) .dtype(codebooks.dtype())
.device(codebooks.device()) .device(codebooks.device()));
);
if (nbooks == 1 && entries == (1 << 16)) if (nbooks == 1 && entries == (1 << 16)) {
{ vllm::aqlm::code1x16_dequant_cuda(codes.data_ptr(), weights.data_ptr(),
vllm::aqlm::code1x16_dequant_cuda( codebooks.data_ptr(), out_features,
codes.data_ptr(), in_features, cumulative_sizes,
weights.data_ptr(),
codebooks.data_ptr(),
out_features,
in_features,
cumulative_sizes,
vllm::aqlm::codebook_stride(codebooks)); vllm::aqlm::codebook_stride(codebooks));
// if you wanted to flip to scaling the weights, (though it's 30%-ish slower and not consistent with gemv implementation.) // if you wanted to flip to scaling the weights, (though it's 30%-ish slower
// weights *= scales.index({"...", 0, 0}); // and not consistent with gemv implementation.) weights *=
// scales.index({"...", 0, 0});
return weights; return weights;
} }
if (nbooks == 2 && entries == (1 << 8)) if (nbooks == 2 && entries == (1 << 8)) {
{ vllm::aqlm::code2x8_dequant_cuda(codes.data_ptr(), weights.data_ptr(),
vllm::aqlm::code2x8_dequant_cuda( codebooks.data_ptr(), out_features,
codes.data_ptr(), in_features, cumulative_sizes,
weights.data_ptr(),
codebooks.data_ptr(),
out_features,
in_features,
cumulative_sizes,
vllm::aqlm::codebook_stride(codebooks)); vllm::aqlm::codebook_stride(codebooks));
// if you wanted to flip to scaling the weights, (though it's 30%-ish slower and not consistent with gemv implementation) // if you wanted to flip to scaling the weights, (though it's 30%-ish slower
// weights *= scales.index({"...", 0, 0}); // and not consistent with gemv implementation) weights *=
// scales.index({"...", 0, 0});
return weights; return weights;
} }
TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, " entries is not currently supported.") TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries,
" entries is not currently supported.")
return {}; return {};
} }
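Both entry points dispatch on the codebook configuration derived from tensor shapes: `nbooks = codebooks.size(0) / num_partitions` and `entries = codebooks.size(1)`, with only 1 codebook of 2^16 entries or 2 codebooks of 2^8 entries supported. A small sketch of that shape arithmetic with hypothetical sizes:

```cpp
#include <cstdio>

int main() {
  // Hypothetical tensors: 3 output partitions, each with one 65536-entry book.
  long long codebooks_dim0 = 3, num_partitions = 3, entries = 1 << 16;
  long long nbooks = codebooks_dim0 / num_partitions;  // 1
  const char* path = (nbooks == 1 && entries == (1 << 16)) ? "code1x16"
                   : (nbooks == 2 && entries == (1 << 8))  ? "code2x8"
                                                           : "unsupported";
  std::printf("nbooks=%lld entries=%lld -> %s path\n", nbooks, entries, path);
  return 0;
}
```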


@ -1,11 +1,11 @@
/* /*
Adapted from https://github.com/mit-han-lab/llm-awq Adapted from https://github.com/mit-han-lab/llm-awq
Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h Modified from NVIDIA FasterTransformer:
https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
@article{lin2023awq, @article{lin2023awq,
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration}, title={AWQ: Activation-aware Weight Quantization for LLM Compression and
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song}, Acceleration}, author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang,
journal={arXiv}, Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023}
year={2023}
} }
*/ */
@ -14,8 +14,7 @@ Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransfor
namespace vllm { namespace vllm {
namespace awq { namespace awq {
__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) {
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
assert(false); assert(false);
#else #else
@ -30,33 +29,40 @@ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
static constexpr uint32_t TOP_MASK = 0x00f000f0; static constexpr uint32_t TOP_MASK = 0x00f000f0;
static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;
// Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing // Note that the entire sequence only requires 1 shift instruction. This is
// format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions. // thanks to the register packing format and the fact that we force our
// In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and // integers to be unsigned, and account for this in the fp16 subtractions. In
// elt_67 to fp16 without having to shift them to the bottom bits before hand. // addition, I exploit the fact that sub and fma have the same throughput in
// order to convert elt_23 and elt_67 to fp16 without having to shift them to
// the bottom bits before hand.
// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW
// immediately before required. // dependency if we issue immediately before required.
const uint32_t top_i4s = i4s >> 8; const uint32_t top_i4s = i4s >> 8;
// Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[0]) : "=r"(h[0])
: "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
"n"(immLut));
// Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[1]) : "=r"(h[1])
: "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
"n"(immLut));
// Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[2]) : "=r"(h[2])
: "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
"n"(immLut));
// Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[3]) : "=r"(h[3])
: "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
"n"(immLut));
// I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the // I use inline PTX below because I am not sure if the compiler will emit
// half2 ctor. In this case, I chose performance reliability over code readability. // float2half instructions if I use the half2 ctor. In this case, I chose
// performance reliability over code readability.
// This is the half2 {1032, 1032} represented as an integer. // This is the half2 {1032, 1032} represented as an integer.
// static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408; // static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
@ -71,13 +77,21 @@ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
// Finally, we construct the output numbers. // Finally, we construct the output numbers.
// Convert elt_01 // Convert elt_01
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); asm volatile("sub.f16x2 %0, %1, %2;\n"
: "=r"(h[0])
: "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_23 // Convert elt_23
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
: "=r"(h[1])
: "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
// Convert elt_45 // Convert elt_45
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); asm volatile("sub.f16x2 %0, %1, %2;\n"
: "=r"(h[2])
: "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_67 // Convert elt_67
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
: "=r"(h[3])
: "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
return result; return result;
#endif #endif
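The PTX above turns eight packed unsigned 4-bit values into eight fp16 values without any integer-to-float conversion instructions: each nibble is masked into the mantissa of an fp16 whose exponent field encodes 1024 (`0x6400`), after which a subtraction removes that bias, while the nibbles sitting four bits higher are recovered with one fma by 1/16 followed by subtracting 64 (the exact magic constants are outside the visible hunk). A plain scalar reference of the values the routine produces, ignoring its interleaved lane ordering:

```cpp
#include <cstdint>
#include <cstdio>

// Illustrative reference: extract the eight unsigned 4-bit values packed in a
// 32-bit word and return them as floats, which is what the lop3/sub/fma
// sequence in dequantize_s4_to_fp16x2 computes (in a different lane order).
static void u4x8_to_float(uint32_t packed, float out[8]) {
  for (int i = 0; i < 8; ++i) {
    out[i] = static_cast<float>((packed >> (4 * i)) & 0xF);
  }
}

int main() {
  float v[8];
  u4x8_to_float(0x76543210u, v);
  for (int i = 0; i < 8; ++i) std::printf("%.0f ", v[i]);  // 0 1 2 3 4 5 6 7
  std::printf("\n");
  return 0;
}
```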


@ -1,14 +1,12 @@
/* /*
Adapted from https://github.com/mit-han-lab/llm-awq Adapted from https://github.com/mit-han-lab/llm-awq
@article{lin2023awq, @article{lin2023awq,
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration}, title={AWQ: Activation-aware Weight Quantization for LLM Compression and
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song}, Acceleration}, author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang,
journal={arXiv}, Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023}
year={2023}
} }
*/ */
#include <torch/extension.h> #include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
@ -20,26 +18,20 @@ namespace vllm {
namespace awq { namespace awq {
// Pack two half values. // Pack two half values.
static inline __device__ __host__ unsigned static inline __device__ __host__ unsigned __pack_half2(const half x,
__pack_half2(const half x, const half y) { const half y) {
unsigned v0 = *((unsigned short*)&x); unsigned v0 = *((unsigned short*)&x);
unsigned v1 = *((unsigned short*)&y); unsigned v1 = *((unsigned short*)&y);
return (v1 << 16) | v0; return (v1 << 16) | v0;
} }
template <int N> template <int N>
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32( __global__ void __launch_bounds__(64)
int G, gemm_forward_4bit_cuda_m16nXk32(int G, int split_k_iters,
int split_k_iters, half* __restrict__ A, int* __restrict__ B,
half* __restrict__ A,
int* __restrict__ B,
half* __restrict__ scaling_factors, half* __restrict__ scaling_factors,
int* __restrict__ zeros, int* __restrict__ zeros, int M, int IC,
int M, int OC, half* __restrict__ C) {
int IC,
int OC,
half* __restrict__ C)
{
// Only support matrix n = 64 or 128 // Only support matrix n = 64 or 128
assert(N == 64 || N == 128); assert(N == 64 || N == 128);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
@ -70,43 +62,46 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
static constexpr int row_stride = 2 * 32 * 8 / N; static constexpr int row_stride = 2 * 32 * 8 / N;
bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < N; bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < N;
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id bool ld_A_flag =
(blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp +
threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id
// bool wb_C_flag = (threadIdx.x / 4) < M; // bool wb_C_flag = (threadIdx.x / 4) < M;
half* A_ptr = A half* A_ptr =
+ (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC A +
+ (((int)threadIdx.x) % (32 / 8)) * 8; (((int)blockIdx_y) / j_factors1 * 16 +
(((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) *
IC +
(((int)threadIdx.x) % (32 / 8)) * 8;
int* B_ptr = B int* B_ptr = B + ((int)threadIdx.y) * (OC / 8) * (256 / N) +
+ ((int)threadIdx.y) * (OC / 8) * (256 / N) (((int)threadIdx.x) / (N / 8)) * (OC / 8) +
+ (((int)threadIdx.x) / (N / 8)) * (OC / 8) (((int)blockIdx_y) % j_factors1) * (N / 8) +
+ (((int)blockIdx_y) % j_factors1) * (N / 8) (((int)threadIdx.x) % (N / 8)) * 1;
+ (((int)threadIdx.x) % (N / 8)) * 1;
// Why * 1 in the above line? // Why * 1 in the above line?
half* A_shared_ptr = A_shared half* A_shared_ptr = A_shared +
+ ((int)threadIdx.y) * row_stride_warp * (32 + 8) ((int)threadIdx.y) * row_stride_warp * (32 + 8) +
+ (((int)threadIdx.x) / (32 / 8)) * (32 + 8) (((int)threadIdx.x) / (32 / 8)) * (32 + 8) +
+ (((int)threadIdx.x) % (32 / 8) ) * 8; (((int)threadIdx.x) % (32 / 8)) * 8;
half* B_shared_ptr = B_shared half* B_shared_ptr = B_shared +
+ ((int)threadIdx.y) * (row_stride / 2) * (N + 8) ((int)threadIdx.y) * (row_stride / 2) * (N + 8) +
+ (((int)threadIdx.x) / (N / 8)) * (N + 8) (((int)threadIdx.x) / (N / 8)) * (N + 8) +
+ (((int)threadIdx.x) % (N / 8)) * 8; (((int)threadIdx.x) % (N / 8)) * 8;
int* zeros_ptr = zeros int* zeros_ptr = zeros + (((int)blockIdx_y) % j_factors1) * (N / 8) +
+ (((int)blockIdx_y) % j_factors1) * (N / 8) ((int)threadIdx.x) % (N / 8);
+ ((int)threadIdx.x) % (N / 8);
half* scaling_factors_ptr = scaling_factors half* scaling_factors_ptr = scaling_factors +
+ (((int)blockIdx_y) % j_factors1) * N (((int)blockIdx_y) % j_factors1) * N +
+ (((int)threadIdx.x) % (N / 8)) * 8; (((int)threadIdx.x) % (N / 8)) * 8;
half* C_ptr = C half* C_ptr =
+ static_cast<long long>(blockIdx_z) * M * OC // blockIdz.x -> split_k dim C +
+ (((int)blockIdx_y) % j_factors1) * N static_cast<long long>(blockIdx_z) * M * OC // blockIdz.x -> split_k dim
+ ((int)threadIdx.y) * (N / 2) + (((int)blockIdx_y) % j_factors1) * N + ((int)threadIdx.y) * (N / 2) +
+ (((int)threadIdx.x) % 4) * 2; (((int)threadIdx.x) % 4) * 2;
// preload s.f. and zeros // preload s.f. and zeros
int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters; int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
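The loop bound above implements split-K: the K dimension is processed in 32-channel chunks, and with `k_0_0 = _k_0_0 * split_k_iters + blockIdx_z` (next hunk) each z-block takes every `split_k_iters`-th chunk, producing partial outputs that `awq_gemm` later sums. A small sketch of that chunk assignment, with hypothetical sizes:

```cpp
#include <cstdio>

// Illustrative: which 32-wide K chunks each split-K block processes.
int main() {
  int IC = 256;           // hypothetical input channels
  int split_k_iters = 2;  // blockIdx_z ranges over [0, split_k_iters)
  int chunks = IC / 32;   // 8 chunks of 32 channels
  int k_bound = (chunks + split_k_iters - 1) / split_k_iters;
  for (int z = 0; z < split_k_iters; ++z) {
    std::printf("blockIdx_z=%d ->", z);
    for (int it = 0; it < k_bound; ++it) {
      int k_0_0 = it * split_k_iters + z;
      if (k_0_0 < chunks) std::printf(" chunk %d", k_0_0);
    }
    std::printf("\n");
  }
  return 0;
}
```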
@ -115,57 +110,83 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z; int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
__syncthreads(); __syncthreads();
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
if (ld_A_flag) if (ld_A_flag) {
{
*(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32)); *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
} } else {
else
{
*(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0); *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
} }
// for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) { // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8)); uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded); uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC)); uint4 B_loaded_scale =
*(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
/* /*
if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){ if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 &&
printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w); threadIdx.y == 0){ printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x,
B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x,
B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
} }
*/ */
// uint4 B_loaded_scale = make_uint4(0, 0, 0, 0); // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8); int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < N / 16; ++ax0_ax1_fused_0) { for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < N / 16; ++ax0_ax1_fused_0) {
// B: 32 x 136 (128+8) float16 // B: 32 x 136 (128+8) float16
// each warp: 32 x 4 // each warp: 32 x 4
// each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4 // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus
// *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8))); // zero -> WB UINT4
// row stride in shared memory: (NWARPS * 32 * 8 / cta_N) // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) *
uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8)); // 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15)
// * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 *
// 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) *
// 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) *
// 8))); row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
uint32_t B_loaded =
*(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
//uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8); // uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N /
// 8)) * 8);
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8); // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x
// % (cta_N / 8)) * 8);
// - zero and * scale // - zero and * scale
// TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale. // TODO (Haotian): can save 4 assembly instructions if sormulate as deq =
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); // q * scale - zero * scale.
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); asm volatile("sub.f16x2 %0, %1, %2;\n"
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); : "=r"(B_loaded_fp16.x)
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); : "=r"(B_loaded_fp16.x)
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); asm volatile("sub.f16x2 %0, %1, %2;\n"
: "=r"(B_loaded_fp16.y)
: "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
: "=r"(B_loaded_fp16.y)
: "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n"
: "=r"(B_loaded_fp16.z)
: "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
: "=r"(B_loaded_fp16.z)
: "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n"
: "=r"(B_loaded_fp16.w)
: "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
: "=r"(B_loaded_fp16.w)
: "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
/* /*
if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){ if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 ==
printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w); 0 && threadIdx.x == 17 && threadIdx.y == 0){ printf("[x] %X %X %X %X\n",
B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
} }
*/ */
// write back // write back
*(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) = B_loaded_fp16; *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) =
B_loaded_fp16;
} }
__syncthreads(); __syncthreads();
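Each packed 32-bit weight decoded in the loop above yields eight fp16 values, to which the sub/fma pairs apply the AWQ affine dequantization w = (q - zero) * scale on packed `half2` lanes; the TODO notes that folding it into w = q * scale - zero * scale would drop the separate subtraction. The same arithmetic in scalar form, with hypothetical inputs:

```cpp
#include <cstdio>

// Illustrative: AWQ per-weight dequantization as computed by the sub/fma
// pairs, written out on scalars instead of packed half2 lanes.
int main() {
  float q = 11.0f;      // 4-bit quantized weight, already converted to float
  float zero = 8.0f;    // group zero-point (also stored as packed 4-bit values)
  float scale = 0.05f;  // per-group, per-output-channel scaling factor
  float w_two_step = (q - zero) * scale;      // formulation used above
  float w_folded = q * scale - zero * scale;  // formulation from the TODO
  std::printf("%f %f\n", w_two_step, w_folded);  // agree up to rounding
  return 0;
}
```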
@ -173,34 +194,43 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
{ {
unsigned int addr; unsigned int addr;
__asm__ __volatile__( __asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, "
"addr; }\n"
: "=r"(addr) : "=r"(addr)
: "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8)))) : "l"((void*)((&(A_shared[(k_0_1 * 16)])) +
); (((((int)threadIdx.x) & 15) * 40) +
((((int)threadIdx.x) >> 4) * 8)))));
__asm__ __volatile__( __asm__ __volatile__(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16" "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
"{%0, %1, %2, %3}, [%4];\n" "{%0, %1, %2, %3}, [%4];\n"
: "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3]) : "=r"(((unsigned*)(A_shared_warp + 0))[0]),
: "r"(addr) "=r"(((unsigned*)(A_shared_warp + 0))[1]),
); "=r"(((unsigned*)(A_shared_warp + 0))[2]),
"=r"(((unsigned*)(A_shared_warp + 0))[3])
: "r"(addr));
} }
for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) { for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) {
{ {
unsigned int addr; unsigned int addr;
__asm__ __volatile__( __asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, "
"addr; }\n"
: "=r"(addr) : "=r"(addr)
: "l"((void *)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + (((int)threadIdx.y) * (N / 2))) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * (N + 8)) + ((((int)threadIdx.x) >> 4) * 8)))) : "l"((void*)((&(B_shared[(((k_0_1 * (N * 16 + 128)) +
); (((int)threadIdx.y) * (N / 2))) +
(ax1_0 * 16))])) +
(((((int)threadIdx.x) & 15) * (N + 8)) +
((((int)threadIdx.x) >> 4) * 8)))));
__asm__ __volatile__( __asm__ __volatile__(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
"{%0, %1, %2, %3}, [%4];\n" "{%0, %1, %2, %3}, [%4];\n"
: "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3]) : "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[0]),
: "r"(addr) "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[1]),
); "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[2]),
"=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[3])
: "r"(addr));
} }
} }
for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) { for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) {
@ -209,48 +239,110 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
__asm__ __volatile__( __asm__ __volatile__(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
"=f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
"=f"(((float*)(C_warp + (j_0_4 * 8)))[3])
: "r"(((unsigned*)(A_shared_warp + 0))[0]),
"r"(((unsigned*)(A_shared_warp + 0))[1]),
"r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]),
"f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
"f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
"f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
"f"(((float*)(C_warp + (j_0_4 * 8)))[3]));
} }
{ {
__asm__ __volatile__( __asm__ __volatile__(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])
: "r"(((unsigned*)(A_shared_warp + 0))[0]),
"r"(((unsigned*)(A_shared_warp + 0))[1]),
"r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]),
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]));
} }
{ {
__asm__ __volatile__( __asm__ __volatile__(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
: "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
"=f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
"=f"(((float*)(C_warp + (j_0_4 * 8)))[3])
: "r"(((unsigned*)(A_shared_warp + 0))[2]),
"r"(((unsigned*)(A_shared_warp + 0))[3]),
"r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]),
"f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
"f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
"f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
"f"(((float*)(C_warp + (j_0_4 * 8)))[3]));
} }
{ {
__asm__ __volatile__( __asm__ __volatile__(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
: "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])
: "r"(((unsigned*)(A_shared_warp + 0))[2]),
"r"(((unsigned*)(A_shared_warp + 0))[3]),
"r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]),
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]));
} }
#else #else
{ {
__asm__ __volatile__( __asm__ __volatile__(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, "
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) "%13};\n"
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
"=f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
"=f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
"=f"(((float*)(C_warp + (j_0_4 * 8)))[3])
: "r"(((unsigned*)(A_shared_warp + 0))[0]),
"r"(((unsigned*)(A_shared_warp + 0))[1]),
"r"(((unsigned*)(A_shared_warp + 0))[2]),
"r"(((unsigned*)(A_shared_warp + 0))[3]),
"r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]),
"r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]),
"f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
"f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
"f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
"f"(((float*)(C_warp + (j_0_4 * 8)))[3]));
} }
{ {
__asm__ __volatile__( __asm__ __volatile__(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, "
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) "%13};\n"
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])
: "r"(((unsigned*)(A_shared_warp + 0))[0]),
"r"(((unsigned*)(A_shared_warp + 0))[1]),
"r"(((unsigned*)(A_shared_warp + 0))[2]),
"r"(((unsigned*)(A_shared_warp + 0))[3]),
"r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]),
"r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]),
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]));
} }
#endif #endif
@ -261,24 +353,20 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
// TODO: Shang: Hoist loop invariance. // TODO: Shang: Hoist loop invariance.
for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) { for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) {
for (int local_id = 0; local_id < 8; ++local_id) { for (int local_id = 0; local_id < 8; ++local_id) {
int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8; int row_offset = (((int)blockIdx_y) / j_factors1) * 16 +
if (row_offset < M) ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
{ if (row_offset < M) {
*(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]); *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 +
local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
} }
} }
} }
#endif #endif
} }
__global__ void __launch_bounds__(64) dequantize_weights( __global__ void __launch_bounds__(64)
int* __restrict__ B, dequantize_weights(int* __restrict__ B, half* __restrict__ scaling_factors,
half* __restrict__ scaling_factors, int* __restrict__ zeros, half* __restrict__ C, int G) {
int* __restrict__ zeros,
half* __restrict__ C,
int G
)
{
int j_factors1 = 4; int j_factors1 = 4;
int row_stride2 = 4; int row_stride2 = 4;
int split_k_iters = 1; int split_k_iters = 1;
@ -310,14 +398,30 @@ __global__ void __launch_bounds__(64) dequantize_weights(
uint32_t B_loaded = *(uint32_t*)B_ptr2; uint32_t B_loaded = *(uint32_t*)B_ptr2;
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); asm volatile("sub.f16x2 %0, %1, %2;\n"
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); : "=r"(B_loaded_fp16.x)
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); : "=r"(B_loaded_fp16.x)
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); asm volatile("sub.f16x2 %0, %1, %2;\n"
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); : "=r"(B_loaded_fp16.y)
: "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
: "=r"(B_loaded_fp16.y)
: "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n"
: "=r"(B_loaded_fp16.z)
: "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
: "=r"(B_loaded_fp16.z)
: "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n"
: "=r"(B_loaded_fp16.w)
: "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
: "=r"(B_loaded_fp16.w)
: "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
*(uint4*)B_shared_ptr2 = B_loaded_fp16; *(uint4*)B_shared_ptr2 = B_loaded_fp16;
@ -329,14 +433,10 @@ __global__ void __launch_bounds__(64) dequantize_weights(
} // namespace awq } // namespace awq
} // namespace vllm } // namespace vllm
torch::Tensor awq_dequantize( torch::Tensor awq_dequantize(torch::Tensor _kernel,
torch::Tensor _kernel,
torch::Tensor _scaling_factors, torch::Tensor _scaling_factors,
torch::Tensor _zeros, torch::Tensor _zeros, int split_k_iters, int thx,
int split_k_iters, int thy) {
int thx,
int thy)
{
int in_c = _kernel.size(0); int in_c = _kernel.size(0);
int qout_c = _kernel.size(1); int qout_c = _kernel.size(1);
int out_c = qout_c * 8; int out_c = qout_c * 8;
@ -362,12 +462,15 @@ torch::Tensor awq_dequantize(
const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors)); const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors));
auto options = torch::TensorOptions().dtype(_scaling_factors.dtype()).device(_scaling_factors.device()); auto options = torch::TensorOptions()
.dtype(_scaling_factors.dtype())
.device(_scaling_factors.device());
at::Tensor _de_kernel = torch::empty({in_c, out_c}, options); at::Tensor _de_kernel = torch::empty({in_c, out_c}, options);
auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>()); auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
auto de_kernel = reinterpret_cast<half*>(_de_kernel.data_ptr<at::Half>()); auto de_kernel = reinterpret_cast<half*>(_de_kernel.data_ptr<at::Half>());
auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>()); auto scaling_factors =
reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>()); auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
dim3 num_blocks(x_blocks, y_blocks); dim3 num_blocks(x_blocks, y_blocks);
@ -386,26 +489,26 @@ torch::Tensor awq_dequantize(
// zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b] // zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b]
// assume that batch_size < 16 for now // assume that batch_size < 16 for now
torch::Tensor awq_gemm( torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
torch::Tensor _in_feats, torch::Tensor _scaling_factors, torch::Tensor _zeros,
torch::Tensor _kernel, int split_k_iters) {
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
int split_k_iters)
{
int num_in_feats = _in_feats.size(0); int num_in_feats = _in_feats.size(0);
int num_in_channels = _in_feats.size(1); int num_in_channels = _in_feats.size(1);
const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats)); const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));
auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device()); auto options = torch::TensorOptions()
at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options); .dtype(_in_feats.dtype())
.device(_in_feats.device());
at::Tensor _out_feats =
torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options);
int num_out_feats = _out_feats.size(-2); int num_out_feats = _out_feats.size(-2);
int num_out_channels = _out_feats.size(-1); int num_out_channels = _out_feats.size(-1);
auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>()); auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>());
auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>()); auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>()); auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>()); auto scaling_factors =
reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>()); auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
int group_size = num_in_channels / _scaling_factors.size(0); int group_size = num_in_channels / _scaling_factors.size(0);
@ -419,28 +522,28 @@ torch::Tensor awq_gemm(
throw std::invalid_argument("OC is not multiple of Group size"); throw std::invalid_argument("OC is not multiple of Group size");
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (num_out_channels % 128 == 0) if (num_out_channels % 128 == 0) {
{
int j_factors1 = num_out_channels / 128 / 1; int j_factors1 = num_out_channels / 128 / 1;
dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
// threadIdx.x: 32 // threadIdx.x: 32
// threadIdx.y: i_factors[2] * j_factors[2] // threadIdx.y: i_factors[2] * j_factors[2]
dim3 threads_per_block(32, 2); dim3 threads_per_block(32, 2);
vllm::awq::gemm_forward_4bit_cuda_m16nXk32<128><<<num_blocks, threads_per_block, 0, stream>>>( vllm::awq::gemm_forward_4bit_cuda_m16nXk32<128>
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, <<<num_blocks, threads_per_block, 0, stream>>>(
num_out_channels, out_feats); group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros,
} num_in_feats, num_in_channels, num_out_channels, out_feats);
else if (num_out_channels % 64 == 0) } else if (num_out_channels % 64 == 0) {
{
int j_factors1 = num_out_channels / 64 / 1; int j_factors1 = num_out_channels / 64 / 1;
dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 *
split_k_iters);
// threadIdx.x: 32 // threadIdx.x: 32
// threadIdx.y: i_factors[2] * j_factors[2] // threadIdx.y: i_factors[2] * j_factors[2]
dim3 threads_per_block(32, 2); dim3 threads_per_block(32, 2);
vllm::awq::gemm_forward_4bit_cuda_m16nXk32<64><<<num_blocks, threads_per_block, 0, stream>>>( vllm::awq::gemm_forward_4bit_cuda_m16nXk32<64>
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, <<<num_blocks, threads_per_block, 0, stream>>>(
num_out_channels, out_feats); group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros,
num_in_feats, num_in_channels, num_out_channels, out_feats);
} }
return _out_feats.sum(0); return _out_feats.sum(0);
} }
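`awq_gemm` allocates its output as `{split_k_iters, num_in_feats, OC}` so that each split-K block writes partial results into its own slice, and the final `_out_feats.sum(0)` folds those partials into the `{M, OC}` result. A scalar sketch of that reduction with dummy values:

```cpp
#include <vector>
#include <cstdio>

// Illustrative: reduce split-K partials of shape [split_k_iters][M * OC]
// into a single [M * OC] result, i.e. the .sum(0) at the end of awq_gemm.
int main() {
  int split_k_iters = 2, M = 2, OC = 3;
  std::vector<std::vector<float>> partial(
      split_k_iters, std::vector<float>(M * OC, 1.0f));  // dummy partials
  std::vector<float> out(M * OC, 0.0f);
  for (int z = 0; z < split_k_iters; ++z)
    for (int i = 0; i < M * OC; ++i) out[i] += partial[z][i];
  std::printf("out[0] = %.1f\n", out[0]);  // 2.0 with these dummy values
  return 0;
}
```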


@ -43,7 +43,8 @@ void cutlass_scaled_mm_dq(torch::Tensor &c, torch::Tensor const &a,
// Check for strides and alignment // Check for strides and alignment
TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
TORCH_CHECK(b.stride(0) == 1); // Column-major TORCH_CHECK(b.stride(0) == 1); // Column-major
TORCH_CHECK(c.stride(0) % 16 == 0 && b.stride(1) % 16 == 0); // 16 Byte Alignment TORCH_CHECK(c.stride(0) % 16 == 0 &&
b.stride(1) % 16 == 0); // 16 Byte Alignment
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
at::cuda::OptionalCUDAGuard const device_guard(device_of(a)); at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
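Aside (not part of the diff): one way a caller can satisfy the stride checks above, sketched with the PyTorch C++ API; make_column_major_b is a hypothetical helper, not part of vLLM.

#include <torch/torch.h>

// a and c stay row-major; b must be column-major, i.e. b.stride(0) == 1.
// Allocating [N, K] row-major and viewing it transposed yields a [K, N]
// tensor with stride(0) == 1 and stride(1) == K, as the check expects.
// The alignment check additionally wants c.stride(0) and b.stride(1) to be
// multiples of 16 elements.
torch::Tensor make_column_major_b(int64_t K, int64_t N,
                                  const torch::TensorOptions& opts) {
  return torch::empty({N, K}, opts).t();
}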
@ -11,33 +11,26 @@
#include "hip_float8_impl.h" #include "hip_float8_impl.h"
struct alignas(1) hip_fp8 struct alignas(1) hip_fp8 {
{ struct from_bits_t {};
struct from_bits_t HIP_FP8_HOST_DEVICE static constexpr from_bits_t from_bits() {
{ return from_bits_t();
}; }
HIP_FP8_HOST_DEVICE static constexpr from_bits_t from_bits() { return from_bits_t(); }
uint8_t data; uint8_t data;
hip_fp8() = default; hip_fp8() = default;
HIP_FP8_HOST_DEVICE constexpr hip_fp8(const hip_fp8&) = default; HIP_FP8_HOST_DEVICE constexpr hip_fp8(const hip_fp8&) = default;
HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v) = delete; HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v) = delete;
explicit HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v, from_bits_t) explicit HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v, from_bits_t)
: data(v) : data(v) {}
{
}
#ifdef __HIP__MI300__ #ifdef __HIP__MI300__
// NOTE: ON-DEVICE... always optimal bias // NOTE: ON-DEVICE... always optimal bias
explicit HIP_FP8_DEVICE hip_fp8(float v) explicit HIP_FP8_DEVICE hip_fp8(float v)
: data(hip_fp8_impl::to_fp8_from_fp32(v)) : data(hip_fp8_impl::to_fp8_from_fp32(v)) {}
{
}
explicit HIP_FP8_DEVICE hip_fp8(_Float16 v) explicit HIP_FP8_DEVICE hip_fp8(_Float16 v)
: hip_fp8(static_cast<float>(v)) : hip_fp8(static_cast<float>(v)) {}
{
}
// Host only implementation using s/w simulation // Host only implementation using s/w simulation
explicit HIP_FP8_HOST explicit HIP_FP8_HOST
@ -45,25 +38,24 @@ struct alignas(1) hip_fp8
// both Host and DEVICE for non-MI300 using s/w simulation // both Host and DEVICE for non-MI300 using s/w simulation
explicit HIP_FP8_HOST_DEVICE explicit HIP_FP8_HOST_DEVICE
#endif // __HIP__MI300__ #endif // __HIP__MI300__
hip_fp8(float v) hip_fp8(float v) {
{ data = hip_fp8_impl::to_float8<4, 3, float, true /*negative_zero_nan*/,
data = hip_fp8_impl::to_float8<4, 3, float, true /*negative_zero_nan*/, true /*clip*/>(v); true /*clip*/>(v);
} }
explicit HIP_FP8_HOST_DEVICE hip_fp8(double v) explicit HIP_FP8_HOST_DEVICE hip_fp8(double v)
: hip_fp8(static_cast<float>(v)) : hip_fp8(static_cast<float>(v)) {}
{
}
#ifdef __HIP__MI300__ #ifdef __HIP__MI300__
// upcast using device specific intrinsic // upcast using device specific intrinsic
explicit inline HIP_FP8_DEVICE operator float() const explicit inline HIP_FP8_DEVICE operator float() const {
{
float fval; float fval;
uint32_t i32val = static_cast<uint32_t>(data); uint32_t i32val = static_cast<uint32_t>(data);
// upcast // upcast
asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val)); asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0"
: "=v"(fval)
: "v"(i32val));
return fval; return fval;
} }
@ -73,95 +65,73 @@ struct alignas(1) hip_fp8
explicit inline HIP_FP8_HOST_DEVICE operator float() const explicit inline HIP_FP8_HOST_DEVICE operator float() const
#endif // __HIP__MI300__ #endif // __HIP__MI300__
{ {
return hip_fp8_impl::from_float8<4, 3, float, true /*negative_zero_nan*/>(data); return hip_fp8_impl::from_float8<4, 3, float, true /*negative_zero_nan*/>(
data);
} }
}; };
namespace std namespace std {
{ inline hip_fp8 sin(hip_fp8 a) { return hip_fp8(sinf(float(a))); }
inline hip_fp8 sin(hip_fp8 a) inline hip_fp8 cos(hip_fp8 a) { return hip_fp8(cosf(float(a))); }
{ HIP_FP8_HOST_DEVICE constexpr hip_fp8 real(const hip_fp8& a) { return a; }
return hip_fp8(sinf(float(a)));
}
inline hip_fp8 cos(hip_fp8 a)
{
return hip_fp8(cosf(float(a)));
}
HIP_FP8_HOST_DEVICE constexpr hip_fp8 real(const hip_fp8& a)
{
return a;
}
} // namespace std } // namespace std
// Special operator overloading // Special operator overloading
inline std::ostream& operator<<(std::ostream& os, const hip_fp8& f8) inline std::ostream& operator<<(std::ostream& os, const hip_fp8& f8) {
{
return os << float(f8); return os << float(f8);
} }
// all + operator overloading with mixed types // all + operator overloading with mixed types
// mixed types, always converts to f32, does computation in f32, and returns float // mixed types, always converts to f32, does computation in f32, and returns
inline HIP_FP8_HOST_DEVICE float operator+(const float fa, hip_fp8 b) // float
{ inline HIP_FP8_HOST_DEVICE float operator+(const float fa, hip_fp8 b) {
return (fa + float(b)); return (fa + float(b));
} }
inline HIP_FP8_HOST_DEVICE float operator+(hip_fp8 a, const float fb) inline HIP_FP8_HOST_DEVICE float operator+(hip_fp8 a, const float fb) {
{
return (float(a) + fb); return (float(a) + fb);
} }
inline HIP_FP8_HOST_DEVICE hip_fp8 operator+(hip_fp8 a, hip_fp8 b) inline HIP_FP8_HOST_DEVICE hip_fp8 operator+(hip_fp8 a, hip_fp8 b) {
{
return hip_fp8(float(a) + float(b)); return hip_fp8(float(a) + float(b));
} }
inline HIP_FP8_HOST_DEVICE hip_fp8& operator+=(hip_fp8& a, hip_fp8 b) inline HIP_FP8_HOST_DEVICE hip_fp8& operator+=(hip_fp8& a, hip_fp8 b) {
{
return a = hip_fp8(float(a) + float(b)); return a = hip_fp8(float(a) + float(b));
} }
// overloading multiplication, always returns float, // overloading multiplication, always returns float,
inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, hip_fp8 b) inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, hip_fp8 b) {
{
return float(a) * float(b); return float(a) * float(b);
} }
inline HIP_FP8_HOST_DEVICE float operator*(float a, hip_fp8 b) inline HIP_FP8_HOST_DEVICE float operator*(float a, hip_fp8 b) {
{
return (a * float(b)); return (a * float(b));
} }
inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, float b) inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, float b) {
{
return (float(a) * b); return (float(a) * b);
} }
inline HIP_FP8_HOST_DEVICE float operator*(int32_t a, hip_fp8 b) inline HIP_FP8_HOST_DEVICE float operator*(int32_t a, hip_fp8 b) {
{
return ((float)a * float(b)); return ((float)a * float(b));
} }
inline HIP_FP8_HOST_DEVICE float operator*(double a, hip_fp8 b) inline HIP_FP8_HOST_DEVICE float operator*(double a, hip_fp8 b) {
{
return ((float)a * float(b)); return ((float)a * float(b));
} }
// overloading for compare // overloading for compare
inline HIP_FP8_HOST_DEVICE bool operator==(hip_fp8 a, hip_fp8 b) inline HIP_FP8_HOST_DEVICE bool operator==(hip_fp8 a, hip_fp8 b) {
{
return (a.data == b.data); return (a.data == b.data);
} }
inline HIP_FP8_HOST_DEVICE bool operator!=(hip_fp8 a, hip_fp8 b) inline HIP_FP8_HOST_DEVICE bool operator!=(hip_fp8 a, hip_fp8 b) {
{
return (a.data != b.data); return (a.data != b.data);
} }
inline HIP_FP8_HOST_DEVICE bool operator>=(hip_fp8 a, hip_fp8 b) inline HIP_FP8_HOST_DEVICE bool operator>=(hip_fp8 a, hip_fp8 b) {
{
return static_cast<float>(a) >= static_cast<float>(b); return static_cast<float>(a) >= static_cast<float>(b);
} }
inline HIP_FP8_HOST_DEVICE bool operator>(hip_fp8 a, hip_fp8 b) inline HIP_FP8_HOST_DEVICE bool operator>(hip_fp8 a, hip_fp8 b) {
{
return static_cast<float>(a) > static_cast<float>(b); return static_cast<float>(a) > static_cast<float>(b);
} }
@ -1,6 +1,7 @@
#pragma once #pragma once
#if defined(__HIPCC__) && (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) #if defined(__HIPCC__) && \
(defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
#define __HIP__MI300__ #define __HIP__MI300__
#endif #endif
@ -14,12 +15,10 @@
#define HIP_FP8_DEVICE #define HIP_FP8_DEVICE
#endif #endif
namespace hip_fp8_impl namespace hip_fp8_impl {
{
#ifdef __HIP__MI300__ #ifdef __HIP__MI300__
HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v) HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v) {
{
uint8_t i8data; uint8_t i8data;
union { union {
float fval; float fval;
@ -30,7 +29,8 @@ HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v)
uint32_t ival = 0; uint32_t ival = 0;
val.fval = v; val.fval = v;
if ((val.i32val & 0x7F800000) != 0x7F800000) { /// propagate NAN/INF, no clipping if ((val.i32val & 0x7F800000) !=
0x7F800000) { /// propagate NAN/INF, no clipping
val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0); val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
} }
@ -43,20 +43,14 @@ HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v)
} }
#endif // __HIP__MI300__ #endif // __HIP__MI300__
HIP_FP8_HOST inline int clz(uint32_t x) HIP_FP8_HOST inline int clz(uint32_t x) { return __builtin_clz(x); }
{
return __builtin_clz(x);
}
#if defined(__HIPCC__) || defined(__CUDA_ARCH__) #if defined(__HIPCC__) || defined(__CUDA_ARCH__)
HIP_FP8_DEVICE inline int clz(uint32_t x) HIP_FP8_DEVICE inline int clz(uint32_t x) { return __clz(x); }
{
return __clz(x);
}
#endif #endif
template <int we, int wm, typename T, bool negative_zero_nan, bool clip> template <int we, int wm, typename T, bool negative_zero_nan, bool clip>
HIP_FP8_HOST_DEVICE uint8_t to_float8(T _x, bool stoch = false, uint32_t rng = 0) HIP_FP8_HOST_DEVICE uint8_t to_float8(T _x, bool stoch = false,
{ uint32_t rng = 0) {
#ifdef __HIPCC__ #ifdef __HIPCC__
constexpr bool is_half = std::is_same<T, _Float16>::value; constexpr bool is_half = std::is_same<T, _Float16>::value;
#else #else
@ -130,7 +124,8 @@ HIP_FP8_HOST_DEVICE uint8_t to_float8(T _x, bool stoch = false, uint32_t rng = 0
// For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent
// bits // bits
const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0); const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0);
const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal const int f8_denormal_act_exponent =
1 - f8_bias; // actual exponent of f8 denormal
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias) // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
// f8_exponent is the converted f8 exponent with bias encoding // f8_exponent is the converted f8 exponent with bias encoding
// exponent_diff is the diff between fp32/fp16 exponent and f8 exponent, // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
@ -146,7 +141,9 @@ are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
(NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */ (NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */
act_exponent = exponent - bias + 1; act_exponent = exponent - bias + 1;
exponent_diff = f8_denormal_act_exponent - act_exponent; // actual exponent is exponent-bias+1 as it is denormal exponent_diff =
f8_denormal_act_exponent -
act_exponent; // actual exponent is exponent-bias+1 as it is denormal
} else { // fp32/fp16 is normal with implicit 1 } else { // fp32/fp16 is normal with implicit 1
act_exponent = exponent - bias; act_exponent = exponent - bias;
if (act_exponent <= f8_denormal_act_exponent) { if (act_exponent <= f8_denormal_act_exponent) {
@ -157,9 +154,9 @@ where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */ So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
exponent_diff = f8_denormal_act_exponent - act_exponent; exponent_diff = f8_denormal_act_exponent - act_exponent;
} else { // both fp32/fp16 and f8 are in normal range } else { // both fp32/fp16 and f8 are in normal range
exponent_diff = 0; // exponent_diff=0 does not mean there is no difference exponent_diff = 0; // exponent_diff=0 does not mean there is no
// for this case, // difference for this case, act_exponent could be
// act_exponent could be larger. Just that it does not need shift mantissa // larger. Just that it does not need shift mantissa
} }
mantissa += (1 << mfmt); // Add the implicit 1 into mantissa mantissa += (1 << mfmt); // Add the implicit 1 into mantissa
} }
@ -181,13 +178,16 @@ where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
bool implicit_one = mantissa & (1 << mfmt); bool implicit_one = mantissa & (1 << mfmt);
// if there is no implicit 1, it means the f8 is denormal and need to adjust // if there is no implicit 1, it means the f8 is denormal and need to adjust
// to denorm exponent // to denorm exponent
f8_exponent = (act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1); f8_exponent = (act_exponent + exponent_diff) /*actual f8 exponent*/ +
f8_bias - (implicit_one ? 0 : 1);
// Now we have the exponent and mantissa adjusted // Now we have the exponent and mantissa adjusted
uint32_t drop_mask = (1 << (mfmt - wm)) - 1; uint32_t drop_mask = (1 << (mfmt - wm)) - 1;
bool odd = mantissa & (1 << (mfmt - wm)); // if the least significant bit that bool odd = mantissa & (1 << (mfmt - wm)); // if the least significant bit
// is not truncated is 1 // that is not truncated is 1
mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask; mantissa +=
(stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) &
drop_mask;
// Now we deal with overflow // Now we deal with overflow
if (f8_exponent == 0) { if (f8_exponent == 0) {
@ -222,8 +222,7 @@ where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
} }
template <int we, int wm, typename T = float, bool negative_zero_nan = true> template <int we, int wm, typename T = float, bool negative_zero_nan = true>
inline HIP_FP8_HOST_DEVICE T from_float8(uint8_t x) inline HIP_FP8_HOST_DEVICE T from_float8(uint8_t x) {
{
#ifdef __HIPCC__ #ifdef __HIPCC__
constexpr bool is_half = std::is_same<T, _Float16>::value; constexpr bool is_half = std::is_same<T, _Float16>::value;
#else #else
@ -285,7 +284,8 @@ inline HIP_FP8_HOST_DEVICE T from_float8(uint8_t x)
return reinterpret_cast<const T&>(retval); return reinterpret_cast<const T&>(retval);
} }
const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0); const int exp_low_cutoff =
(1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0);
// subnormal input // subnormal input
if (exponent == 0) { if (exponent == 0) {
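Aside (not part of the diff): the non-stochastic rounding step in to_float8 above (drop_mask / midpoint / odd) is round-to-nearest-even; here is that step isolated on a plain integer mantissa, a sketch with hypothetical names.

#include <cstdint>

// The wm_drop low bits are dropped; ties round toward the value whose kept LSB is 0.
inline uint32_t round_mantissa_rne(uint32_t mantissa, int wm_drop) {
  const uint32_t drop_mask = (1u << wm_drop) - 1;
  const bool midpoint = (mantissa & drop_mask) == (1u << (wm_drop - 1));
  const bool odd = mantissa & (1u << wm_drop);  // kept LSB before rounding
  // Same expression as the kernel: away from a tie, adding the dropped bits to
  // themselves makes plain truncation round to nearest; at a tie, the odd/even
  // choice of increment implements ties-to-even.
  mantissa += (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa) & drop_mask;
  return mantissa & ~drop_mask;  // any carry into the kept bits is preserved
}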
@ -9,29 +9,27 @@
#include "../../../attention/dtype_float32.cuh" #include "../../../attention/dtype_float32.cuh"
#include "../../../attention/dtype_bfloat16.cuh" #include "../../../attention/dtype_bfloat16.cuh"
namespace vllm namespace vllm {
{
#ifdef USE_ROCM #ifdef USE_ROCM
namespace fp8 { namespace fp8 {
#ifdef ENABLE_FP8 #ifdef ENABLE_FP8
template <typename Tout, typename Tin> template <typename Tout, typename Tin>
__inline__ __device__ Tout vec_conversion(const Tin& x) __inline__ __device__ Tout vec_conversion(const Tin& x) {
{
return x; return x;
} }
template <typename Tout, typename Tin> template <typename Tout, typename Tin>
__inline__ __device__ Tout scaled_vec_conversion(const Tin& x, const float scale) __inline__ __device__ Tout scaled_vec_conversion(const Tin& x,
{ const float scale) {
return x; return x;
} }
// fp8 -> half // fp8 -> half
template <> template <>
__inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(const uint8_t& a) __inline__ __device__ uint16_t
{ vec_conversion<uint16_t, uint8_t>(const uint8_t& a) {
hip_fp8 f8{a, hip_fp8::from_bits()}; hip_fp8 f8{a, hip_fp8::from_bits()};
__half_raw res; __half_raw res;
res.data = static_cast<float>(f8); res.data = static_cast<float>(f8);
@ -40,9 +38,10 @@ __inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(const uint8_t&
// fp8x2 -> half2 // fp8x2 -> half2
template <> template <>
__inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(const uint16_t& a) __inline__ __device__ uint32_t
{ vec_conversion<uint32_t, uint16_t>(const uint16_t& a) {
#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) #if defined(__HIP__MI300__) && \
defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
union { union {
__half2_raw h2r; __half2_raw h2r;
@ -65,8 +64,7 @@ __inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(const uint16_t
// fp8x4 -> half2x2 // fp8x4 -> half2x2
template <> template <>
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a) __inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a) {
{
union { union {
uint2 u32x2; uint2 u32x2;
uint32_t u32[2]; uint32_t u32[2];
@ -78,8 +76,7 @@ __inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a)
// fp8x8 -> half2x4 // fp8x8 -> half2x4
template <> template <>
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a) __inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a) {
{
union { union {
uint4 u64x2; uint4 u64x2;
uint2 u64[2]; uint2 u64[2];
@ -93,8 +90,8 @@ using __nv_bfloat16 = __hip_bfloat16;
// fp8 -> __nv_bfloat16 // fp8 -> __nv_bfloat16
template <> template <>
__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) __inline__ __device__ __nv_bfloat16
{ vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) {
hip_fp8 f8{a, hip_fp8::from_bits()}; hip_fp8 f8{a, hip_fp8::from_bits()};
float f{f8}; float f{f8};
return __float2bfloat16(f); return __float2bfloat16(f);
@ -104,8 +101,8 @@ using __nv_bfloat162 = __hip_bfloat162;
// fp8x2 -> __nv_bfloat162 // fp8x2 -> __nv_bfloat162
template <> template <>
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) __inline__ __device__ __nv_bfloat162
{ vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) {
__nv_bfloat162 res; __nv_bfloat162 res;
res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a); res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U)); res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
@ -114,8 +111,8 @@ __inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(co
// fp8x4 -> bf16_4_t // fp8x4 -> bf16_4_t
template <> template <>
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a) __inline__ __device__ bf16_4_t
{ vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a) {
bf16_4_t res; bf16_4_t res;
res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a); res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U)); res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
@ -124,8 +121,7 @@ __inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(const uint32_t
// fp8x8 -> bf16_8_t // fp8x8 -> bf16_8_t
template <> template <>
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a) __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a) {
{
bf16_4_t tmp1, tmp2; bf16_4_t tmp1, tmp2;
tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x); tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y); tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
@ -139,17 +135,17 @@ __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a)
// fp8 -> float // fp8 -> float
template <> template <>
__inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a) __inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a) {
{
hip_fp8 fp8{a, hip_fp8::from_bits()}; hip_fp8 fp8{a, hip_fp8::from_bits()};
return static_cast<float>(fp8); return static_cast<float>(fp8);
} }
// fp8x2 -> float2 // fp8x2 -> float2
template <> template <>
__inline__ __device__ float2 vec_conversion<float2, uint16_t>(const uint16_t& a) __inline__ __device__ float2
{ vec_conversion<float2, uint16_t>(const uint16_t& a) {
#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) #if defined(__HIP__MI300__) && \
defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
float2 res; float2 res;
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
res.x = f2[0]; res.x = f2[0];
@ -165,8 +161,8 @@ __inline__ __device__ float2 vec_conversion<float2, uint16_t>(const uint16_t& a)
// fp8x4 -> float4 // fp8x4 -> float4
template <> template <>
__inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(const uint32_t& a) __inline__ __device__ Float4_
{ vec_conversion<Float4_, uint32_t>(const uint32_t& a) {
Float4_ res; Float4_ res;
res.x = vec_conversion<float2, uint16_t>((uint16_t)a); res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U)); res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
@ -175,8 +171,7 @@ __inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(const uint32_t&
// fp8x8 -> float8 // fp8x8 -> float8
template <> template <>
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a) __inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a) {
{
Float4_ tmp1, tmp2; Float4_ tmp1, tmp2;
tmp1 = vec_conversion<Float4_, uint32_t>(a.x); tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
tmp2 = vec_conversion<Float4_, uint32_t>(a.y); tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
@ -190,8 +185,8 @@ __inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a)
// half -> fp8 // half -> fp8
template <> template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(const uint16_t& a) __inline__ __device__ uint8_t
{ vec_conversion<uint8_t, uint16_t>(const uint16_t& a) {
__half_raw tmp; __half_raw tmp;
tmp.x = a; tmp.x = a;
@ -201,24 +196,23 @@ __inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(const uint16_t&
// bf16 -> fp8 // bf16 -> fp8
template <> template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a) __inline__ __device__ uint8_t
{ vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a) {
hip_fp8 res{__bfloat162float(a)}; hip_fp8 res{__bfloat162float(a)};
return res.data; return res.data;
} }
// float -> fp8 // float -> fp8
template <> template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a) __inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a) {
{
hip_fp8 f8(a); hip_fp8 f8(a);
return f8.data; return f8.data;
} }
// fp8x4 -> float4 // fp8x4 -> float4
template <> template <>
__inline__ __device__ float4 vec_conversion<float4, uint32_t>(const uint32_t& a) __inline__ __device__ float4
{ vec_conversion<float4, uint32_t>(const uint32_t& a) {
Float4_ tmp = vec_conversion<Float4_, uint32_t>(a); Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
return res; return res;
@ -226,8 +220,8 @@ __inline__ __device__ float4 vec_conversion<float4, uint32_t>(const uint32_t& a)
// float2 -> half2 // float2 -> half2
template <> template <>
__inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(const float2& a) __inline__ __device__ uint32_t
{ vec_conversion<uint32_t, float2>(const float2& a) {
union { union {
half2 float16; half2 float16;
uint32_t uint32; uint32_t uint32;
@ -239,8 +233,7 @@ __inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(const float2& a)
// Float4 -> half2x2 // Float4 -> half2x2
template <> template <>
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a) __inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a) {
{
uint2 b; uint2 b;
float2 val; float2 val;
val.x = a.x.x; val.x = a.x.x;
@ -255,8 +248,7 @@ __inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a)
// Float4 -> float4 // Float4 -> float4
template <> template <>
__inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a) __inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a) {
{
float4 b; float4 b;
b.x = a.x.x; b.x = a.x.x;
b.y = a.x.y; b.y = a.x.y;
@ -267,8 +259,7 @@ __inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a)
// Float8 -> half2x4 // Float8 -> half2x4
template <> template <>
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a) __inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a) {
{
uint4 b; uint4 b;
b.x = vec_conversion<uint32_t, float2>(a.x); b.x = vec_conversion<uint32_t, float2>(a.x);
b.y = vec_conversion<uint32_t, float2>(a.y); b.y = vec_conversion<uint32_t, float2>(a.y);
@ -279,16 +270,16 @@ __inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a)
// float2 -> bfloat162 // float2 -> bfloat162
template <> template <>
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2& a) __inline__ __device__ __nv_bfloat162
{ vec_conversion<__nv_bfloat162, float2>(const float2& a) {
__nv_bfloat162 b = __float22bfloat162_rn(a); __nv_bfloat162 b = __float22bfloat162_rn(a);
return b; return b;
} }
// Float4 -> bfloat162x2 // Float4 -> bfloat162x2
template <> template <>
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(const Float4_& a) __inline__ __device__ bf16_4_t
{ vec_conversion<bf16_4_t, Float4_>(const Float4_& a) {
bf16_4_t b; bf16_4_t b;
b.x = __float22bfloat162_rn(a.x); b.x = __float22bfloat162_rn(a.x);
b.y = __float22bfloat162_rn(a.y); b.y = __float22bfloat162_rn(a.y);
@ -297,8 +288,8 @@ __inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(const Float4_&
// Float8 -> bfloat162x4 // Float8 -> bfloat162x4
template <> template <>
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(const Float8_& a) __inline__ __device__ bf16_8_t
{ vec_conversion<bf16_8_t, Float8_>(const Float8_& a) {
bf16_8_t b; bf16_8_t b;
b.x = __float22bfloat162_rn(a.x); b.x = __float22bfloat162_rn(a.x);
b.y = __float22bfloat162_rn(a.y); b.y = __float22bfloat162_rn(a.y);
@ -307,20 +298,19 @@ __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(const Float8_&
return b; return b;
} }
/* Scaled and vectorized conversions, for data exchange between high and low
precision domains
/* Scaled and vectorized conversions, for data exchange between high and low precision domains Convention of the scale in API, e.g: FP8_data = Quantization(
High_Precision_data / scale ) s.t. Quantize(HP / scale) => FP8 Dequant(FP8) *
Convention of the scale in API, e.g: FP8_data = Quantization( High_Precision_data / scale ) scale => HP
s.t.
Quantize(HP / scale) => FP8
Dequant(FP8) * scale => HP
*/ */
// fp8 -> half // fp8 -> half
template <> template <>
__inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, const float scale) __inline__ __device__ uint16_t
{ scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, const float scale) {
hip_fp8 f8{a, hip_fp8::from_bits()}; hip_fp8 f8{a, hip_fp8::from_bits()};
__half_raw res; __half_raw res;
res.data = static_cast<float>(f8) * scale; res.data = static_cast<float>(f8) * scale;
@ -329,9 +319,10 @@ __inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, uint8_t>(const ui
// fp8x2 -> half2 // fp8x2 -> half2
template <> template <>
__inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, const float scale) __inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(
{ const uint16_t& a, const float scale) {
#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) #if defined(__HIP__MI300__) && \
defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
union { union {
__half2_raw h2r; __half2_raw h2r;
@ -346,29 +337,32 @@ __inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(const u
uint32_t u32; uint32_t u32;
} tmp; } tmp;
tmp.u16[0] = scaled_vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a), scale); tmp.u16[0] =
tmp.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a >> 8U), scale); scaled_vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a), scale);
tmp.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>(
static_cast<uint8_t>(a >> 8U), scale);
return tmp.u32; return tmp.u32;
#endif #endif
} }
// fp8x4 -> half2x2 // fp8x4 -> half2x2
template <> template <>
__inline__ __device__ uint2 scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, const float scale) __inline__ __device__ uint2
{ scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, const float scale) {
union { union {
uint2 u32x2; uint2 u32x2;
uint32_t u32[2]; uint32_t u32[2];
} tmp; } tmp;
tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale); tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale);
tmp.u32[1] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale); tmp.u32[1] =
scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale);
return tmp.u32x2; return tmp.u32x2;
} }
// fp8x8 -> half2x4 // fp8x8 -> half2x4
template <> template <>
__inline__ __device__ uint4 scaled_vec_conversion<uint4, uint2>(const uint2& a, const float scale) __inline__ __device__ uint4
{ scaled_vec_conversion<uint4, uint2>(const uint2& a, const float scale) {
union { union {
uint4 u64x2; uint4 u64x2;
uint2 u64[2]; uint2 u64[2];
@ -382,8 +376,9 @@ using __nv_bfloat16 = __hip_bfloat16;
// fp8 -> __nv_bfloat16 // fp8 -> __nv_bfloat16
template <> template <>
__inline__ __device__ __nv_bfloat16 scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, const float scale) __inline__ __device__ __nv_bfloat16
{ scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a,
const float scale) {
hip_fp8 f8{a, hip_fp8::from_bits()}; hip_fp8 f8{a, hip_fp8::from_bits()};
float f{f8}; float f{f8};
return __float2bfloat16(f * scale); return __float2bfloat16(f * scale);
@ -393,28 +388,31 @@ using __nv_bfloat162 = __hip_bfloat162;
// fp8x2 -> __nv_bfloat162 // fp8x2 -> __nv_bfloat162
template <> template <>
__inline__ __device__ __nv_bfloat162 scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a, const float scale) __inline__ __device__ __nv_bfloat162
{ scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a,
const float scale) {
__nv_bfloat162 res; __nv_bfloat162 res;
res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale); res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale);
res.y = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale); res.y =
scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale);
return res; return res;
} }
// fp8x4 -> bf16_4_t // fp8x4 -> bf16_4_t
template <> template <>
__inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a, const float scale) __inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
{ const uint32_t& a, const float scale) {
bf16_4_t res; bf16_4_t res;
res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale); res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale);
res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), scale); res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U),
scale);
return res; return res;
} }
// fp8x8 -> bf16_8_t // fp8x8 -> bf16_8_t
template <> template <>
__inline__ __device__ bf16_8_t scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, const float scale) __inline__ __device__ bf16_8_t
{ scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, const float scale) {
bf16_4_t tmp1, tmp2; bf16_4_t tmp1, tmp2;
tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale); tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale);
tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale); tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale);
@ -428,17 +426,18 @@ __inline__ __device__ bf16_8_t scaled_vec_conversion<bf16_8_t, uint2>(const uint
// fp8 -> float // fp8 -> float
template <> template <>
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(const uint8_t& a, const float scale) __inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
{ const uint8_t& a, const float scale) {
hip_fp8 fp8{a, hip_fp8::from_bits()}; hip_fp8 fp8{a, hip_fp8::from_bits()};
return static_cast<float>(fp8) * scale; return static_cast<float>(fp8) * scale;
} }
// fp8x2 -> float2 // fp8x2 -> float2
template <> template <>
__inline__ __device__ float2 scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, const float scale) __inline__ __device__ float2
{ scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, const float scale) {
#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) #if defined(__HIP__MI300__) && \
defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
float2 res; float2 res;
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
res.x = f2[0] * scale; res.x = f2[0] * scale;
@ -447,15 +446,16 @@ __inline__ __device__ float2 scaled_vec_conversion<float2, uint16_t>(const uint1
#else #else
float2 res; float2 res;
res.x = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a), scale); res.x = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a), scale);
res.y = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U), scale); res.y = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U),
scale);
return res; return res;
#endif #endif
} }
// fp8x4 -> float4 // fp8x4 -> float4
template <> template <>
__inline__ __device__ Float4_ scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale) __inline__ __device__ Float4_
{ scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale) {
Float4_ res; Float4_ res;
res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale); res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale);
res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale); res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale);
@ -464,8 +464,8 @@ __inline__ __device__ Float4_ scaled_vec_conversion<Float4_, uint32_t>(const uin
// fp8x8 -> float8 // fp8x8 -> float8
template <> template <>
__inline__ __device__ Float8_ scaled_vec_conversion<Float8_, uint2>(const uint2& a, const float scale) __inline__ __device__ Float8_
{ scaled_vec_conversion<Float8_, uint2>(const uint2& a, const float scale) {
Float4_ tmp1, tmp2; Float4_ tmp1, tmp2;
tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale); tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale);
tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale); tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale);
@ -477,15 +477,14 @@ __inline__ __device__ Float8_ scaled_vec_conversion<Float8_, uint2>(const uint2&
return res; return res;
} }
/* Quantize(HP / scale) => FP8 */ /* Quantize(HP / scale) => FP8 */
// TODO(Hai): vectorized to add // TODO(Hai): vectorized to add
// half -> fp8 // half -> fp8
template <> template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, const float scale) __inline__ __device__ uint8_t
{ scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, const float scale) {
__half_raw tmp; __half_raw tmp;
tmp.x = a; tmp.x = a;
@ -495,24 +494,24 @@ __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, uint16_t>(const uin
// bf16 -> fp8 // bf16 -> fp8
template <> template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a, const float scale) __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
{ const __nv_bfloat16& a, const float scale) {
hip_fp8 res{__bfloat162float(a) / scale}; hip_fp8 res{__bfloat162float(a) / scale};
return res.data; return res.data;
} }
// float -> fp8 // float -> fp8
template <> template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, float>(const float& a, const float scale) __inline__ __device__ uint8_t
{ scaled_vec_conversion<uint8_t, float>(const float& a, const float scale) {
hip_fp8 f8(a / scale); hip_fp8 f8(a / scale);
return f8.data; return f8.data;
} }
// fp8x4 -> float4 // fp8x4 -> float4
template <> template <>
__inline__ __device__ float4 scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, const float scale) __inline__ __device__ float4
{ scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, const float scale) {
Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale); Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale);
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
return res; return res;
@ -539,9 +538,10 @@ __inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) {
assert(false); assert(false);
} }
// The following macro is used to dispatch the conversion function based on the // The following macro is used to dispatch the conversion function based on
// data type of the key and value cache. The FN is a macro that calls a function // the data type of the key and value cache. The FN is a macro that calls a
// with template<typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>. // function with template<typename scalar_t, typename cache_t,
// Fp8KVCacheDataType kv_dt>.
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \ #define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
if (KV_DTYPE == "auto") { \ if (KV_DTYPE == "auto") { \
if (SRC_DTYPE == at::ScalarType::Float) { \ if (SRC_DTYPE == at::ScalarType::Float) { \
@ -562,13 +562,14 @@ __inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) {
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else { \ } else { \
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \ } \
} else { \ } else { \
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \ TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
} \ } \
} }
} // fp8 } // namespace fp8
#endif // USE_ROCM #endif // USE_ROCM
} // namespace vllm } // namespace vllm
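Aside (not part of the diff): the scale convention documented in the comment above, written out as a scalar round trip through the ROCm specializations; a sketch that assumes this header is included and USE_ROCM / ENABLE_FP8 are defined.

__device__ void fp8_scale_roundtrip(float hp, float scale) {
  // HP -> FP8: quantize hp / scale
  uint8_t q = vllm::fp8::scaled_vec_conversion<uint8_t, float>(hp, scale);
  // FP8 -> HP: dequantize and multiply the scale back; `back` approximates hp
  float back = vllm::fp8::scaled_vec_conversion<float, uint8_t>(q, scale);
  (void)back;
}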
@ -11,8 +11,10 @@ namespace vllm {
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) { __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
float old; float old;
old = (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) : old = (value >= 0)
__uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value))); ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
: __uint_as_float(
atomicMin((unsigned int*)addr, __float_as_uint(value)));
return old; return old;
} }
@ -20,7 +22,8 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
#define FP8_E4M3_MAX std::numeric_limits<c10::Float8_e4m3fn>::max() #define FP8_E4M3_MAX std::numeric_limits<c10::Float8_e4m3fn>::max()
template <typename scalar_t> template <typename scalar_t>
__device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(const scalar_t val, const float scale) { __device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(
const scalar_t val, const float scale) {
float x = static_cast<float>(val) / scale; float x = static_cast<float>(val) / scale;
float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX)); float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
return static_cast<c10::Float8_e4m3fn>(r); return static_cast<c10::Float8_e4m3fn>(r);
@ -33,8 +36,7 @@ __device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(const scalar
// a value <= 0.0 and we need to wait for all thread blocks to // a value <= 0.0 and we need to wait for all thread blocks to
// finish before consuming *scale. // finish before consuming *scale.
template <typename scalar_t> template <typename scalar_t>
__global__ void segmented_max_reduction( __global__ void segmented_max_reduction(float* __restrict__ scale,
float* __restrict__ scale,
const scalar_t* __restrict__ input, const scalar_t* __restrict__ input,
int64_t num_elems) { int64_t num_elems) {
__shared__ float cache[1024]; __shared__ float cache[1024];
@ -64,13 +66,13 @@ __global__ void segmented_max_reduction(
// Finally, since cache[0] contains the maximum for this thread block, // Finally, since cache[0] contains the maximum for this thread block,
// atomically write the max to the target location // atomically write the max to the target location
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
atomicMaxFloat(scale, cache[0] / std::numeric_limits<c10::Float8_e4m3fn>::max()); atomicMaxFloat(scale,
cache[0] / std::numeric_limits<c10::Float8_e4m3fn>::max());
} }
} }
template <typename scalar_t> template <typename scalar_t>
__global__ void scaled_fp8_quant_kernel( __global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
c10::Float8_e4m3fn* __restrict__ out,
const scalar_t* __restrict__ input, const scalar_t* __restrict__ input,
const float* __restrict__ scale, const float* __restrict__ scale,
int64_t num_elems) { int64_t num_elems) {
@ -83,8 +85,7 @@ __global__ void scaled_fp8_quant_kernel(
} // namespace vllm } // namespace vllm
void static_scaled_fp8_quant( void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch::Tensor& out, // [..., d]
torch::Tensor& input, // [..., d] torch::Tensor& input, // [..., d]
torch::Tensor& scale) // [1] torch::Tensor& scale) // [1]
{ {
@ -95,19 +96,14 @@ void static_scaled_fp8_quant(
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
"scaled_fp8_quant_kernel",
[&] {
vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>( vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<c10::Float8_e4m3fn>(), out.data_ptr<c10::Float8_e4m3fn>(), input.data_ptr<scalar_t>(),
input.data_ptr<scalar_t>(), scale.data_ptr<float>(), num_elems);
scale.data_ptr<float>(),
num_elems);
}); });
} }
void dynamic_scaled_fp8_quant( void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch::Tensor& out, // [..., d]
torch::Tensor& input, // [..., d] torch::Tensor& input, // [..., d]
torch::Tensor& scale) // [1] torch::Tensor& scale) // [1]
{ {
@ -118,18 +114,11 @@ void dynamic_scaled_fp8_quant(
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
"scaled_fp8_quant_kernel",
[&] {
vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>( vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>(
scale.data_ptr<float>(), scale.data_ptr<float>(), input.data_ptr<scalar_t>(), num_elems);
input.data_ptr<scalar_t>(),
num_elems);
vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>( vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<c10::Float8_e4m3fn>(), out.data_ptr<c10::Float8_e4m3fn>(), input.data_ptr<scalar_t>(),
input.data_ptr<scalar_t>(), scale.data_ptr<float>(), num_elems);
scale.data_ptr<float>(),
num_elems);
}); });
} }
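Aside (not part of the diff): a host-side reference for what the two launches in dynamic_scaled_fp8_quant compute together, useful as a test oracle; a sketch under the assumption that the reduction tracks the running absolute maximum and that it is positive.

#include <algorithm>
#include <cmath>
#include <vector>

std::vector<float> dynamic_fp8_quant_reference(const std::vector<float>& input,
                                               float fp8_max, float* scale_out) {
  float amax = 0.f;                                   // segmented_max_reduction
  for (float v : input) amax = std::max(amax, std::fabs(v));
  const float scale = amax / fp8_max;                 // assumes amax > 0
  std::vector<float> out(input.size());
  for (size_t i = 0; i < input.size(); ++i)           // scaled_fp8_quant_kernel
    out[i] = std::max(-fp8_max, std::min(input[i] / scale, fp8_max));
  *scale_out = scale;
  return out;  // the device kernel additionally casts each value to Float8_e4m3fn
}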
@ -406,7 +406,6 @@ template <>
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>( __inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
const uint8_t& a, const float scale, const uint8_t& a, const float scale,
const __nv_fp8_interpretation_t fp8_type) { const __nv_fp8_interpretation_t fp8_type) {
// fp8 -> half // fp8 -> half
__half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type); __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
uint16_t tmp = res.x; uint16_t tmp = res.x;
@ -523,9 +522,10 @@ __inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) {
assert(false); assert(false);
} }
// The following macro is used to dispatch the conversion function based on the // The following macro is used to dispatch the conversion function based on
// data type of the key and value cache. The FN is a macro that calls a function // the data type of the key and value cache. The FN is a macro that calls a
// with template<typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>. // function with template<typename scalar_t, typename cache_t,
// Fp8KVCacheDataType kv_dt>.
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \ #define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
if (KV_DTYPE == "auto") { \ if (KV_DTYPE == "auto") { \
if (SRC_DTYPE == at::ScalarType::Float) { \ if (SRC_DTYPE == at::ScalarType::Float) { \
@ -546,7 +546,8 @@ __inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) {
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else { \ } else { \
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \ } \
} else if (KV_DTYPE == "fp8_e5m2") { \ } else if (KV_DTYPE == "fp8_e5m2") { \
if (SRC_DTYPE == at::ScalarType::Float) { \ if (SRC_DTYPE == at::ScalarType::Float) { \
@ -556,7 +557,8 @@ __inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) {
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \ FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else { \ } else { \
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \ } \
} else { \ } else { \
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \ TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
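Aside (not part of the diff): a hypothetical caller of DISPATCH_BY_KV_CACHE_DTYPE above; launch_my_kernel and CALL_MY_KERNEL are illustrative names, not part of vLLM.

#include <string>
#include <torch/torch.h>

// FN receives the resolved <scalar_t, cache_t, kv_dt> triple.
#define CALL_MY_KERNEL(SCALAR_T, CACHE_T, KV_DT) \
  launch_my_kernel<SCALAR_T, CACHE_T, KV_DT>(key_cache, value_cache);

void dispatch_example(torch::Tensor& key_cache, torch::Tensor& value_cache,
                      const std::string& kv_cache_dtype) {
  DISPATCH_BY_KV_CACHE_DTYPE(key_cache.scalar_type(), kv_cache_dtype,
                             CALL_MY_KERNEL);
}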
@ -9,40 +9,36 @@ namespace vllm {
namespace gptq { namespace gptq {
// atomicAdd for half types, to support CC < 7.x // atomicAdd for half types, to support CC < 7.x
__device__ __forceinline__ void atomicAdd_half(half* address, half val) __device__ __forceinline__ void atomicAdd_half(half* address, half val) {
{ unsigned int* address_as_ui =
unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2)); (unsigned int*)((char*)address - ((size_t)address & 2));
unsigned int old = *address_as_ui; unsigned int old = *address_as_ui;
unsigned int assumed; unsigned int assumed;
do do {
{
assumed = old; assumed = old;
__half_raw hsum; __half_raw hsum;
hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
half tmpres = __hadd(hsum, val); half tmpres = __hadd(hsum, val);
hsum = __half_raw(tmpres); hsum = __half_raw(tmpres);
old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x; old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16)
: (old & 0xffff0000) | hsum.x;
old = atomicCAS(address_as_ui, assumed, old); old = atomicCAS(address_as_ui, assumed, old);
} } while (assumed != old);
while (assumed != old);
} }
// atomicAdd for half2 types // atomicAdd for half2 types
__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) {
{
unsigned int* address_as_ui = (unsigned int*)address; unsigned int* address_as_ui = (unsigned int*)address;
unsigned int old = *address_as_ui; unsigned int old = *address_as_ui;
unsigned int assumed; unsigned int assumed;
do do {
{
assumed = old; assumed = old;
half2 old_val = *((half2*)&old); half2 old_val = *((half2*)&old);
half2 new_val = __hadd2(old_val, val); half2 new_val = __hadd2(old_val, val);
old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val)); old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
} } while (assumed != old);
while (assumed != old);
} }
// //
@ -50,10 +46,14 @@ __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
#if defined(__CUDA_ARCH__) || defined(USE_ROCM) #if defined(__CUDA_ARCH__) || defined(USE_ROCM)
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM) #if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); } __device__ __forceinline__ void atomicAdd(half* address, half val) {
atomicAdd_half(address, val);
}
#if __CUDA_ARCH__ < 600 || defined(USE_ROCM) #if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); } __device__ __forceinline__ void atomicAdd(half2* address, half2 val) {
atomicAdd_half2(address, val);
}
#endif #endif
#endif #endif
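Aside (not part of the diff): the two helpers above follow the standard atomicCAS retry pattern; a generic sketch of that pattern with a hypothetical helper, for any 32-bit payload.

// modify() computes the new word from the currently observed one; the loop
// retries until no other thread changed the word in between.
template <typename F>
__device__ void atomic_update_u32(unsigned int* addr, F modify) {
  unsigned int old = *addr, assumed;
  do {
    assumed = old;
    old = atomicCAS(addr, assumed, modify(assumed));
  } while (assumed != old);
}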
@ -1,5 +1,6 @@
/* /*
Adapted from https://github.com/turboderp/exllamav2 and https://github.com/turboderp/exllama Adapted from https://github.com/turboderp/exllamav2 and
https://github.com/turboderp/exllama
*/ */
#ifndef _matrix_view_cuh #ifndef _matrix_view_cuh
@ -13,24 +14,31 @@ Adapted from https://github.com/turboderp/exllamav2 and https://github.com/turbo
namespace vllm { namespace vllm {
namespace gptq { namespace gptq {
class MatrixView_half class MatrixView_half {
{
public: public:
const half* data; const half* data;
const int height; const int height;
const int width; const int width;
__device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width) __device__ __forceinline__ MatrixView_half(const half* data, const int height,
: data(data), height(height), width(width) const int width)
{ } : data(data), height(height), width(width) {}
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } __device__ __forceinline__ half item(int row, int column) const {
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } return data[row * width + column];
__device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); } }
__device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; } __device__ __forceinline__ half2 item_half2(int row, int column) const {
return ((half2*)data)[(row * width + column) / 2];
}
__device__ __forceinline__ half2 item_half2half2(int row, int column) const {
return __half2half2(data[row * width + column]);
}
__device__ __forceinline__ const half* item_ptr(int row, int column) const {
return &data[row * width + column];
}
__device__ __forceinline__ void item4(half (&items)[4], int row, int column) const __device__ __forceinline__ void item4(half (&items)[4], int row,
{ int column) const {
half2* ptr = (half2*)item_ptr(row, column); half2* ptr = (half2*)item_ptr(row, column);
half2 i01 = ptr[0]; half2 i01 = ptr[0];
half2 i23 = ptr[1]; half2 i23 = ptr[1];
@ -39,8 +47,8 @@ public:
items[2] = __low2half(i23); items[2] = __low2half(i23);
items[3] = __high2half(i23); items[3] = __high2half(i23);
} }
__device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const __device__ __forceinline__ void item4_f(float (&items)[4], int row,
{ int column) const {
half2* ptr = (half2*)item_ptr(row, column); half2* ptr = (half2*)item_ptr(row, column);
half2 i01 = ptr[0]; half2 i01 = ptr[0];
half2 i23 = ptr[1]; half2 i23 = ptr[1];
@ -50,8 +58,8 @@ public:
items[3] = __half2float(__high2half(i23)); items[3] = __half2float(__high2half(i23));
} }
__device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const __device__ __forceinline__ void item4_h2(half2 (&items)[4], int row,
{ int column) const {
half2* ptr = (half2*)item_ptr(row, column); half2* ptr = (half2*)item_ptr(row, column);
half2 i01 = ptr[0]; half2 i01 = ptr[0];
half2 i23 = ptr[1]; half2 i23 = ptr[1];
@ -62,25 +70,34 @@ public:
} }
}; };
class MatrixView_half_rw class MatrixView_half_rw {
{
public: public:
half* data; half* data;
const int height; const int height;
const int width; const int width;
__device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width) __device__ __forceinline__ MatrixView_half_rw(half* data, const int height,
: data(data), height(height), width(width) const int width)
{ } : data(data), height(height), width(width) {}
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } __device__ __forceinline__ half item(int row, int column) const {
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } return data[row * width + column];
__device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; } }
__device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; } __device__ __forceinline__ half2 item_half2(int row, int column) const {
__device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; } return ((half2*)data)[(row * width + column) / 2];
}
__device__ __forceinline__ half* item_ptr(int row, int column) {
return &data[row * width + column];
}
__device__ __forceinline__ void set(int row, int column, half value) {
data[row * width + column] = value;
}
__device__ __forceinline__ void set_half2(int row, int column, half2 value) {
((half2*)data)[(row * width + column) / 2] = value;
}
__device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3) __device__ __forceinline__ void set4(int row, int column, half v0, half v1,
{ half v2, half v3) {
half2 v01 = __halves2half2(v0, v1); half2 v01 = __halves2half2(v0, v1);
half2 v23 = __halves2half2(v2, v3); half2 v23 = __halves2half2(v2, v3);
half2* ptr = (half2*)item_ptr(row, column); half2* ptr = (half2*)item_ptr(row, column);
@ -89,33 +106,32 @@ public:
} }
}; };
class MatrixView_q4_row class MatrixView_q4_row {
{
public: public:
const uint32_t* data; const uint32_t* data;
const int height; const int height;
const int width; const int width;
__device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width) __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data,
: data(data), height(height), width(width) const int height,
{ } const int width)
: data(data), height(height), width(width) {}
__device__ __forceinline__ int item(int row, int column) const __device__ __forceinline__ int item(int row, int column) const {
{
int shift = (column & 0x07) * 4; int shift = (column & 0x07) * 4;
return (data[row * width / 8 + column / 8] >> shift) & 0x0f; return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
} }
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const __device__ __forceinline__ void item2(int (&items)[2], int row,
{ int column) const {
int shift = (column & 0x07) * 4; int shift = (column & 0x07) * 4;
uint32_t d = data[row * width / 8 + column / 8] >> shift; uint32_t d = data[row * width / 8 + column / 8] >> shift;
items[0] = d & 0x0f; items[0] = d & 0x0f;
items[1] = (d >> 4) & 0x0f; items[1] = (d >> 4) & 0x0f;
} }
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const __device__ __forceinline__ void item4(int (&items)[4], int row,
{ int column) const {
int shift = (column & 0x07) * 4; int shift = (column & 0x07) * 4;
uint32_t d = data[row * width / 8 + column / 8] >> shift; uint32_t d = data[row * width / 8 + column / 8] >> shift;
items[0] = d & 0x0f; items[0] = d & 0x0f;
@ -125,54 +141,57 @@ public:
} }
}; };
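Aside (not part of the diff): the nibble packing that MatrixView_q4_row::item above decodes, as a small host-side sketch with hypothetical helper names.

#include <cstdint>

// Eight 4-bit values share one uint32_t, lowest nibble first.
inline uint32_t pack_q4_row(const int (&vals)[8]) {
  uint32_t word = 0;
  for (int i = 0; i < 8; ++i)
    word |= (static_cast<uint32_t>(vals[i]) & 0x0f) << (i * 4);
  return word;
}

inline int unpack_q4(uint32_t word, int column) {
  int shift = (column & 0x07) * 4;  // same expression as item()
  return (word >> shift) & 0x0f;
}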
class MatrixView_q4_column class MatrixView_q4_column {
{
public: public:
const uint32_t* data; const uint32_t* data;
const int height; const int height;
const int width; const int width;
__device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width) __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data,
: data(data), height(height), width(width) const int height,
{ } const int width)
: data(data), height(height), width(width) {}
__device__ __forceinline__ int item(int row, int column) const __device__ __forceinline__ int item(int row, int column) const {
{
int shift = (row & 0x07) * 4; int shift = (row & 0x07) * 4;
return (data[row / 8 * width + column] >> shift) & 0x0f; return (data[row / 8 * width + column] >> shift) & 0x0f;
} }
__device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; } __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) {
__device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; } return data[row / 8 * width + column];
}
__device__ __forceinline__ const uint32_t* item_uint32_ptr(int row,
int column) {
return &data[row / 8 * width + column];
}
}; };
class MatrixView_q2_row class MatrixView_q2_row {
{
public: public:
const uint32_t* data; const uint32_t* data;
const int height; const int height;
const int width; const int width;
__device__ __forceinline__ MatrixView_q2_row(const uint32_t* data, const int height, const int width) __device__ __forceinline__ MatrixView_q2_row(const uint32_t* data,
: data(data), height(height), width(width) const int height,
{ } const int width)
: data(data), height(height), width(width) {}
__device__ __forceinline__ int item(int row, int column) const __device__ __forceinline__ int item(int row, int column) const {
{
int shift = (column & 0x0f) * 2; int shift = (column & 0x0f) * 2;
return (data[row * width / 16 + column / 16] >> shift) & 0x03; return (data[row * width / 16 + column / 16] >> shift) & 0x03;
} }
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const __device__ __forceinline__ void item2(int (&items)[2], int row,
{ int column) const {
int shift = (column & 0x0f) * 2; int shift = (column & 0x0f) * 2;
uint32_t d = data[row * width / 16 + column / 16] >> shift; uint32_t d = data[row * width / 16 + column / 16] >> shift;
items[0] = d & 0x03; items[0] = d & 0x03;
items[1] = (d >> 2) & 0x03; items[1] = (d >> 2) & 0x03;
} }
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const __device__ __forceinline__ void item4(int (&items)[4], int row,
{ int column) const {
int shift = (column & 0x0f) * 2; int shift = (column & 0x0f) * 2;
uint32_t d = data[row * width / 16 + column / 16] >> shift; uint32_t d = data[row * width / 16 + column / 16] >> shift;
items[0] = d & 0x03; items[0] = d & 0x03;
@ -182,26 +201,27 @@ public:
} }
}; };
class MatrixView_q3_row class MatrixView_q3_row {
{
public: public:
const uint32_t* data; const uint32_t* data;
const int height; const int height;
const int width; const int width;
__device__ __forceinline__ MatrixView_q3_row(const uint32_t* data, const int height, const int width) __device__ __forceinline__ MatrixView_q3_row(const uint32_t* data,
: data(data), height(height), width(width) const int height,
{ } const int width)
: data(data), height(height), width(width) {}
__device__ __forceinline__ int item(int row, int column) const __device__ __forceinline__ int item(int row, int column) const {
{
int z_w = column * 3 / 32; int z_w = column * 3 / 32;
int z_mod = column & 0x1f; int z_mod = column & 0x1f;
if (z_mod == 10) { if (z_mod == 10) {
return (data[row * width * 3 / 32 + z_w] >> 30) | ((data[row * width * 3 / 32 + (z_w + 1)] << 2) & 0x4); return (data[row * width * 3 / 32 + z_w] >> 30) |
((data[row * width * 3 / 32 + (z_w + 1)] << 2) & 0x4);
} else if (z_mod == 21) { } else if (z_mod == 21) {
return (data[row * width * 3 / 32 + z_w] >> 31) | ((data[row * width * 3 / 32 + (z_w + 1)] << 1) & 0x6); return (data[row * width * 3 / 32 + z_w] >> 31) |
((data[row * width * 3 / 32 + (z_w + 1)] << 1) & 0x6);
} else if (z_mod < 10) { } else if (z_mod < 10) {
return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3)) & 0x07; return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3)) & 0x07;
} else if (z_mod < 21) { } else if (z_mod < 21) {
@ -211,18 +231,20 @@ public:
} }
} }
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const __device__ __forceinline__ void item4(int (&items)[4], int row,
{ int column) const {
int shift = (column & 0x1f); int shift = (column & 0x1f);
uint32_t d; uint32_t d;
if (shift <= 4) { if (shift <= 4) {
d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3); d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3);
} else if (shift == 8) { } else if (shift == 8) {
d = (data[row * width / 32 * 3 + column * 3 / 32] >> 24) | ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0x0f) << 8); d = (data[row * width / 32 * 3 + column * 3 / 32] >> 24) |
((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0x0f) << 8);
} else if (shift <= 16) { } else if (shift <= 16) {
d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 32); d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 32);
} else if (shift == 20) { } else if (shift == 20) {
d = (data[row * width / 32 * 3 + column * 3 / 32] >> 28) | ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0xff) << 4); d = (data[row * width / 32 * 3 + column * 3 / 32] >> 28) |
((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0xff) << 4);
} else { } else {
d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 64); d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 64);
} }
@ -233,33 +255,32 @@ public:
} }
}; };
class MatrixView_q8_row class MatrixView_q8_row {
{
public: public:
const uint32_t* data; const uint32_t* data;
const int height; const int height;
const int width; const int width;
__device__ __forceinline__ MatrixView_q8_row(const uint32_t* data, const int height, const int width) __device__ __forceinline__ MatrixView_q8_row(const uint32_t* data,
: data(data), height(height), width(width) const int height,
{ } const int width)
: data(data), height(height), width(width) {}
__device__ __forceinline__ int item(int row, int column) const __device__ __forceinline__ int item(int row, int column) const {
{
int shift = (column & 0x03) * 8; int shift = (column & 0x03) * 8;
return (data[row * width / 4 + column / 4] >> shift) & 0xff; return (data[row * width / 4 + column / 4] >> shift) & 0xff;
} }
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const __device__ __forceinline__ void item2(int (&items)[2], int row,
{ int column) const {
int shift = (column & 0x03) * 8; int shift = (column & 0x03) * 8;
uint32_t d = data[row * width / 4 + column / 4] >> shift; uint32_t d = data[row * width / 4 + column / 4] >> shift;
items[0] = d & 0xff; items[0] = d & 0xff;
items[1] = (d >> 8) & 0xff; items[1] = (d >> 8) & 0xff;
} }
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const __device__ __forceinline__ void item4(int (&items)[4], int row,
{ int column) const {
int shift = (column & 0x03) * 2; int shift = (column & 0x03) * 2;
uint32_t d = data[row * width / 4 + column / 4] >> shift; uint32_t d = data[row * width / 4 + column / 4] >> shift;
items[0] = d & 0xff; items[0] = d & 0xff;
File diff suppressed because it is too large.
@ -14,18 +14,12 @@ namespace gptq {
// //
// ffddbb99 77553311 eeccaa88 66442200 // ffddbb99 77553311 eeccaa88 66442200
__forceinline__ __device__ void shuffle_2bit_16 __forceinline__ __device__ void shuffle_2bit_16(uint32_t* q, int stride) {
(
uint32_t* q,
int stride
)
{
uint32_t qa = q[0]; uint32_t qa = q[0];
uint32_t qb = 0; uint32_t qb = 0;
#pragma unroll #pragma unroll
for (int i = 0; i < 8; i++) for (int i = 0; i < 8; i++) {
{
uint32_t qa0 = qa & 0x03; uint32_t qa0 = qa & 0x03;
uint32_t qa1 = (qa & 0x0c) >> 2; uint32_t qa1 = (qa & 0x0c) >> 2;
qa >>= 4; qa >>= 4;
@ -35,14 +29,9 @@ __forceinline__ __device__ void shuffle_2bit_16
q[0] = qb; q[0] = qb;
} }
__forceinline__ __device__ void dequant_2bit_16 __forceinline__ __device__ void dequant_2bit_16(const uint32_t q_0,
( half2 (&dq)[8], int stride,
const uint32_t q_0, const uint32_t zero) {
half2 (&dq)[8],
int stride,
const uint32_t zero
)
{
const uint32_t c0 = 0x64006400; const uint32_t c0 = 0x64006400;
const half y4_ = __float2half_rn(1.0f / 4.0f); const half y4_ = __float2half_rn(1.0f / 4.0f);
const half y16_ = __float2half_rn(1.0f / 16.0f); const half y16_ = __float2half_rn(1.0f / 16.0f);
@ -11,12 +11,7 @@ namespace gptq {
// vjjjhhhf ffdddbbb uiiiggge eecccaaa // vjjjhhhf ffdddbbb uiiiggge eecccaaa
// vtttrrrp ppnnnlll usssqqqo oommmkkk // vtttrrrp ppnnnlll usssqqqo oommmkkk
__forceinline__ __device__ void shuffle_3bit_32 __forceinline__ __device__ void shuffle_3bit_32(uint32_t* q, int stride) {
(
uint32_t* q,
int stride
)
{
uint32_t qa = q[0 * stride]; uint32_t qa = q[0 * stride];
uint32_t qb = q[1 * stride]; uint32_t qb = q[1 * stride];
uint32_t qc = q[2 * stride]; uint32_t qc = q[2 * stride];
@ -40,9 +35,27 @@ __forceinline__ __device__ void shuffle_3bit_32
uint32_t zb = 0; uint32_t zb = 0;
uint32_t zc = 0; uint32_t zc = 0;
for (int i = 0; i < 5; i++) { uint32_t t0 = qa & 0x07; uint32_t t1 = (qa & 0x38) >> 3; qa >>= 6; za |= (t0 << (i * 3)); za |= (t1 << (i * 3 + 16)); } for (int i = 0; i < 5; i++) {
for (int i = 0; i < 5; i++) { uint32_t t0 = qb & 0x07; uint32_t t1 = (qb & 0x38) >> 3; qb >>= 6; zb |= (t0 << (i * 3)); zb |= (t1 << (i * 3 + 16)); } uint32_t t0 = qa & 0x07;
for (int i = 0; i < 5; i++) { uint32_t t0 = qc & 0x07; uint32_t t1 = (qc & 0x38) >> 3; qc >>= 6; zc |= (t0 << (i * 3)); zc |= (t1 << (i * 3 + 16)); } uint32_t t1 = (qa & 0x38) >> 3;
qa >>= 6;
za |= (t0 << (i * 3));
za |= (t1 << (i * 3 + 16));
}
for (int i = 0; i < 5; i++) {
uint32_t t0 = qb & 0x07;
uint32_t t1 = (qb & 0x38) >> 3;
qb >>= 6;
zb |= (t0 << (i * 3));
zb |= (t1 << (i * 3 + 16));
}
for (int i = 0; i < 5; i++) {
uint32_t t0 = qc & 0x07;
uint32_t t1 = (qc & 0x38) >> 3;
qc >>= 6;
zc |= (t0 << (i * 3));
zc |= (t1 << (i * 3 + 16));
}
// za: 9997775 55333111 8886664 44222000 // za: 9997775 55333111 8886664 44222000
// zb: jjjhhhf ffdddbbb iiiggge eecccaaa // zb: jjjhhhf ffdddbbb iiiggge eecccaaa
@ -65,16 +78,11 @@ __forceinline__ __device__ void shuffle_3bit_32
q[2 * stride] = zc; q[2 * stride] = zc;
} }
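For illustration, the regrouping loops in shuffle_3bit_32 can be exercised on the host with a standalone C++ sketch (regroup is a hypothetical stand-in for one of the three identical loops above; it splits each 6-bit pair into its even/odd 3-bit halves and packs them into the low and high 16 bits of the result, cf. the za/zb/zc layout comments above):

#include <cstdint>
#include <cstdio>

static uint32_t regroup(uint32_t qa) {
  uint32_t za = 0;
  for (int i = 0; i < 5; i++) {
    uint32_t t0 = qa & 0x07;         // even-position 3-bit value
    uint32_t t1 = (qa & 0x38) >> 3;  // odd-position 3-bit value
    qa >>= 6;
    za |= (t0 << (i * 3));
    za |= (t1 << (i * 3 + 16));
  }
  return za;
}

int main() {
  // Ten 3-bit values 0..9 packed consecutively into the low 30 bits.
  uint32_t qa = 0;
  for (uint32_t v = 0; v < 10; v++) qa |= v << (3 * v);
  std::printf("before: %08x  after: %08x\n", (unsigned)qa, (unsigned)regroup(qa));
  return 0;
}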
__forceinline__ __device__ void dequant_3bit_32 __forceinline__ __device__ void dequant_3bit_32(const uint32_t q_0,
(
const uint32_t q_0,
const uint32_t q_1, const uint32_t q_1,
const uint32_t q_2, const uint32_t q_2,
half2 (&dq)[16], half2 (&dq)[16], int stride,
int stride, const uint32_t zero) {
const uint32_t zero
)
{
const uint32_t c0 = 0x64006400; const uint32_t c0 = 0x64006400;
const half y8_ = __float2half_rn(1.0f / 8.0f); const half y8_ = __float2half_rn(1.0f / 8.0f);
const half y64_ = __float2half_rn(1.0f / 64.0f); const half y64_ = __float2half_rn(1.0f / 64.0f);
@ -13,18 +13,12 @@ namespace gptq {
// //
// 77775555 33331111 66664444 22220000 // 77775555 33331111 66664444 22220000
__forceinline__ __device__ void shuffle_4bit_8 __forceinline__ __device__ void shuffle_4bit_8(uint32_t* q, int stride) {
(
uint32_t* q,
int stride
)
{
uint32_t qa = q[0]; uint32_t qa = q[0];
uint32_t qb = 0; uint32_t qb = 0;
#pragma unroll #pragma unroll
for (int i = 0; i < 4; i++) for (int i = 0; i < 4; i++) {
{
uint32_t qa0 = qa & 0x0f; uint32_t qa0 = qa & 0x0f;
uint32_t qa1 = (qa & 0xf0) >> 4; uint32_t qa1 = (qa & 0xf0) >> 4;
qa >>= 8; qa >>= 8;
@ -34,14 +28,9 @@ __forceinline__ __device__ void shuffle_4bit_8
q[0] = qb; q[0] = qb;
} }
__forceinline__ __device__ void dequant_4bit_8 __forceinline__ __device__ void dequant_4bit_8(const uint32_t q_0,
( half2 (&dq)[4], int stride,
const uint32_t q_0, const uint32_t zero) {
half2 (&dq)[4],
int stride,
const uint32_t zero
)
{
const uint32_t c0 = 0x64006400; const uint32_t c0 = 0x64006400;
const half y16_ = __float2half_rn(1.0f / 16.0f); const half y16_ = __float2half_rn(1.0f / 16.0f);
const half2 y16 = __halves2half2(y16_, y16_); const half2 y16 = __halves2half2(y16_, y16_);
@ -63,14 +52,9 @@ __forceinline__ __device__ void dequant_4bit_8
dq[3] = __hfma2(q3.as_half2, y16, z16); dq[3] = __hfma2(q3.as_half2, y16, z16);
} }
__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale __forceinline__ __device__ void dequant_4bit_8_prep_zero_scale(
( const uint32_t zero, const half scale, half2 (&z1z16)[2],
const uint32_t zero, half2 (&y1y16)[2]) {
const half scale,
half2 (&z1z16)[2],
half2 (&y1y16)[2]
)
{
half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));
@ -86,13 +70,9 @@ __forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
y1y16[1] = __hmul2(scale2, __half2half2(y16)); y1y16[1] = __hmul2(scale2, __half2half2(y16));
} }
__forceinline__ __device__ void dequant_4bit_8_prep_zero __forceinline__ __device__ void dequant_4bit_8_prep_zero(const uint32_t zero,
(
const uint32_t zero,
half2 (&z1z16)[2], half2 (&z1z16)[2],
half2(&y1y16)[2] half2 (&y1y16)[2]) {
)
{
half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));
@ -106,39 +86,38 @@ __forceinline__ __device__ void dequant_4bit_8_prep_zero
y1y16[1] = __half2half2(y16); y1y16[1] = __half2half2(y16);
} }
__forceinline__ __device__ void dequant_4bit_8_gptq(const uint32_t q_0,
__forceinline__ __device__ void dequant_4bit_8_gptq
(
const uint32_t q_0,
half2 (&dq)[4], half2 (&dq)[4],
half2 (&z1z16)[2], half2 (&z1z16)[2],
half2 (&y1y16)[2], half2 (&y1y16)[2],
int stride, int stride, bool scaled) {
bool scaled
)
{
const uint32_t c0 = 0x64006400; const uint32_t c0 = 0x64006400;
uint32_t qa = q_0; uint32_t qa = q_0;
half2_uint32 q0((qa & 0x000f000f) | c0); // half2( q[0] + 1024, q[1] + 1024 ) half2_uint32 q0((qa & 0x000f000f) |
half2_uint32 q1((qa & 0x00f000f0) | c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 ) c0); // half2( q[0] + 1024, q[1] + 1024 )
half2_uint32 q1((qa & 0x00f000f0) |
c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 )
qa >>= 8; qa >>= 8;
half2_uint32 q2((qa & 0x000f000f) | c0); // half2( q[4] + 1024, q[5] + 1024 ) half2_uint32 q2((qa & 0x000f000f) |
half2_uint32 q3((qa & 0x00f000f0) | c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 ) c0); // half2( q[4] + 1024, q[5] + 1024 )
half2_uint32 q3((qa & 0x00f000f0) |
c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 )
if (scaled) if (scaled) {
{ dq[0] = __hfma2(q0.as_half2, y1y16[0],
dq[0] = __hfma2(q0.as_half2, y1y16[0], z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s) z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s)
dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s) dq[1] = __hfma2(q1.as_half2, y1y16[1],
z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s)
dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]); dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]);
dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]);
} } else {
else
{
dq[0] = __hadd2(q0.as_half2, z1z16[0]); // half2( q[0] - z, q[1] - z ) dq[0] = __hadd2(q0.as_half2, z1z16[0]); // half2( q[0] - z, q[1] - z )
dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] - z, q[3] - z ) dq[1] = __hfma2(q1.as_half2, y1y16[1],
z1z16[1]); // half2( q[2] - z, q[3] - z )
dq[2] = __hadd2(q2.as_half2, z1z16[0]); // half2( q[4] - z, q[5] - z ) dq[2] = __hadd2(q2.as_half2, z1z16[0]); // half2( q[4] - z, q[5] - z )
dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); // half2( q[6] - z, q[7] - z ) dq[3] = __hfma2(q3.as_half2, y1y16[1],
z1z16[1]); // half2( q[6] - z, q[7] - z )
} }
} }
} // namespace gptq } // namespace gptq
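For reference, the role of the 0x6400 / 0xe400 constants in these dequantizers can be checked with a standalone C++ sketch (illustrative only; the minimal decoder below handles normal fp16 values, which is all the trick relies on): ORing a 4-bit value q into the mantissa of half(1024.0f) yields half(1024 + q), and 0xe400 | zero encodes -(1024 + zero), so adding the two cancels both the bias and the zero point.

#include <cmath>
#include <cstdint>
#include <cstdio>

// Decode a normal fp16 bit pattern: sign * (1 + mantissa/1024) * 2^(exponent - 15).
static float half_bits_to_float(uint16_t h) {
  int sign = (h >> 15) ? -1 : 1;
  int exponent = (h >> 10) & 0x1f;
  int mantissa = h & 0x3ff;
  return sign * (1.0f + mantissa / 1024.0f) * std::ldexp(1.0f, exponent - 15);
}

int main() {
  for (int q = 0; q < 16; q += 5) {
    uint16_t biased = (uint16_t)(0x6400 | q);  // same OR as (qa & 0x000f000f) | c0
    uint16_t negzp = (uint16_t)(0xe400 | q);   // same trick as half_uint16 z1(0xe400 | zero)
    std::printf("q=%2d  0x6400|q -> %7.1f   0xe400|q -> %8.1f\n", q,
                half_bits_to_float(biased), half_bits_to_float(negzp));
  }
  return 0;
}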
@ -10,28 +10,18 @@ Copied from https://github.com/turboderp/exllamav2
namespace vllm { namespace vllm {
namespace gptq { namespace gptq {
__forceinline__ __device__ void shuffle_8bit_4 __forceinline__ __device__ void shuffle_8bit_4(uint32_t* q, int stride) {}
(
uint32_t* q,
int stride
)
{
}
__forceinline__ __device__ void dequant_8bit_8 __forceinline__ __device__ void dequant_8bit_8(const uint32_t q_0,
(
const uint32_t q_0,
const uint32_t q_1, const uint32_t q_1,
half2 (&dq)[4], half2 (&dq)[4], int stride,
int stride, const uint32_t zero) {
const uint32_t zero
)
{
half dqh[8]; half dqh[8];
for (int i = 0; i < 4; i++) dqh[i] = dq_ns(exb(q_0, i * 8, 0xff), zero); for (int i = 0; i < 4; i++) dqh[i] = dq_ns(exb(q_0, i * 8, 0xff), zero);
for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), zero); for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), zero);
for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); for (int i = 0; i < 4; i++)
dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
} }
} // namespace gptq } // namespace gptq
@ -8,16 +8,14 @@ Copied from https://github.com/turboderp/exllamav2
namespace vllm { namespace vllm {
namespace gptq { namespace gptq {
union half2_uint32 union half2_uint32 {
{
uint32_t as_uint32; uint32_t as_uint32;
half2 as_half2; half2 as_half2;
__device__ half2_uint32(uint32_t val) : as_uint32(val) {} __device__ half2_uint32(uint32_t val) : as_uint32(val) {}
__device__ half2_uint32(half2 val) : as_half2(val) {} __device__ half2_uint32(half2 val) : as_half2(val) {}
}; };
union half_uint16 union half_uint16 {
{
uint16_t as_uint16; uint16_t as_uint16;
half as_half; half as_half;
__device__ half_uint16(uint16_t val) : as_uint16(val) {} __device__ half_uint16(uint16_t val) : as_uint16(val) {}
@ -26,32 +24,30 @@ union half_uint16
// Max_scale premultiplied by 1/256 // Max_scale premultiplied by 1/256
__forceinline__ __device__ half dq_scale(const int qs, const half max_scale) __forceinline__ __device__ half dq_scale(const int qs, const half max_scale) {
{
int qs_i = qs + 1; int qs_i = qs + 1;
half qs_h = __int2half_rn(qs_i * qs_i); half qs_h = __int2half_rn(qs_i * qs_i);
qs_h = __hmul(qs_h, max_scale); qs_h = __hmul(qs_h, max_scale);
return qs_h; return qs_h;
} }
__forceinline__ __device__ half dq(const int q, const int qzero, const half scale) __forceinline__ __device__ half dq(const int q, const int qzero,
{ const half scale) {
return __hmul(__int2half_rn(q - qzero), scale); return __hmul(__int2half_rn(q - qzero), scale);
} }
__forceinline__ __device__ half dq_ns(const int q, const int qzero) __forceinline__ __device__ half dq_ns(const int q, const int qzero) {
{
// return __hsub(__int2half_rn(q), __int2half_rn(qzero)); // return __hsub(__int2half_rn(q), __int2half_rn(qzero));
return __int2half_rn(q - qzero); return __int2half_rn(q - qzero);
} }
__forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask) __forceinline__ __device__ int exb(const uint32_t q, const int shift,
{ const int mask) {
return (int)((q >> shift) & mask); return (int)((q >> shift) & mask);
} }
__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask) __forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0,
{ const int shift, const int mask) {
return (int)(__funnelshift_rc(q0, q1, shift) & mask); return (int)(__funnelshift_rc(q0, q1, shift) & mask);
} }
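The helpers above are small enough to mirror on the host; a standalone C++ sketch with float standing in for half (the *_host names are illustrative stand-ins for exb, dq and dq_scale):

#include <cstdint>
#include <cstdio>

static int exb_host(uint32_t q, int shift, int mask) { return (int)((q >> shift) & mask); }

// (q - qzero) * scale, as in dq().
static float dq_host(int q, int qzero, float scale) { return (float)(q - qzero) * scale; }

// Scale code qs maps to (qs + 1)^2 * max_scale, where max_scale is already
// premultiplied by 1/256, so qs = 0..15 spans 1/256 .. 1 of the full scale.
static float dq_scale_host(int qs, float max_scale_over_256) {
  int qs_i = qs + 1;
  return (float)(qs_i * qs_i) * max_scale_over_256;
}

int main() {
  uint32_t word = 0xA3u;                           // two packed 4-bit weights: 3 and 10
  float scale = dq_scale_host(15, 2.0f / 256.0f);  // largest code -> full scale 2.0
  for (int i = 0; i < 2; i++) {
    int q = exb_host(word, 4 * i, 0x0f);
    std::printf("q=%2d -> %f\n", q, dq_host(q, /*qzero=*/8, scale));
  }
  return 0;
}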
@ -22,11 +22,15 @@
#include "gptq_marlin.cuh" #include "gptq_marlin.cuh"
#include "gptq_marlin_dtypes.cuh" #include "gptq_marlin_dtypes.cuh"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) static_assert(\ #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
std::is_same<scalar_t, half>::value || std::is_same<scalar_t, nv_bfloat16>::value, \ static_assert(std::is_same<scalar_t, half>::value || \
std::is_same<scalar_t, nv_bfloat16>::value, \
"only float16 and bfloat16 is supported"); "only float16 and bfloat16 is supported");
template <typename T> inline std::string str(T x) { return std::to_string(x); } template <typename T>
inline std::string str(T x) {
return std::to_string(x);
}
namespace gptq_marlin { namespace gptq_marlin {
@ -41,17 +45,18 @@ template <typename scalar_t, // compute dtype, half or nv_float16
const int num_bits, // number of bits used for weights const int num_bits, // number of bits used for weights
const int threads, // number of threads in a threadblock const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the threadblock // dimension (batchsize) of the
// threadblock
const int thread_n_blocks, // same for n dimension (output) const int thread_n_blocks, // same for n dimension (output)
const int thread_k_blocks, // same for k dimension (reduction) const int thread_k_blocks, // same for k dimension (reduction)
const int stages, // number of stages for the async global->shared const int stages, // number of stages for the async global->shared
// fetch pipeline // fetch pipeline
const bool has_act_order, // whether act_order is enabled const bool has_act_order, // whether act_order is enabled
const int group_blocks = -1 // number of consecutive 16x16 blocks with const int group_blocks = -1 // number of consecutive 16x16 blocks
// a separate quantization scale // with a separate quantization scale
> >
__global__ void __global__ void Marlin(
Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk const int4* __restrict__ A, // fp16 input matrix of shape mxk
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
int4* __restrict__ C, // fp16 output buffer of shape mxn int4* __restrict__ C, // fp16 output buffer of shape mxn
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
@ -88,17 +93,19 @@ __device__ inline void mma(const typename ScalarType<scalar_t>::FragA &a_frag,
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b); const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
float* c = reinterpret_cast<float*>(&frag_c); float* c = reinterpret_cast<float*>(&frag_c);
if constexpr (std::is_same<scalar_t, half>::value) { if constexpr (std::is_same<scalar_t, half>::value) {
asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) { } else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else { } else {
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
} }
@ -107,7 +114,8 @@ __device__ inline void mma(const typename ScalarType<scalar_t>::FragA &a_frag,
// Instruction for loading a full 16x16 matrix fragment of operand A from shared // Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout. // memory, directly in tensor core layout.
template <typename scalar_t> template <typename scalar_t>
__device__ inline void ldsm4(typename ScalarType<scalar_t>::FragA &frag_a, const void *smem_ptr) { __device__ inline void ldsm4(typename ScalarType<scalar_t>::FragA& frag_a,
const void* smem_ptr) {
uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a); uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr)); uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
@ -118,7 +126,8 @@ __device__ inline void ldsm4(typename ScalarType<scalar_t>::FragA &frag_a, const
// Lookup-table based 3-input logical operation; explicitly used for // Lookup-table based 3-input logical operation; explicitly used for
// dequantization as the compiler does not seem to automatically recognize it in // dequantization as the compiler does not seem to automatically recognize it in
// all cases. // all cases.
template <int lut> __device__ inline int lop3(int a, int b, int c) { template <int lut>
__device__ inline int lop3(int a, int b, int c) {
int res; int res;
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(res) : "=r"(res)
@ -140,8 +149,10 @@ __device__ inline uint32_t prmt(uint32_t a) {
// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 // Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
// values. We mostly follow the strategy in the link below, with some small // values. We mostly follow the strategy in the link below, with some small
// changes: // changes:
// - FP16: https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 // - FP16:
// - BF16: https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385 // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
template <typename scalar_t> template <typename scalar_t>
__device__ inline typename ScalarType<scalar_t>::FragB dequant_4bit(int q) { __device__ inline typename ScalarType<scalar_t>::FragB dequant_4bit(int q) {
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
@ -170,7 +181,8 @@ __device__ inline typename ScalarType<half>::FragB dequant_4bit<half>(int q) {
} }
template <> template <>
__device__ inline typename ScalarType<nv_bfloat16>::FragB dequant_4bit<nv_bfloat16>(int q) { __device__ inline typename ScalarType<nv_bfloat16>::FragB
dequant_4bit<nv_bfloat16>(int q) {
static constexpr uint32_t MASK = 0x000f000f; static constexpr uint32_t MASK = 0x000f000f;
static constexpr uint32_t EX = 0x43004300; static constexpr uint32_t EX = 0x43004300;
@ -193,10 +205,12 @@ __device__ inline typename ScalarType<nv_bfloat16>::FragB dequant_4bit<nv_bfloat
return frag_b; return frag_b;
} }
// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or bf16 // Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or
// Reference: // bf16 Reference:
// - FP16: https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 // - FP16:
// - BF16: https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175 // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
template <typename scalar_t> template <typename scalar_t>
__device__ inline typename ScalarType<scalar_t>::FragB dequant_8bit(int q) { __device__ inline typename ScalarType<scalar_t>::FragB dequant_8bit(int q) {
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
@ -222,11 +236,13 @@ __device__ inline typename ScalarType<half>::FragB dequant_8bit<half>(int q) {
} }
template <> template <>
__device__ inline typename ScalarType<nv_bfloat16>::FragB dequant_8bit<nv_bfloat16>(int q) { __device__ inline typename ScalarType<nv_bfloat16>::FragB
dequant_8bit<nv_bfloat16>(int q) {
typename ScalarType<nv_bfloat16>::FragB frag_b; typename ScalarType<nv_bfloat16>::FragB frag_b;
float fp32_intermediates[4]; float fp32_intermediates[4];
uint32_t * fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates); uint32_t* fp32_intermediates_casted =
reinterpret_cast<uint32_t*>(fp32_intermediates);
static constexpr uint32_t fp32_base = 0x4B000000; static constexpr uint32_t fp32_base = 0x4B000000;
fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
@ -240,8 +256,10 @@ __device__ inline typename ScalarType<nv_bfloat16>::FragB dequant_8bit<nv_bfloat
fp32_intermediates[3] -= 8388736.f; fp32_intermediates[3] -= 8388736.f;
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&frag_b); uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&frag_b);
bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632); bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632); fp32_intermediates_casted[1], 0x7632);
bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
fp32_intermediates_casted[3], 0x7632);
return frag_b; return frag_b;
} }
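A quick standalone C++ check of why fp32_base = 0x4B000000 works (illustrative only): a float with that exponent pattern and an int8 value q in its low mantissa byte equals 2^23 + q, so subtracting 8388736 (= 2^23 + 128) recovers q - 128, i.e. the zero-point-shifted value; the __byte_perm calls above just move each input byte into that low mantissa position.

#include <cstdint>
#include <cstring>
#include <cstdio>

int main() {
  for (int q = 0; q < 256; q += 51) {
    uint32_t bits = 0x4B000000u | (uint32_t)q;  // same base as fp32_base in the kernel
    float f;
    std::memcpy(&f, &bits, sizeof f);           // f == 8388608.0f + q
    std::printf("q=%3d  float=%10.1f  minus 8388736 -> %6.1f\n", q, f, f - 8388736.0f);
  }
  return 0;
}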
@ -250,9 +268,11 @@ __device__ inline typename ScalarType<nv_bfloat16>::FragB dequant_8bit<nv_bfloat
// only for grouped quantization. // only for grouped quantization.
template <typename scalar_t> template <typename scalar_t>
__device__ inline void scale(typename ScalarType<scalar_t>::FragB& frag_b, __device__ inline void scale(typename ScalarType<scalar_t>::FragB& frag_b,
typename ScalarType<scalar_t>::FragS &frag_s, int i) { typename ScalarType<scalar_t>::FragS& frag_s,
int i) {
using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2; using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
scalar_t2 s = ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t *>(&frag_s)[i]); scalar_t2 s =
ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_s)[i]);
frag_b[0] = __hmul2(frag_b[0], s); frag_b[0] = __hmul2(frag_b[0], s);
frag_b[1] = __hmul2(frag_b[1], s); frag_b[1] = __hmul2(frag_b[1], s);
} }
@ -280,7 +300,8 @@ __device__ inline void scale4(typename ScalarType<scalar_t>::FragB &frag_b,
// Given 2 floats multiply by 2 scales (halves) // Given 2 floats multiply by 2 scales (halves)
template <typename scalar_t> template <typename scalar_t>
__device__ inline void scale_float(float *c, typename ScalarType<scalar_t>::FragS &s) { __device__ inline void scale_float(float* c,
typename ScalarType<scalar_t>::FragS& s) {
scalar_t* s_ptr = reinterpret_cast<scalar_t*>(&s); scalar_t* s_ptr = reinterpret_cast<scalar_t*>(&s);
c[0] = __fmul_rn(c[0], ScalarType<scalar_t>::num2float(s_ptr[0])); c[0] = __fmul_rn(c[0], ScalarType<scalar_t>::num2float(s_ptr[0]));
c[1] = __fmul_rn(c[1], ScalarType<scalar_t>::num2float(s_ptr[1])); c[1] = __fmul_rn(c[1], ScalarType<scalar_t>::num2float(s_ptr[1]));
@ -325,7 +346,6 @@ __global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr,
int const* __restrict__ perm_int_ptr, int const* __restrict__ perm_int_ptr,
int4* __restrict__ out_int4_ptr, int size_m, int4* __restrict__ out_int4_ptr, int size_m,
int size_k, int block_rows) { int size_k, int block_rows) {
int start_row = block_rows * blockIdx.x; int start_row = block_rows * blockIdx.x;
int finish_row = start_row + block_rows; int finish_row = start_row + block_rows;
if (finish_row > size_m) { if (finish_row > size_m) {
@ -341,8 +361,7 @@ __global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr,
int offset = row * row_stride; int offset = row * row_stride;
half const *a_row_half = half const* a_row_half = reinterpret_cast<half const*>(a_int4_ptr + offset);
reinterpret_cast<half const *>(a_int4_ptr + offset);
half* out_half = reinterpret_cast<half*>(out_int4_ptr + offset); half* out_half = reinterpret_cast<half*>(out_int4_ptr + offset);
int base_k = 0; int base_k = 0;
@ -378,17 +397,18 @@ template <typename scalar_t, // compute dtype, half or nv_float16
const int num_bits, // number of bits used for weights const int num_bits, // number of bits used for weights
const int threads, // number of threads in a threadblock const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the threadblock // dimension (batchsize) of the
// threadblock
const int thread_n_blocks, // same for n dimension (output) const int thread_n_blocks, // same for n dimension (output)
const int thread_k_blocks, // same for k dimension (reduction) const int thread_k_blocks, // same for k dimension (reduction)
const int stages, // number of stages for the async global->shared const int stages, // number of stages for the async global->shared
// fetch pipeline // fetch pipeline
const bool has_act_order, // whether act_order is enabled const bool has_act_order, // whether act_order is enabled
const int group_blocks = -1 // number of consecutive 16x16 blocks with const int group_blocks = -1 // number of consecutive 16x16 blocks
// a separate quantization scale // with a separate quantization scale
> >
__global__ void __global__ void Marlin(
Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk const int4* __restrict__ A, // fp16 input matrix of shape mxk
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
int4* __restrict__ C, // fp16 output buffer of shape mxn int4* __restrict__ C, // fp16 output buffer of shape mxn
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
@ -465,27 +485,22 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
auto init_slice = [&]() { auto init_slice = [&]() {
slice_iters = slice_iters =
iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;
slice_iters = 0; if (slice_iters == 0) return;
if (slice_iters == 0) if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
return;
if (slice_row + slice_iters > k_tiles)
slice_iters = k_tiles - slice_row;
slice_count = 1; slice_count = 1;
slice_idx = 0; slice_idx = 0;
int col_first = iters * div_ceil(k_tiles * slice_col_par, iters); int col_first = iters * div_ceil(k_tiles * slice_col_par, iters);
if (col_first <= k_tiles * (slice_col_par + 1)) { if (col_first <= k_tiles * (slice_col_par + 1)) {
int col_off = col_first - k_tiles * slice_col_par; int col_off = col_first - k_tiles * slice_col_par;
slice_count = div_ceil(k_tiles - col_off, iters); slice_count = div_ceil(k_tiles - col_off, iters);
if (col_off > 0) if (col_off > 0) slice_count++;
slice_count++;
int delta_first = iters * blockIdx.x - col_first; int delta_first = iters * blockIdx.x - col_first;
if (delta_first < 0 || (col_off == 0 && delta_first == 0)) if (delta_first < 0 || (col_off == 0 && delta_first == 0))
slice_idx = slice_count - 1; slice_idx = slice_count - 1;
else { else {
slice_idx = slice_count - 1 - delta_first / iters; slice_idx = slice_count - 1 - delta_first / iters;
if (col_off > 0) if (col_off > 0) slice_idx--;
slice_idx--;
} }
} }
if (slice_col == n_tiles) { if (slice_col == n_tiles) {
@ -785,7 +800,8 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
int4* sh_a_stage = sh_a + a_sh_stage * pipe; int4* sh_a_stage = sh_a + a_sh_stage * pipe;
#pragma unroll #pragma unroll
for (int i = 0; i < thread_m_blocks; i++) for (int i = 0; i < thread_m_blocks; i++)
ldsm4<scalar_t>(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); ldsm4<scalar_t>(frag_a[k % 2][i],
&sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
int4* sh_b_stage = sh_b + b_sh_stage * pipe; int4* sh_b_stage = sh_b + b_sh_stage * pipe;
#pragma unroll #pragma unroll
@ -906,7 +922,6 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
#pragma unroll #pragma unroll
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
int actual_k = cur_k + k_frag_offsets[i]; int actual_k = cur_k + k_frag_offsets[i];
int group_id = sh_g_idx_int_ptr[actual_k]; int group_id = sh_g_idx_int_ptr[actual_k];
@ -943,8 +958,9 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// Apply scale to frag_b0 // Apply scale to frag_b0
if constexpr (has_act_order) { if constexpr (has_act_order) {
scale4<scalar_t>(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], scale4<scalar_t>(frag_b0, act_frag_s[k % 2][0][j],
act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0); act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j],
act_frag_s[k % 2][3][j], 0);
} else { } else {
if constexpr (group_blocks != -1) { if constexpr (group_blocks != -1) {
scale<scalar_t>(frag_b0, frag_s[k % 2][j], 0); scale<scalar_t>(frag_b0, frag_s[k % 2][j], 0);
@ -953,8 +969,9 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// Apply scale to frag_b1 // Apply scale to frag_b1
if constexpr (has_act_order) { if constexpr (has_act_order) {
scale4<scalar_t>(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], scale4<scalar_t>(frag_b1, act_frag_s[k % 2][0][j],
act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1); act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j],
act_frag_s[k % 2][3][j], 1);
} else { } else {
if constexpr (group_blocks != -1) { if constexpr (group_blocks != -1) {
@ -997,8 +1014,8 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
int red_sh_wr = int red_sh_wr =
red_sh_delta * j + (red_sh_rd - red_sh_stride * i); red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
if (i < red_off) { if (i < red_off) {
float *c_rd = reinterpret_cast<float *>( float* c_rd =
&sh[red_sh_delta * j + red_sh_rd]); reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]); float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
#pragma unroll #pragma unroll
for (int k = 0; k < 4; k++) for (int k = 0; k < 4; k++)
@ -1049,16 +1066,16 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
int row = (threadIdx.x % 32) / 4; int row = (threadIdx.x % 32) / 4;
if (!first) { if (!first) {
// Interestingly, doing direct global accesses here really seems to mess up the // Interestingly, doing direct global accesses here really seems to mess up
// compiler and lead to slowdowns, hence we also use async-copies even though // the compiler and lead to slowdowns, hence we also use async-copies even
// these fetches are not actually asynchronous. // though these fetches are not actually asynchronous.
#pragma unroll #pragma unroll
for (int i = 0; i < thread_m_blocks * 4; i++) { for (int i = 0; i < thread_m_blocks * 4; i++) {
cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], cp_async4_pred(
&sh[c_sh_wr + c_sh_wr_delta * i],
&C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
c_gl_wr_delta_i * (i % 2)], c_gl_wr_delta_i * (i % 2)],
i < (thread_m_blocks - 1) * 4 || i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
8 * (i / 2) + row < prob_m);
} }
cp_async_fence(); cp_async_fence();
cp_async_wait<0>(); cp_async_wait<0>();
@ -1116,7 +1133,8 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// We first reorder in shared memory to guarantee the most efficient final // We first reorder in shared memory to guarantee the most efficient final
// global write patterns // global write patterns
auto write = [&](int idx, float c0, float c1, FragS& s) { auto write = [&](int idx, float c0, float c1, FragS& s) {
scalar_t2 res = Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); scalar_t2 res =
Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1));
// For per-column quantization we finally apply the scale here (only for // For per-column quantization we finally apply the scale here (only for
// 4-bit) // 4-bit)
@ -1286,14 +1304,18 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
for (int i = 0; i < thread_m_blocks; i++) { for (int i = 0; i < thread_m_blocks; i++) {
#pragma unroll #pragma unroll
for (int j = 0; j < 4; j++) { for (int j = 0; j < 4; j++) {
scale_float<scalar_t>(reinterpret_cast<float *>(&frag_c[i][j][0][0]), scale_float<scalar_t>(
reinterpret_cast<float*>(&frag_c[i][j][0][0]),
frag_s[j / 2][2 * (j % 2) + 0]); frag_s[j / 2][2 * (j % 2) + 0]);
scale_float<scalar_t>(reinterpret_cast<float *>(&frag_c[i][j][0][2]), scale_float<scalar_t>(
reinterpret_cast<float*>(&frag_c[i][j][0][2]),
frag_s[j / 2][2 * (j % 2) + 0]); frag_s[j / 2][2 * (j % 2) + 0]);
scale_float<scalar_t>(reinterpret_cast<float *>(&frag_c[i][j][1][0]), scale_float<scalar_t>(
reinterpret_cast<float*>(&frag_c[i][j][1][0]),
frag_s[j / 2][2 * (j % 2) + 1]); frag_s[j / 2][2 * (j % 2) + 1]);
scale_float<scalar_t>(reinterpret_cast<float *>(&frag_c[i][j][1][2]), scale_float<scalar_t>(
reinterpret_cast<float*>(&frag_c[i][j][1][2]),
frag_s[j / 2][2 * (j % 2) + 1]); frag_s[j / 2][2 * (j % 2) + 1]);
} }
} }
@ -1320,8 +1342,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
if (slice_col == 0) { if (slice_col == 0) {
#pragma unroll #pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
B_ptr[i] -= b_gl_stride;
} }
// Update slice k/n for scales loading // Update slice k/n for scales loading
@ -1341,20 +1362,21 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
} }
} }
#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ #define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \ THREAD_K_BLOCKS, HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \
else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \ else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \ thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \ thread_k_blocks == THREAD_K_BLOCKS && \
has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \
num_threads == NUM_THREADS) { \ num_threads == NUM_THREADS) { \
cudaFuncSetAttribute( \ cudaFuncSetAttribute( \
Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, GROUP_BLOCKS>, \ THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
GROUP_BLOCKS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, GROUP_BLOCKS> \ THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \ GROUP_BLOCKS><<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n, \ A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n, \
prob_k, locks); \ prob_k, locks); \
} }
@ -1673,8 +1695,7 @@ void marlin_mm_f16i4(const void *A, const void *B, void *C, void *s,
// Note that parallel > 1 currently only works for inputs without any // Note that parallel > 1 currently only works for inputs without any
// padding // padding
par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks); par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks);
if (par > max_par) if (par > max_par) par = max_par;
par = max_par;
prob_m = (16 * exec_cfg.max_m_blocks) * par; prob_m = (16 * exec_cfg.max_m_blocks) * par;
i += exec_cfg.max_m_blocks * (par - 1); i += exec_cfg.max_m_blocks * (par - 1);
thread_m_blocks = exec_cfg.max_m_blocks; thread_m_blocks = exec_cfg.max_m_blocks;
@ -1824,18 +1845,21 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
int dev = a.get_device(); int dev = a.get_device();
if (a.scalar_type() == at::ScalarType::Half) { if (a.scalar_type() == at::ScalarType::Half) {
gptq_marlin::marlin_mm_f16i4<half>( gptq_marlin::marlin_mm_f16i4<half>(
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(), b_scales.data_ptr<at::Half>(), a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::Half>(), size_m, size_n, b_scales.data_ptr<at::Half>(), g_idx.data_ptr(), perm.data_ptr(),
size_k, workspace.data_ptr(), num_bits, has_act_order, is_k_full, a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), workspace.data_ptr(), num_bits, has_act_order, is_k_full, num_groups,
thread_k, thread_n, sms, gptq_marlin::max_par); group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k,
thread_n, sms, gptq_marlin::max_par);
} else if (a.scalar_type() == at::ScalarType::BFloat16) { } else if (a.scalar_type() == at::ScalarType::BFloat16) {
gptq_marlin::marlin_mm_f16i4<nv_bfloat16>( gptq_marlin::marlin_mm_f16i4<nv_bfloat16>(
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(), c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(), a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(),
size_k, workspace.data_ptr(), num_bits, has_act_order, is_k_full, g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), size_m, size_n, size_k, workspace.data_ptr(), num_bits, has_act_order,
thread_k, thread_n, sms, gptq_marlin::max_par); is_k_full, num_groups, group_size, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
gptq_marlin::max_par);
} else { } else {
TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16"); TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16");
} }
@ -11,12 +11,13 @@
namespace gptq_marlin { namespace gptq_marlin {
// 8 warps are a good choice since every SM has 4 schedulers and having more than 1 warp per // 8 warps are a good choice since every SM has 4 schedulers and having more
// schedule allows some more latency hiding. At the same time, we want relatively few warps to have // than 1 warp per schedule allows some more latency hiding. At the same time,
// many registers per warp and small tiles. // we want relatively few warps to have many registers per warp and small tiles.
static constexpr int default_threads = 256; static constexpr int default_threads = 256;
static constexpr int pipe_stages = 4; // 4 pipeline stages fit into shared memory static constexpr int pipe_stages =
4; // 4 pipeline stages fit into shared memory
static constexpr int min_thread_n = 64; static constexpr int min_thread_n = 64;
static constexpr int min_thread_k = 64; static constexpr int min_thread_k = 64;
@ -38,10 +39,12 @@ constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
// No support for async // No support for async
#else #else
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) { __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) {
const int BYTES = 16; const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr)); uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile("{\n" asm volatile(
"{\n"
" .reg .pred p;\n" " .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n" " setp.ne.b32 p, %0, 0;\n"
" @p cp.async.cg.shared.global [%1], [%2], %3;\n" " @p cp.async.cg.shared.global [%1], [%2], %3;\n"
@ -52,13 +55,16 @@ __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { __device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
const int BYTES = 16; const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr)); uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile("{\n" asm volatile(
"{\n"
" cp.async.cg.shared.global [%0], [%1], %2;\n" " cp.async.cg.shared.global [%0], [%1], %2;\n"
"}\n" ::"r"(smem), "}\n" ::"r"(smem),
"l"(glob_ptr), "n"(BYTES)); "l"(glob_ptr), "n"(BYTES));
} }
__device__ inline void cp_async_fence() { asm volatile("cp.async.commit_group;\n" ::); } __device__ inline void cp_async_fence() {
asm volatile("cp.async.commit_group;\n" ::);
}
template <int n> template <int n>
__device__ inline void cp_async_wait() { __device__ inline void cp_async_wait() {
@ -5,12 +5,10 @@
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda_bf16.h> #include <cuda_bf16.h>
namespace gptq_marlin { namespace gptq_marlin {
template <typename scalar_t> template <typename scalar_t>
class ScalarType { class ScalarType {};
};
template <> template <>
class ScalarType<half> { class ScalarType<half> {
@ -26,13 +24,21 @@ public:
using FragC = Vec<float, 4>; using FragC = Vec<float, 4>;
using FragS = Vec<half2, 1>; using FragS = Vec<half2, 1>;
static __device__ float inline num2float(const half x) { return __half2float(x); } static __device__ float inline num2float(const half x) {
return __half2float(x);
}
static __device__ half2 inline num2num2(const half x) { return __half2half2(x); } static __device__ half2 inline num2num2(const half x) {
return __half2half2(x);
}
static __device__ half2 inline nums2num2(const half x1, const half x2) { return __halves2half2(x1, x2); } static __device__ half2 inline nums2num2(const half x1, const half x2) {
return __halves2half2(x1, x2);
}
static __host__ __device__ half inline float2num(const float x) { return __float2half(x); } static __host__ __device__ half inline float2num(const float x) {
return __float2half(x);
}
}; };
template <> template <>
@ -47,16 +53,25 @@ public:
using FragS = Vec<nv_bfloat162, 1>; using FragS = Vec<nv_bfloat162, 1>;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
static __device__ float inline num2float(const nv_bfloat16 x) { return __bfloat162float(x); } static __device__ float inline num2float(const nv_bfloat16 x) {
return __bfloat162float(x);
}
static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) { return __bfloat162bfloat162(x); } static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) {
return __bfloat162bfloat162(x);
}
static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1, const nv_bfloat16 x2) { return __halves2bfloat162(x1, x2); } static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1,
const nv_bfloat16 x2) {
return __halves2bfloat162(x1, x2);
}
static __host__ __device__ nv_bfloat16 inline float2num(const float x) { return __float2bfloat16(x); } static __host__ __device__ nv_bfloat16 inline float2num(const float x) {
return __float2bfloat16(x);
}
#endif #endif
}; };
} } // namespace gptq_marlin
#endif #endif
@ -12,10 +12,10 @@ static constexpr int tile_n_size = tile_k_size * 4;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
template <int const num_threads, int const num_bits, bool const has_perm> template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void __global__ void marlin_repack_kernel(
marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, uint32_t const* __restrict__ b_q_weight_ptr,
uint32_t const *__restrict__ perm_ptr, uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
uint32_t *__restrict__ out_ptr, int size_k, int size_n) {} int size_k, int size_n) {}
} // namespace gptq_marlin } // namespace gptq_marlin
@ -30,10 +30,10 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
#else #else
template <int const num_threads, int const num_bits, bool const has_perm> template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void __global__ void marlin_repack_kernel(
marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, uint32_t const* __restrict__ b_q_weight_ptr,
uint32_t const *__restrict__ perm_ptr, uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
uint32_t *__restrict__ out_ptr, int size_k, int size_n) { int size_k, int size_n) {
constexpr int pack_factor = 32 / num_bits; constexpr int pack_factor = 32 / num_bits;
int k_tiles = size_k / tile_k_size; int k_tiles = size_k / tile_k_size;
@ -176,7 +176,6 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
} }
} else { } else {
uint32_t b1_vals[tile_ints]; uint32_t b1_vals[tile_ints];
uint32_t b2_vals[tile_ints]; uint32_t b2_vals[tile_ints];
@ -320,8 +319,7 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
// Get ptrs // Get ptrs
uint32_t const* b_q_weight_ptr = uint32_t const* b_q_weight_ptr =
reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr()); reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());
uint32_t const *perm_ptr = uint32_t const* perm_ptr = reinterpret_cast<uint32_t const*>(perm.data_ptr());
reinterpret_cast<uint32_t const *>(perm.data_ptr());
uint32_t* out_ptr = reinterpret_cast<uint32_t*>(out.data_ptr()); uint32_t* out_ptr = reinterpret_cast<uint32_t*>(out.data_ptr());
// Get dev info // Get dev info
@ -25,7 +25,10 @@
#include <iostream> #include <iostream>
template <typename T> inline std::string str(T x) { return std::to_string(x); } template <typename T>
inline std::string str(T x) {
return std::to_string(x);
}
namespace marlin { namespace marlin {
@ -38,7 +41,8 @@ constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }
// corresponding index accesses must be compile-time constants, which is why we // corresponding index accesses must be compile-time constants, which is why we
// extensively use `#pragma unroll` throughout the kernel code to guarantee // extensively use `#pragma unroll` throughout the kernel code to guarantee
// this. // this.
template <typename T, int n> struct Vec { template <typename T, int n>
struct Vec {
T elems[n]; T elems[n];
__device__ T& operator[](int i) { return elems[i]; } __device__ T& operator[](int i) { return elems[i]; }
}; };
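The comment above is the key design point: a plain C++ mirror of Vec shows the shape, and on the device the fully unrolled loops index elems with compile-time constants so the fragment can live in registers (host-side sketch, illustrative only):

#include <cstdio>

template <typename T, int n>
struct Vec {
  T elems[n];
  T& operator[](int i) { return elems[i]; }
};

using FragC = Vec<float, 4>;  // same shape as the accumulator fragments used by the kernel

int main() {
  FragC frag_c = {{1.f, 2.f, 3.f, 4.f}};
  float sum = 0.f;
  for (int i = 0; i < 4; i++)  // the device code unrolls loops like this with #pragma unroll
    sum += frag_c[i];
  std::printf("sum = %f\n", sum);
  return 0;
}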
@ -59,7 +63,8 @@ __device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr,
bool pred = true) { bool pred = true) {
const int BYTES = 16; const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr)); uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile("{\n" asm volatile(
"{\n"
" .reg .pred p;\n" " .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n" " setp.ne.b32 p, %0, 0;\n"
" @p cp.async.cg.shared.global [%1], [%2], %3;\n" " @p cp.async.cg.shared.global [%1], [%2], %3;\n"
@ -71,9 +76,11 @@ __device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr,
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { __device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
const int BYTES = 16; const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr)); uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile("{\n" asm volatile(
"{\n"
" cp.async.cg.shared.global [%0], [%1], %2;\n" " cp.async.cg.shared.global [%0], [%1], %2;\n"
"}\n" :: "r"(smem), "l"(glob_ptr), "n"(BYTES)); "}\n" ::"r"(smem),
"l"(glob_ptr), "n"(BYTES));
} }
// Async copy fence. // Async copy fence.
@ -82,7 +89,8 @@ __device__ inline void cp_async_fence() {
} }
// Wait until at most `n` async copy stages are still pending. // Wait until at most `n` async copy stages are still pending.
template <int n> __device__ inline void cp_async_wait() { template <int n>
__device__ inline void cp_async_wait() {
asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
} }
@ -93,11 +101,12 @@ __device__ inline void mma(const FragA &a_frag, const FragB &frag_b,
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag); const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b); const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
float* c = reinterpret_cast<float*>(&frag_c); float* c = reinterpret_cast<float*>(&frag_c);
asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} }
// Instruction for loading a full 16x16 matrix fragment of operand A from shared // Instruction for loading a full 16x16 matrix fragment of operand A from shared
@ -113,7 +122,8 @@ __device__ inline void ldsm4(FragA &frag_a, const void *smem_ptr) {
// Lookup-table based 3-input logical operation; explicitly used for // Lookup-table based 3-input logical operation; explicitly used for
// dequantization as the compiler does not seem to automatically recognize it in // dequantization as the compiler does not seem to automatically recognize it in
// all cases. // all cases.
template <int lut> __device__ inline int lop3(int a, int b, int c) { template <int lut>
__device__ inline int lop3(int a, int b, int c) {
int res; int res;
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(res) : "=r"(res)
@ -189,20 +199,21 @@ __device__ inline void barrier_release(int *lock, bool reset = false) {
template <const int threads, // number of threads in a threadblock template <const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the threadblock // dimension (batchsize) of the
// threadblock
const int thread_n_blocks, // same for n dimension (output) const int thread_n_blocks, // same for n dimension (output)
const int thread_k_blocks, // same for k dimension (reduction) const int thread_k_blocks, // same for k dimension (reduction)
const int stages, // number of stages for the async global->shared const int stages, // number of stages for the async global->shared
// fetch pipeline // fetch pipeline
const int group_blocks = -1 // number of consecutive 16x16 blocks with const int group_blocks = -1 // number of consecutive 16x16 blocks
// a separate quantization scale // with a separate quantization scale
> >
__global__ void __global__ void Marlin(
Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk const int4* __restrict__ A, // fp16 input matrix of shape mxk
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
int4* __restrict__ C, // fp16 output buffer of shape mxn int4* __restrict__ C, // fp16 output buffer of shape mxn
const int4 const int4* __restrict__ s, // fp16 quantization scales of shape
*__restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn // (k/groupsize)xn
int prob_m, // batch dimension m int prob_m, // batch dimension m
int prob_n, // output dimension n int prob_n, // output dimension n
int prob_k, // reduction dimension k int prob_k, // reduction dimension k
@ -261,27 +272,22 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
auto init_slice = [&]() { auto init_slice = [&]() {
slice_iters = slice_iters =
iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;
slice_iters = 0; if (slice_iters == 0) return;
if (slice_iters == 0) if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
return;
if (slice_row + slice_iters > k_tiles)
slice_iters = k_tiles - slice_row;
slice_count = 1; slice_count = 1;
slice_idx = 0; slice_idx = 0;
int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); int col_first = iters * ceildiv(k_tiles * slice_col_par, iters);
if (col_first <= k_tiles * (slice_col_par + 1)) { if (col_first <= k_tiles * (slice_col_par + 1)) {
int col_off = col_first - k_tiles * slice_col_par; int col_off = col_first - k_tiles * slice_col_par;
slice_count = ceildiv(k_tiles - col_off, iters); slice_count = ceildiv(k_tiles - col_off, iters);
if (col_off > 0) if (col_off > 0) slice_count++;
slice_count++;
int delta_first = iters * blockIdx.x - col_first; int delta_first = iters * blockIdx.x - col_first;
if (delta_first < 0 || (col_off == 0 && delta_first == 0)) if (delta_first < 0 || (col_off == 0 && delta_first == 0))
slice_idx = slice_count - 1; slice_idx = slice_count - 1;
else { else {
slice_idx = slice_count - 1 - delta_first / iters; slice_idx = slice_count - 1 - delta_first / iters;
if (col_off > 0) if (col_off > 0) slice_idx--;
slice_idx--;
} }
} }
if (slice_col == n_tiles) { if (slice_col == n_tiles) {
@ -305,7 +311,8 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
a_gl_stride * a_gl_stride *
(threads / a_gl_rd_delta_o); // between subsequent accesses within a tile (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile
constexpr int a_sh_wr_delta = constexpr int a_sh_wr_delta =
a_sh_stride * (threads / a_gl_rd_delta_o); // between shared memory writes a_sh_stride *
(threads / a_gl_rd_delta_o); // between shared memory writes
constexpr int a_sh_rd_delta_o = constexpr int a_sh_rd_delta_o =
2 * ((threads / 32) / 2 * ((threads / 32) /
(thread_n_blocks / 4)); // between shared memory tile reads (thread_n_blocks / 4)); // between shared memory tile reads
@ -447,8 +454,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// Only fetch scales if this tile starts a new group // Only fetch scales if this tile starts a new group
if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) { if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) {
int4* sh_s_stage = sh_s + s_sh_stage * pipe; int4* sh_s_stage = sh_s + s_sh_stage * pipe;
if (s_sh_wr_pred) if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
s_gl_rd += s_gl_rd_delta; s_gl_rd += s_gl_rd_delta;
} }
} }
@ -500,11 +506,9 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
FragB frag_b0 = dequant(b_quant); FragB frag_b0 = dequant(b_quant);
// If there are no groups, we can just scale the final output once and can // If there are no groups, we can just scale the final output once and can
// avoid doing so for each weight. // avoid doing so for each weight.
if (group_blocks != -1) if (group_blocks != -1) scale(frag_b0, frag_s[k % 2][j], 0);
scale(frag_b0, frag_s[k % 2][j], 0);
FragB frag_b1 = dequant(b_quant_shift); FragB frag_b1 = dequant(b_quant_shift);
if (group_blocks != -1) if (group_blocks != -1) scale(frag_b1, frag_s[k % 2][j], 1);
scale(frag_b1, frag_s[k % 2][j], 1);
#pragma unroll #pragma unroll
for (int i = 0; i < thread_m_blocks; i++) { for (int i = 0; i < thread_m_blocks; i++) {
mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
@ -540,8 +544,8 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
int red_sh_wr = int red_sh_wr =
red_sh_delta * j + (red_sh_rd - red_sh_stride * i); red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
if (i < red_off) { if (i < red_off) {
float *c_rd = reinterpret_cast<float *>( float* c_rd =
&sh[red_sh_delta * j + red_sh_rd]); reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]); float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
#pragma unroll #pragma unroll
for (int k = 0; k < 4; k++) for (int k = 0; k < 4; k++)
@ -571,9 +575,9 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
}; };
// Since multiple threadblocks may process parts of the same column slice, we // Since multiple threadblocks may process parts of the same column slice, we
// finally have to globally reduce over the results. As the striped partitioning // finally have to globally reduce over the results. As the striped
// minimizes the number of such reductions and our outputs are usually rather // partitioning minimizes the number of such reductions and our outputs are
// small, we perform this reduction serially in L2 cache. // usually rather small, we perform this reduction serially in L2 cache.
auto global_reduce = [&](bool first = false, bool last = false) { auto global_reduce = [&](bool first = false, bool last = false) {
// We are very careful here to reduce directly in the output buffer to // We are very careful here to reduce directly in the output buffer to
// maximize L2 cache utilization in this step. To do this, we write out // maximize L2 cache utilization in this step. To do this, we write out
@ -592,16 +596,16 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
int row = (threadIdx.x % 32) / 4; int row = (threadIdx.x % 32) / 4;
if (!first) { if (!first) {
// Interestingly, doing direct global accesses here really seems to mess up the // Interestingly, doing direct global accesses here really seems to mess up
// compiler and lead to slowdowns, hence we also use async-copies even though // the compiler and lead to slowdowns, hence we also use async-copies even
// these fetches are not actually asynchronous. // though these fetches are not actually asynchronous.
#pragma unroll #pragma unroll
for (int i = 0; i < thread_m_blocks * 4; i++) { for (int i = 0; i < thread_m_blocks * 4; i++) {
cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], cp_async4_pred(
&sh[c_sh_wr + c_sh_wr_delta * i],
&C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
c_gl_wr_delta_i * (i % 2)], c_gl_wr_delta_i * (i % 2)],
i < (thread_m_blocks - 1) * 4 || i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
8 * (i / 2) + row < prob_m);
} }
cp_async_fence(); cp_async_fence();
cp_async_wait<0>(); cp_async_wait<0>();
@ -700,8 +704,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// Start global fetch and register load pipelines. // Start global fetch and register load pipelines.
auto start_pipes = [&]() { auto start_pipes = [&]() {
#pragma unroll #pragma unroll
for (int i = 0; i < stages - 1; i++) for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters);
fetch_to_shared(i, i, i < slice_iters);
zero_accums(); zero_accums();
wait_for_stage(); wait_for_stage();
fetch_to_registers(0, 0); fetch_to_registers(0, 0);
@ -711,9 +714,9 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// Main loop. // Main loop.
while (slice_iters) { while (slice_iters) {
// We unroll over both the global fetch and the register load pipeline to ensure // We unroll over both the global fetch and the register load pipeline to
// all shared memory accesses are static. Note that both pipelines have even // ensure all shared memory accesses are static. Note that both pipelines have
// length meaning that the next iteration will always start at index 0. // even length meaning that the next iteration will always start at index 0.
#pragma unroll #pragma unroll
for (int pipe = 0; pipe < stages;) { for (int pipe = 0; pipe < stages;) {
#pragma unroll #pragma unroll
@ -728,8 +731,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
matmul(k); matmul(k);
} }
slice_iters--; slice_iters--;
if (slice_iters == 0) if (slice_iters == 0) break;
break;
} }
a_gl_rd += a_gl_rd_delta_o * stages; a_gl_rd += a_gl_rd_delta_o * stages;
@ -742,8 +744,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// For per-column scales, we only fetch them here in the final step before // For per-column scales, we only fetch them here in the final step before
// write-out // write-out
if (group_blocks == -1 && last) { if (group_blocks == -1 && last) {
if (s_sh_wr_pred) if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
cp_async_fence(); cp_async_fence();
} }
thread_block_reduce(); thread_block_reduce();
@ -775,8 +776,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
if (slice_col == 0) { if (slice_col == 0) {
#pragma unroll #pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
B_ptr[i] -= b_gl_stride;
} }
s_gl_rd = s_sh_stride * slice_col + threadIdx.x; s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
start_pipes(); start_pipes();
@ -789,20 +789,21 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
template <const int threads, // number of threads in a threadblock template <const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the threadblock // dimension (batchsize) of the
// threadblock
const int thread_n_blocks, // same for n dimension (output) const int thread_n_blocks, // same for n dimension (output)
const int thread_k_blocks, // same for k dimension (reduction) const int thread_k_blocks, // same for k dimension (reduction)
const int stages, // number of stages for the async global->shared const int stages, // number of stages for the async global->shared
// fetch pipeline // fetch pipeline
const int group_blocks = -1 // number of consecutive 16x16 blocks with const int group_blocks = -1 // number of consecutive 16x16 blocks
// a separate quantization scale // with a separate quantization scale
> >
__global__ void __global__ void Marlin(
Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk const int4* __restrict__ A, // fp16 input matrix of shape mxk
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
int4* __restrict__ C, // fp16 output buffer of shape mxn int4* __restrict__ C, // fp16 output buffer of shape mxn
const int4 const int4* __restrict__ s, // fp16 quantization scales of shape
*__restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn // (k/groupsize)xn
int prob_m, // batch dimension m int prob_m, // batch dimension m
int prob_n, // output dimension n int prob_n, // output dimension n
int prob_k, // reduction dimension k int prob_k, // reduction dimension k
@ -907,7 +908,6 @@ bool is_valid_config(thread_config_t const &th_config, int prob_m, int prob_n,
} }
thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) {
if (prob_m <= 16) { if (prob_m <= 16) {
for (auto th_config : small_batch_thread_configs) { for (auto th_config : small_batch_thread_configs) {
if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { if (is_valid_config(th_config, prob_m, prob_n, prob_k)) {
@ -1011,8 +1011,7 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, int prob_m,
// Note that parallel > 1 currently only works for inputs without any // Note that parallel > 1 currently only works for inputs without any
// padding // padding
par = (16 * thread_m_blocks - pad) / 64; par = (16 * thread_m_blocks - pad) / 64;
if (par > max_par) if (par > max_par) par = max_par;
par = max_par;
prob_m = 64 * par; prob_m = 64 * par;
i += 4 * (par - 1); i += 4 * (par - 1);
thread_m_blocks = 4; thread_m_blocks = 4;
@ -1046,7 +1045,6 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, int prob_m,
torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& workspace, torch::Tensor& b_scales, torch::Tensor& workspace,
int64_t size_m, int64_t size_n, int64_t size_k) { int64_t size_m, int64_t size_n, int64_t size_k) {
// Verify M // Verify M
TORCH_CHECK(size_m == a.size(0), TORCH_CHECK(size_m == a.size(0),
"Shape mismatch: a.size(0) = " + str(a.size(0)) + "Shape mismatch: a.size(0) = " + str(a.size(0)) +
@ -1074,9 +1072,9 @@ torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
int actual_size_n = int actual_size_n =
(b_q_weight.size(1) / marlin::tile_size) * marlin::pack_factor_4bit; (b_q_weight.size(1) / marlin::tile_size) * marlin::pack_factor_4bit;
TORCH_CHECK(size_n == actual_size_n, TORCH_CHECK(
"size_n = " + str(size_n) + size_n == actual_size_n,
", actual_size_n = " + str(actual_size_n)); "size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n));
// Verify A device and strides // Verify A device and strides
TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
View File
@ -26,12 +26,14 @@ constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }
// corresponding index accesses must be compile-time constants, which is why we // corresponding index accesses must be compile-time constants, which is why we
// extensively use `#pragma unroll` throughout the kernel code to guarantee // extensively use `#pragma unroll` throughout the kernel code to guarantee
// this. // this.
template <typename T, int n> struct Vec { template <typename T, int n>
struct Vec {
T elems[n]; T elems[n];
__device__ T& operator[](int i) { return elems[i]; } __device__ T& operator[](int i) { return elems[i]; }
}; };
template <int M_, int N_, int K_> struct ShapeBase { template <int M_, int N_, int K_>
struct ShapeBase {
static constexpr int M = M_, N = N_, K = K_; static constexpr int M = M_, N = N_, K = K_;
}; };
View File
@ -28,7 +28,8 @@ __device__ inline void cp_async4_pred_zfill(void *smem_ptr,
const int BYTES = 16; const int BYTES = 16;
int src_in_bytes = (zfill ? 0 : BYTES); int src_in_bytes = (zfill ? 0 : BYTES);
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr)); uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile("{\n" asm volatile(
"{\n"
" .reg .pred p;\n" " .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n" " setp.ne.b32 p, %0, 0;\n"
" @p cp.async.cg.shared.global [%1], [%2], %3;\n" " @p cp.async.cg.shared.global [%1], [%2], %3;\n"
@ -40,7 +41,8 @@ __device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr,
bool pred = true) { bool pred = true) {
const int BYTES = 16; const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr)); uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile("{\n" asm volatile(
"{\n"
" .reg .pred p;\n" " .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n" " setp.ne.b32 p, %0, 0;\n"
" @p cp.async.cg.shared.global [%1], [%2], %3;\n" " @p cp.async.cg.shared.global [%1], [%2], %3;\n"
@ -52,7 +54,8 @@ __device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr,
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { __device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
const int BYTES = 16; const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr)); uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile("{\n" asm volatile(
"{\n"
" cp.async.cg.shared.global [%0], [%1], %2;\n" " cp.async.cg.shared.global [%0], [%1], %2;\n"
"}\n" ::"r"(smem), "}\n" ::"r"(smem),
"l"(glob_ptr), "n"(BYTES)); "l"(glob_ptr), "n"(BYTES));
@ -64,7 +67,8 @@ __device__ inline void cp_async_fence() {
} }
// Wait until at most `n` async copy stages are still pending. // Wait until at most `n` async copy stages are still pending.
template <int n> __device__ inline void cp_async_wait() { template <int n>
__device__ inline void cp_async_wait() {
asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
} }
View File
@ -31,42 +31,47 @@ __device__ inline void mma_sp(const FragB &a_frag0, const FragB &a_frag1,
const uint32_t* e = reinterpret_cast<const uint32_t*>(&frag_m); const uint32_t* e = reinterpret_cast<const uint32_t*>(&frag_m);
float* c = reinterpret_cast<float*>(&frag_c); float* c = reinterpret_cast<float*>(&frag_c);
if (psel == 0) { if (psel == 0) {
asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " asm volatile(
"mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
"{%12,%13,%14,%15}, %16, 0x0;\n" "{%12,%13,%14,%15}, %16, 0x0;\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), "r"(b[2]),
"r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]),
"f"(c[2]), "f"(c[3]), "r"(e[0])); "r"(e[0]));
asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " asm volatile(
"mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
"{%12,%13,%14,%15}, %16, 0x0;\n" "{%12,%13,%14,%15}, %16, 0x0;\n"
: "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7]) : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), "r"(b[3]),
"r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), "f"(c[6]), "f"(c[7]),
"f"(c[6]), "f"(c[7]), "r"(e[0])); "r"(e[0]));
} else { } else {
asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " asm volatile(
"mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
"{%12,%13,%14,%15}, %16, 0x1;\n" "{%12,%13,%14,%15}, %16, 0x1;\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), "r"(b[2]),
"r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]),
"f"(c[2]), "f"(c[3]), "r"(e[0])); "r"(e[0]));
asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " asm volatile(
"mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
"{%12,%13,%14,%15}, %16, 0x1;\n" "{%12,%13,%14,%15}, %16, 0x1;\n"
: "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7]) : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), "r"(b[3]),
"r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), "f"(c[6]), "f"(c[7]),
"f"(c[6]), "f"(c[7]), "r"(e[0])); "r"(e[0]));
} }
} }
// Lookup-table based 3-input logical operation; explicitly used for // Lookup-table based 3-input logical operation; explicitly used for
// dequantization as the compiler does not seem to automatically recognize it in // dequantization as the compiler does not seem to automatically recognize it in
// all cases. // all cases.
template <int lut> __device__ inline int lop3(int a, int b, int c) { template <int lut>
__device__ inline int lop3(int a, int b, int c) {
int res; int res;
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(res) : "=r"(res)
View File
@ -37,7 +37,10 @@
#endif #endif
template <typename T> inline std::string str(T x) { return std::to_string(x); } template <typename T>
inline std::string str(T x) {
return std::to_string(x);
}
namespace marlin_24 { namespace marlin_24 {
@ -57,22 +60,23 @@ static constexpr int max_par = 16;
template <const int num_bits, // weight bits template <const int num_bits, // weight bits
const int threads, // number of threads in a threadblock const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the threadblock // dimension (batchsize) of the
// threadblock
const int thread_n_blocks, // same for n dimension (output) const int thread_n_blocks, // same for n dimension (output)
const int thread_k_blocks, // same for k dimension (reduction) const int thread_k_blocks, // same for k dimension (reduction)
const int stages, // number of stages for the async global->shared const int stages, // number of stages for the async global->shared
// fetch pipeline // fetch pipeline
const int group_blocks = -1 // number of consecutive 16x16 blocks with const int group_blocks = -1 // number of consecutive 16x16 blocks
// a separate quantization scale // with a separate quantization scale
> >
__global__ void Marlin_24( __global__ void Marlin_24(
const int4* __restrict__ A, // fp16 input matrix of shape mxk const int4* __restrict__ A, // fp16 input matrix of shape mxk
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
const int4 const int4* __restrict__ meta, // 2bit metadata information about 2:4
*__restrict__ meta, // 2bit metadata information about 2:4 format on B // format on B
int4* __restrict__ C, // fp16 output buffer of shape mxn int4* __restrict__ C, // fp16 output buffer of shape mxn
const int4 const int4* __restrict__ s, // fp16 quantization scales of shape
*__restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn // (k/groupsize)xn
int prob_m, // batch dimension m int prob_m, // batch dimension m
int prob_n, // output dimension n int prob_n, // output dimension n
int prob_k, // reduction dimension k int prob_k, // reduction dimension k
@ -95,22 +99,23 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
template <const int num_bits, // weight bits template <const int num_bits, // weight bits
const int threads, // number of threads in a threadblock const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the threadblock // dimension (batchsize) of the
// threadblock
const int thread_n_blocks, // same for n dimension (output) const int thread_n_blocks, // same for n dimension (output)
const int thread_k_blocks, // same for k dimension (reduction) const int thread_k_blocks, // same for k dimension (reduction)
const int stages, // number of stages for the async global->shared const int stages, // number of stages for the async global->shared
// fetch pipeline // fetch pipeline
const int group_blocks = -1 // number of consecutive 16x16 blocks with const int group_blocks = -1 // number of consecutive 16x16 blocks
// a separate quantization scale // with a separate quantization scale
> >
__global__ void Marlin_24( __global__ void Marlin_24(
const int4* __restrict__ A, // fp16 input matrix of shape mxk const int4* __restrict__ A, // fp16 input matrix of shape mxk
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
const int4 const int4* __restrict__ meta, // 2bit metadata information about 2:4
*__restrict__ meta, // 2bit metadata information about 2:4 format on B // format on B
int4* __restrict__ C, // fp16 output buffer of shape mxn int4* __restrict__ C, // fp16 output buffer of shape mxn
const int4 const int4* __restrict__ s, // fp16 quantization scales of shape
*__restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn // (k/groupsize)xn
int prob_m, // batch dimension m int prob_m, // batch dimension m
int prob_n, // output dimension n int prob_n, // output dimension n
int prob_k, // reduction dimension k int prob_k, // reduction dimension k
@ -174,27 +179,22 @@ __global__ void Marlin_24(
auto init_slice = [&]() { auto init_slice = [&]() {
slice_iters = slice_iters =
iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;
slice_iters = 0; if (slice_iters == 0) return;
if (slice_iters == 0) if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
return;
if (slice_row + slice_iters > k_tiles)
slice_iters = k_tiles - slice_row;
slice_count = 1; slice_count = 1;
slice_idx = 0; slice_idx = 0;
int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); int col_first = iters * ceildiv(k_tiles * slice_col_par, iters);
if (col_first <= k_tiles * (slice_col_par + 1)) { if (col_first <= k_tiles * (slice_col_par + 1)) {
int col_off = col_first - k_tiles * slice_col_par; int col_off = col_first - k_tiles * slice_col_par;
slice_count = ceildiv(k_tiles - col_off, iters); slice_count = ceildiv(k_tiles - col_off, iters);
if (col_off > 0) if (col_off > 0) slice_count++;
slice_count++;
int delta_first = iters * blockIdx.x - col_first; int delta_first = iters * blockIdx.x - col_first;
if (delta_first < 0 || (col_off == 0 && delta_first == 0)) if (delta_first < 0 || (col_off == 0 && delta_first == 0))
slice_idx = slice_count - 1; slice_idx = slice_count - 1;
else { else {
slice_idx = slice_count - 1 - delta_first / iters; slice_idx = slice_count - 1 - delta_first / iters;
if (col_off > 0) if (col_off > 0) slice_idx--;
slice_idx--;
} }
} }
if (slice_col == n_tiles) { if (slice_col == n_tiles) {
@ -392,8 +392,7 @@ __global__ void Marlin_24(
for (int i = 0; i < b_sh_wr_iters; i++) { for (int i = 0; i < b_sh_wr_iters; i++) {
#pragma unroll #pragma unroll
for (int j = 0; j < b_thread_vecs; j++) { for (int j = 0; j < b_thread_vecs; j++) {
cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j);
B_ptr[i] + j);
} }
B_ptr[i] += b_gl_rd_delta_o; B_ptr[i] += b_gl_rd_delta_o;
} }
@ -401,15 +400,13 @@ __global__ void Marlin_24(
#pragma unroll #pragma unroll
for (int i = 0; i < m_sh_iters; i++) { for (int i = 0; i < m_sh_iters; i++) {
if (m_sh_wr_pred) if (m_sh_wr_pred)
cp_async4(&sh_meta_stage[m_sh_wr_delta * i + m_sh_wr], cp_async4(&sh_meta_stage[m_sh_wr_delta * i + m_sh_wr], meta_ptr[i]);
meta_ptr[i]);
meta_ptr[i] += m_gl_rd_delta_o; meta_ptr[i] += m_gl_rd_delta_o;
} }
// Only fetch scales if this tile starts a new group // Only fetch scales if this tile starts a new group
if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) { if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) {
int4* sh_s_stage = sh_s + s_sh_stage * pipe; int4* sh_s_stage = sh_s + s_sh_stage * pipe;
if (s_sh_wr_pred) if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
s_gl_rd += s_gl_rd_delta; s_gl_rd += s_gl_rd_delta;
} }
} }
@ -519,8 +516,8 @@ __global__ void Marlin_24(
(threadIdx.x % b_sh_stride_threads); (threadIdx.x % b_sh_stride_threads);
// Parallel logarithmic shared memory reduction. We make sure to avoid any // Parallel logarithmic shared memory reduction. We make sure to avoid any
// unnecessary read or write iterations, e.g., for two warps we write only once // unnecessary read or write iterations, e.g., for two warps we write only
// by warp 1 and read only once by warp 0. // once by warp 1 and read only once by warp 0.
#pragma unroll #pragma unroll
for (int m_block = 0; m_block < thread_m_blocks; m_block++) { for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
#pragma unroll #pragma unroll
@ -531,8 +528,8 @@ __global__ void Marlin_24(
int red_sh_wr = int red_sh_wr =
red_sh_delta * j + (red_sh_rd - red_sh_stride * i); red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
if (i < red_off) { if (i < red_off) {
float *c_rd = reinterpret_cast<float *>( float* c_rd =
&sh[red_sh_delta * j + red_sh_rd]); reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]); float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
#pragma unroll #pragma unroll
for (int k = 0; k < 4; k++) for (int k = 0; k < 4; k++)
@ -562,9 +559,9 @@ __global__ void Marlin_24(
}; };
// Since multiple threadblocks may process parts of the same column slice, we // Since multiple threadblocks may process parts of the same column slice, we
// finally have to globally reduce over the results. As the striped partitioning // finally have to globally reduce over the results. As the striped
// minimizes the number of such reductions and our outputs are usually rather // partitioning minimizes the number of such reductions and our outputs are
// small, we perform this reduction serially in L2 cache. // usually rather small, we perform this reduction serially in L2 cache.
auto global_reduce = [&](bool first = false, bool last = false) { auto global_reduce = [&](bool first = false, bool last = false) {
// We are very careful here to reduce directly in the output buffer to // We are very careful here to reduce directly in the output buffer to
// maximize L2 cache utilization in this step. To do this, we write out // maximize L2 cache utilization in this step. To do this, we write out
@ -584,9 +581,9 @@ __global__ void Marlin_24(
int col = 2 * ((threadIdx.x % 32) % 4); int col = 2 * ((threadIdx.x % 32) % 4);
if (!first) { if (!first) {
// Interestingly, doing direct global accesses here really seems to mess up the // Interestingly, doing direct global accesses here really seems to mess up
// compiler and lead to slowdowns, hence we also use async-copies even though // the compiler and lead to slowdowns, hence we also use async-copies even
// these fetches are not actually asynchronous. // though these fetches are not actually asynchronous.
#pragma unroll #pragma unroll
for (int i = 0; i < thread_m_blocks * 4; i++) { for (int i = 0; i < thread_m_blocks * 4; i++) {
cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i],
@ -722,8 +719,7 @@ __global__ void Marlin_24(
// Start global fetch and register load pipelines. // Start global fetch and register load pipelines.
auto start_pipes = [&]() { auto start_pipes = [&]() {
#pragma unroll #pragma unroll
for (int i = 0; i < stages - 1; i++) for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters);
fetch_to_shared(i, i, i < slice_iters);
zero_accums(); zero_accums();
wait_for_stage(); wait_for_stage();
fetch_to_registers(0, 0); fetch_to_registers(0, 0);
@ -733,9 +729,9 @@ __global__ void Marlin_24(
// Main loop. // Main loop.
while (slice_iters) { while (slice_iters) {
// We unroll over both the global fetch and the register load pipeline to ensure // We unroll over both the global fetch and the register load pipeline to
// all shared memory accesses are static. Note that both pipelines have even // ensure all shared memory accesses are static. Note that both pipelines have
// length meaning that the next iteration will always start at index 0. // even length meaning that the next iteration will always start at index 0.
#pragma unroll #pragma unroll
for (int pipe = 0; pipe < stages;) { for (int pipe = 0; pipe < stages;) {
fetch_to_shared((pipe + stages - 1) % stages, pipe, fetch_to_shared((pipe + stages - 1) % stages, pipe,
@ -747,8 +743,7 @@ __global__ void Marlin_24(
pipe++; pipe++;
slice_iters--; slice_iters--;
if (slice_iters == 0) if (slice_iters == 0) break;
break;
} }
a_gl_rd += a_gl_rd_delta_o * stages; a_gl_rd += a_gl_rd_delta_o * stages;
@ -762,13 +757,11 @@ __global__ void Marlin_24(
// write-out // write-out
if constexpr (group_blocks == -1) { if constexpr (group_blocks == -1) {
if constexpr (num_bits == 8) { if constexpr (num_bits == 8) {
if (s_sh_wr_pred) if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
cp_async_fence(); cp_async_fence();
} else { } else {
if (last) { if (last) {
if (s_sh_wr_pred) if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
cp_async_fence(); cp_async_fence();
} }
} }
@ -851,11 +844,9 @@ __global__ void Marlin_24(
meta_ptr[i] += (m_sh_stride)-m_gl_rd_delta_o * k_tiles; meta_ptr[i] += (m_sh_stride)-m_gl_rd_delta_o * k_tiles;
if (slice_col == 0) { if (slice_col == 0) {
#pragma unroll #pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
B_ptr[i] -= b_gl_stride;
#pragma unroll #pragma unroll
for (int i = 0; i < m_sh_iters; i++) for (int i = 0; i < m_sh_iters; i++) meta_ptr[i] -= m_gl_stride;
meta_ptr[i] -= m_gl_stride;
} }
s_gl_rd = s_sh_stride * slice_col + threadIdx.x; s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
start_pipes(); start_pipes();
@ -904,8 +895,8 @@ void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C,
if (thread_k == -1 || thread_m == -1) { if (thread_k == -1 || thread_m == -1) {
if (prob_n <= 16) { if (prob_n <= 16) {
// For small batch sizes, better partitioning is slightly more important // For small batch sizes, better partitioning is slightly more important
// better compute utilization // than better compute utilization
thread_k = 128; thread_k = 128;
thread_m = 128; thread_m = 128;
} else { } else {
@ -946,8 +937,7 @@ void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C,
// Note that parallel > 1 currently only works for inputs without any // Note that parallel > 1 currently only works for inputs without any
// padding // padding
par = (16 * thread_n_blocks - pad) / 64; par = (16 * thread_n_blocks - pad) / 64;
if (par > max_par) if (par > max_par) par = max_par;
par = max_par;
prob_n = 64 * par; prob_n = 64 * par;
i += 4 * (par - 1); i += 4 * (par - 1);
thread_n_blocks = 4; thread_n_blocks = 4;
@ -1037,9 +1027,9 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
" is not divisible by tile_size = " + str(marlin_24::tile_size)); " is not divisible by tile_size = " + str(marlin_24::tile_size));
int actual_size_n = (b_q_weight.size(1) / marlin_24::tile_size) * pack_factor; int actual_size_n = (b_q_weight.size(1) / marlin_24::tile_size) * pack_factor;
TORCH_CHECK(size_n == actual_size_n, TORCH_CHECK(
"size_n = " + str(size_n) + size_n == actual_size_n,
", actual_size_n = " + str(actual_size_n)); "size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n));
// Verify meta // Verify meta
TORCH_CHECK(b_meta.size(0) == size_k / 8 / 2 / 2, TORCH_CHECK(b_meta.size(0) == size_k / 8 / 2 / 2,
View File
@ -32,12 +32,8 @@ __global__ void NUQ4MatMulKernel(
#else #else
float2* __restrict__ mul, float2* __restrict__ mul,
#endif #endif
const __half* __restrict__ lookup_table, const __half* __restrict__ lookup_table, int height, int width, int batch,
int height, int vec_height) {
int width,
int batch,
int vec_height
) {
const int blockwidth2 = BLOCKWIDTH / 2; const int blockwidth2 = BLOCKWIDTH / 2;
@ -80,7 +76,9 @@ __global__ void NUQ4MatMulKernel(
__syncthreads(); __syncthreads();
if (threadIdx.x < blockwidth2) if (threadIdx.x < blockwidth2)
blockvec[threadIdx.x] = vec[b * vec_height / 2 + (row / BLOCKHEIGHT4) * blockwidth2 + threadIdx.x]; blockvec[threadIdx.x] =
vec[b * vec_height / 2 + (row / BLOCKHEIGHT4) * blockwidth2 +
threadIdx.x];
__syncthreads(); __syncthreads();
while (k < blockwidth2) { while (k < blockwidth2) {
@ -143,7 +141,8 @@ __global__ void NUQ4MatMulKernel(
#ifndef USE_ROCM #ifndef USE_ROCM
res = __hadd(__hadd(res2.x, res2.y), res); res = __hadd(__hadd(res2.x, res2.y), res);
#else #else
res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)), res); res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)),
res);
#endif #endif
i += width; i += width;
@ -183,22 +182,16 @@ __global__ void NUQ4MatMulKernel(
} // namespace vllm } // namespace vllm
// 4-bit matvec kernel (LUT-based) // 4-bit matvec kernel (LUT-based)
void squeezellm_gemm( void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor vec, torch::Tensor lookup_table) {
torch::Tensor mat,
torch::Tensor mul,
torch::Tensor lookup_table
) {
int height = mat.size(0); int height = mat.size(0);
int width = mat.size(1); int width = mat.size(1);
int batch = vec.size(0); int batch = vec.size(0);
int vec_height = vec.size(1); int vec_height = vec.size(1);
dim3 blocks( dim3 blocks((height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
(height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4, (width + BLOCKWIDTH - 1) / BLOCKWIDTH);
(width + BLOCKWIDTH - 1) / BLOCKWIDTH
);
dim3 threads(BLOCKWIDTH); dim3 threads(BLOCKWIDTH);
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
@ -211,14 +204,12 @@ void squeezellm_gemm(
#endif #endif
mat.data_ptr<int>(), mat.data_ptr<int>(),
#ifndef USE_ROCM #ifndef USE_ROCM
(half2*) mul.data<at::Half>(), (half2*)mul.data<at::Half>(), (__half*)lookup_table.data<at::Half>(),
(__half*) lookup_table.data<at::Half>(),
#else #else
(float2*)mul.data_ptr<float>(), (float2*)mul.data_ptr<float>(),
(__half*)lookup_table.data_ptr<at::Half>(), (__half*)lookup_table.data_ptr<at::Half>(),
#endif #endif
height, width, batch, vec_height height, width, batch, vec_height);
);
} }
#undef BLOCKWIDTH #undef BLOCKWIDTH
View File
@ -1,5 +1,6 @@
/* /*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh * Adapted from
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh
* Copyright (c) 2023, The vLLM team. * Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* *
@ -43,17 +44,18 @@ __inline__ __device__ T blockReduceSum(T val) {
static_assert(maxBlockSize <= 1024); static_assert(maxBlockSize <= 1024);
if constexpr (maxBlockSize > WARP_SIZE) { if constexpr (maxBlockSize > WARP_SIZE) {
val = warpReduceSum<T>(val); val = warpReduceSum<T>(val);
// Calculates max number of lanes that need to participate in the last warpReduce // Calculates max number of lanes that need to participate in the last
// warpReduce
constexpr int maxActiveLanes = (maxBlockSize + WARP_SIZE - 1) / WARP_SIZE; constexpr int maxActiveLanes = (maxBlockSize + WARP_SIZE - 1) / WARP_SIZE;
static __shared__ T shared[maxActiveLanes]; static __shared__ T shared[maxActiveLanes];
int lane = threadIdx.x % WARP_SIZE; int lane = threadIdx.x % WARP_SIZE;
int wid = threadIdx.x / WARP_SIZE; int wid = threadIdx.x / WARP_SIZE;
if (lane == 0) if (lane == 0) shared[wid] = val;
shared[wid] = val;
__syncthreads(); __syncthreads();
val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane] : (T)(0.0f); val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane]
: (T)(0.0f);
val = warpReduceSum<T, _nextPow2(maxActiveLanes)>(val); val = warpReduceSum<T, _nextPow2(maxActiveLanes)>(val);
} else { } else {
// A single warpReduce is equal to blockReduce // A single warpReduce is equal to blockReduce
View File
@ -26,6 +26,7 @@ RUFF_VERSION=$(ruff --version | awk '{print $2}')
MYPY_VERSION=$(mypy --version | awk '{print $2}') MYPY_VERSION=$(mypy --version | awk '{print $2}')
CODESPELL_VERSION=$(codespell --version) CODESPELL_VERSION=$(codespell --version)
ISORT_VERSION=$(isort --vn) ISORT_VERSION=$(isort --vn)
CLANGFORMAT_VERSION=$(clang-format --version | awk '{print $3}')
# # params: tool name, tool version, required version # # params: tool name, tool version, required version
tool_version_check() { tool_version_check() {
@ -40,6 +41,7 @@ tool_version_check "ruff" $RUFF_VERSION "$(grep "ruff==" requirements-dev.txt |
tool_version_check "mypy" "$MYPY_VERSION" "$(grep mypy requirements-dev.txt | cut -d'=' -f3)" tool_version_check "mypy" "$MYPY_VERSION" "$(grep mypy requirements-dev.txt | cut -d'=' -f3)"
tool_version_check "isort" "$ISORT_VERSION" "$(grep isort requirements-dev.txt | cut -d'=' -f3)" tool_version_check "isort" "$ISORT_VERSION" "$(grep isort requirements-dev.txt | cut -d'=' -f3)"
tool_version_check "codespell" "$CODESPELL_VERSION" "$(grep codespell requirements-dev.txt | cut -d'=' -f3)" tool_version_check "codespell" "$CODESPELL_VERSION" "$(grep codespell requirements-dev.txt | cut -d'=' -f3)"
tool_version_check "clang-format" "$CLANGFORMAT_VERSION" "$(grep clang-format requirements-dev.txt | cut -d'=' -f3)"
YAPF_FLAGS=( YAPF_FLAGS=(
'--recursive' '--recursive'
@ -179,7 +181,6 @@ lint_changed() {
} }
# Run Ruff # Run Ruff
echo 'vLLM ruff:'
### This flag lints individual files. --files *must* be the first command line ### This flag lints individual files. --files *must* be the first command line
### arg to use this option. ### arg to use this option.
if [[ "$1" == '--files' ]]; then if [[ "$1" == '--files' ]]; then
@ -192,6 +193,7 @@ else
# Format only the files that changed in last commit. # Format only the files that changed in last commit.
lint_changed lint_changed
fi fi
echo 'vLLM ruff: Done'
# check spelling of specified files # check spelling of specified files
isort_check() { isort_check() {
@ -233,6 +235,59 @@ else
fi fi
echo 'vLLM isort: Done' echo 'vLLM isort: Done'
# Clang-format section
# Exclude some files from formatting because they are vendored
# NOTE: Keep up to date with .github/workflows/clang-format.yml
CLANG_FORMAT_EXCLUDES=(
'csrc/moe/topk_softmax_kernels.cu'
'csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu'
'csrc/punica/bgmv/bgmv_config.h'
'csrc/punica/bgmv/bgmv_impl.cuh'
'csrc/punica/bgmv/vec_dtypes.cuh'
'csrc/punica/punica_ops.cu'
'csrc/punica/type_convert.h'
)
# Format specified files with clang-format
clang_format() {
clang-format -i "$@"
}
# Format files that differ from main branch with clang-format.
clang_format_changed() {
# The `if` guard ensures that clang-format is not invoked with an empty list
# of filenames, which would cause it to receive 0 positional arguments and
# hang waiting for STDIN.
#
# `--diff-filter=ACM` and $MERGEBASE are used to ensure that we only format
# files that exist on both branches.
MERGEBASE="$(git merge-base origin/main HEAD)"
# Get the list of changed files, excluding the specified ones
changed_files=$(git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.h' '*.cpp' '*.cu' '*.cuh' | grep -vFf <(printf "%s\n" "${CLANG_FORMAT_EXCLUDES[@]}"))
if [ -n "$changed_files" ]; then
echo "$changed_files" | xargs -P 5 clang-format -i
fi
}
# Format all files with clang-format
clang_format_all() {
find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \
| grep -vFf <(printf "%s\n" "${CLANG_FORMAT_EXCLUDES[@]}") \
| xargs clang-format -i
}
# Run clang-format
if [[ "$1" == '--files' ]]; then
clang_format "${@:2}"
elif [[ "$1" == '--all' ]]; then
clang_format_all
else
clang_format_changed
fi
echo 'vLLM clang-format: Done'
if ! git diff --quiet &>/dev/null; then if ! git diff --quiet &>/dev/null; then
echo 'Reformatted files. Please review and stage the changes.' echo 'Reformatted files. Please review and stage the changes.'
echo 'Changes not staged for commit:' echo 'Changes not staged for commit:'
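The new section gives three entry points that mirror the existing yapf/ruff/isort flow. A minimal usage sketch, assuming the script above is the repository's format.sh (its name is not visible in this hunk) and using a placeholder path for the --files case:

# Default mode: runs clang_format_changed, formatting only the .h/.cpp/.cu/.cuh
# files that differ from origin/main and are not listed in CLANG_FORMAT_EXCLUDES.
./format.sh

# Format specific files; --files must be the first argument.
./format.sh --files csrc/example_kernel.cu   # placeholder path

# Reformat every non-excluded C++/CUDA file under csrc/.
./format.sh --all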
View File
@ -5,6 +5,7 @@ tomli==2.0.1
ruff==0.1.5 ruff==0.1.5
codespell==2.2.6 codespell==2.2.6
isort==5.13.2 isort==5.13.2
clang-format==18.1.5
# type checking # type checking
mypy==1.9.0 mypy==1.9.0
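With the formatter now pinned alongside the other dev tools, a local environment can be brought to the matching version before running the script. A hedged sketch, assuming a pip-based setup and that this file is the requirements-dev.txt referenced by the version checks above:

pip install -r requirements-dev.txt   # pulls in clang-format==18.1.5 with the other dev tools
clang-format --version                # should report 18.1.5, matching tool_version_check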