[CI/Build] Enforce style for C++ and CUDA code with clang-format (#4722)
This commit is contained in:
parent 9b9a10d6cb, commit 5f6d10c14c
26 .clang-format Normal file
@@ -0,0 +1,26 @@
BasedOnStyle: Google
UseTab: Never
IndentWidth: 2
ColumnLimit: 80

# Force pointers to the type for C++.
DerivePointerAlignment: false
PointerAlignment: Left

# Reordering #include statements can (and currently will) introduce errors
SortIncludes: false

# Style choices
AlignConsecutiveAssignments: false
AlignConsecutiveDeclarations: false
IndentPPDirectives: BeforeHash

IncludeCategories:
  - Regex: '^<'
    Priority: 4
  - Regex: '^"(llvm|llvm-c|clang|clang-c|mlir|mlir-c)/'
    Priority: 3
  - Regex: '^"(qoda|\.\.)/'
    Priority: 2
  - Regex: '.*'
    Priority: 1
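For reference, a minimal sketch of how this configuration could be applied locally. It assumes the commands are run from the repository root (so clang-format picks up this .clang-format file) and uses csrc/cache.h, one of the files touched by this commit, purely as an example; the pinned version matches the one installed by the CI workflow below.

# Install the same clang-format version the CI workflow pins.
pip install clang-format==18.1.5

# Check a single file against the repository style without modifying it.
clang-format --dry-run --Werror csrc/cache.h

# Rewrite the file in place to match the style.
clang-format -i csrc/cache.h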
42 .github/workflows/clang-format.yml vendored Normal file
@@ -0,0 +1,42 @@
name: clang-format

on:
  # Trigger the workflow on push or pull request,
  # but only for the main branch
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

jobs:
  clang-format:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.11"]
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install clang-format==18.1.5
      - name: Running clang-format
        run: |
          EXCLUDES=(
            'csrc/moe/topk_softmax_kernels.cu'
            'csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu'
            'csrc/punica/bgmv/bgmv_config.h'
            'csrc/punica/bgmv/bgmv_impl.cuh'
            'csrc/punica/bgmv/vec_dtypes.cuh'
            'csrc/punica/punica_ops.cu'
            'csrc/punica/type_convert.h'
          )
          find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \
            | grep -vFf <(printf "%s\n" "${EXCLUDES[@]}") \
            | xargs clang-format --dry-run --Werror
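The same check can be reproduced, or the files auto-fixed, outside of CI. A sketch, assuming it is run from the repository root with clang-format 18.1.5 on the PATH and the same EXCLUDES array defined as in the workflow above:

# Report style violations exactly as the CI job does.
find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \
  | grep -vFf <(printf "%s\n" "${EXCLUDES[@]}") \
  | xargs clang-format --dry-run --Werror

# Or rewrite the offending files in place.
find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \
  | grep -vFf <(printf "%s\n" "${EXCLUDES[@]}") \
  | xargs clang-format -i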
@@ -63,31 +63,25 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), \
"act_and_mul_kernel", \
[&] { \
vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), \
d); \
input.scalar_type(), "act_and_mul_kernel", [&] { \
vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
});

void silu_and_mul(
torch::Tensor& out, // [..., d]
void silu_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
}

void gelu_and_mul(
torch::Tensor& out, // [..., d]
void gelu_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
}

void gelu_tanh_and_mul(
torch::Tensor& out, // [..., d]
void gelu_tanh_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel);
@@ -118,14 +112,10 @@ __global__ void activation_kernel(
dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), \
"activation_kernel", \
[&] { \
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), \
d); \
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] { \
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
});

namespace vllm {
@@ -140,21 +130,20 @@ __device__ __forceinline__ T gelu_new_kernel(const T& x) {
template <typename T>
__device__ __forceinline__ T gelu_fast_kernel(const T& x) {
const float f = (float)x;
const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x));
const T t =
(T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x));
return ((T)0.5) * x * (((T)1.0) + t);
}

} // namespace vllm

void gelu_new(
torch::Tensor& out, // [..., d]
void gelu_new(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., d]
{
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
}

void gelu_fast(
torch::Tensor& out, // [..., d]
void gelu_fast(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., d]
{
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
@@ -1,5 +1,6 @@
/*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Adapted from
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
@@ -1,5 +1,6 @@
/*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* Adapted from
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
@ -82,30 +83,27 @@ inline __device__ float block_sum(float* red_smem, float sum) {
|
||||
|
||||
// TODO(woosuk): Merge the last two dimensions of the grid.
|
||||
// Grid: (num_heads, num_seqs, max_num_partitions).
|
||||
template<
|
||||
typename scalar_t,
|
||||
typename cache_t,
|
||||
int HEAD_SIZE,
|
||||
int BLOCK_SIZE,
|
||||
int NUM_THREADS,
|
||||
vllm::Fp8KVCacheDataType KV_DTYPE,
|
||||
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
|
||||
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
|
||||
int PARTITION_SIZE = 0> // Zero means no partitioning.
|
||||
__device__ void paged_attention_kernel(
|
||||
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
|
||||
float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
|
||||
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size]
|
||||
float* __restrict__ max_logits, // [num_seqs, num_heads,
|
||||
// max_num_partitions]
|
||||
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions,
|
||||
// head_size]
|
||||
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
||||
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
|
||||
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
|
||||
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
|
||||
// head_size/x, block_size, x]
|
||||
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
|
||||
// head_size, block_size]
|
||||
const int num_kv_heads, // [num_heads]
|
||||
const float scale,
|
||||
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
const int* __restrict__ seq_lens, // [num_seqs]
|
||||
const int max_num_blocks_per_seq,
|
||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||
const int q_stride,
|
||||
const int kv_block_stride,
|
||||
const int kv_head_stride,
|
||||
const int q_stride, const int kv_block_stride, const int kv_head_stride,
|
||||
const float kv_scale) {
|
||||
const int seq_idx = blockIdx.y;
|
||||
const int partition_idx = blockIdx.z;
|
||||
@ -118,22 +116,29 @@ __device__ void paged_attention_kernel(
|
||||
}
|
||||
|
||||
const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
|
||||
const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
|
||||
const int num_blocks_per_partition =
|
||||
USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
|
||||
|
||||
// [start_block_idx, end_block_idx) is the range of blocks to process.
|
||||
const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
|
||||
const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks);
|
||||
const int start_block_idx =
|
||||
USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
|
||||
const int end_block_idx =
|
||||
MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks);
|
||||
const int num_blocks = end_block_idx - start_block_idx;
|
||||
|
||||
// [start_token_idx, end_token_idx) is the range of tokens to process.
|
||||
const int start_token_idx = start_block_idx * BLOCK_SIZE;
|
||||
const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len);
|
||||
const int end_token_idx =
|
||||
MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len);
|
||||
const int num_tokens = end_token_idx - start_token_idx;
|
||||
|
||||
constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
|
||||
constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS
|
||||
constexpr int NUM_THREAD_GROUPS =
|
||||
NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE
|
||||
// divides NUM_THREADS
|
||||
assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
|
||||
constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
|
||||
constexpr int NUM_TOKENS_PER_THREAD_GROUP =
|
||||
DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
|
||||
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
||||
const int thread_idx = threadIdx.x;
|
||||
const int warp_idx = thread_idx / WARP_SIZE;
|
||||
@ -143,13 +148,14 @@ __device__ void paged_attention_kernel(
|
||||
const int num_heads = gridDim.x;
|
||||
const int num_queries_per_kv = num_heads / num_kv_heads;
|
||||
const int kv_head_idx = head_idx / num_queries_per_kv;
|
||||
const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
|
||||
const float alibi_slope =
|
||||
alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
|
||||
|
||||
// A vector type to store a part of a key or a query.
|
||||
// The vector size is configured in such a way that the threads in a thread group
|
||||
// fetch or compute 16 bytes at a time.
|
||||
// For example, if the size of a thread group is 4 and the data type is half,
|
||||
// then the vector size is 16 / (4 * sizeof(half)) == 2.
|
||||
// The vector size is configured in such a way that the threads in a thread
|
||||
// group fetch or compute 16 bytes at a time. For example, if the size of a
|
||||
// thread group is 4 and the data type is half, then the vector size is 16 /
|
||||
// (4 * sizeof(half)) == 2.
|
||||
constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
|
||||
using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
|
||||
using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
|
||||
@ -163,18 +169,21 @@ __device__ void paged_attention_kernel(
|
||||
|
||||
// Load the query to registers.
|
||||
// Each thread in a thread group has a different part of the query.
|
||||
// For example, if the the thread group size is 4, then the first thread in the group
|
||||
// has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
|
||||
// th vectors of the query, and so on.
|
||||
// NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
|
||||
// For example, if the the thread group size is 4, then the first thread in
|
||||
// the group has 0, 4, 8, ... th vectors of the query, and the second thread
|
||||
// has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because
|
||||
// q is split from a qkv tensor, it may not be contiguous.
|
||||
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
|
||||
__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
|
||||
#pragma unroll
|
||||
for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
|
||||
for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
|
||||
i += NUM_THREAD_GROUPS) {
|
||||
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
|
||||
q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
|
||||
q_vecs[thread_group_offset][i] =
|
||||
*reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
|
||||
}
|
||||
__syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs
|
||||
__syncthreads(); // TODO(naed90): possible speedup if this is replaced with a
|
||||
// memory wall right before we use q_vecs
|
||||
|
||||
// Memory planning.
|
||||
extern __shared__ char shared_mem[];
|
||||
@ -193,44 +202,50 @@ __device__ void paged_attention_kernel(
|
||||
// Each thread group in a warp fetches a key from the block, and computes
|
||||
// dot product with the query.
|
||||
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
|
||||
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
|
||||
// NOTE(woosuk): The block number is stored in int32. However, we cast it to int64
|
||||
// because int32 can lead to overflow when this variable is multiplied by large numbers
|
||||
// (e.g., kv_block_stride).
|
||||
const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
|
||||
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
|
||||
block_idx += NUM_WARPS) {
|
||||
// NOTE(woosuk): The block number is stored in int32. However, we cast it to
|
||||
// int64 because int32 can lead to overflow when this variable is multiplied
|
||||
// by large numbers (e.g., kv_block_stride).
|
||||
const int64_t physical_block_number =
|
||||
static_cast<int64_t>(block_table[block_idx]);
|
||||
|
||||
// Load a key to registers.
|
||||
// Each thread in a thread group has a different part of the key.
|
||||
// For example, if the the thread group size is 4, then the first thread in the group
|
||||
// has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th
|
||||
// vectors of the key, and so on.
|
||||
// For example, if the the thread group size is 4, then the first thread in
|
||||
// the group has 0, 4, 8, ... th vectors of the key, and the second thread
|
||||
// has 1, 5, 9, ... th vectors of the key, and so on.
|
||||
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
|
||||
const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
|
||||
const int physical_block_offset =
|
||||
(thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
|
||||
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
|
||||
K_vec k_vecs[NUM_VECS_PER_THREAD];
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
|
||||
const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride
|
||||
+ kv_head_idx * kv_head_stride
|
||||
+ physical_block_offset * x;
|
||||
const cache_t* k_ptr =
|
||||
k_cache + physical_block_number * kv_block_stride +
|
||||
kv_head_idx * kv_head_stride + physical_block_offset * x;
|
||||
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
|
||||
const int offset1 = (vec_idx * VEC_SIZE) / x;
|
||||
const int offset2 = (vec_idx * VEC_SIZE) % x;
|
||||
|
||||
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
|
||||
k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
|
||||
k_vecs[j] = *reinterpret_cast<const K_vec*>(
|
||||
k_ptr + offset1 * BLOCK_SIZE * x + offset2);
|
||||
} else {
|
||||
// Vector conversion from Quant_vec to K_vec.
|
||||
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
|
||||
k_ptr + offset1 * BLOCK_SIZE * x + offset2);
|
||||
k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(k_vec_quant, kv_scale);
|
||||
k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
|
||||
k_vec_quant, kv_scale);
|
||||
}
|
||||
}
|
||||
|
||||
// Compute dot product.
|
||||
// This includes a reduction across the threads in the same thread group.
|
||||
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
|
||||
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(
|
||||
q_vecs[thread_group_offset], k_vecs);
|
||||
// Add the ALiBi bias if slopes are given.
|
||||
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;
|
||||
|
||||
@ -285,13 +300,12 @@ __device__ void paged_attention_kernel(
|
||||
|
||||
// If partitioning is enabled, store the max logit and exp_sum.
|
||||
if (USE_PARTITIONING && thread_idx == 0) {
|
||||
float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
|
||||
+ head_idx * max_num_partitions
|
||||
+ partition_idx;
|
||||
float* max_logits_ptr = max_logits +
|
||||
seq_idx * num_heads * max_num_partitions +
|
||||
head_idx * max_num_partitions + partition_idx;
|
||||
*max_logits_ptr = qk_max;
|
||||
float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
|
||||
+ head_idx * max_num_partitions
|
||||
+ partition_idx;
|
||||
float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions +
|
||||
head_idx * max_num_partitions + partition_idx;
|
||||
*exp_sums_ptr = exp_sum;
|
||||
}
|
||||
|
||||
@ -304,7 +318,8 @@ __device__ void paged_attention_kernel(
|
||||
|
||||
constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
|
||||
constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
|
||||
constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
|
||||
constexpr int NUM_ROWS_PER_THREAD =
|
||||
DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
|
||||
|
||||
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
|
||||
float accs[NUM_ROWS_PER_THREAD];
|
||||
@ -315,18 +330,21 @@ __device__ void paged_attention_kernel(
|
||||
|
||||
scalar_t zero_value;
|
||||
zero(zero_value);
|
||||
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
|
||||
// NOTE(woosuk): The block number is stored in int32. However, we cast it to int64
|
||||
// because int32 can lead to overflow when this variable is multiplied by large numbers
|
||||
// (e.g., kv_block_stride).
|
||||
const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
|
||||
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
|
||||
block_idx += NUM_WARPS) {
|
||||
// NOTE(woosuk): The block number is stored in int32. However, we cast it to
|
||||
// int64 because int32 can lead to overflow when this variable is multiplied
|
||||
// by large numbers (e.g., kv_block_stride).
|
||||
const int64_t physical_block_number =
|
||||
static_cast<int64_t>(block_table[block_idx]);
|
||||
const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
|
||||
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
|
||||
L_vec logits_vec;
|
||||
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx));
|
||||
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx -
|
||||
start_token_idx));
|
||||
|
||||
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride
|
||||
+ kv_head_idx * kv_head_stride;
|
||||
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
|
||||
kv_head_idx * kv_head_stride;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
||||
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
|
||||
@ -337,14 +355,17 @@ __device__ void paged_attention_kernel(
|
||||
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
|
||||
v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
|
||||
} else {
|
||||
V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
|
||||
V_quant_vec v_quant_vec =
|
||||
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
|
||||
// Vector conversion from V_quant_vec to V_vec.
|
||||
v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec, kv_scale);
|
||||
v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
|
||||
kv_scale);
|
||||
}
|
||||
if (block_idx == num_seq_blocks - 1) {
|
||||
// NOTE(woosuk): When v_vec contains the tokens that are out of the context,
|
||||
// we should explicitly zero out the values since they may contain NaNs.
|
||||
// See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
|
||||
// NOTE(woosuk): When v_vec contains the tokens that are out of the
|
||||
// context, we should explicitly zero out the values since they may
|
||||
// contain NaNs. See
|
||||
// https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
|
||||
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < V_VEC_SIZE; j++) {
|
||||
@ -367,8 +388,8 @@ __device__ void paged_attention_kernel(
|
||||
accs[i] = acc;
|
||||
}
|
||||
|
||||
// NOTE(woosuk): A barrier is required because the shared memory space for logits
|
||||
// is reused for the output.
|
||||
// NOTE(woosuk): A barrier is required because the shared memory space for
|
||||
// logits is reused for the output.
|
||||
__syncthreads();
|
||||
|
||||
// Perform reduction across warps.
|
||||
@ -405,9 +426,9 @@ __device__ void paged_attention_kernel(
|
||||
|
||||
// Write the final output.
|
||||
if (warp_idx == 0) {
|
||||
scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
|
||||
+ head_idx * max_num_partitions * HEAD_SIZE
|
||||
+ partition_idx * HEAD_SIZE;
|
||||
scalar_t* out_ptr =
|
||||
out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
|
||||
head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
||||
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
|
||||
@ -419,77 +440,73 @@ __device__ void paged_attention_kernel(
|
||||
}
|
||||
|
||||
// Grid: (num_heads, num_seqs, 1).
|
||||
template<
|
||||
typename scalar_t,
|
||||
typename cache_t,
|
||||
int HEAD_SIZE,
|
||||
int BLOCK_SIZE,
|
||||
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
|
||||
int NUM_THREADS,
|
||||
vllm::Fp8KVCacheDataType KV_DTYPE>
|
||||
__global__ void paged_attention_v1_kernel(
|
||||
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
|
||||
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
||||
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
|
||||
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
|
||||
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
|
||||
// head_size/x, block_size, x]
|
||||
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
|
||||
// head_size, block_size]
|
||||
const int num_kv_heads, // [num_heads]
|
||||
const float scale,
|
||||
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
const int* __restrict__ seq_lens, // [num_seqs]
|
||||
const int max_num_blocks_per_seq,
|
||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||
const int q_stride,
|
||||
const int kv_block_stride,
|
||||
const int kv_head_stride,
|
||||
const int q_stride, const int kv_block_stride, const int kv_head_stride,
|
||||
const float kv_scale) {
|
||||
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE>(
|
||||
/* exp_sums */ nullptr, /* max_logits */ nullptr,
|
||||
out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens,
|
||||
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale);
|
||||
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
|
||||
KV_DTYPE>(
|
||||
/* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
|
||||
v_cache, num_kv_heads, scale, block_tables, seq_lens,
|
||||
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
|
||||
kv_head_stride, kv_scale);
|
||||
}
|
||||
|
||||
// Grid: (num_heads, num_seqs, max_num_partitions).
|
||||
template<
|
||||
typename scalar_t,
|
||||
typename cache_t,
|
||||
int HEAD_SIZE,
|
||||
int BLOCK_SIZE,
|
||||
int NUM_THREADS,
|
||||
vllm::Fp8KVCacheDataType KV_DTYPE,
|
||||
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
|
||||
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
|
||||
int PARTITION_SIZE>
|
||||
__global__ void paged_attention_v2_kernel(
|
||||
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
|
||||
float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
|
||||
scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
|
||||
float* __restrict__ max_logits, // [num_seqs, num_heads,
|
||||
// max_num_partitions]
|
||||
scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
|
||||
// max_num_partitions, head_size]
|
||||
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
||||
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
|
||||
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
|
||||
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
|
||||
// head_size/x, block_size, x]
|
||||
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
|
||||
// head_size, block_size]
|
||||
const int num_kv_heads, // [num_heads]
|
||||
const float scale,
|
||||
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
const int* __restrict__ seq_lens, // [num_seqs]
|
||||
const int max_num_blocks_per_seq,
|
||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||
const int q_stride,
|
||||
const int kv_block_stride,
|
||||
const int kv_head_stride,
|
||||
const int q_stride, const int kv_block_stride, const int kv_head_stride,
|
||||
const float kv_scale) {
|
||||
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE, PARTITION_SIZE>(
|
||||
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
|
||||
KV_DTYPE, PARTITION_SIZE>(
|
||||
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
|
||||
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes,
|
||||
q_stride, kv_block_stride, kv_head_stride, kv_scale);
|
||||
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
|
||||
kv_block_stride, kv_head_stride, kv_scale);
|
||||
}
|
||||
|
||||
// Grid: (num_heads, num_seqs).
|
||||
template<
|
||||
typename scalar_t,
|
||||
int HEAD_SIZE,
|
||||
int NUM_THREADS,
|
||||
template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
|
||||
int PARTITION_SIZE>
|
||||
__global__ void paged_attention_v2_reduce_kernel(
|
||||
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
|
||||
const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
|
||||
const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
|
||||
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
|
||||
const float* __restrict__ exp_sums, // [num_seqs, num_heads,
|
||||
// max_num_partitions]
|
||||
const float* __restrict__ max_logits, // [num_seqs, num_heads,
|
||||
// max_num_partitions]
|
||||
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
|
||||
// max_num_partitions, head_size]
|
||||
const int* __restrict__ seq_lens, // [num_seqs]
|
||||
const int max_num_partitions) {
|
||||
const int num_heads = gridDim.x;
|
||||
@ -499,9 +516,11 @@ __global__ void paged_attention_v2_reduce_kernel(
|
||||
const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
|
||||
if (num_partitions == 1) {
|
||||
// No need to reduce. Only copy tmp_out to out.
|
||||
scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
|
||||
const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
|
||||
+ head_idx * max_num_partitions * HEAD_SIZE;
|
||||
scalar_t* out_ptr =
|
||||
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
|
||||
const scalar_t* tmp_out_ptr =
|
||||
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
|
||||
head_idx * max_num_partitions * HEAD_SIZE;
|
||||
for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
|
||||
out_ptr[i] = tmp_out_ptr[i];
|
||||
}
|
||||
@ -520,8 +539,9 @@ __global__ void paged_attention_v2_reduce_kernel(
|
||||
|
||||
// Load max logits to shared memory.
|
||||
float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
|
||||
const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
|
||||
+ head_idx * max_num_partitions;
|
||||
const float* max_logits_ptr = max_logits +
|
||||
seq_idx * num_heads * max_num_partitions +
|
||||
head_idx * max_num_partitions;
|
||||
float max_logit = -FLT_MAX;
|
||||
for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
|
||||
const float l = max_logits_ptr[i];
|
||||
@ -550,9 +570,11 @@ __global__ void paged_attention_v2_reduce_kernel(
|
||||
max_logit = VLLM_SHFL_SYNC(max_logit, 0);
|
||||
|
||||
// Load rescaled exp sums to shared memory.
|
||||
float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
|
||||
const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
|
||||
+ head_idx * max_num_partitions;
|
||||
float* shared_exp_sums =
|
||||
reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
|
||||
const float* exp_sums_ptr = exp_sums +
|
||||
seq_idx * num_heads * max_num_partitions +
|
||||
head_idx * max_num_partitions;
|
||||
float global_exp_sum = 0.0f;
|
||||
for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
|
||||
float l = shared_max_logits[i];
|
||||
@ -565,14 +587,17 @@ __global__ void paged_attention_v2_reduce_kernel(
|
||||
const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
|
||||
|
||||
// Aggregate tmp_out to out.
|
||||
const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
|
||||
+ head_idx * max_num_partitions * HEAD_SIZE;
|
||||
scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
|
||||
const scalar_t* tmp_out_ptr =
|
||||
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
|
||||
head_idx * max_num_partitions * HEAD_SIZE;
|
||||
scalar_t* out_ptr =
|
||||
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
|
||||
#pragma unroll
|
||||
for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
|
||||
float acc = 0.0f;
|
||||
for (int j = 0; j < num_partitions; ++j) {
|
||||
acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum;
|
||||
acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] *
|
||||
inv_global_exp_sum;
|
||||
}
|
||||
from_float(out_ptr[i], acc);
|
||||
}
|
||||
@ -582,44 +607,25 @@ __global__ void paged_attention_v2_reduce_kernel(
|
||||
|
||||
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
|
||||
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
|
||||
((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
|
||||
KV_DTYPE>), shared_mem_size); \
|
||||
vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
|
||||
KV_DTYPE><<<grid, block, shared_mem_size, stream>>>( \
|
||||
out_ptr, \
|
||||
query_ptr, \
|
||||
key_cache_ptr, \
|
||||
value_cache_ptr, \
|
||||
num_kv_heads, \
|
||||
scale, \
|
||||
block_tables_ptr, \
|
||||
seq_lens_ptr, \
|
||||
max_num_blocks_per_seq, \
|
||||
alibi_slopes_ptr, \
|
||||
q_stride, \
|
||||
kv_block_stride, \
|
||||
kv_head_stride, \
|
||||
((void*)vllm::paged_attention_v1_kernel< \
|
||||
T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE>), \
|
||||
shared_mem_size); \
|
||||
vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
|
||||
NUM_THREADS, KV_DTYPE> \
|
||||
<<<grid, block, shared_mem_size, stream>>>( \
|
||||
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
|
||||
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
|
||||
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
|
||||
kv_scale);
|
||||
|
||||
// TODO(woosuk): Tune NUM_THREADS.
|
||||
template<
|
||||
typename T,
|
||||
typename CACHE_T,
|
||||
int BLOCK_SIZE,
|
||||
vllm::Fp8KVCacheDataType KV_DTYPE,
|
||||
int NUM_THREADS = 128>
|
||||
template <typename T, typename CACHE_T, int BLOCK_SIZE,
|
||||
vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS = 128>
|
||||
void paged_attention_v1_launcher(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
int num_kv_heads,
|
||||
float scale,
|
||||
torch::Tensor& block_tables,
|
||||
torch::Tensor& seq_lens,
|
||||
int max_seq_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
float kv_scale) {
|
||||
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache, int num_kv_heads, float scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale) {
|
||||
int num_seqs = query.size(0);
|
||||
int num_heads = query.size(1);
|
||||
int head_size = query.size(2);
|
||||
@ -632,8 +638,9 @@ void paged_attention_v1_launcher(
|
||||
assert(head_size % thread_group_size == 0);
|
||||
|
||||
// NOTE: alibi_slopes is optional.
|
||||
const float* alibi_slopes_ptr = alibi_slopes ?
|
||||
reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
|
||||
const float* alibi_slopes_ptr =
|
||||
alibi_slopes
|
||||
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
|
||||
: nullptr;
|
||||
|
||||
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
|
||||
@ -644,7 +651,8 @@ void paged_attention_v1_launcher(
|
||||
int* seq_lens_ptr = seq_lens.data_ptr<int>();
|
||||
|
||||
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
||||
int padded_max_seq_len = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
|
||||
int padded_max_seq_len =
|
||||
DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
|
||||
int logits_size = padded_max_seq_len * sizeof(float);
|
||||
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
|
||||
// Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
|
||||
@ -685,17 +693,8 @@ void paged_attention_v1_launcher(
|
||||
|
||||
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \
|
||||
paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE>( \
|
||||
out, \
|
||||
query, \
|
||||
key_cache, \
|
||||
value_cache, \
|
||||
num_kv_heads, \
|
||||
scale, \
|
||||
block_tables, \
|
||||
seq_lens, \
|
||||
max_seq_len, \
|
||||
alibi_slopes, \
|
||||
kv_scale);
|
||||
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
|
||||
seq_lens, max_seq_len, alibi_slopes, kv_scale);
|
||||
|
||||
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
|
||||
// 1, 2, 4, 64, 128, 256.
|
||||
@ -718,72 +717,43 @@ void paged_attention_v1_launcher(
|
||||
void paged_attention_v1(
|
||||
torch::Tensor& out, // [num_seqs, num_heads, head_size]
|
||||
torch::Tensor& query, // [num_seqs, num_heads, head_size]
|
||||
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||
torch::Tensor&
|
||||
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
torch::Tensor&
|
||||
value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||
int num_kv_heads, // [num_heads]
|
||||
float scale,
|
||||
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
torch::Tensor& seq_lens, // [num_seqs]
|
||||
int block_size,
|
||||
int max_seq_len,
|
||||
int block_size, int max_seq_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype,
|
||||
float kv_scale) {
|
||||
|
||||
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, CALL_V1_LAUNCHER_BLOCK_SIZE)
|
||||
}
|
||||
const std::string& kv_cache_dtype, float kv_scale){
|
||||
|
||||
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
|
||||
CALL_V1_LAUNCHER_BLOCK_SIZE)}
|
||||
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
|
||||
vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
|
||||
KV_DTYPE, PARTITION_SIZE> \
|
||||
vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
|
||||
NUM_THREADS, KV_DTYPE, PARTITION_SIZE> \
|
||||
<<<grid, block, shared_mem_size, stream>>>( \
|
||||
exp_sums_ptr, \
|
||||
max_logits_ptr, \
|
||||
tmp_out_ptr, \
|
||||
query_ptr, \
|
||||
key_cache_ptr, \
|
||||
value_cache_ptr, \
|
||||
num_kv_heads, \
|
||||
scale, \
|
||||
block_tables_ptr, \
|
||||
seq_lens_ptr, \
|
||||
max_num_blocks_per_seq, \
|
||||
alibi_slopes_ptr, \
|
||||
q_stride, \
|
||||
kv_block_stride, \
|
||||
kv_head_stride, \
|
||||
kv_scale); \
|
||||
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE> \
|
||||
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
|
||||
value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
|
||||
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
|
||||
kv_block_stride, kv_head_stride, kv_scale); \
|
||||
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
|
||||
PARTITION_SIZE> \
|
||||
<<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
|
||||
out_ptr, \
|
||||
exp_sums_ptr, \
|
||||
max_logits_ptr, \
|
||||
tmp_out_ptr, \
|
||||
seq_lens_ptr, \
|
||||
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
|
||||
max_num_partitions);
|
||||
|
||||
template<
|
||||
typename T,
|
||||
typename CACHE_T,
|
||||
int BLOCK_SIZE,
|
||||
vllm::Fp8KVCacheDataType KV_DTYPE,
|
||||
int NUM_THREADS = 128,
|
||||
template <typename T, typename CACHE_T, int BLOCK_SIZE,
|
||||
vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS = 128,
|
||||
int PARTITION_SIZE = 512>
|
||||
void paged_attention_v2_launcher(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& exp_sums,
|
||||
torch::Tensor& max_logits,
|
||||
torch::Tensor& tmp_out,
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
int num_kv_heads,
|
||||
float scale,
|
||||
torch::Tensor& block_tables,
|
||||
torch::Tensor& seq_lens,
|
||||
int max_seq_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
float kv_scale) {
|
||||
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
|
||||
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache, int num_kv_heads, float scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale) {
|
||||
int num_seqs = query.size(0);
|
||||
int num_heads = query.size(1);
|
||||
int head_size = query.size(2);
|
||||
@ -796,8 +766,9 @@ void paged_attention_v2_launcher(
|
||||
assert(head_size % thread_group_size == 0);
|
||||
|
||||
// NOTE: alibi_slopes is optional.
|
||||
const float* alibi_slopes_ptr = alibi_slopes ?
|
||||
reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
|
||||
const float* alibi_slopes_ptr =
|
||||
alibi_slopes
|
||||
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
|
||||
: nullptr;
|
||||
|
||||
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
|
||||
@ -855,19 +826,8 @@ void paged_attention_v2_launcher(
|
||||
|
||||
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \
|
||||
paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE>( \
|
||||
out, \
|
||||
exp_sums, \
|
||||
max_logits, \
|
||||
tmp_out, \
|
||||
query, \
|
||||
key_cache, \
|
||||
value_cache, \
|
||||
num_kv_heads, \
|
||||
scale, \
|
||||
block_tables, \
|
||||
seq_lens, \
|
||||
max_seq_len, \
|
||||
alibi_slopes, \
|
||||
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
|
||||
num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
|
||||
kv_scale);
|
||||
|
||||
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
|
||||
@ -892,20 +852,22 @@ void paged_attention_v2(
|
||||
torch::Tensor& out, // [num_seqs, num_heads, head_size]
|
||||
torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
|
||||
torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions]
|
||||
torch::Tensor& tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
|
||||
torch::Tensor&
|
||||
tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
|
||||
torch::Tensor& query, // [num_seqs, num_heads, head_size]
|
||||
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||
torch::Tensor&
|
||||
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
torch::Tensor&
|
||||
value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||
int num_kv_heads, // [num_heads]
|
||||
float scale,
|
||||
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
torch::Tensor& seq_lens, // [num_seqs]
|
||||
int block_size,
|
||||
int max_seq_len,
|
||||
int block_size, int max_seq_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype,
|
||||
float kv_scale) {
|
||||
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, CALL_V2_LAUNCHER_BLOCK_SIZE)
|
||||
const std::string& kv_cache_dtype, float kv_scale) {
|
||||
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
|
||||
CALL_V2_LAUNCHER_BLOCK_SIZE)
|
||||
}
|
||||
|
||||
#undef WARP_SIZE
|
||||
|
@ -1,5 +1,6 @@
|
||||
/*
|
||||
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
|
||||
* Adapted from
|
||||
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
|
||||
* Copyright (c) 2023, The vLLM team.
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
|
@ -1,6 +1,8 @@
|
||||
/*
|
||||
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
|
||||
* and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
|
||||
* Adapted from
|
||||
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
|
||||
* and
|
||||
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
|
||||
* Copyright (c) 2023, The vLLM team.
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
@ -280,7 +282,8 @@ inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
|
||||
}
|
||||
|
||||
// Vector fused multiply-add.
|
||||
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
|
||||
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b,
|
||||
__nv_bfloat162 c) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
@ -288,7 +291,8 @@ inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bf
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) {
|
||||
inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b,
|
||||
__nv_bfloat162 c) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
|
@ -1,6 +1,8 @@
|
||||
/*
|
||||
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
|
||||
* and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
|
||||
* Adapted from
|
||||
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
|
||||
* and
|
||||
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
|
||||
* Copyright (c) 2023, The vLLM team.
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
@ -130,7 +132,9 @@ inline __device__ uint32_t float2_to_half2(float2 f) {
|
||||
} tmp;
|
||||
#ifndef USE_ROCM
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
|
||||
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n"
|
||||
: "=r"(tmp.u32)
|
||||
: "f"(f.y), "f"(f.x));
|
||||
#else
|
||||
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
|
||||
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
|
||||
@ -327,9 +331,13 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) {
|
||||
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
|
||||
uint32_t d;
|
||||
#ifndef USE_ROCM
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
|
||||
: "=r"(d)
|
||||
: "r"(a), "r"(b), "r"(c));
|
||||
#else
|
||||
asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c));
|
||||
asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n"
|
||||
: "=v"(d)
|
||||
: "v"(a), "v"(b), "v"(c));
|
||||
#endif
|
||||
return d;
|
||||
}
|
||||
@ -470,13 +478,9 @@ inline __device__ void from_float(uint4& dst, Float8_ src) {
|
||||
}
|
||||
|
||||
// From float16 to float32.
|
||||
inline __device__ float to_float(uint16_t u) {
|
||||
return half_to_float(u);
|
||||
}
|
||||
inline __device__ float to_float(uint16_t u) { return half_to_float(u); }
|
||||
|
||||
inline __device__ float2 to_float(uint32_t u) {
|
||||
return half2_to_float2(u);
|
||||
}
|
||||
inline __device__ float2 to_float(uint32_t u) { return half2_to_float2(u); }
|
||||
|
||||
inline __device__ Float4_ to_float(uint2 u) {
|
||||
Float4_ tmp;
|
||||
@ -495,8 +499,6 @@ inline __device__ Float8_ to_float(uint4 u) {
|
||||
}
|
||||
|
||||
// Zero-out a variable.
|
||||
inline __device__ void zero(uint16_t& dst) {
|
||||
dst = uint16_t(0);
|
||||
}
|
||||
inline __device__ void zero(uint16_t& dst) { dst = uint16_t(0); }
|
||||
|
||||
} // namespace vllm
|
||||
|
@ -1,6 +1,8 @@
|
||||
/*
|
||||
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
|
||||
* and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
|
||||
* Adapted from
|
||||
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
|
||||
* and
|
||||
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
|
||||
* Copyright (c) 2023, The vLLM team.
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
@ -66,9 +68,7 @@ struct FloatVec<float4> {
|
||||
};
|
||||
|
||||
// Vector addition.
|
||||
inline __device__ float add(float a, float b) {
|
||||
return a + b;
|
||||
}
|
||||
inline __device__ float add(float a, float b) { return a + b; }
|
||||
|
||||
inline __device__ float2 add(float2 a, float2 b) {
|
||||
float2 c;
|
||||
@ -129,9 +129,7 @@ inline __device__ float4 mul(float a, float4 b) {
|
||||
}
|
||||
|
||||
// Vector fused multiply-add.
|
||||
inline __device__ float fma(float a, float b, float c) {
|
||||
return a * b + c;
|
||||
}
|
||||
inline __device__ float fma(float a, float b, float c) { return a * b + c; }
|
||||
|
||||
inline __device__ float2 fma(float2 a, float2 b, float2 c) {
|
||||
float2 d;
|
||||
@ -208,9 +206,7 @@ inline __device__ float sum(Float8_ v) {
|
||||
}
|
||||
|
||||
// Vector dot product.
|
||||
inline __device__ float dot(float a, float b) {
|
||||
return a * b;
|
||||
}
|
||||
inline __device__ float dot(float a, float b) { return a * b; }
|
||||
|
||||
inline __device__ float dot(float2 a, float2 b) {
|
||||
float2 c = mul<float2, float2, float2>(a, b);
|
||||
@ -232,42 +228,24 @@ inline __device__ float dot(Float8_ a, Float8_ b) {
|
||||
}
|
||||
|
||||
// From float to float.
|
||||
inline __device__ void from_float(float& dst, float src) {
|
||||
dst = src;
|
||||
}
|
||||
inline __device__ void from_float(float& dst, float src) { dst = src; }
|
||||
|
||||
inline __device__ void from_float(float2& dst, float2 src) {
|
||||
dst = src;
|
||||
}
|
||||
inline __device__ void from_float(float2& dst, float2 src) { dst = src; }
|
||||
|
||||
inline __device__ void from_float(float4& dst, float4 src) {
|
||||
dst = src;
|
||||
}
|
||||
inline __device__ void from_float(float4& dst, float4 src) { dst = src; }
|
||||
|
||||
// From float to float.
|
||||
inline __device__ float to_float(float u) {
|
||||
return u;
|
||||
}
|
||||
inline __device__ float to_float(float u) { return u; }
|
||||
|
||||
inline __device__ float2 to_float(float2 u) {
|
||||
return u;
|
||||
}
|
||||
inline __device__ float2 to_float(float2 u) { return u; }
|
||||
|
||||
inline __device__ float4 to_float(float4 u) {
|
||||
return u;
|
||||
}
|
||||
inline __device__ float4 to_float(float4 u) { return u; }
|
||||
|
||||
inline __device__ Float4_ to_float(Float4_ u) {
|
||||
return u;
|
||||
}
|
||||
inline __device__ Float4_ to_float(Float4_ u) { return u; }
|
||||
|
||||
inline __device__ Float8_ to_float(Float8_ u) {
|
||||
return u;
|
||||
}
|
||||
inline __device__ Float8_ to_float(Float8_ u) { return u; }
|
||||
|
||||
// Zero-out a variable.
|
||||
inline __device__ void zero(float& dst) {
|
||||
dst = 0.f;
|
||||
}
|
||||
inline __device__ void zero(float& dst) { dst = 0.f; }
|
||||
|
||||
} // namespace vllm
|
||||
|
28 csrc/cache.h
@@ -5,36 +5,24 @@
#include <map>
#include <vector>

void swap_blocks(
torch::Tensor& src,
torch::Tensor& dst,
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
const torch::Tensor& block_mapping);

void copy_blocks(
std::vector<torch::Tensor>& key_caches,
void copy_blocks(std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& value_caches,
const torch::Tensor& block_mapping);

void reshape_and_cache(
torch::Tensor& key,
torch::Tensor& value,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype,
const float kv_scale);
const std::string& kv_cache_dtype, const float kv_scale);

void reshape_and_cache_flash(
torch::Tensor& key,
torch::Tensor& value,
void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype);

// Just for unittest
void convert_fp8(
torch::Tensor& dst_cache,
torch::Tensor& src_cache,
const float scale,
const std::string& kv_cache_dtype);
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
const float scale, const std::string& kv_cache_dtype);
@ -21,16 +21,13 @@
|
||||
typedef __hip_bfloat16 __nv_bfloat16;
|
||||
#endif
|
||||
|
||||
void swap_blocks(
|
||||
torch::Tensor& src,
|
||||
torch::Tensor& dst,
|
||||
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
|
||||
const torch::Tensor& block_mapping) {
|
||||
torch::Device src_device = src.device();
|
||||
torch::Device dst_device = dst.device();
|
||||
cudaMemcpyKind memcpy_type;
|
||||
if (src_device.is_cuda() && dst_device.is_cuda()) {
|
||||
TORCH_CHECK(
|
||||
src_device.index() == dst_device.index(),
|
||||
TORCH_CHECK(src_device.index() == dst_device.index(),
|
||||
"src and dst must be on the same GPU");
|
||||
memcpy_type = cudaMemcpyDeviceToDevice;
|
||||
} else if (src_device.is_cuda() && dst_device.is_cpu()) {
|
||||
@ -50,7 +47,8 @@ void swap_blocks(
|
||||
char* dst_ptr = static_cast<char*>(dst.data_ptr());
|
||||
|
||||
const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
|
||||
const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(
|
||||
src_device.is_cuda() ? src_device : dst_device);
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
// NOTE(woosuk): This can be slow if the number of blocks is large.
|
||||
const int64_t num_blocks = block_mapping.size(0);
|
||||
@ -59,12 +57,8 @@ void swap_blocks(
|
||||
int64_t dst_block_number = block_mapping[i][1].item<int64_t>();
|
||||
int64_t src_offset = src_block_number * block_size_in_bytes;
|
||||
int64_t dst_offset = dst_block_number * block_size_in_bytes;
|
||||
cudaMemcpyAsync(
|
||||
dst_ptr + dst_offset,
|
||||
src_ptr + src_offset,
|
||||
block_size_in_bytes,
|
||||
memcpy_type,
|
||||
stream);
|
||||
cudaMemcpyAsync(dst_ptr + dst_offset, src_ptr + src_offset,
|
||||
block_size_in_bytes, memcpy_type, stream);
|
||||
}
|
||||
}
|
||||
|
||||
@ -72,8 +66,7 @@ namespace vllm {
|
||||
|
||||
// Grid: (num_layers, num_pairs)
|
||||
template <typename scalar_t>
|
||||
__global__ void copy_blocks_kernel(
|
||||
int64_t* key_cache_ptrs,
|
||||
__global__ void copy_blocks_kernel(int64_t* key_cache_ptrs,
|
||||
int64_t* value_cache_ptrs,
|
||||
const int64_t* __restrict__ block_mapping,
|
||||
const int numel_per_block) {
|
||||
@ -81,7 +74,8 @@ __global__ void copy_blocks_kernel(
|
||||
const int pair_idx = blockIdx.y;
|
||||
|
||||
scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
|
||||
scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
|
||||
scalar_t* value_cache =
|
||||
reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
|
||||
int64_t src_block_number = block_mapping[2 * pair_idx];
|
||||
int64_t dst_block_number = block_mapping[2 * pair_idx + 1];
|
||||
|
||||
@ -101,8 +95,7 @@ __global__ void copy_blocks_kernel(
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
void copy_blocks(
|
||||
std::vector<torch::Tensor>& key_caches,
|
||||
void copy_blocks(std::vector<torch::Tensor>& key_caches,
|
||||
std::vector<torch::Tensor>& value_caches,
|
||||
const torch::Tensor& block_mapping) {
|
||||
int num_layers = key_caches.size();
|
||||
@ -118,8 +111,10 @@ void copy_blocks(
|
||||
int64_t key_cache_ptrs[num_layers];
|
||||
int64_t value_cache_ptrs[num_layers];
|
||||
for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
|
||||
key_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
|
||||
value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
|
||||
key_cache_ptrs[layer_idx] =
|
||||
reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
|
||||
value_cache_ptrs[layer_idx] =
|
||||
reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
|
||||
}
|
||||
|
||||
// block_mapping is a 2D tensor with shape (num_pairs, 2).
|
||||
@ -127,10 +122,12 @@ void copy_blocks(
|
||||
|
||||
// Move the data structures to the GPU.
|
||||
// NOTE: This synchronizes the CPU and GPU.
|
||||
torch::Tensor key_cache_ptrs_tensor = torch::from_blob(
|
||||
key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
|
||||
torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
|
||||
value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
|
||||
torch::Tensor key_cache_ptrs_tensor =
|
||||
torch::from_blob(key_cache_ptrs, {num_layers}, torch::kInt64)
|
||||
.to(cache_device);
|
||||
torch::Tensor value_cache_ptrs_tensor =
|
||||
torch::from_blob(value_cache_ptrs, {num_layers}, torch::kInt64)
|
||||
.to(cache_device);
|
||||
|
||||
// Launch the kernel.
|
||||
const int numel_per_block = key_caches[0][0].numel();
|
||||
@ -143,8 +140,7 @@ void copy_blocks(
|
||||
vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
key_cache_ptrs_tensor.data_ptr<int64_t>(),
|
||||
value_cache_ptrs_tensor.data_ptr<int64_t>(),
|
||||
block_mapping.data_ptr<int64_t>(),
|
||||
numel_per_block);
|
||||
block_mapping.data_ptr<int64_t>(), numel_per_block);
|
||||
}));
|
||||
}
|
||||
|
||||
@ -154,15 +150,13 @@ template<typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
|
||||
__global__ void reshape_and_cache_kernel(
|
||||
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
|
||||
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
|
||||
cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||
cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x,
|
||||
// block_size, x]
|
||||
cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size,
|
||||
// block_size]
|
||||
const int64_t* __restrict__ slot_mapping, // [num_tokens]
|
||||
const int key_stride,
|
||||
const int value_stride,
|
||||
const int num_heads,
|
||||
const int head_size,
|
||||
const int block_size,
|
||||
const int x,
|
||||
const int key_stride, const int value_stride, const int num_heads,
|
||||
const int head_size, const int block_size, const int x,
|
||||
const float kv_scale) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
const int64_t slot_idx = slot_mapping[token_idx];
|
||||
@ -184,23 +178,24 @@ __global__ void reshape_and_cache_kernel(
|
||||
const int x_idx = head_offset / x;
|
||||
const int x_offset = head_offset % x;
|
||||
|
||||
const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
|
||||
+ head_idx * (head_size / x) * block_size * x
|
||||
+ x_idx * block_size * x
|
||||
+ block_offset * x
|
||||
+ x_offset;
|
||||
const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size
|
||||
+ head_idx * head_size * block_size
|
||||
+ head_offset * block_size
|
||||
+ block_offset;
|
||||
const int64_t tgt_key_idx =
|
||||
block_idx * num_heads * (head_size / x) * block_size * x +
|
||||
head_idx * (head_size / x) * block_size * x + x_idx * block_size * x +
|
||||
block_offset * x + x_offset;
|
||||
const int64_t tgt_value_idx =
|
||||
block_idx * num_heads * head_size * block_size +
|
||||
head_idx * head_size * block_size + head_offset * block_size +
|
||||
block_offset;
|
||||
scalar_t tgt_key = key[src_key_idx];
|
||||
scalar_t tgt_value = value[src_value_idx];
|
||||
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
|
||||
key_cache[tgt_key_idx] = tgt_key;
|
||||
value_cache[tgt_value_idx] = tgt_value;
|
||||
} else {
|
||||
key_cache[tgt_key_idx] = fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, kv_scale);
|
||||
value_cache[tgt_value_idx] = fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, kv_scale);
|
||||
key_cache[tgt_key_idx] =
|
||||
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, kv_scale);
|
||||
value_cache[tgt_value_idx] =
|
||||
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, kv_scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -209,15 +204,13 @@ template<typename scalar_t>
|
||||
__global__ void reshape_and_cache_flash_kernel(
|
||||
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
|
||||
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
|
||||
scalar_t* __restrict__ k_cache, // [num_blocks, block_size, num_heads, head_size]
|
||||
scalar_t* __restrict__ v_cache, // [num_blocks, block_size, num_heads, head_size]
|
||||
scalar_t* __restrict__ k_cache, // [num_blocks, block_size, num_heads,
|
||||
// head_size]
|
||||
scalar_t* __restrict__ v_cache, // [num_blocks, block_size, num_heads,
|
||||
// head_size]
|
||||
const int64_t* __restrict__ slot_mapping, // [num_tokens]
|
||||
const int block_stride,
|
||||
const int key_stride,
|
||||
const int value_stride,
|
||||
const int num_heads,
|
||||
const int head_size,
|
||||
const int block_size) {
|
||||
const int block_stride, const int key_stride, const int value_stride,
|
||||
const int num_heads, const int head_size, const int block_size) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
const int64_t slot_idx = slot_mapping[token_idx];
|
||||
// NOTE: slot_idx can be -1 if the token is padded
|
||||
@ -232,10 +225,9 @@ __global__ void reshape_and_cache_flash_kernel(
|
||||
const int64_t src_value_idx = token_idx * value_stride + i;
|
||||
const int head_idx = i / head_size;
|
||||
const int head_offset = i % head_size;
|
||||
const int64_t tgt_value_idx = block_idx * block_stride
|
||||
+ block_offset * num_heads * head_size
|
||||
+ head_idx * head_size
|
||||
+ head_offset;
|
||||
const int64_t tgt_value_idx = block_idx * block_stride +
|
||||
block_offset * num_heads * head_size +
|
||||
head_idx * head_size + head_offset;
|
||||
k_cache[tgt_value_idx] = key[src_key_idx];
|
||||
v_cache[tgt_value_idx] = value[src_value_idx];
|
||||
}
|
||||
@ -246,29 +238,24 @@ __global__ void reshape_and_cache_flash_kernel(
|
||||
// CACHE_T is the data type of key and value tensors.
|
||||
// KV_DTYPE is the real data type of kv-cache.
|
||||
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \
|
||||
vllm::reshape_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE><<<grid, block, 0, stream>>>( \
|
||||
vllm::reshape_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<KV_T*>(key.data_ptr()), \
|
||||
reinterpret_cast<KV_T*>(value.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
|
||||
slot_mapping.data_ptr<int64_t>(), \
|
||||
key_stride, \
|
||||
value_stride, \
|
||||
num_heads, \
|
||||
head_size, \
|
||||
block_size, \
|
||||
x, \
|
||||
kv_scale);
|
||||
slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
|
||||
num_heads, head_size, block_size, x, kv_scale);
|
||||
|
||||
void reshape_and_cache(
|
||||
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
||||
torch::Tensor& value, // [num_tokens, num_heads, head_size]
|
||||
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||
torch::Tensor&
|
||||
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
torch::Tensor&
|
||||
value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||
torch::Tensor& slot_mapping, // [num_tokens]
|
||||
const std::string& kv_cache_dtype,
|
||||
const float kv_scale)
|
||||
{
|
||||
const std::string& kv_cache_dtype, const float kv_scale) {
|
||||
int num_tokens = key.size(0);
|
||||
int num_heads = key.size(1);
|
||||
int head_size = key.size(2);
|
||||
@ -283,7 +270,8 @@ void reshape_and_cache(
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype, CALL_RESHAPE_AND_CACHE)
|
||||
DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
|
||||
CALL_RESHAPE_AND_CACHE)
|
||||
}
|
||||
|
||||
void reshape_and_cache_flash(
|
||||
@ -292,8 +280,7 @@ void reshape_and_cache_flash(
|
||||
torch::Tensor& k_cache, // [num_blocks, block_size, num_heads, head_size]
|
||||
torch::Tensor& v_cache, // [num_blocks, block_size, num_heads, head_size]
|
||||
torch::Tensor& slot_mapping, // [num_tokens]
|
||||
const std::string& kv_cache_dtype)
|
||||
{
|
||||
const std::string& kv_cache_dtype) {
|
||||
// FIXME: only support auto datatype, does not support fp8
|
||||
if (kv_cache_dtype != "auto") {
|
||||
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
|
||||
@ -313,36 +300,28 @@ void reshape_and_cache_flash(
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
key.scalar_type(),
|
||||
"reshape_and_cache_flash",
|
||||
[&] {
|
||||
vllm::reshape_and_cache_flash_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
key.data_ptr<scalar_t>(),
|
||||
value.data_ptr<scalar_t>(),
|
||||
k_cache.data_ptr<scalar_t>(),
|
||||
v_cache.data_ptr<scalar_t>(),
|
||||
slot_mapping.data_ptr<int64_t>(),
|
||||
block_stride,
|
||||
key_stride,
|
||||
value_stride,
|
||||
num_heads,
|
||||
head_size,
|
||||
block_size);
|
||||
key.scalar_type(), "reshape_and_cache_flash", [&] {
|
||||
vllm::reshape_and_cache_flash_kernel<scalar_t>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
|
||||
k_cache.data_ptr<scalar_t>(), v_cache.data_ptr<scalar_t>(),
|
||||
slot_mapping.data_ptr<int64_t>(), block_stride, key_stride,
|
||||
value_stride, num_heads, head_size, block_size);
|
||||
});
|
||||
}
|
||||
|
||||
namespace vllm {
|
||||
|
||||
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
|
||||
__global__ void convert_fp8_kernel(
|
||||
const Tin* __restrict__ src_cache,
|
||||
__global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache,
|
||||
Tout* __restrict__ dst_cache,
|
||||
const float kv_scale,
|
||||
const int64_t block_stride) {
|
||||
const int64_t block_idx = blockIdx.x;
|
||||
for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
|
||||
int64_t idx = block_idx * block_stride + i;
|
||||
dst_cache[idx] = fp8::scaled_convert<Tout, Tin, kv_dt>(src_cache[idx], kv_scale);
|
||||
dst_cache[idx] =
|
||||
fp8::scaled_convert<Tout, Tin, kv_dt>(src_cache[idx], kv_scale);
|
||||
}
|
||||
}
|
||||
|
||||
@ -351,23 +330,16 @@ __global__ void convert_fp8_kernel(
|
||||
#define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \
|
||||
vllm::convert_fp8_kernel<Tout, Tin, KV_DTYPE><<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<Tin*>(src_cache.data_ptr()), \
|
||||
reinterpret_cast<Tout*>(dst_cache.data_ptr()), \
|
||||
kv_scale, \
|
||||
block_stride);
|
||||
reinterpret_cast<Tout*>(dst_cache.data_ptr()), kv_scale, block_stride);
|
||||
|
||||
// Only for testing.
|
||||
void convert_fp8(
|
||||
torch::Tensor& dst_cache,
|
||||
torch::Tensor& src_cache,
|
||||
const float kv_scale,
|
||||
const std::string& kv_cache_dtype)
|
||||
{
|
||||
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
|
||||
const float kv_scale, const std::string& kv_cache_dtype) {
|
||||
torch::Device src_device = src_cache.device();
|
||||
torch::Device dst_device = dst_cache.device();
|
||||
TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU")
|
||||
TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU")
|
||||
TORCH_CHECK(
|
||||
src_device.index() == dst_device.index(),
|
||||
TORCH_CHECK(src_device.index() == dst_device.index(),
|
||||
"src and dst must be on the same GPU");
|
||||
at::cuda::OptionalCUDAGuard device_guard(src_device);
|
||||
|
||||
@ -398,13 +370,15 @@ void convert_fp8(
|
||||
} else if (src_cache.dtype() == at::ScalarType::Half) {
|
||||
CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
|
||||
} else if (src_cache.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, vllm::Fp8KVCacheDataType::kFp8E4M3);
|
||||
CALL_CONVERT_FP8(uint8_t, __nv_bfloat16,
|
||||
vllm::Fp8KVCacheDataType::kFp8E4M3);
|
||||
} else if (dst_cache.dtype() == at::ScalarType::Float) {
|
||||
CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
|
||||
} else if (dst_cache.dtype() == at::ScalarType::Half) {
|
||||
CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
|
||||
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
|
||||
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t,
|
||||
vllm::Fp8KVCacheDataType::kFp8E4M3);
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype);
|
||||
|
@ -81,12 +81,10 @@ void silu_and_mul(torch::Tensor &out, torch::Tensor &input) {
|
||||
int num_tokens = input.numel() / input.size(-1);
|
||||
int d = input.size(-1) / 2;
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "silu_and_mul_impl", [&] {
|
||||
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "silu_and_mul_impl", [&] {
|
||||
CPU_KERNEL_GUARD_IN(silu_and_mul_impl)
|
||||
activation_kernel<scalar_t, silu_act, true>(num_tokens, d,
|
||||
input.data_ptr<scalar_t>(),
|
||||
out.data_ptr<scalar_t>());
|
||||
activation_kernel<scalar_t, silu_act, true>(
|
||||
num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
|
||||
CPU_KERNEL_GUARD_OUT(silu_and_mul_impl)
|
||||
});
|
||||
}
|
||||
@ -97,12 +95,10 @@ void gelu_and_mul(torch::Tensor &out, // [..., d]
|
||||
int num_tokens = input.numel() / input.size(-1);
|
||||
int d = input.size(-1) / 2;
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "gelu_and_mul_impl", [&] {
|
||||
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_and_mul_impl", [&] {
|
||||
CPU_KERNEL_GUARD_IN(gelu_and_mul_impl)
|
||||
activation_kernel<scalar_t, gelu_act, true>(num_tokens, d,
|
||||
input.data_ptr<scalar_t>(),
|
||||
out.data_ptr<scalar_t>());
|
||||
activation_kernel<scalar_t, gelu_act, true>(
|
||||
num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
|
||||
CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl)
|
||||
});
|
||||
}
|
||||
|
@ -2,7 +2,8 @@
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename scalar_t> struct KernelVecType {
|
||||
template <typename scalar_t>
|
||||
struct KernelVecType {
|
||||
using q_load_vec_type = void;
|
||||
using q_vec_type = void;
|
||||
using k_load_vec_type = void;
|
||||
@ -11,7 +12,8 @@ template <typename scalar_t> struct KernelVecType {
|
||||
using v_load_vec_type = void;
|
||||
};
|
||||
|
||||
template <> struct KernelVecType<float> {
|
||||
template <>
|
||||
struct KernelVecType<float> {
|
||||
using q_load_vec_type = vec_op::FP32Vec4;
|
||||
using q_vec_type = vec_op::FP32Vec16;
|
||||
using k_load_vec_type = vec_op::FP32Vec16;
|
||||
@ -21,7 +23,8 @@ template <> struct KernelVecType<float> {
|
||||
};
|
||||
|
||||
#ifdef __AVX512BF16__
|
||||
template <> struct KernelVecType<c10::BFloat16> {
|
||||
template <>
|
||||
struct KernelVecType<c10::BFloat16> {
|
||||
using q_load_vec_type = vec_op::BF16Vec8;
|
||||
using q_vec_type = vec_op::BF16Vec32;
|
||||
using k_load_vec_type = vec_op::BF16Vec32;
|
||||
@ -30,7 +33,8 @@ template <> struct KernelVecType<c10::BFloat16> {
|
||||
using v_load_vec_type = vec_op::BF16Vec16;
|
||||
};
|
||||
#else
|
||||
template <> struct KernelVecType<c10::BFloat16> {
|
||||
template <>
|
||||
struct KernelVecType<c10::BFloat16> {
|
||||
using q_load_vec_type = vec_op::BF16Vec8;
|
||||
using q_vec_type = vec_op::FP32Vec16;
|
||||
using k_load_vec_type = vec_op::BF16Vec16;
|
||||
@ -67,9 +71,10 @@ FORCE_INLINE std::pair<T, T> reduceSoftmax(T *data, const int size,
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
FORCE_INLINE std::pair<T, T>
|
||||
reduceSoftmaxAlibi(T *data, const int size, const int capacity,
|
||||
const float alibi_slope, const int start_index,
|
||||
FORCE_INLINE std::pair<T, T> reduceSoftmaxAlibi(T* data, const int size,
|
||||
const int capacity,
|
||||
const float alibi_slope,
|
||||
const int start_index,
|
||||
const int seq_len) {
|
||||
data[0] += alibi_slope * (start_index - seq_len + 1);
|
||||
T max = data[0];
|
||||
@ -215,16 +220,16 @@ FORCE_INLINE void reduceValueBlock(const float *prob, const scalar_t *v_block,
|
||||
namespace {
|
||||
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE>
|
||||
struct paged_attention_v1_impl {
|
||||
static void
|
||||
call(scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size]
|
||||
static void call(
|
||||
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
|
||||
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
||||
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
|
||||
// head_size/x, block_size, x]
|
||||
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
|
||||
// head_size, block_size]
|
||||
const int num_kv_heads, const float scale,
|
||||
const int
|
||||
*__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
const int* __restrict__ block_tables, // [num_seqs,
|
||||
// max_num_blocks_per_seq]
|
||||
const int* __restrict__ seq_lens, // [num_seqs]
|
||||
const int max_num_blocks_per_seq,
|
||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||
@ -257,8 +262,7 @@ struct paged_attention_v1_impl {
|
||||
const int64_t kv_head_idx = head_idx / num_queries_per_kv;
|
||||
const scalar_t* __restrict__ q_vec_ptr =
|
||||
q + seq_idx * q_stride + head_idx * HEAD_SIZE;
|
||||
const int last_block_token_num =
|
||||
seq_len - (block_num - 1) * BLOCK_SIZE;
|
||||
const int last_block_token_num = seq_len - (block_num - 1) * BLOCK_SIZE;
|
||||
float* __restrict__ thread_block_logits =
|
||||
logits + omp_get_thread_num() * max_seq_len_padded;
|
||||
|
||||
@ -282,8 +286,7 @@ struct paged_attention_v1_impl {
|
||||
block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0,
|
||||
seq_len);
|
||||
} else {
|
||||
reduceSoftmax(thread_block_logits, seq_len,
|
||||
block_num * BLOCK_SIZE);
|
||||
reduceSoftmax(thread_block_logits, seq_len, block_num * BLOCK_SIZE);
|
||||
}
|
||||
|
||||
// Compute value
|
||||
@ -348,8 +351,8 @@ template <typename T, int BLOCK_SIZE>
|
||||
void paged_attention_v1_impl_launcher(
|
||||
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache, int num_kv_heads, float scale,
|
||||
torch::Tensor &block_tables, torch::Tensor &seq_lens,
|
||||
int max_seq_len, const c10::optional<torch::Tensor> &alibi_slopes) {
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes) {
|
||||
int num_seqs = query.size(0);
|
||||
int num_heads = query.size(1);
|
||||
int head_size = query.size(2);
|
||||
@ -415,9 +418,8 @@ void paged_attention_v1_impl_launcher(
|
||||
void paged_attention_v1(torch::Tensor& out, torch::Tensor& query,
|
||||
torch::Tensor& key_cache, torch::Tensor& value_cache,
|
||||
int num_kv_heads, float scale,
|
||||
torch::Tensor &block_tables,
|
||||
torch::Tensor &seq_lens, int block_size,
|
||||
int max_seq_len,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens,
|
||||
int block_size, int max_seq_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, float kv_scale) {
|
||||
TORCH_CHECK(kv_scale == 1.0f);
|
||||
@ -435,9 +437,10 @@ template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int PARTITION_SIZE>
|
||||
struct paged_attention_v2_impl {
|
||||
static void call(
|
||||
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
|
||||
float *__restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
|
||||
float
|
||||
*__restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
|
||||
float* __restrict__ exp_sums, // [num_seqs, num_heads,
|
||||
// max_num_partitions]
|
||||
float* __restrict__ max_logits, // [num_seqs, num_heads,
|
||||
// max_num_partitions]
|
||||
scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
|
||||
// max_num_partitions, head_size]
|
||||
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
||||
@ -446,8 +449,8 @@ struct paged_attention_v2_impl {
|
||||
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
|
||||
// head_size, block_size]
|
||||
const int num_kv_heads, const float scale,
|
||||
const int
|
||||
*__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
const int* __restrict__ block_tables, // [num_seqs,
|
||||
// max_num_blocks_per_seq]
|
||||
const int* __restrict__ seq_lens, // [num_seqs]
|
||||
const int max_num_blocks_per_seq,
|
||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||
@ -468,8 +471,7 @@ struct paged_attention_v2_impl {
|
||||
const int seq_len = seq_lens[seq_idx];
|
||||
const int start_token_idx = partition_idx * PARTITION_SIZE;
|
||||
|
||||
if (start_token_idx >= seq_len)
|
||||
continue;
|
||||
if (start_token_idx >= seq_len) continue;
|
||||
|
||||
const int partition_num =
|
||||
(seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
|
||||
@ -477,8 +479,7 @@ struct paged_attention_v2_impl {
|
||||
const int token_num =
|
||||
(std::min(seq_len, start_token_idx + PARTITION_SIZE) -
|
||||
start_token_idx);
|
||||
const int block_num =
|
||||
(token_num + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
const int block_num = (token_num + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
const int last_block_token_num =
|
||||
token_num - (block_num - 1) * BLOCK_SIZE;
|
||||
const int* seq_block_table = block_tables +
|
||||
@ -510,8 +511,8 @@ struct paged_attention_v2_impl {
|
||||
logits, token_num, block_num * BLOCK_SIZE,
|
||||
alibi_slopes[head_idx], start_token_idx, seq_len);
|
||||
} else {
|
||||
max_and_sum = reduceSoftmax(logits, token_num,
|
||||
block_num * BLOCK_SIZE);
|
||||
max_and_sum =
|
||||
reduceSoftmax(logits, token_num, block_num * BLOCK_SIZE);
|
||||
}
|
||||
|
||||
auto&& [max_logit, exp_sum] = max_and_sum;
|
||||
@ -587,8 +588,7 @@ struct paged_attention_v2_impl {
|
||||
const int partition_num =
|
||||
(seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
|
||||
|
||||
if (partition_num == 1)
|
||||
continue;
|
||||
if (partition_num == 1) continue;
|
||||
|
||||
reducePartitonSoftmax(
|
||||
max_logits + seq_idx * num_heads * max_num_partitions +
|
||||
@ -603,8 +603,8 @@ struct paged_attention_v2_impl {
|
||||
using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
|
||||
static_assert(v_load_vec_type::get_elem_num() == BLOCK_SIZE);
|
||||
constexpr int head_elem_num_per_group =
|
||||
16; // Note: didn't align with the cacheline size, due to some HEAD_SIZE
|
||||
// didn't align with 64 bytes
|
||||
16; // Note: didn't align with the cacheline size, due to some
|
||||
// HEAD_SIZE didn't align with 64 bytes
|
||||
static_assert(HEAD_SIZE % head_elem_num_per_group == 0);
|
||||
constexpr int head_group_num = HEAD_SIZE / head_elem_num_per_group;
|
||||
const float* __restrict__ rescale_factors = exp_sums;
|
||||
@ -616,8 +616,7 @@ struct paged_attention_v2_impl {
|
||||
const int partition_num =
|
||||
(seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
|
||||
|
||||
if (partition_num == 1)
|
||||
continue;
|
||||
if (partition_num == 1) continue;
|
||||
|
||||
const float* __restrict__ seq_head_rescale_factors =
|
||||
rescale_factors + seq_idx * num_heads * max_num_partitions +
|
||||
@ -713,8 +712,8 @@ void paged_attention_v2_impl_launcher(
|
||||
#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
|
||||
paged_attention_v2_impl_launcher<T, BLOCK_SIZE>( \
|
||||
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
|
||||
num_kv_heads, scale, block_tables, seq_lens, block_size, \
|
||||
max_seq_len, alibi_slopes);
|
||||
num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, \
|
||||
alibi_slopes);
|
||||
|
||||
#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
|
||||
switch (block_size) { \
|
||||
|
@ -5,17 +5,18 @@
|
||||
|
||||
namespace {
|
||||
template <typename scalar_t>
|
||||
void copy_blocks_cpu_impl(
|
||||
std::vector<torch::Tensor> &key_caches,
|
||||
void copy_blocks_cpu_impl(std::vector<torch::Tensor>& key_caches,
|
||||
std::vector<torch::Tensor>& value_caches,
|
||||
const torch::Tensor& mapping_pairs,
|
||||
const int element_num_per_block, const int layer_num) {
|
||||
const int element_num_per_block,
|
||||
const int layer_num) {
|
||||
const size_t pair_num = mapping_pairs.size(0);
|
||||
const size_t block_bytes = sizeof(scalar_t) * element_num_per_block;
|
||||
#pragma omp parallel for collapse(2)
|
||||
for (int layer = 0; layer < layer_num; ++layer) {
|
||||
for (size_t pair = 0; pair < pair_num; ++pair) {
|
||||
int64_t source_offset = element_num_per_block * mapping_pairs[pair][0].item<int64_t>();
|
||||
int64_t source_offset =
|
||||
element_num_per_block * mapping_pairs[pair][0].item<int64_t>();
|
||||
int64_t target_offset =
|
||||
element_num_per_block * mapping_pairs[pair][1].item<int64_t>();
|
||||
scalar_t* key_cache_ptr = key_caches[layer].data_ptr<scalar_t>();
|
||||
|
@ -87,8 +87,8 @@ void fused_add_rms_norm_impl(scalar_t *__restrict__ input,
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void rms_norm(torch::Tensor &out, torch::Tensor &input,
|
||||
torch::Tensor &weight, float epsilon) {
|
||||
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
|
||||
float epsilon) {
|
||||
int hidden_size = input.size(-1);
|
||||
int num_tokens = input.numel() / hidden_size;
|
||||
|
||||
|
@ -4,16 +4,16 @@
|
||||
namespace {
|
||||
template <typename scalar_t>
|
||||
void rotary_embedding_impl(
|
||||
const int64_t
|
||||
*__restrict__ positions, // [batch_size, seq_len] or [num_tokens]
|
||||
scalar_t
|
||||
*__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or
|
||||
/// [num_tokens, num_heads, head_size]
|
||||
scalar_t
|
||||
*__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or
|
||||
// [num_tokens, num_kv_heads, head_size]
|
||||
const scalar_t
|
||||
*__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
|
||||
const int64_t* __restrict__ positions, // [batch_size, seq_len] or
|
||||
// [num_tokens]
|
||||
scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads,
|
||||
/// head_size] or [num_tokens, num_heads,
|
||||
/// head_size]
|
||||
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
|
||||
// head_size] or [num_tokens, num_kv_heads,
|
||||
// head_size]
|
||||
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
|
||||
// 2]
|
||||
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
|
||||
const int num_heads, const int num_kv_heads, const int head_size,
|
||||
const int num_tokens) {
|
||||
@ -94,16 +94,16 @@ void rotary_embedding_impl(
|
||||
|
||||
template <typename scalar_t>
|
||||
void rotary_embedding_gptj_impl(
|
||||
const int64_t
|
||||
*__restrict__ positions, // [batch_size, seq_len] or [num_tokens]
|
||||
scalar_t
|
||||
*__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or
|
||||
/// [num_tokens, num_heads, head_size]
|
||||
scalar_t
|
||||
*__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or
|
||||
// [num_tokens, num_kv_heads, head_size]
|
||||
const scalar_t
|
||||
*__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
|
||||
const int64_t* __restrict__ positions, // [batch_size, seq_len] or
|
||||
// [num_tokens]
|
||||
scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads,
|
||||
/// head_size] or [num_tokens, num_heads,
|
||||
/// head_size]
|
||||
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
|
||||
// head_size] or [num_tokens, num_kv_heads,
|
||||
// head_size]
|
||||
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
|
||||
// 2]
|
||||
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
|
||||
const int num_heads, const int num_kv_heads, const int head_size,
|
||||
const int num_tokens) {
|
||||
|
@ -8,66 +8,37 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
pybind11::module ops = m.def_submodule("ops", "vLLM custom operators");

// Attention ops
ops.def(
"paged_attention_v1",
&paged_attention_v1,
"Compute the attention between an input query and the cached keys/values using PagedAttention.");
ops.def(
"paged_attention_v2",
&paged_attention_v2,
"PagedAttention V2.");
ops.def("paged_attention_v1", &paged_attention_v1,
"Compute the attention between an input query and the cached "
"keys/values using PagedAttention.");
ops.def("paged_attention_v2", &paged_attention_v2, "PagedAttention V2.");

// Activation ops
ops.def(
"silu_and_mul",
&silu_and_mul,
"Activation function used in SwiGLU.");
ops.def(
"gelu_and_mul",
&gelu_and_mul,
ops.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU.");
ops.def("gelu_and_mul", &gelu_and_mul,
"Activation function used in GeGLU with `none` approximation.");
ops.def(
"gelu_tanh_and_mul",
&gelu_tanh_and_mul,
ops.def("gelu_tanh_and_mul", &gelu_tanh_and_mul,
"Activation function used in GeGLU with `tanh` approximation.");
ops.def(
"gelu_new",
&gelu_new,
"GELU implementation used in GPT-2.");
ops.def(
"gelu_fast",
&gelu_fast,
"Approximate GELU implementation.");
ops.def("gelu_new", &gelu_new, "GELU implementation used in GPT-2.");
ops.def("gelu_fast", &gelu_fast, "Approximate GELU implementation.");

// Layernorm
ops.def(
"rms_norm",
&rms_norm,
ops.def("rms_norm", &rms_norm,
"Apply Root Mean Square (RMS) Normalization to the input tensor.");

ops.def(
"fused_add_rms_norm",
&fused_add_rms_norm,
ops.def("fused_add_rms_norm", &fused_add_rms_norm,
"In-place fused Add and RMS Normalization");

// Rotary embedding
ops.def(
"rotary_embedding",
&rotary_embedding,
ops.def("rotary_embedding", &rotary_embedding,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");

// Cache ops
pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
cache_ops.def(
"swap_blocks",
&swap_blocks,
cache_ops.def("swap_blocks", &swap_blocks,
"Swap in (out) the cache blocks from src to dst");
cache_ops.def(
"copy_blocks",
&copy_blocks,
cache_ops.def("copy_blocks", &copy_blocks,
"Copy the cache blocks from src to dst");
cache_ops.def(
"reshape_and_cache",
&reshape_and_cache,
cache_ops.def("reshape_and_cache", &reshape_and_cache,
"Reshape the key and value tensors and cache them");
}
@ -17,7 +17,8 @@
#endif

#ifndef USE_ROCM
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask)
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) \
__shfl_xor_sync(uint32_t(-1), var, lane_mask)
#else
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
#endif
@ -29,7 +30,8 @@
#endif

#ifndef USE_ROCM
#define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down_sync(uint32_t(-1), var, lane_delta)
#define VLLM_SHFL_DOWN_SYNC(var, lane_delta) \
__shfl_down_sync(uint32_t(-1), var, lane_delta)
#else
#define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta)
#endif
@ -41,4 +43,3 @@
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
#endif
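As a hedged aside (not part of this commit): the VLLM_SHFL_XOR_SYNC macro above only papers over the CUDA/ROCm spelling of the warp shuffle intrinsic. A typical use is a butterfly warp reduction like the sketch below; the 32-lane warp width is an assumption for the CUDA path, and the helper name is hypothetical.

// Hedged sketch, not from this diff: a butterfly warp reduction built on
// VLLM_SHFL_XOR_SYNC. Assumes a 32-lane warp (CUDA); a ROCm build would
// have 64-wide wavefronts.
__device__ float warp_reduce_sum(float val) {
  // Each step folds in the value from the lane whose index differs in one
  // bit, so after log2(32) steps every lane holds the warp-wide sum.
  for (int mask = 16; mask > 0; mask >>= 1) {
    val += VLLM_SHFL_XOR_SYNC(val, mask);
  }
  return val;
}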
@ -2,9 +2,6 @@

#include <torch/extension.h>

int get_device_attribute(
int attribute,
int device_id);
int get_device_attribute(int attribute, int device_id);

int get_max_shared_memory_per_block_device_attribute(
int device_id);
int get_max_shared_memory_per_block_device_attribute(int device_id);

@ -2,25 +2,19 @@
#include <hip/hip_runtime.h>
#include <hip/hip_runtime_api.h>
#endif
int get_device_attribute(
int attribute,
int device_id)
{
int get_device_attribute(int attribute, int device_id) {
int device, value;
if (device_id < 0) {
cudaGetDevice(&device);
}
else {
} else {
device = device_id;
}
cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute), device);
cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute),
device);
return value;
}

int get_max_shared_memory_per_block_device_attribute(
int device_id)
{
int get_max_shared_memory_per_block_device_attribute(int device_id) {
int attribute;
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
// cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
@ -80,8 +80,7 @@ void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out,
|
||||
}
|
||||
case at::ScalarType::Half: {
|
||||
fa->allreduce<half>(stream, reinterpret_cast<half*>(inp.data_ptr()),
|
||||
reinterpret_cast<half *>(out.data_ptr()),
|
||||
out.numel());
|
||||
reinterpret_cast<half*>(out.data_ptr()), out.numel());
|
||||
break;
|
||||
}
|
||||
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
|
||||
|
@ -137,8 +137,7 @@ DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg,
|
||||
// Latency = 1 p2p write
|
||||
sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1;
|
||||
// wait until we got true from all ranks
|
||||
while (!self_sg->start[blockIdx.x][threadIdx.x])
|
||||
;
|
||||
while (!self_sg->start[blockIdx.x][threadIdx.x]);
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
@ -162,8 +161,7 @@ DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg,
|
||||
// Latency = 1 p2p write
|
||||
sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1;
|
||||
// wait until we got true from all ranks
|
||||
while (!self_sg->end[blockIdx.x][threadIdx.x])
|
||||
;
|
||||
while (!self_sg->end[blockIdx.x][threadIdx.x]);
|
||||
}
|
||||
if constexpr (!final_sync) __syncthreads();
|
||||
}
|
||||
@ -192,8 +190,7 @@ __global__ void __launch_bounds__(512, 1)
|
||||
// do the actual reduction
|
||||
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
|
||||
idx += gridDim.x * blockDim.x) {
|
||||
((P *)result)[idx] =
|
||||
packed_reduce<P, ngpus, A>((const P **)&dp.ptrs[0], idx);
|
||||
((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
|
||||
}
|
||||
end_sync<ngpus, true>(sg, self_sg, rank);
|
||||
}
|
||||
|
@ -12,8 +12,7 @@
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)

#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
@ -22,8 +21,8 @@
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)

#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))
AT_DISPATCH_SWITCH(TYPE, NAME, \
VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))

#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
@ -33,5 +32,4 @@
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)

#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
@ -23,9 +23,7 @@ __global__ void rms_norm_kernel(
|
||||
scalar_t* __restrict__ out, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ weight, // [hidden_size]
|
||||
const float epsilon,
|
||||
const int num_tokens,
|
||||
const int hidden_size) {
|
||||
const float epsilon, const int num_tokens, const int hidden_size) {
|
||||
__shared__ float s_variance;
|
||||
float variance = 0.0f;
|
||||
|
||||
@ -41,11 +39,11 @@ __global__ void rms_norm_kernel(
|
||||
|
||||
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
|
||||
float x = (float)input[blockIdx.x * hidden_size + idx];
|
||||
out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
|
||||
out[blockIdx.x * hidden_size + idx] =
|
||||
((scalar_t)(x * s_variance)) * weight[idx];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/* Converter structs for the conversion from torch types to HIP/CUDA types,
|
||||
and the associated type conversions within HIP/CUDA. These helpers need
|
||||
to be implemented for now because the relevant type conversion
|
||||
@ -57,7 +55,9 @@ __global__ void rms_norm_kernel(
|
||||
If true, the struct should be fully defined as shown in the examples below.
|
||||
*/
|
||||
template <typename torch_type>
|
||||
struct _typeConvert { static constexpr bool exists = false; };
|
||||
struct _typeConvert {
|
||||
static constexpr bool exists = false;
|
||||
};
|
||||
|
||||
#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
|
||||
// CUDA < 12.0 runs into issues with packed type conversion
|
||||
@ -68,9 +68,15 @@ struct _typeConvert<c10::Half> {
|
||||
using packed_hip_type = __half2;
|
||||
|
||||
__device__ static inline float convert(hip_type x) { return __half2float(x); }
|
||||
__device__ static inline float2 convert(packed_hip_type x) { return __half22float2(x); }
|
||||
__device__ static inline hip_type convert(float x) { return __float2half_rn(x); }
|
||||
__device__ static inline packed_hip_type convert(float2 x) { return __float22half2_rn(x); }
|
||||
__device__ static inline float2 convert(packed_hip_type x) {
|
||||
return __half22float2(x);
|
||||
}
|
||||
__device__ static inline hip_type convert(float x) {
|
||||
return __float2half_rn(x);
|
||||
}
|
||||
__device__ static inline packed_hip_type convert(float2 x) {
|
||||
return __float22half2_rn(x);
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
@ -82,13 +88,22 @@ struct _typeConvert<c10::BFloat16> {
|
||||
using hip_type = __nv_bfloat16;
|
||||
using packed_hip_type = __nv_bfloat162;
|
||||
|
||||
__device__ static inline float convert(hip_type x) { return __bfloat162float(x); }
|
||||
__device__ static inline float2 convert(packed_hip_type x) { return __bfloat1622float2(x); }
|
||||
__device__ static inline hip_type convert(float x) { return __float2bfloat16(x); }
|
||||
__device__ static inline packed_hip_type convert(float2 x) { return __float22bfloat162_rn(x); }
|
||||
__device__ static inline float convert(hip_type x) {
|
||||
return __bfloat162float(x);
|
||||
}
|
||||
__device__ static inline float2 convert(packed_hip_type x) {
|
||||
return __bfloat1622float2(x);
|
||||
}
|
||||
__device__ static inline hip_type convert(float x) {
|
||||
return __float2bfloat16(x);
|
||||
}
|
||||
__device__ static inline packed_hip_type convert(float2 x) {
|
||||
return __float22bfloat162_rn(x);
|
||||
}
|
||||
};
|
||||
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
|
||||
#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >=
|
||||
// 12000))
|
||||
|
||||
/* Vector POD struct to generate vectorized and packed FP16/BF16 ops
|
||||
for appropriate specializations of fused_add_rms_norm_kernel.
|
||||
@ -117,8 +132,7 @@ struct alignas(16) _f16Vec {
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < width; ++i)
|
||||
data[i] += other.data[i];
|
||||
for (int i = 0; i < width; ++i) data[i] += other.data[i];
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
@ -134,8 +148,7 @@ struct alignas(16) _f16Vec {
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < width; ++i)
|
||||
data[i] *= other.data[i];
|
||||
for (int i = 0; i < width; ++i) data[i] *= other.data[i];
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
@ -185,14 +198,12 @@ struct alignas(16) _f16Vec {
|
||||
packed and vectorized operations, which help with the
|
||||
memory latency bottleneck. */
|
||||
template <typename scalar_t, int width>
|
||||
__global__ std::enable_if_t<
|
||||
(width > 0) && _typeConvert<scalar_t>::exists> fused_add_rms_norm_kernel(
|
||||
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
|
||||
fused_add_rms_norm_kernel(
|
||||
scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
scalar_t* __restrict__ residual, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ weight, // [hidden_size]
|
||||
const float epsilon,
|
||||
const int num_tokens,
|
||||
const int hidden_size) {
|
||||
const float epsilon, const int num_tokens, const int hidden_size) {
|
||||
// Sanity checks on our vector struct and type-punned pointer arithmetic
|
||||
static_assert(std::is_pod_v<_f16Vec<scalar_t, width>>);
|
||||
static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);
|
||||
@ -203,9 +214,12 @@ __global__ std::enable_if_t<
|
||||
/* These and the argument pointers are all declared `restrict` as they are
|
||||
not aliased in practice. Argument pointers should not be dereferenced
|
||||
in this kernel as that would be undefined behavior */
|
||||
auto* __restrict__ input_v = reinterpret_cast<_f16Vec<scalar_t, width>*>(input);
|
||||
auto* __restrict__ residual_v = reinterpret_cast<_f16Vec<scalar_t, width>*>(residual);
|
||||
auto* __restrict__ weight_v = reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);
|
||||
auto* __restrict__ input_v =
|
||||
reinterpret_cast<_f16Vec<scalar_t, width>*>(input);
|
||||
auto* __restrict__ residual_v =
|
||||
reinterpret_cast<_f16Vec<scalar_t, width>*>(residual);
|
||||
auto* __restrict__ weight_v =
|
||||
reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);
|
||||
|
||||
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
|
||||
int id = blockIdx.x * vec_hidden_size + idx;
|
||||
@ -218,7 +232,8 @@ __global__ std::enable_if_t<
|
||||
calculation of max_block_size in fused_add_rms_norm */
|
||||
if (num_tokens < 256) {
|
||||
variance = blockReduceSum<float, 1024>(variance);
|
||||
} else variance = blockReduceSum<float, 256>(variance);
|
||||
} else
|
||||
variance = blockReduceSum<float, 256>(variance);
|
||||
if (threadIdx.x == 0) {
|
||||
s_variance = rsqrtf(variance / hidden_size + epsilon);
|
||||
}
|
||||
@ -233,19 +248,16 @@ __global__ std::enable_if_t<
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/* Generic fused_add_rms_norm_kernel
|
||||
The width field is not used here but necessary for other specializations.
|
||||
*/
|
||||
template <typename scalar_t, int width>
|
||||
__global__ std::enable_if_t<
|
||||
(width == 0) || !_typeConvert<scalar_t>::exists> fused_add_rms_norm_kernel(
|
||||
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
|
||||
fused_add_rms_norm_kernel(
|
||||
scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
scalar_t* __restrict__ residual, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ weight, // [hidden_size]
|
||||
const float epsilon,
|
||||
const int num_tokens,
|
||||
const int hidden_size) {
|
||||
const float epsilon, const int num_tokens, const int hidden_size) {
|
||||
__shared__ float s_variance;
|
||||
float variance = 0.0f;
|
||||
|
||||
@ -260,7 +272,8 @@ __global__ std::enable_if_t<
|
||||
calculation of max_block_size in fused_add_rms_norm */
|
||||
if (num_tokens < 256) {
|
||||
variance = blockReduceSum<float, 1024>(variance);
|
||||
} else variance = blockReduceSum<float, 256>(variance);
|
||||
} else
|
||||
variance = blockReduceSum<float, 256>(variance);
|
||||
if (threadIdx.x == 0) {
|
||||
s_variance = rsqrtf(variance / hidden_size + epsilon);
|
||||
}
|
||||
@ -268,14 +281,14 @@ __global__ std::enable_if_t<
|
||||
|
||||
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
|
||||
float x = (float)residual[blockIdx.x * hidden_size + idx];
|
||||
input[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
|
||||
input[blockIdx.x * hidden_size + idx] =
|
||||
((scalar_t)(x * s_variance)) * weight[idx];
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
void rms_norm(
|
||||
torch::Tensor& out, // [..., hidden_size]
|
||||
void rms_norm(torch::Tensor& out, // [..., hidden_size]
|
||||
torch::Tensor& input, // [..., hidden_size]
|
||||
torch::Tensor& weight, // [hidden_size]
|
||||
float epsilon) {
|
||||
@ -286,37 +299,24 @@ void rms_norm(
|
||||
dim3 block(std::min(hidden_size, 1024));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(),
|
||||
"rms_norm_kernel",
|
||||
[&] {
|
||||
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
|
||||
vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<scalar_t>(),
|
||||
input.data_ptr<scalar_t>(),
|
||||
weight.data_ptr<scalar_t>(),
|
||||
epsilon,
|
||||
num_tokens,
|
||||
hidden_size);
|
||||
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
|
||||
weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
|
||||
});
|
||||
}
|
||||
|
||||
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
|
||||
VLLM_DISPATCH_FLOATING_TYPES( \
|
||||
input.scalar_type(), \
|
||||
"fused_add_rms_norm_kernel", \
|
||||
[&] { \
|
||||
vllm::fused_add_rms_norm_kernel \
|
||||
<scalar_t, width><<<grid, block, 0, stream>>>( \
|
||||
input.data_ptr<scalar_t>(), \
|
||||
input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \
|
||||
vllm::fused_add_rms_norm_kernel<scalar_t, width> \
|
||||
<<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(), \
|
||||
residual.data_ptr<scalar_t>(), \
|
||||
weight.data_ptr<scalar_t>(), \
|
||||
epsilon, \
|
||||
num_tokens, \
|
||||
hidden_size); \
|
||||
weight.data_ptr<scalar_t>(), epsilon, \
|
||||
num_tokens, hidden_size); \
|
||||
});
|
||||
|
||||
void fused_add_rms_norm(
|
||||
torch::Tensor& input, // [..., hidden_size]
|
||||
void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
|
||||
torch::Tensor& residual, // [..., hidden_size]
|
||||
torch::Tensor& weight, // [hidden_size]
|
||||
float epsilon) {
|
||||
@ -342,8 +342,8 @@ void fused_add_rms_norm(
|
||||
auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
|
||||
auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
|
||||
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
|
||||
bool ptrs_are_aligned = inp_ptr % 16 == 0 && res_ptr % 16 == 0 \
|
||||
&& wt_ptr % 16 == 0;
|
||||
bool ptrs_are_aligned =
|
||||
inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
|
||||
if (ptrs_are_aligned && hidden_size % 8 == 0) {
|
||||
LAUNCH_FUSED_ADD_RMS_NORM(8);
|
||||
} else {
|
||||
|
@ -3,5 +3,6 @@
#include <torch/extension.h>

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("topk_softmax", &topk_softmax, "Apply topk softmax to the gating outputs.");
m.def("topk_softmax", &topk_softmax,
"Apply topk softmax to the gating outputs.");
}

@ -2,8 +2,6 @@

#include <torch/extension.h>

void topk_softmax(
torch::Tensor& topk_weights,
torch::Tensor& topk_indices,
void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
torch::Tensor& token_expert_indices,
torch::Tensor& gating_output);
@ -12,11 +12,12 @@
|
||||
namespace vllm {
|
||||
|
||||
namespace {
|
||||
__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, int32_t col) {
|
||||
__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
|
||||
int32_t col) {
|
||||
// don't worry about overflow because num_experts is relatively small
|
||||
return row * total_col + col;
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
|
||||
@ -24,15 +25,17 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
|
||||
int32_t* expert_ids,
|
||||
int32_t* total_tokens_post_pad,
|
||||
int32_t num_experts,
|
||||
int32_t block_size,
|
||||
size_t numel) {
|
||||
int32_t block_size, size_t numel) {
|
||||
const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
|
||||
const size_t start_idx = threadIdx.x * tokens_per_thread;
|
||||
|
||||
extern __shared__ int32_t shared_mem[];
|
||||
|
||||
int32_t* tokens_cnts = shared_mem; // 2d tensor with shape (num_experts + 1, num_experts)
|
||||
int32_t* cumsum = shared_mem + (num_experts + 1) * num_experts; // 1d tensor with shape (num_experts + 1)
|
||||
int32_t* tokens_cnts =
|
||||
shared_mem; // 2d tensor with shape (num_experts + 1, num_experts)
|
||||
int32_t* cumsum =
|
||||
shared_mem + (num_experts + 1) *
|
||||
num_experts; // 1d tensor with shape (num_experts + 1)
|
||||
|
||||
for (int i = 0; i < num_experts; ++i) {
|
||||
tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
|
||||
@ -40,8 +43,8 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
|
||||
|
||||
/**
|
||||
* In the first step we compute token_cnts[thread_index + 1][expert_index],
|
||||
* which counts how many tokens in the token shard of thread_index are assigned
|
||||
* to expert expert_index.
|
||||
* which counts how many tokens in the token shard of thread_index are
|
||||
* assigned to expert expert_index.
|
||||
*/
|
||||
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
|
||||
++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])];
|
||||
@ -52,7 +55,8 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
|
||||
// For each expert we accumulate the token counts from the different threads.
|
||||
tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
|
||||
for (int i = 1; i <= blockDim.x; ++i) {
|
||||
tokens_cnts[index(num_experts, i, threadIdx.x)] += tokens_cnts[index(num_experts, i-1, threadIdx.x)];
|
||||
tokens_cnts[index(num_experts, i, threadIdx.x)] +=
|
||||
tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
@ -61,7 +65,10 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
|
||||
if (threadIdx.x == 0) {
|
||||
cumsum[0] = 0;
|
||||
for (int i = 1; i <= num_experts; ++i) {
|
||||
cumsum[i] = cumsum[i-1] + CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], block_size) * block_size;
|
||||
cumsum[i] = cumsum[i - 1] +
|
||||
CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)],
|
||||
block_size) *
|
||||
block_size;
|
||||
}
|
||||
*total_tokens_post_pad = cumsum[num_experts];
|
||||
}
|
||||
@ -69,57 +76,59 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
|
||||
__syncthreads();
|
||||
|
||||
/**
|
||||
* For each expert, each thread processes the tokens of the corresponding blocks
|
||||
* and stores the corresponding expert_id for each block.
|
||||
* For each expert, each thread processes the tokens of the corresponding
|
||||
* blocks and stores the corresponding expert_id for each block.
|
||||
*/
|
||||
for (int i = cumsum[threadIdx.x];i < cumsum[threadIdx.x + 1];i += block_size) {
|
||||
for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
|
||||
i += block_size) {
|
||||
expert_ids[i / block_size] = threadIdx.x;
|
||||
}
|
||||
|
||||
/**
|
||||
* Each thread processes a token shard, calculating the index of each token after
|
||||
* sorting by expert number. Given the example topk_ids = [0,1,2,1,2,3,0,3,4] and
|
||||
* block_size = 4, then the output would be [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *],
|
||||
* where * represents a padding value(preset in python).
|
||||
* Each thread processes a token shard, calculating the index of each token
|
||||
* after sorting by expert number. Given the example topk_ids =
|
||||
* [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *,
|
||||
* *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a
|
||||
* padding value(preset in python).
|
||||
*/
|
||||
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
|
||||
int32_t expert_id = topk_ids[i];
|
||||
/** The cumsum[expert_id] stores the starting index of the tokens that the
|
||||
* expert with expert_id needs to process, and tokens_cnts[threadIdx.x][expert_id]
|
||||
* stores the indices of the tokens processed by the expert with expert_id within
|
||||
* the current thread's token shard.
|
||||
* expert with expert_id needs to process, and
|
||||
* tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
|
||||
* processed by the expert with expert_id within the current thread's token
|
||||
* shard.
|
||||
*/
|
||||
int32_t rank_post_pad = tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + cumsum[expert_id];
|
||||
int32_t rank_post_pad =
|
||||
tokens_cnts[index(num_experts, threadIdx.x, expert_id)] +
|
||||
cumsum[expert_id];
|
||||
sorted_token_ids[rank_post_pad] = i;
|
||||
++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace vllm

void moe_align_block_size(
torch::Tensor topk_ids,
int num_experts,
int block_size,
torch::Tensor sorted_token_ids,
void moe_align_block_size(torch::Tensor topk_ids, int num_experts,
int block_size, torch::Tensor sorted_token_ids,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad) {
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_INTEGRAL_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
// calc needed amount of shared mem for `tokens_cnts` and `cumsum` tensors
const int32_t shared_mem = ((num_experts + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
// tensors
const int32_t shared_mem =
((num_experts + 1) * num_experts + (num_experts + 1)) *
sizeof(int32_t);

// set dynamic shared mem
auto kernel = vllm::moe_align_block_size_kernel<scalar_t>;
AT_CUDA_CHECK(
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize((void *)kernel, shared_mem));
AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
(void*)kernel, shared_mem));
kernel<<<1, num_experts, shared_mem, stream>>>(
topk_ids.data_ptr<scalar_t>(),
sorted_token_ids.data_ptr<int32_t>(),
topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(),
num_tokens_post_pad.data_ptr<int32_t>(),
num_experts,
block_size,
num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
topk_ids.numel());
});
}
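// The dynamic shared-memory request above covers a (num_experts + 1) x
// num_experts `tokens_cnts` table plus a (num_experts + 1)-entry `cumsum`
// row, all int32_t. A hedged host-side sketch of the same arithmetic, with
// 8 experts as an illustrative example:
#include <cstddef>
#include <cstdint>
#include <cstdio>

constexpr std::size_t moe_align_shared_mem_bytes(int num_experts) {
  return ((num_experts + 1) * num_experts + (num_experts + 1)) *
         sizeof(int32_t);
}

int main() {
  // 8 experts -> (9 * 8 + 9) * 4 = 324 bytes of dynamic shared memory.
  printf("%zu\n", moe_align_shared_mem_bytes(8));
}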
|
||||
|
214
csrc/ops.h
214
csrc/ops.h
@ -2,204 +2,115 @@
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
void paged_attention_v1(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
int num_kv_heads,
|
||||
float scale,
|
||||
torch::Tensor& block_tables,
|
||||
torch::Tensor& seq_lens,
|
||||
int block_size,
|
||||
void paged_attention_v1(torch::Tensor& out, torch::Tensor& query,
|
||||
torch::Tensor& key_cache, torch::Tensor& value_cache,
|
||||
int num_kv_heads, float scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens,
|
||||
int block_size, int max_seq_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, float kv_scale);
|
||||
|
||||
void paged_attention_v2(torch::Tensor& out, torch::Tensor& exp_sums,
|
||||
torch::Tensor& max_logits, torch::Tensor& tmp_out,
|
||||
torch::Tensor& query, torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache, int num_kv_heads,
|
||||
float scale, torch::Tensor& block_tables,
|
||||
torch::Tensor& seq_lens, int block_size,
|
||||
int max_seq_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype,
|
||||
float kv_scale);
|
||||
const std::string& kv_cache_dtype, float kv_scale);
|
||||
|
||||
void paged_attention_v2(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& exp_sums,
|
||||
torch::Tensor& max_logits,
|
||||
torch::Tensor& tmp_out,
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
int num_kv_heads,
|
||||
float scale,
|
||||
torch::Tensor& block_tables,
|
||||
torch::Tensor& seq_lens,
|
||||
int block_size,
|
||||
int max_seq_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype,
|
||||
float kv_scale);
|
||||
|
||||
void rms_norm(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input,
|
||||
torch::Tensor& weight,
|
||||
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
|
||||
float epsilon);
|
||||
|
||||
void fused_add_rms_norm(
|
||||
torch::Tensor& input,
|
||||
torch::Tensor& residual,
|
||||
torch::Tensor& weight,
|
||||
float epsilon);
|
||||
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
|
||||
torch::Tensor& weight, float epsilon);
|
||||
|
||||
void rotary_embedding(
|
||||
torch::Tensor& positions,
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key,
|
||||
int head_size,
|
||||
torch::Tensor& cos_sin_cache,
|
||||
bool is_neox);
|
||||
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
|
||||
torch::Tensor& key, int head_size,
|
||||
torch::Tensor& cos_sin_cache, bool is_neox);
|
||||
|
||||
void batched_rotary_embedding(
|
||||
torch::Tensor& positions,
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key,
|
||||
int head_size,
|
||||
torch::Tensor& cos_sin_cache,
|
||||
bool is_neox,
|
||||
void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
|
||||
torch::Tensor& key, int head_size,
|
||||
torch::Tensor& cos_sin_cache, bool is_neox,
|
||||
int rot_dim,
|
||||
torch::Tensor& cos_sin_cache_offsets);
|
||||
|
||||
void silu_and_mul(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input);
|
||||
void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
|
||||
|
||||
void gelu_and_mul(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input);
|
||||
void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);
|
||||
|
||||
void gelu_tanh_and_mul(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input);
|
||||
void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
|
||||
|
||||
void gelu_new(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input);
|
||||
void gelu_new(torch::Tensor& out, torch::Tensor& input);
|
||||
|
||||
void gelu_fast(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input);
|
||||
void gelu_fast(torch::Tensor& out, torch::Tensor& input);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
torch::Tensor aqlm_gemm(
|
||||
const torch::Tensor& input,
|
||||
const torch::Tensor& codes,
|
||||
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
|
||||
const torch::Tensor& codebooks,
|
||||
const torch::Tensor& scales,
|
||||
const torch::Tensor& codebook_partition_sizes,
|
||||
const std::optional<torch::Tensor>& bias
|
||||
);
|
||||
const std::optional<torch::Tensor>& bias);
|
||||
|
||||
torch::Tensor aqlm_dequant(
|
||||
const torch::Tensor& codes,
|
||||
torch::Tensor aqlm_dequant(const torch::Tensor& codes,
|
||||
const torch::Tensor& codebooks,
|
||||
const torch::Tensor& codebook_partition_sizes
|
||||
);
|
||||
const torch::Tensor& codebook_partition_sizes);
|
||||
|
||||
torch::Tensor awq_gemm(
|
||||
torch::Tensor _in_feats,
|
||||
torch::Tensor _kernel,
|
||||
torch::Tensor _scaling_factors,
|
||||
torch::Tensor _zeros,
|
||||
torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
|
||||
torch::Tensor _scaling_factors, torch::Tensor _zeros,
|
||||
int split_k_iters);
|
||||
|
||||
torch::Tensor awq_dequantize(
|
||||
torch::Tensor _kernel,
|
||||
torch::Tensor awq_dequantize(torch::Tensor _kernel,
|
||||
torch::Tensor _scaling_factors,
|
||||
torch::Tensor _zeros,
|
||||
int split_k_iters,
|
||||
int thx,
|
||||
torch::Tensor _zeros, int split_k_iters, int thx,
|
||||
int thy);
|
||||
|
||||
torch::Tensor marlin_gemm(
|
||||
torch::Tensor& a,
|
||||
torch::Tensor& b_q_weight,
|
||||
torch::Tensor& b_scales,
|
||||
torch::Tensor& workspace,
|
||||
int64_t size_m,
|
||||
int64_t size_n,
|
||||
int64_t size_k);
|
||||
torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
torch::Tensor& b_scales, torch::Tensor& workspace,
|
||||
int64_t size_m, int64_t size_n, int64_t size_k);
|
||||
|
||||
torch::Tensor gptq_marlin_24_gemm(
|
||||
torch::Tensor &a,
|
||||
torch::Tensor &b_q_weight,
|
||||
torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
torch::Tensor& b_meta,
|
||||
torch::Tensor& b_scales,
|
||||
torch::Tensor &workspace,
|
||||
int64_t num_bits,
|
||||
int64_t size_m,
|
||||
int64_t size_n,
|
||||
torch::Tensor& workspace, int64_t num_bits,
|
||||
int64_t size_m, int64_t size_n,
|
||||
int64_t size_k);
|
||||
|
||||
torch::Tensor gptq_marlin_gemm(
|
||||
torch::Tensor &a,
|
||||
torch::Tensor &b_q_weight,
|
||||
torch::Tensor &b_scales,
|
||||
torch::Tensor &g_idx,
|
||||
torch::Tensor &perm,
|
||||
torch::Tensor &workspace,
|
||||
int64_t num_bits,
|
||||
int64_t size_m,
|
||||
int64_t size_n,
|
||||
int64_t size_k,
|
||||
bool is_k_full);
|
||||
torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
torch::Tensor& b_scales, torch::Tensor& g_idx,
|
||||
torch::Tensor& perm, torch::Tensor& workspace,
|
||||
int64_t num_bits, int64_t size_m, int64_t size_n,
|
||||
int64_t size_k, bool is_k_full);
|
||||
|
||||
torch::Tensor gptq_marlin_repack(
|
||||
torch::Tensor &b_q_weight,
|
||||
torch::Tensor &perm,
|
||||
int64_t size_k,
|
||||
int64_t size_n,
|
||||
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
|
||||
int64_t size_k, int64_t size_n,
|
||||
int64_t num_bits);
|
||||
|
||||
int cutlass_scaled_mm_dq(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor const &a,
|
||||
torch::Tensor const &b,
|
||||
torch::Tensor const &a_scales,
|
||||
int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales);
|
||||
|
||||
#endif
|
||||
|
||||
void squeezellm_gemm(
|
||||
torch::Tensor vec,
|
||||
torch::Tensor mat,
|
||||
torch::Tensor mul,
|
||||
void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
|
||||
torch::Tensor lookup_table);
|
||||
|
||||
torch::Tensor gptq_gemm(
|
||||
torch::Tensor a,
|
||||
torch::Tensor b_q_weight,
|
||||
torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
|
||||
torch::Tensor b_gptq_qzeros,
|
||||
torch::Tensor b_gptq_scales,
|
||||
torch::Tensor b_g_idx,
|
||||
bool use_exllama,
|
||||
int bit);
|
||||
torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
|
||||
bool use_exllama, int bit);
|
||||
|
||||
void gptq_shuffle(
|
||||
torch::Tensor q_weight,
|
||||
torch::Tensor q_perm,
|
||||
int bit);
|
||||
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int bit);
|
||||
|
||||
void static_scaled_fp8_quant(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input,
|
||||
void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
|
||||
torch::Tensor& scale);
|
||||
|
||||
void dynamic_scaled_fp8_quant(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input,
|
||||
void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
|
||||
torch::Tensor& scale);
|
||||
|
||||
void moe_align_block_size(
|
||||
torch::Tensor topk_ids,
|
||||
int num_experts,
|
||||
int block_size,
|
||||
torch::Tensor sorted_token_ids,
|
||||
void moe_align_block_size(torch::Tensor topk_ids, int num_experts,
|
||||
int block_size, torch::Tensor sorted_token_ids,
|
||||
torch::Tensor experts_ids,
|
||||
torch::Tensor num_tokens_post_pad);
|
||||
|
||||
@ -219,7 +130,8 @@ int meta_size();
|
||||
void register_buffer(fptr_t _fa, torch::Tensor& t,
|
||||
const std::vector<std::string>& handles,
|
||||
const std::vector<int64_t>& offsets);
|
||||
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
|
||||
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
|
||||
fptr_t _fa);
|
||||
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
|
||||
const std::vector<std::vector<int64_t>>& offsets);
|
||||
#endif
|
||||
|
@ -9,12 +9,8 @@ namespace vllm {
|
||||
|
||||
template <typename scalar_t, bool IS_NEOX>
|
||||
inline __device__ void apply_token_rotary_embedding(
|
||||
scalar_t* __restrict__ arr,
|
||||
const scalar_t* __restrict__ cos_ptr,
|
||||
const scalar_t* __restrict__ sin_ptr,
|
||||
int rot_offset,
|
||||
int embed_dim)
|
||||
{
|
||||
scalar_t* __restrict__ arr, const scalar_t* __restrict__ cos_ptr,
|
||||
const scalar_t* __restrict__ sin_ptr, int rot_offset, int embed_dim) {
|
||||
int x_index, y_index;
|
||||
scalar_t cos, sin;
|
||||
if (IS_NEOX) {
|
||||
@ -39,17 +35,15 @@ inline __device__ void apply_token_rotary_embedding(
|
||||
|
||||
template <typename scalar_t, bool IS_NEOX>
|
||||
inline __device__ void apply_rotary_embedding(
|
||||
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
|
||||
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
|
||||
const scalar_t* cache_ptr,
|
||||
const int head_size,
|
||||
const int num_heads,
|
||||
const int num_kv_heads,
|
||||
const int rot_dim,
|
||||
const int token_idx,
|
||||
const int64_t query_stride,
|
||||
const int64_t key_stride)
|
||||
{
|
||||
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
|
||||
// head_size] or [num_tokens, num_heads,
|
||||
// head_size]
|
||||
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
|
||||
// head_size] or [num_tokens, num_kv_heads,
|
||||
// head_size]
|
||||
const scalar_t* cache_ptr, const int head_size, const int num_heads,
|
||||
const int num_kv_heads, const int rot_dim, const int token_idx,
|
||||
const int64_t query_stride, const int64_t key_stride) {
|
||||
const int embed_dim = rot_dim / 2;
|
||||
const scalar_t* cos_ptr = cache_ptr;
|
||||
const scalar_t* sin_ptr = cache_ptr + embed_dim;
|
||||
@ -59,8 +53,8 @@ inline __device__ void apply_rotary_embedding(
|
||||
const int head_idx = i / embed_dim;
|
||||
const int64_t token_head = token_idx * query_stride + head_idx * head_size;
|
||||
const int rot_offset = i % embed_dim;
|
||||
apply_token_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
|
||||
sin_ptr, rot_offset, embed_dim);
|
||||
apply_token_rotary_embedding<scalar_t, IS_NEOX>(
|
||||
query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
|
||||
}
|
||||
|
||||
const int nk = num_kv_heads * embed_dim;
|
||||
@ -68,59 +62,71 @@ inline __device__ void apply_rotary_embedding(
|
||||
const int head_idx = i / embed_dim;
|
||||
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
|
||||
const int rot_offset = i % embed_dim;
|
||||
apply_token_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
|
||||
sin_ptr, rot_offset, embed_dim);
|
||||
apply_token_rotary_embedding<scalar_t, IS_NEOX>(
|
||||
key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
|
||||
}
|
||||
}
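// A host-side sketch of the rotation applied to one head, assuming the usual
// index layouts: NeoX-style pairs (k, k + embed_dim) and GPT-J-style
// interleaved pairs (2k, 2k + 1). This mirrors what the per-offset helper
// does for each rot_offset; it is an illustration, not the kernel itself.
#include <vector>

void rotate_head(std::vector<float>& head, const std::vector<float>& cos_row,
                 const std::vector<float>& sin_row, int rot_dim,
                 bool is_neox) {
  const int embed_dim = rot_dim / 2;
  for (int k = 0; k < embed_dim; ++k) {
    const int x_index = is_neox ? k : 2 * k;
    const int y_index = is_neox ? embed_dim + k : 2 * k + 1;
    const float c = cos_row[k];
    const float s = sin_row[k];
    const float x = head[x_index];
    const float y = head[y_index];
    head[x_index] = x * c - y * s;  // standard rotary rotation
    head[y_index] = y * c + x * s;
  }
}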
|
||||
|
||||
template <typename scalar_t, bool IS_NEOX>
|
||||
__global__ void rotary_embedding_kernel(
|
||||
const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens]
|
||||
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
|
||||
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
|
||||
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
|
||||
const int rot_dim,
|
||||
const int64_t query_stride,
|
||||
const int64_t key_stride,
|
||||
const int num_heads,
|
||||
const int num_kv_heads,
|
||||
const int head_size) {
|
||||
const int64_t* __restrict__ positions, // [batch_size, seq_len] or
|
||||
// [num_tokens]
|
||||
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
|
||||
// head_size] or [num_tokens, num_heads,
|
||||
// head_size]
|
||||
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
|
||||
// head_size] or [num_tokens, num_kv_heads,
|
||||
// head_size]
|
||||
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
|
||||
// 2]
|
||||
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
|
||||
const int num_heads, const int num_kv_heads, const int head_size) {
|
||||
// Each thread block is responsible for one token.
|
||||
const int token_idx = blockIdx.x;
|
||||
int64_t pos = positions[token_idx];
|
||||
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
|
||||
|
||||
apply_rotary_embedding<scalar_t, IS_NEOX>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride);
|
||||
apply_rotary_embedding<scalar_t, IS_NEOX>(
|
||||
query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
|
||||
token_idx, query_stride, key_stride);
|
||||
}
|
||||
|
||||
template <typename scalar_t, bool IS_NEOX>
|
||||
__global__ void batched_rotary_embedding_kernel(
|
||||
const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens]
|
||||
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
|
||||
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
|
||||
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
|
||||
const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len] or [num_tokens]
|
||||
const int rot_dim,
|
||||
const int64_t query_stride,
|
||||
const int64_t key_stride,
|
||||
const int num_heads,
|
||||
const int num_kv_heads,
|
||||
const int head_size) {
|
||||
const int64_t* __restrict__ positions, // [batch_size, seq_len] or
|
||||
// [num_tokens]
|
||||
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
|
||||
// head_size] or [num_tokens, num_heads,
|
||||
// head_size]
|
||||
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
|
||||
// head_size] or [num_tokens, num_kv_heads,
|
||||
// head_size]
|
||||
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
|
||||
// 2]
|
||||
const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len]
|
||||
// or [num_tokens]
|
||||
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
|
||||
const int num_heads, const int num_kv_heads, const int head_size) {
|
||||
// Each thread block is responsible for one token.
|
||||
const int token_idx = blockIdx.x;
|
||||
int64_t pos = positions[token_idx];
|
||||
int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx];
|
||||
const scalar_t* cache_ptr = cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim;
|
||||
const scalar_t* cache_ptr =
|
||||
cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim;
|
||||
|
||||
apply_rotary_embedding<scalar_t, IS_NEOX>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride);
|
||||
apply_rotary_embedding<scalar_t, IS_NEOX>(
|
||||
query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
|
||||
token_idx, query_stride, key_stride);
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
void rotary_embedding(
|
||||
torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens]
|
||||
torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
|
||||
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
|
||||
torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or
|
||||
// [num_tokens, num_heads * head_size]
|
||||
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or
|
||||
// [num_tokens, num_kv_heads * head_size]
|
||||
int head_size,
|
||||
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
|
||||
bool is_neox) {
|
||||
@ -135,33 +141,18 @@ void rotary_embedding(
|
||||
dim3 block(std::min(num_heads * rot_dim / 2, 512));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
query.scalar_type(),
|
||||
"rotary_embedding",
|
||||
[&] {
|
||||
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
|
||||
if (is_neox) {
|
||||
vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
|
||||
positions.data_ptr<int64_t>(),
|
||||
query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(),
|
||||
cos_sin_cache.data_ptr<scalar_t>(),
|
||||
rot_dim,
|
||||
query_stride,
|
||||
key_stride,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_size);
|
||||
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(), rot_dim,
|
||||
query_stride, key_stride, num_heads, num_kv_heads, head_size);
|
||||
} else {
|
||||
vllm::rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
|
||||
positions.data_ptr<int64_t>(),
|
||||
query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(),
|
||||
cos_sin_cache.data_ptr<scalar_t>(),
|
||||
rot_dim,
|
||||
query_stride,
|
||||
key_stride,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
vllm::rotary_embedding_kernel<scalar_t, false>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
|
||||
rot_dim, query_stride, key_stride, num_heads, num_kv_heads,
|
||||
head_size);
|
||||
}
|
||||
});
|
||||
@ -173,12 +164,13 @@ and process in batched manner.
|
||||
*/
|
||||
void batched_rotary_embedding(
|
||||
torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens]
|
||||
torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
|
||||
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
|
||||
torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or
|
||||
// [num_tokens, num_heads * head_size]
|
||||
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or
|
||||
// [num_tokens, num_kv_heads * head_size]
|
||||
int head_size,
|
||||
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
|
||||
bool is_neox,
|
||||
int rot_dim,
|
||||
bool is_neox, int rot_dim,
|
||||
torch::Tensor& cos_sin_cache_offsets // [num_tokens]
|
||||
) {
|
||||
int64_t num_tokens = cos_sin_cache_offsets.size(0);
|
||||
@ -191,36 +183,21 @@ void batched_rotary_embedding(
|
||||
dim3 block(std::min(num_heads * rot_dim / 2, 512));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
query.scalar_type(),
|
||||
"rotary_embedding",
|
||||
[&] {
|
||||
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
|
||||
if (is_neox) {
|
||||
vllm::batched_rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
|
||||
positions.data_ptr<int64_t>(),
|
||||
query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(),
|
||||
cos_sin_cache.data_ptr<scalar_t>(),
|
||||
cos_sin_cache_offsets.data_ptr<int64_t>(),
|
||||
rot_dim,
|
||||
query_stride,
|
||||
key_stride,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_size);
|
||||
vllm::batched_rotary_embedding_kernel<scalar_t, true>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
|
||||
cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
|
||||
key_stride, num_heads, num_kv_heads, head_size);
|
||||
} else {
|
||||
vllm::batched_rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
|
||||
positions.data_ptr<int64_t>(),
|
||||
query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(),
|
||||
cos_sin_cache.data_ptr<scalar_t>(),
|
||||
cos_sin_cache_offsets.data_ptr<int64_t>(),
|
||||
rot_dim,
|
||||
query_stride,
|
||||
key_stride,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_size);
|
||||
vllm::batched_rotary_embedding_kernel<scalar_t, false>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
|
||||
cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
|
||||
key_stride, num_heads, num_kv_heads, head_size);
|
||||
}
|
||||
});
|
||||
}
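// A small host-side sketch of the per-token cache lookup the batched kernel
// performs: each token's offset (e.g. selecting among several concatenated
// cos/sin caches, as used when multiple LoRAs are active) is added to its
// position before indexing a rot_dim-wide row. Row-major layout is assumed.
#include <cstdint>
#include <vector>

const float* cos_sin_row(const std::vector<float>& cos_sin_cache, int rot_dim,
                         int64_t pos, int64_t cache_offset) {
  return cos_sin_cache.data() + (cache_offset + pos) * rot_dim;
}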
|
||||
|
114
csrc/pybind.cpp
114
csrc/pybind.cpp
@ -8,114 +8,85 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
pybind11::module ops = m.def_submodule("ops", "vLLM custom operators");
|
||||
|
||||
// Attention ops
|
||||
ops.def(
|
||||
"paged_attention_v1",
|
||||
&paged_attention_v1,
|
||||
"Compute the attention between an input query and the cached keys/values using PagedAttention.");
|
||||
ops.def(
|
||||
"paged_attention_v2",
|
||||
&paged_attention_v2,
|
||||
"PagedAttention V2.");
|
||||
ops.def("paged_attention_v1", &paged_attention_v1,
|
||||
"Compute the attention between an input query and the cached "
|
||||
"keys/values using PagedAttention.");
|
||||
ops.def("paged_attention_v2", &paged_attention_v2, "PagedAttention V2.");
|
||||
|
||||
// Activation ops
|
||||
ops.def(
|
||||
"silu_and_mul",
|
||||
&silu_and_mul,
|
||||
"Activation function used in SwiGLU.");
|
||||
ops.def(
|
||||
"gelu_and_mul",
|
||||
&gelu_and_mul,
|
||||
ops.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU.");
|
||||
ops.def("gelu_and_mul", &gelu_and_mul,
|
||||
"Activation function used in GeGLU with `none` approximation.");
|
||||
ops.def(
|
||||
"gelu_tanh_and_mul",
|
||||
&gelu_tanh_and_mul,
|
||||
ops.def("gelu_tanh_and_mul", &gelu_tanh_and_mul,
|
||||
"Activation function used in GeGLU with `tanh` approximation.");
|
||||
ops.def(
|
||||
"gelu_new",
|
||||
&gelu_new,
|
||||
"GELU implementation used in GPT-2.");
|
||||
ops.def(
|
||||
"gelu_fast",
|
||||
&gelu_fast,
|
||||
"Approximate GELU implementation.");
|
||||
ops.def("gelu_new", &gelu_new, "GELU implementation used in GPT-2.");
|
||||
ops.def("gelu_fast", &gelu_fast, "Approximate GELU implementation.");
|
||||
|
||||
// Layernorm
|
||||
ops.def(
|
||||
"rms_norm",
|
||||
&rms_norm,
|
||||
ops.def("rms_norm", &rms_norm,
|
||||
"Apply Root Mean Square (RMS) Normalization to the input tensor.");
|
||||
|
||||
ops.def(
|
||||
"fused_add_rms_norm",
|
||||
&fused_add_rms_norm,
|
||||
ops.def("fused_add_rms_norm", &fused_add_rms_norm,
|
||||
"In-place fused Add and RMS Normalization");
|
||||
|
||||
// Rotary embedding
|
||||
ops.def(
|
||||
"rotary_embedding",
|
||||
&rotary_embedding,
|
||||
ops.def("rotary_embedding", &rotary_embedding,
|
||||
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
|
||||
|
||||
ops.def(
|
||||
"batched_rotary_embedding",
|
||||
&batched_rotary_embedding,
|
||||
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key (supports multiple loras)");
|
||||
ops.def("batched_rotary_embedding", &batched_rotary_embedding,
|
||||
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key "
|
||||
"(supports multiple loras)");
|
||||
|
||||
// Quantization ops
|
||||
#ifndef USE_ROCM
|
||||
ops.def("aqlm_gemm", &aqlm_gemm, "Quantized GEMM for AQLM");
|
||||
ops.def("aqlm_dequant", &aqlm_dequant, "Decompression method for AQLM");
|
||||
ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
|
||||
ops.def("marlin_gemm", &marlin_gemm, "Marlin (Dense) Optimized Quantized GEMM for GPTQ");
|
||||
ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm, "Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ");
|
||||
ops.def("gptq_marlin_gemm", &gptq_marlin_gemm, "gptq_marlin Optimized Quantized GEMM for GPTQ");
|
||||
ops.def("gptq_marlin_repack", &gptq_marlin_repack, "gptq_marlin repack from GPTQ");
|
||||
ops.def("marlin_gemm", &marlin_gemm,
|
||||
"Marlin (Dense) Optimized Quantized GEMM for GPTQ");
|
||||
ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm,
|
||||
"Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ");
|
||||
ops.def("gptq_marlin_gemm", &gptq_marlin_gemm,
|
||||
"gptq_marlin Optimized Quantized GEMM for GPTQ");
|
||||
ops.def("gptq_marlin_repack", &gptq_marlin_repack,
|
||||
"gptq_marlin repack from GPTQ");
|
||||
ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
|
||||
ops.def("cutlass_scaled_mm_dq", &cutlass_scaled_mm_dq, "CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column quantization.");
|
||||
ops.def("cutlass_scaled_mm_dq", &cutlass_scaled_mm_dq,
|
||||
"CUTLASS w8a8 GEMM, supporting symmetric per-tensor or "
|
||||
"per-row/column quantization.");
|
||||
#endif
|
||||
|
||||
ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
|
||||
ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
|
||||
ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
|
||||
ops.def("static_scaled_fp8_quant", &static_scaled_fp8_quant, "Compute FP8 quantized tensor for given scaling factor");
|
||||
ops.def("dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor");
|
||||
ops.def(
|
||||
"moe_align_block_size",
|
||||
&moe_align_block_size,
|
||||
"Aligning the number of tokens to be processed by each expert such that it is divisible by the block size.");
|
||||
ops.def("static_scaled_fp8_quant", &static_scaled_fp8_quant,
|
||||
"Compute FP8 quantized tensor for given scaling factor");
|
||||
ops.def("dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant,
|
||||
"Compute FP8 quantized tensor and scaling factor");
|
||||
ops.def("moe_align_block_size", &moe_align_block_size,
|
||||
"Aligning the number of tokens to be processed by each expert such "
|
||||
"that it is divisible by the block size.");
|
||||
|
||||
// Cache ops
|
||||
pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
|
||||
cache_ops.def(
|
||||
"swap_blocks",
|
||||
&swap_blocks,
|
||||
cache_ops.def("swap_blocks", &swap_blocks,
|
||||
"Swap in (out) the cache blocks from src to dst");
|
||||
cache_ops.def(
|
||||
"copy_blocks",
|
||||
©_blocks,
|
||||
cache_ops.def("copy_blocks", ©_blocks,
|
||||
"Copy the cache blocks from src to dst");
|
||||
cache_ops.def(
|
||||
"reshape_and_cache",
|
||||
&reshape_and_cache,
|
||||
cache_ops.def("reshape_and_cache", &reshape_and_cache,
|
||||
"Reshape the key and value tensors and cache them");
|
||||
cache_ops.def(
|
||||
"reshape_and_cache_flash",
|
||||
&reshape_and_cache_flash,
|
||||
cache_ops.def("reshape_and_cache_flash", &reshape_and_cache_flash,
|
||||
"Reshape the key and value tensors and cache them");
|
||||
cache_ops.def(
|
||||
"convert_fp8",
|
||||
&convert_fp8,
|
||||
cache_ops.def("convert_fp8", &convert_fp8,
|
||||
"Convert the key and value cache to fp8 data type");
|
||||
|
||||
// Cuda utils
|
||||
pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils");
|
||||
cuda_utils.def(
|
||||
"get_device_attribute",
|
||||
&get_device_attribute,
|
||||
pybind11::module cuda_utils =
|
||||
m.def_submodule("cuda_utils", "vLLM cuda utils");
|
||||
cuda_utils.def("get_device_attribute", &get_device_attribute,
|
||||
"Gets the specified device attribute.");
|
||||
|
||||
cuda_utils.def(
|
||||
"get_max_shared_memory_per_block_device_attribute",
|
||||
cuda_utils.def("get_max_shared_memory_per_block_device_attribute",
|
||||
&get_max_shared_memory_per_block_device_attribute,
|
||||
"Gets the maximum shared memory per block device attribute.");
|
||||
|
||||
@ -134,5 +105,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
custom_ar.def("register_graph_buffers", ®ister_graph_buffers,
|
||||
"register_graph_buffers");
|
||||
#endif
|
||||
|
||||
}
|
||||
|
@ -25,30 +25,26 @@
|
||||
#include <iostream>
|
||||
#include <cstdlib>
|
||||
|
||||
|
||||
namespace vllm {
|
||||
namespace aqlm {
|
||||
|
||||
__global__ void Code1x16MatVec(
|
||||
const int4* __restrict__ A,
|
||||
const int4* __restrict__ B,
|
||||
int4* __restrict__ C,
|
||||
const int4* __restrict__ codebook,
|
||||
const int prob_m,
|
||||
const int4* __restrict__ A, const int4* __restrict__ B,
|
||||
int4* __restrict__ C, const int4* __restrict__ codebook, const int prob_m,
|
||||
const int prob_k,
|
||||
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long.
|
||||
const int4 codebook_a_sizes, // cumulative sizes of A spanning each
|
||||
// codebook, at most 3 long.
|
||||
const int codebook_stride // as int4.
|
||||
) {
|
||||
int a_gl_stride = prob_k / 8 / 8;
|
||||
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
|
||||
bool pred = a_gl_rd < prob_m;
|
||||
|
||||
if (pred)
|
||||
{
|
||||
// advance to the correct codebook, this easy because we only multiply one column of the codebook.
|
||||
if (pred) {
|
||||
// advance to the correct codebook, this easy because we only multiply one
|
||||
// column of the codebook.
|
||||
auto codebook_size = &codebook_a_sizes.x;
|
||||
while (a_gl_rd >= *codebook_size)
|
||||
{
|
||||
while (a_gl_rd >= *codebook_size) {
|
||||
codebook += codebook_stride;
|
||||
++codebook_size;
|
||||
}
|
||||
@ -67,8 +63,7 @@ __global__ void Code1x16MatVec(
|
||||
// We pad shared memory to avoid bank conflicts during reads
|
||||
__syncthreads();
|
||||
for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) {
|
||||
if (b_gl_rd + i < prob_k / 8)
|
||||
sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
|
||||
if (b_gl_rd + i < prob_k / 8) sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
|
||||
}
|
||||
__syncthreads();
|
||||
b_gl_rd += 32 * 8;
|
||||
@ -79,19 +74,16 @@ __global__ void Code1x16MatVec(
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) {
|
||||
uint32_t dec[4];
|
||||
// We bypass the L1 cache to avoid massive amounts of memory streaming that doesn't
|
||||
// actually help us; this brings > 2x speedup.
|
||||
asm volatile (
|
||||
"ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
|
||||
// We bypass the L1 cache to avoid massive amounts of memory streaming
|
||||
// that doesn't actually help us; this brings > 2x speedup.
|
||||
asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3])
|
||||
: "l"((void*) &codebook[enc[i]])
|
||||
);
|
||||
: "l"((void*)&codebook[enc[i]]));
|
||||
half2* a = reinterpret_cast<half2*>(&dec);
|
||||
half2* b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]);
|
||||
half2 res2 = {};
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; j++)
|
||||
res2 = __hfma2(a[j], b[j], res2);
|
||||
for (int j = 0; j < 4; j++) res2 = __hfma2(a[j], b[j], res2);
|
||||
res += __half2float(res2.x) + __half2float(res2.y);
|
||||
b_sh_rd++;
|
||||
}
|
||||
@ -101,21 +93,18 @@ __global__ void Code1x16MatVec(
|
||||
|
||||
if (pred) {
|
||||
#pragma unroll
|
||||
for (int i = 16; i > 0; i /= 2)
|
||||
res += __shfl_down_sync(0xffffffff, res, i);
|
||||
for (int i = 16; i > 0; i /= 2) res += __shfl_down_sync(0xffffffff, res, i);
|
||||
if (threadIdx.x % 32 == 0)
|
||||
reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res);
|
||||
}
|
||||
}
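// The loop above folds the 32 per-lane partial dot products together with
// __shfl_down_sync so that lane 0 ends up holding the warp's sum (only lane
// 0 writes the result). A host-side emulation of the same offset sequence,
// for illustration only:
#include <cstdio>

int main() {
  float lanes[32];
  for (int l = 0; l < 32; ++l) lanes[l] = 1.0f;  // each lane contributes 1
  for (int off = 16; off > 0; off /= 2)
    for (int l = 0; l < 32; ++l)
      if (l + off < 32) lanes[l] += lanes[l + off];  // __shfl_down_sync analogue
  printf("%f\n", lanes[0]);  // 32.0: the full warp sum lands in lane 0
}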
|
||||
|
||||
__global__ void Code2x8MatVec(
|
||||
const int4* __restrict__ A,
|
||||
const int4* __restrict__ B,
|
||||
int4* __restrict__ C,
|
||||
const int4* __restrict__ codebook,
|
||||
int prob_m,
|
||||
const int4* __restrict__ A, const int4* __restrict__ B,
|
||||
int4* __restrict__ C, const int4* __restrict__ codebook, int prob_m,
|
||||
int prob_k,
|
||||
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long.
|
||||
const int4 codebook_a_sizes, // cumulative sizes of A spanning each
|
||||
// codebook, at most 3 long.
|
||||
const int codebook_stride // as int4.
|
||||
|
||||
) {
|
||||
@ -123,12 +112,11 @@ __global__ void Code2x8MatVec(
|
||||
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
|
||||
bool pred = a_gl_rd < prob_m;
|
||||
|
||||
if (pred)
|
||||
{
|
||||
// advance to the correct codebook, this easy because we only multiply one column of the codebook.
|
||||
if (pred) {
|
||||
// advance to the correct codebook, this easy because we only multiply one
|
||||
// column of the codebook.
|
||||
auto codebook_size = &codebook_a_sizes.x;
|
||||
while (a_gl_rd >= *codebook_size)
|
||||
{
|
||||
while (a_gl_rd >= *codebook_size) {
|
||||
codebook += codebook_stride;
|
||||
++codebook_size;
|
||||
}
|
||||
@ -149,8 +137,7 @@ __global__ void Code2x8MatVec(
|
||||
for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) {
|
||||
int4 dec = codebook[i];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 8; j++)
|
||||
sh_code[8 * i + (j + lane) % 8] = dec;
|
||||
for (int j = 0; j < 8; j++) sh_code[8 * i + (j + lane) % 8] = dec;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
@ -161,8 +148,7 @@ __global__ void Code2x8MatVec(
|
||||
// We pad shared memory to avoid bank conflicts during reads
|
||||
__syncthreads();
|
||||
for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) {
|
||||
if (b_gl_rd + i < prob_k / 8)
|
||||
sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
|
||||
if (b_gl_rd + i < prob_k / 8) sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
|
||||
}
|
||||
__syncthreads();
|
||||
b_gl_rd += 32 * 8;
|
||||
@ -172,8 +158,10 @@ __global__ void Code2x8MatVec(
|
||||
const uint8_t* enc = reinterpret_cast<const uint8_t*>(&A[a_gl_rd]);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) {
|
||||
half2* a0 = reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]);
|
||||
half2* a1 = reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]);
|
||||
half2* a0 =
|
||||
reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]);
|
||||
half2* a1 =
|
||||
reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]);
|
||||
half2* b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]);
|
||||
half2 res2 = {};
|
||||
#pragma unroll
|
||||
@ -188,33 +176,28 @@ __global__ void Code2x8MatVec(
|
||||
|
||||
if (pred) {
|
||||
#pragma unroll
|
||||
for (int i = 16; i > 0; i /= 2)
|
||||
res += __shfl_down_sync(0xffffffff, res, i);
|
||||
for (int i = 16; i > 0; i /= 2) res += __shfl_down_sync(0xffffffff, res, i);
|
||||
if (threadIdx.x % 32 == 0)
|
||||
reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
__global__ void Code1x16Dequant(
|
||||
const int4* __restrict__ A,
|
||||
int4* __restrict__ C,
|
||||
const int4* __restrict__ codebook,
|
||||
int prob_m,
|
||||
int prob_k,
|
||||
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, sums to m.
|
||||
const int4* __restrict__ A, int4* __restrict__ C,
|
||||
const int4* __restrict__ codebook, int prob_m, int prob_k,
|
||||
const int4 codebook_a_sizes, // cumulative sizes of A spanning each
|
||||
// codebook, at most 3 long, sums to m.
|
||||
const int codebook_stride // as int4
|
||||
) {
|
||||
int a_gl_stride = prob_k / 8 / 8;
|
||||
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
|
||||
bool pred = a_gl_rd < prob_m;
|
||||
|
||||
if (pred)
|
||||
{
|
||||
// advance to the correct codebook, this easy because we only multiply one column of the codebook.
|
||||
if (pred) {
|
||||
// advance to the correct codebook, this easy because we only multiply one
|
||||
// column of the codebook.
|
||||
auto codebook_size = &codebook_a_sizes.x;
|
||||
while (a_gl_rd >= *codebook_size)
|
||||
{
|
||||
while (a_gl_rd >= *codebook_size) {
|
||||
codebook += codebook_stride;
|
||||
++codebook_size;
|
||||
}
|
||||
@ -235,13 +218,11 @@ __global__ void Code1x16Dequant(
|
||||
for (int i = 0; i < 8; i++) {
|
||||
int4 chunk;
|
||||
auto dec = reinterpret_cast<uint32_t*>(&chunk);
|
||||
// We bypass the L1 cache to avoid massive amounts of memory streaming that doesn't
|
||||
// actually help us; this brings > 2x speedup.
|
||||
asm volatile (
|
||||
"ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
|
||||
// We bypass the L1 cache to avoid massive amounts of memory streaming
|
||||
// that doesn't actually help us; this brings > 2x speedup.
|
||||
asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3])
|
||||
: "l"((void*) &codebook[enc[i]])
|
||||
);
|
||||
: "l"((void*)&codebook[enc[i]]));
|
||||
|
||||
C[a_gl_rd * 8 + i] = chunk;
|
||||
}
|
||||
@ -250,26 +231,23 @@ __global__ void Code1x16Dequant(
|
||||
}
|
||||
}
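// The inline PTX in the kernels above issues 128-bit loads with the .cg
// qualifier (cache in L2, bypass L1). A rough stand-alone CUDA sketch of the
// same access pattern using the __ldcg intrinsic (sm_32+); this is an
// illustration, not a drop-in replacement for the hand-written PTX.
__global__ void ldcg_copy_demo(const int4* __restrict__ src,
                               int4* __restrict__ dst, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) dst[i] = __ldcg(&src[i]);  // L2-only cached load
}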
|
||||
|
||||
|
||||
__global__ void Code2x8Dequant(
|
||||
const int4* __restrict__ A,
|
||||
int4* __restrict__ C,
|
||||
const int4* __restrict__ codebook,
|
||||
int prob_m,
|
||||
int prob_k,
|
||||
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, corresponds to cols.
|
||||
const int4* __restrict__ A, int4* __restrict__ C,
|
||||
const int4* __restrict__ codebook, int prob_m, int prob_k,
|
||||
const int4
|
||||
codebook_a_sizes, // cumulative sizes of A spanning each codebook, at
|
||||
// most 3 long, corresponds to cols.
|
||||
const int codebook_stride // as int4
|
||||
) {
|
||||
int a_gl_stride = prob_k / 8 / 8;
|
||||
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
|
||||
bool pred = a_gl_rd < prob_m;
|
||||
|
||||
if (pred)
|
||||
{
|
||||
// advance to the correct codebook, this easy because we only multiply one column of the codebook.
|
||||
if (pred) {
|
||||
// advance to the correct codebook, this easy because we only multiply one
|
||||
// column of the codebook.
|
||||
auto codebook_size = &codebook_a_sizes.x;
|
||||
while (a_gl_rd >= *codebook_size)
|
||||
{
|
||||
while (a_gl_rd >= *codebook_size) {
|
||||
codebook += codebook_stride;
|
||||
++codebook_size;
|
||||
}
|
||||
@ -291,8 +269,7 @@ __global__ void Code2x8Dequant(
|
||||
for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) {
|
||||
int4 dec = codebook[i];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 8; j++)
|
||||
sh_code[8 * i + (j + lane) % 8] = dec;
|
||||
for (int j = 0; j < 8; j++) sh_code[8 * i + (j + lane) % 8] = dec;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
@ -305,8 +282,10 @@ __global__ void Code2x8Dequant(
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) {
|
||||
int4 chunk;
|
||||
half2* a0 = reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]);
|
||||
half2* a1 = reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]);
|
||||
half2* a0 =
|
||||
reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]);
|
||||
half2* a1 =
|
||||
reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; j++)
|
||||
reinterpret_cast<half2*>(&chunk)[j] = __hadd2(a0[j], a1[j]);
|
||||
@ -317,22 +296,15 @@ __global__ void Code2x8Dequant(
|
||||
}
|
||||
}
|
||||
|
||||
inline int ceildiv(int a, int b) {
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
inline int ceildiv(int a, int b) { return (a + b - 1) / b; }
|
||||
|
||||
const int THREAD_M = 16;
|
||||
|
||||
void code1x16_matvec_cuda(
|
||||
const void* __restrict__ A,
|
||||
const void* __restrict__ B,
|
||||
void* __restrict__ C,
|
||||
const void* __restrict__ codebook,
|
||||
int prob_m,
|
||||
int prob_k,
|
||||
const int4 codebook_a_sizes,
|
||||
const int codebook_stride
|
||||
) {
|
||||
void code1x16_matvec_cuda(const void* __restrict__ A,
|
||||
const void* __restrict__ B, void* __restrict__ C,
|
||||
const void* __restrict__ codebook, int prob_m,
|
||||
int prob_k, const int4 codebook_a_sizes,
|
||||
const int codebook_stride) {
|
||||
int sms;
|
||||
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
|
||||
int waves = 0;
|
||||
@ -346,27 +318,15 @@ void code1x16_matvec_cuda(
|
||||
int threads = 32 * thread_m;
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
Code1x16MatVec<<<blocks, threads, 16 * 32 * 9, stream>>>(
|
||||
(const int4*) A,
|
||||
(const int4*) B,
|
||||
(int4*) C,
|
||||
(const int4*) codebook,
|
||||
prob_m,
|
||||
prob_k,
|
||||
codebook_a_sizes,
|
||||
codebook_stride
|
||||
);
|
||||
(const int4*)A, (const int4*)B, (int4*)C, (const int4*)codebook, prob_m,
|
||||
prob_k, codebook_a_sizes, codebook_stride);
|
||||
}
|
||||
|
||||
void code2x8_matvec_cuda(
|
||||
const void* __restrict__ A,
|
||||
const void* __restrict__ B,
|
||||
void code2x8_matvec_cuda(const void* __restrict__ A, const void* __restrict__ B,
|
||||
void* __restrict__ C,
|
||||
const void* __restrict__ codebook,
|
||||
int prob_m,
|
||||
int prob_k,
|
||||
const int4 codebook_a_sizes,
|
||||
const int codebook_stride
|
||||
) {
|
||||
const void* __restrict__ codebook, int prob_m,
|
||||
int prob_k, const int4 codebook_a_sizes,
|
||||
const int codebook_stride) {
|
||||
int sms;
|
||||
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
|
||||
int waves = 0;
|
||||
@ -379,29 +339,19 @@ void code2x8_matvec_cuda(
|
||||
int blocks = ceildiv(prob_m, thread_m);
|
||||
int threads = 32 * thread_m;
|
||||
int shared = 16 * (2 * 256 * 8 + 32 * 9);
|
||||
cudaFuncSetAttribute(
|
||||
Code2x8MatVec, cudaFuncAttributeMaxDynamicSharedMemorySize, shared
|
||||
);
|
||||
cudaFuncSetAttribute(Code2x8MatVec,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, shared);
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
Code2x8MatVec<<<blocks, threads, shared, stream>>>(
|
||||
(const int4*) A,
|
||||
(const int4*) B,
|
||||
(int4*) C,
|
||||
(const int4*) codebook,
|
||||
prob_m,
|
||||
prob_k,
|
||||
codebook_a_sizes,
|
||||
codebook_stride
|
||||
);
|
||||
(const int4*)A, (const int4*)B, (int4*)C, (const int4*)codebook, prob_m,
|
||||
prob_k, codebook_a_sizes, codebook_stride);
|
||||
}
|
||||
|
||||
void code1x16_dequant_cuda(
|
||||
const void* __restrict__ A,
|
||||
void* __restrict__ C,
|
||||
const void* __restrict__ codebook,
|
||||
int prob_m,
|
||||
int prob_k,
|
||||
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long.
|
||||
const void* __restrict__ A, void* __restrict__ C,
|
||||
const void* __restrict__ codebook, int prob_m, int prob_k,
|
||||
const int4 codebook_a_sizes, // cumulative sizes of A spanning each
|
||||
// codebook, at most 3 long.
|
||||
const int codebook_stride // as int4.
|
||||
) {
|
||||
int sms;
|
||||
@ -417,24 +367,20 @@ void code1x16_dequant_cuda(
|
||||
int threads = 32 * thread_m;
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
Code1x16Dequant<<<blocks, threads, 0, stream>>>(
|
||||
(const int4*) A,
|
||||
(int4*) C,
|
||||
(const int4*) codebook,
|
||||
prob_m,
|
||||
prob_k,
|
||||
codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long.
|
||||
(const int4*)A, (int4*)C, (const int4*)codebook, prob_m, prob_k,
|
||||
codebook_a_sizes, // cumulative sizes of A spanning each codebook, at
|
||||
// most 3 long.
|
||||
codebook_stride // as int4.
|
||||
);
|
||||
}
|
||||
|
||||
// Dequantizes the code and codebook into weights.
|
||||
void code2x8_dequant_cuda(
|
||||
const void* __restrict__ A,
|
||||
void* __restrict__ C,
|
||||
const void* __restrict__ codebook,
|
||||
int prob_m,
|
||||
int prob_k,
|
||||
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, corresponds to cols.
|
||||
const void* __restrict__ A, void* __restrict__ C,
|
||||
const void* __restrict__ codebook, int prob_m, int prob_k,
|
||||
const int4
|
||||
codebook_a_sizes, // cumulative sizes of A spanning each codebook, at
|
||||
// most 3 long, corresponds to cols.
|
||||
const int codebook_stride // as int4
|
||||
) {
|
||||
int sms;
|
||||
@ -451,50 +397,33 @@ void code2x8_dequant_cuda(
|
||||
int shared = 16 * (2 * 256 * 8 + 32 * 9);
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
|
||||
cudaFuncSetAttribute(
|
||||
Code2x8Dequant, cudaFuncAttributeMaxDynamicSharedMemorySize, shared
|
||||
);
|
||||
cudaFuncSetAttribute(Code2x8Dequant,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, shared);
|
||||
Code2x8Dequant<<<blocks, threads, shared, stream>>>(
|
||||
(const int4*) A,
|
||||
(int4*) C,
|
||||
(const int4*) codebook,
|
||||
prob_m,
|
||||
prob_k,
|
||||
codebook_a_sizes,
|
||||
codebook_stride
|
||||
);
|
||||
(const int4*)A, (int4*)C, (const int4*)codebook, prob_m, prob_k,
|
||||
codebook_a_sizes, codebook_stride);
|
||||
}
|
||||
|
||||
int codebook_stride(const torch::Tensor& codebooks)
|
||||
{
|
||||
int codebook_stride(const torch::Tensor& codebooks) {
|
||||
return codebooks.stride(0) * codebooks.element_size() / sizeof(int4);
|
||||
}
|
||||
|
||||
void code1x16_matvec(
|
||||
const torch::Tensor& A,
|
||||
const torch::Tensor& B,
|
||||
torch::Tensor& C,
|
||||
const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& C,
|
||||
const torch::Tensor& codebook,
|
||||
const int4 codebook_a_sizes // cumulative sizes of A spanning each codebook, at most 3 long.
|
||||
const int4 codebook_a_sizes // cumulative sizes of A spanning each
|
||||
// codebook, at most 3 long.
|
||||
) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
|
||||
int prob_m = C.size(0);
|
||||
int prob_k = B.size(0);
|
||||
|
||||
code1x16_matvec_cuda(
|
||||
A.data_ptr(),
|
||||
B.data_ptr(),
|
||||
C.data_ptr(),
|
||||
codebook.data_ptr(),
|
||||
prob_m,
|
||||
prob_k,
|
||||
codebook_a_sizes,
|
||||
codebook_stride(codebook)
|
||||
);
|
||||
code1x16_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(),
|
||||
codebook.data_ptr(), prob_m, prob_k, codebook_a_sizes,
|
||||
codebook_stride(codebook));
|
||||
}
|
||||
|
||||
torch::Tensor code1x16_matmat(
|
||||
const torch::Tensor& input,
|
||||
torch::Tensor code1x16_matmat(const torch::Tensor& input,
|
||||
const torch::Tensor& codes,
|
||||
const torch::Tensor& codebooks,
|
||||
const torch::Tensor& scales,
|
||||
@ -503,22 +432,15 @@ torch::Tensor code1x16_matmat(
|
||||
auto input_sizes = input.sizes();
|
||||
auto out_features = codes.size(0) * codebooks.size(2);
|
||||
auto flat_input = input.reshape({-1, input.size(-1)});
|
||||
auto flat_output = torch::empty({flat_input.size(0), out_features},
|
||||
torch::TensorOptions()
|
||||
.dtype(input.dtype())
|
||||
.device(input.device())
|
||||
);
|
||||
auto flat_output = torch::empty(
|
||||
{flat_input.size(0), out_features},
|
||||
torch::TensorOptions().dtype(input.dtype()).device(input.device()));
|
||||
|
||||
for (int i = 0; i < flat_input.size(0); ++i) {
|
||||
auto input_vec = flat_input.index({i});
|
||||
auto output_vec = flat_output.index({i});
|
||||
code1x16_matvec(
|
||||
codes.squeeze(2),
|
||||
input_vec,
|
||||
output_vec,
|
||||
codebooks,
|
||||
codebook_a_sizes
|
||||
);
|
||||
code1x16_matvec(codes.squeeze(2), input_vec, output_vec, codebooks,
|
||||
codebook_a_sizes);
|
||||
}
|
||||
flat_output *= scales.flatten().unsqueeze(0);
|
||||
|
||||
@ -533,55 +455,35 @@ torch::Tensor code1x16_matmat(
|
||||
return output;
|
||||
}
|
||||
|
||||
void code2x8_matvec(
|
||||
const torch::Tensor& A,
|
||||
const torch::Tensor& B,
|
||||
torch::Tensor& C,
|
||||
const torch::Tensor& codebook,
|
||||
const int4 codebook_a_sizes
|
||||
) {
|
||||
void code2x8_matvec(const torch::Tensor& A, const torch::Tensor& B,
|
||||
torch::Tensor& C, const torch::Tensor& codebook,
|
||||
const int4 codebook_a_sizes) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
|
||||
int prob_m = C.size(0);
|
||||
int prob_k = B.size(0);
|
||||
code2x8_matvec_cuda(
|
||||
A.data_ptr(),
|
||||
B.data_ptr(),
|
||||
C.data_ptr(),
|
||||
codebook.data_ptr(),
|
||||
prob_m,
|
||||
prob_k,
|
||||
codebook_a_sizes,
|
||||
2 * codebook_stride(codebook)
|
||||
);
|
||||
code2x8_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(),
|
||||
codebook.data_ptr(), prob_m, prob_k, codebook_a_sizes,
|
||||
2 * codebook_stride(codebook));
|
||||
}
|
||||
|
||||
torch::Tensor code2x8_matmat(
|
||||
const torch::Tensor& input,
|
||||
torch::Tensor code2x8_matmat(const torch::Tensor& input,
|
||||
const torch::Tensor& codes,
|
||||
const torch::Tensor& codebooks,
|
||||
const torch::Tensor& scales,
|
||||
const int4 codebook_a_sizes,
|
||||
const std::optional<torch::Tensor>& bias
|
||||
) {
|
||||
const std::optional<torch::Tensor>& bias) {
|
||||
auto input_sizes = input.sizes();
|
||||
auto out_features = codes.size(0) * codebooks.size(2);
|
||||
auto flat_input = input.reshape({-1, input.size(-1)});
|
||||
auto flat_output = torch::empty({flat_input.size(0), out_features},
|
||||
torch::TensorOptions()
|
||||
.dtype(input.dtype())
|
||||
.device(input.device())
|
||||
);
|
||||
auto flat_output = torch::empty(
|
||||
{flat_input.size(0), out_features},
|
||||
torch::TensorOptions().dtype(input.dtype()).device(input.device()));
|
||||
|
||||
for (int i = 0; i < flat_input.size(0); ++i) {
|
||||
auto input_vec = flat_input.index({i});
|
||||
auto output_vec = flat_output.index({i});
|
||||
code2x8_matvec(
|
||||
codes.squeeze(2),
|
||||
input_vec,
|
||||
output_vec,
|
||||
codebooks,
|
||||
codebook_a_sizes
|
||||
);
|
||||
code2x8_matvec(codes.squeeze(2), input_vec, output_vec, codebooks,
|
||||
codebook_a_sizes);
|
||||
}
|
||||
flat_output *= scales.flatten().unsqueeze(0);
|
||||
if (bias.has_value()) {
|
||||
@ -596,21 +498,18 @@ torch::Tensor code2x8_matmat(
|
||||
}
|
||||
|
||||
// Accumulate the partition sizes.
|
||||
int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes)
|
||||
{
|
||||
int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes) {
|
||||
int4 cumulative_sizes;
|
||||
auto cumulative_size = &cumulative_sizes.x;
|
||||
int i = 0;
|
||||
int last = 0;
|
||||
assert(codebook_partition_sizes.size(0) <= 4);
|
||||
for (; i < codebook_partition_sizes.size(0); ++i, ++cumulative_size)
|
||||
{
|
||||
for (; i < codebook_partition_sizes.size(0); ++i, ++cumulative_size) {
|
||||
*cumulative_size = codebook_partition_sizes[i].item<int>() + last;
|
||||
last = *cumulative_size;
|
||||
}
|
||||
// fill in the rest with unreachable.
|
||||
for (; i < 4; ++i, ++cumulative_size)
|
||||
{
|
||||
for (; i < 4; ++i, ++cumulative_size) {
|
||||
*cumulative_size = last * 10;
|
||||
}
|
||||
return cumulative_sizes;
|
||||
@ -619,41 +518,36 @@ int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes)
|
||||
} // namespace aqlm
|
||||
} // namespace vllm
|
||||
|
||||
|
||||
torch::Tensor aqlm_gemm(
|
||||
const torch::Tensor& input,
|
||||
const torch::Tensor& codes,
|
||||
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
|
||||
const torch::Tensor& codebooks,
|
||||
const torch::Tensor& scales,
|
||||
const torch::Tensor& codebook_partition_sizes,
|
||||
const std::optional<torch::Tensor>& bias
|
||||
)
|
||||
{
|
||||
int4 cumulative_sizes = vllm::aqlm::accumulate_sizes(codebook_partition_sizes);
|
||||
const std::optional<torch::Tensor>& bias) {
|
||||
int4 cumulative_sizes =
|
||||
vllm::aqlm::accumulate_sizes(codebook_partition_sizes);
|
||||
|
||||
int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0);
|
||||
int const entries = codebooks.size(1);
|
||||
|
||||
if (nbooks == 1 && entries == (1 << 16))
|
||||
{
|
||||
return vllm::aqlm::code1x16_matmat(input, codes, codebooks, scales, cumulative_sizes, bias);
|
||||
if (nbooks == 1 && entries == (1 << 16)) {
|
||||
return vllm::aqlm::code1x16_matmat(input, codes, codebooks, scales,
|
||||
cumulative_sizes, bias);
|
||||
}
|
||||
if (nbooks == 2 && entries == (1 << 8))
|
||||
{
|
||||
return vllm::aqlm::code2x8_matmat(input, codes, codebooks, scales, cumulative_sizes, bias);
|
||||
if (nbooks == 2 && entries == (1 << 8)) {
|
||||
return vllm::aqlm::code2x8_matmat(input, codes, codebooks, scales,
|
||||
cumulative_sizes, bias);
|
||||
}
|
||||
|
||||
TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, " entries is not currently supported.")
|
||||
TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries,
|
||||
" entries is not currently supported.")
|
||||
return {};
|
||||
}
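// A minimal sketch of the dispatch rule used in aqlm_gemm above: the kernel
// flavor is chosen from how many codebooks cover each partition and how many
// entries each codebook has. The values in main() are illustrative only.
#include <cstdio>

const char* pick_aqlm_kernel(int nbooks, int entries) {
  if (nbooks == 1 && entries == (1 << 16)) return "code1x16_matmat";
  if (nbooks == 2 && entries == (1 << 8)) return "code2x8_matmat";
  return "unsupported";
}

int main() {
  printf("%s\n", pick_aqlm_kernel(1, 65536));  // code1x16_matmat
  printf("%s\n", pick_aqlm_kernel(2, 256));    // code2x8_matmat
}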
|
||||
|
||||
torch::Tensor aqlm_dequant(
|
||||
const torch::Tensor& codes,
|
||||
torch::Tensor aqlm_dequant(const torch::Tensor& codes,
|
||||
const torch::Tensor& codebooks,
|
||||
const torch::Tensor& codebook_partition_sizes
|
||||
)
|
||||
{
|
||||
int4 cumulative_sizes = vllm::aqlm::accumulate_sizes(codebook_partition_sizes);
|
||||
const torch::Tensor& codebook_partition_sizes) {
|
||||
int4 cumulative_sizes =
|
||||
vllm::aqlm::accumulate_sizes(codebook_partition_sizes);
|
||||
|
||||
int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0);
|
||||
int const entries = codebooks.size(1);
|
||||
@ -670,43 +564,35 @@ torch::Tensor aqlm_dequant(
|
||||
auto weights = torch::empty({out_features, in_features},
|
||||
torch::TensorOptions()
|
||||
.dtype(codebooks.dtype())
|
||||
.device(codebooks.device())
|
||||
);
|
||||
.device(codebooks.device()));
|
||||
|
||||
if (nbooks == 1 && entries == (1 << 16))
|
||||
{
|
||||
vllm::aqlm::code1x16_dequant_cuda(
|
||||
codes.data_ptr(),
|
||||
weights.data_ptr(),
|
||||
codebooks.data_ptr(),
|
||||
out_features,
|
||||
in_features,
|
||||
cumulative_sizes,
|
||||
if (nbooks == 1 && entries == (1 << 16)) {
|
||||
vllm::aqlm::code1x16_dequant_cuda(codes.data_ptr(), weights.data_ptr(),
|
||||
codebooks.data_ptr(), out_features,
|
||||
in_features, cumulative_sizes,
|
||||
vllm::aqlm::codebook_stride(codebooks));
|
||||
|
||||
// if you wanted to flip to scaling the weights, (though it's 30%-ish slower and not consistent with gemv implementation.)
|
||||
// weights *= scales.index({"...", 0, 0});
|
||||
// if you wanted to flip to scaling the weights, (though it's 30%-ish slower
|
||||
// and not consistent with gemv implementation.) weights *=
|
||||
// scales.index({"...", 0, 0});
|
||||
|
||||
return weights;
|
||||
}
|
||||
|
||||
if (nbooks == 2 && entries == (1 << 8))
|
||||
{
|
||||
vllm::aqlm::code2x8_dequant_cuda(
|
||||
codes.data_ptr(),
|
||||
weights.data_ptr(),
|
||||
codebooks.data_ptr(),
|
||||
out_features,
|
||||
in_features,
|
||||
cumulative_sizes,
|
||||
if (nbooks == 2 && entries == (1 << 8)) {
|
||||
vllm::aqlm::code2x8_dequant_cuda(codes.data_ptr(), weights.data_ptr(),
|
||||
codebooks.data_ptr(), out_features,
|
||||
in_features, cumulative_sizes,
|
||||
vllm::aqlm::codebook_stride(codebooks));
|
||||
|
||||
// if you wanted to flip to scaling the weights, (though it's 30%-ish slower and not consistent with gemv implementation)
|
||||
// weights *= scales.index({"...", 0, 0});
|
||||
// If you wanted to switch to scaling the weights instead (though it's ~30%
// slower and not consistent with the gemv implementation):
// weights *= scales.index({"...", 0, 0});
|
||||
|
||||
return weights;
|
||||
}
|
||||
|
||||
TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, " entries is not currently supported.")
|
||||
TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries,
|
||||
" entries is not currently supported.")
|
||||
return {};
|
||||
}
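For readers unfamiliar with AQLM, the dispatch above supports exactly two layouts: one codebook with 2^16 entries ("1x16") or two codebooks with 2^8 entries each ("2x8"), and aqlm_dequant materializes the dense weight matrix from codes and codebooks. The host-side sketch below only illustrates the additive-codebook reconstruction idea; the flat layouts assumed for `codes` and `codebooks` are mine for this sketch and are not the kernels' actual memory layout.

#include <cstdint>
#include <vector>

// Illustrative scalar reconstruction: each group of 8 consecutive weights in a
// row is the sum of one 8-wide vector per codebook, selected by the stored
// code. The index layouts here are assumptions made for this sketch only.
std::vector<float> aqlm_dequant_reference(
    const std::vector<int32_t>& codes,    // assumed [out_features, groups, nbooks]
    const std::vector<float>& codebooks,  // assumed [nbooks, entries, 8]
    int out_features, int in_features, int nbooks, int entries) {
  const int group = 8;
  const int groups = in_features / group;
  std::vector<float> weights(static_cast<size_t>(out_features) * in_features, 0.0f);
  for (int row = 0; row < out_features; ++row)
    for (int g = 0; g < groups; ++g)
      for (int b = 0; b < nbooks; ++b) {
        int code = codes[(static_cast<size_t>(row) * groups + g) * nbooks + b];
        const float* vec = &codebooks[(static_cast<size_t>(b) * entries + code) * group];
        for (int k = 0; k < group; ++k)
          weights[static_cast<size_t>(row) * in_features + g * group + k] += vec[k];
      }
  return weights;
}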
|
||||
|
@ -1,11 +1,11 @@
/*
Adapted from https://github.com/mit-han-lab/llm-awq
Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
Modified from NVIDIA FasterTransformer:
https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
@article{lin2023awq,
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
journal={arXiv},
year={2023}
title={AWQ: Activation-aware Weight Quantization for LLM Compression and
Acceleration}, author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang,
Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023}
}
*/

@ -14,8 +14,7 @@ Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransfor
namespace vllm {
namespace awq {

__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
{
__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
assert(false);
#else
@ -30,33 +29,40 @@ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
static constexpr uint32_t TOP_MASK = 0x00f000f0;
static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;

// Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
// format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
// In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
// elt_67 to fp16 without having to shift them to the bottom bits beforehand.
// Note that the entire sequence only requires 1 shift instruction. This is
// thanks to the register packing format and the fact that we force our
// integers to be unsigned, and account for this in the fp16 subtractions. In
// addition, I exploit the fact that sub and fma have the same throughput in
// order to convert elt_23 and elt_67 to fp16 without having to shift them to
// the bottom bits beforehand.

// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue
// immediately before required.
// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW
// dependency if we issue immediately before required.
const uint32_t top_i4s = i4s >> 8;
|
||||
// Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
|
||||
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||
: "=r"(h[0])
|
||||
: "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
|
||||
: "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
|
||||
"n"(immLut));
|
||||
// Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
|
||||
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||
: "=r"(h[1])
|
||||
: "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
|
||||
: "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
|
||||
"n"(immLut));
|
||||
// Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
|
||||
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||
: "=r"(h[2])
|
||||
: "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
|
||||
: "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
|
||||
"n"(immLut));
|
||||
// Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
|
||||
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||
: "=r"(h[3])
|
||||
: "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
|
||||
: "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
|
||||
"n"(immLut));
|
||||
|
||||
// I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the
|
||||
// half2 ctor. In this case, I chose performance reliability over code readability.
|
||||
// I use inline PTX below because I am not sure if the compiler will emit
|
||||
// float2half instructions if I use the half2 ctor. In this case, I chose
|
||||
// performance reliability over code readability.
|
||||
|
||||
// This is the half2 {1032, 1032} represented as an integer.
|
||||
// static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
|
||||
@ -71,13 +77,21 @@ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
|
||||
|
||||
// Finally, we construct the output numbers.
|
||||
// Convert elt_01
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n"
|
||||
: "=r"(h[0])
|
||||
: "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
|
||||
// Convert elt_23
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
|
||||
: "=r"(h[1])
|
||||
: "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
|
||||
// Convert elt_45
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n"
|
||||
: "=r"(h[2])
|
||||
: "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
|
||||
// Convert elt_67
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
|
||||
: "=r"(h[3])
|
||||
: "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
|
||||
|
||||
return result;
|
||||
#endif
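The comments in this hunk describe the classic exponent-bias trick: each 4-bit value is OR-ed into the low mantissa bits of the half-precision pattern 0x6400 (which encodes 1024.0, where one mantissa ULP is exactly 1.0), so the resulting half equals 1024 plus the nibble and a single f16 subtract recovers the integer with no int-to-float conversion. A minimal host-side check of that identity, written in plain C++ rather than the PTX above:

#include <cassert>
#include <cmath>
#include <cstdint>

// Decode a normal half-precision bit pattern to float (enough for this demo).
static float half_bits_to_float(uint16_t h) {
  int exponent = (h >> 10) & 0x1F;          // biased exponent
  int mantissa = h & 0x3FF;
  assert(exponent != 0 && exponent != 31);  // normal numbers only, for brevity
  return std::ldexp(1.0f + mantissa / 1024.0f, exponent - 15);
}

int main() {
  for (uint16_t nibble = 0; nibble < 16; ++nibble) {
    uint16_t packed = 0x6400 | nibble;             // (i4s & 0xF) | MAGIC_NUM
    float recovered = half_bits_to_float(packed) - 1024.0f;
    assert(recovered == static_cast<float>(nibble));
  }
  return 0;
}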
|
||||
|
@ -1,14 +1,12 @@
|
||||
/*
|
||||
Adapted from https://github.com/mit-han-lab/llm-awq
|
||||
@article{lin2023awq,
|
||||
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
|
||||
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
|
||||
journal={arXiv},
|
||||
year={2023}
|
||||
title={AWQ: Activation-aware Weight Quantization for LLM Compression and
|
||||
Acceleration}, author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang,
|
||||
Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023}
|
||||
}
|
||||
*/
|
||||
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
@ -20,26 +18,20 @@ namespace vllm {
|
||||
namespace awq {
|
||||
|
||||
// Pack two half values.
|
||||
static inline __device__ __host__ unsigned
|
||||
__pack_half2(const half x, const half y) {
|
||||
static inline __device__ __host__ unsigned __pack_half2(const half x,
|
||||
const half y) {
|
||||
unsigned v0 = *((unsigned short*)&x);
|
||||
unsigned v1 = *((unsigned short*)&y);
|
||||
return (v1 << 16) | v0;
|
||||
}
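__pack_half2 simply concatenates two raw half bit patterns, with the first argument landing in the low 16 bits. An equivalent host-side helper, shown only to make the packing order explicit:

#include <cstdint>

// Mirror of __pack_half2 on raw bit patterns: x -> low half-word,
// y -> high half-word.
static inline uint32_t pack_half2_bits(uint16_t x_bits, uint16_t y_bits) {
  return (static_cast<uint32_t>(y_bits) << 16) | x_bits;
}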
|
||||
|
||||
template <int N>
|
||||
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
|
||||
int G,
|
||||
int split_k_iters,
|
||||
half* __restrict__ A,
|
||||
int* __restrict__ B,
|
||||
__global__ void __launch_bounds__(64)
|
||||
gemm_forward_4bit_cuda_m16nXk32(int G, int split_k_iters,
|
||||
half* __restrict__ A, int* __restrict__ B,
|
||||
half* __restrict__ scaling_factors,
|
||||
int* __restrict__ zeros,
|
||||
int M,
|
||||
int IC,
|
||||
int OC,
|
||||
half* __restrict__ C)
|
||||
{
|
||||
int* __restrict__ zeros, int M, int IC,
|
||||
int OC, half* __restrict__ C) {
|
||||
// Only support matrix n = 64 or 128
|
||||
assert(N == 64 || N == 128);
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
|
||||
@ -70,43 +62,46 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
|
||||
static constexpr int row_stride = 2 * 32 * 8 / N;
|
||||
bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < N;
|
||||
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
|
||||
bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id
|
||||
bool ld_A_flag =
|
||||
(blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp +
|
||||
threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id
|
||||
// bool wb_C_flag = (threadIdx.x / 4) < M;
|
||||
|
||||
half* A_ptr = A
|
||||
+ (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
|
||||
+ (((int)threadIdx.x) % (32 / 8)) * 8;
|
||||
half* A_ptr =
|
||||
A +
|
||||
(((int)blockIdx_y) / j_factors1 * 16 +
|
||||
(((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) *
|
||||
IC +
|
||||
(((int)threadIdx.x) % (32 / 8)) * 8;
|
||||
|
||||
int* B_ptr = B
|
||||
+ ((int)threadIdx.y) * (OC / 8) * (256 / N)
|
||||
+ (((int)threadIdx.x) / (N / 8)) * (OC / 8)
|
||||
+ (((int)blockIdx_y) % j_factors1) * (N / 8)
|
||||
+ (((int)threadIdx.x) % (N / 8)) * 1;
|
||||
int* B_ptr = B + ((int)threadIdx.y) * (OC / 8) * (256 / N) +
|
||||
(((int)threadIdx.x) / (N / 8)) * (OC / 8) +
|
||||
(((int)blockIdx_y) % j_factors1) * (N / 8) +
|
||||
(((int)threadIdx.x) % (N / 8)) * 1;
|
||||
// Why * 1 in the above line?
|
||||
|
||||
half* A_shared_ptr = A_shared
|
||||
+ ((int)threadIdx.y) * row_stride_warp * (32 + 8)
|
||||
+ (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
|
||||
+ (((int)threadIdx.x) % (32 / 8) ) * 8;
|
||||
half* A_shared_ptr = A_shared +
|
||||
((int)threadIdx.y) * row_stride_warp * (32 + 8) +
|
||||
(((int)threadIdx.x) / (32 / 8)) * (32 + 8) +
|
||||
(((int)threadIdx.x) % (32 / 8)) * 8;
|
||||
|
||||
half* B_shared_ptr = B_shared
|
||||
+ ((int)threadIdx.y) * (row_stride / 2) * (N + 8)
|
||||
+ (((int)threadIdx.x) / (N / 8)) * (N + 8)
|
||||
+ (((int)threadIdx.x) % (N / 8)) * 8;
|
||||
half* B_shared_ptr = B_shared +
|
||||
((int)threadIdx.y) * (row_stride / 2) * (N + 8) +
|
||||
(((int)threadIdx.x) / (N / 8)) * (N + 8) +
|
||||
(((int)threadIdx.x) % (N / 8)) * 8;
|
||||
|
||||
int* zeros_ptr = zeros
|
||||
+ (((int)blockIdx_y) % j_factors1) * (N / 8)
|
||||
+ ((int)threadIdx.x) % (N / 8);
|
||||
int* zeros_ptr = zeros + (((int)blockIdx_y) % j_factors1) * (N / 8) +
|
||||
((int)threadIdx.x) % (N / 8);
|
||||
|
||||
half* scaling_factors_ptr = scaling_factors
|
||||
+ (((int)blockIdx_y) % j_factors1) * N
|
||||
+ (((int)threadIdx.x) % (N / 8)) * 8;
|
||||
half* scaling_factors_ptr = scaling_factors +
|
||||
(((int)blockIdx_y) % j_factors1) * N +
|
||||
(((int)threadIdx.x) % (N / 8)) * 8;
|
||||
|
||||
half* C_ptr = C
|
||||
+ static_cast<long long>(blockIdx_z) * M * OC // blockIdz.x -> split_k dim
|
||||
+ (((int)blockIdx_y) % j_factors1) * N
|
||||
+ ((int)threadIdx.y) * (N / 2)
|
||||
+ (((int)threadIdx.x) % 4) * 2;
|
||||
half* C_ptr =
|
||||
C +
|
||||
static_cast<long long>(blockIdx_z) * M * OC // blockIdz.x -> split_k dim
|
||||
+ (((int)blockIdx_y) % j_factors1) * N + ((int)threadIdx.y) * (N / 2) +
|
||||
(((int)threadIdx.x) % 4) * 2;
|
||||
|
||||
// preload s.f. and zeros
|
||||
int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
|
||||
@ -115,57 +110,83 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
|
||||
int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
|
||||
__syncthreads();
|
||||
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
|
||||
if (ld_A_flag)
|
||||
{
|
||||
if (ld_A_flag) {
|
||||
*(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
|
||||
}
|
||||
else
|
||||
{
|
||||
} else {
|
||||
*(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
|
||||
}
|
||||
|
||||
// for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
|
||||
uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
|
||||
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
|
||||
uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
|
||||
uint4 B_loaded_scale =
|
||||
*(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
|
||||
/*
|
||||
if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
|
||||
printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
|
||||
if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 &&
|
||||
threadIdx.y == 0){ printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x,
|
||||
B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x,
|
||||
B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
|
||||
}
|
||||
*/
|
||||
// uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
|
||||
int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
|
||||
|
||||
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < N / 16; ++ax0_ax1_fused_0) {
|
||||
|
||||
// B: 32 x 136 (128+8) float16
|
||||
// each warp: 32 x 4
|
||||
// each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
|
||||
// *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
|
||||
// row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
|
||||
uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
|
||||
// each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus
|
||||
// zero -> WB UINT4
|
||||
// *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) *
|
||||
// 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15)
|
||||
// * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 *
|
||||
// 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) *
|
||||
// 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) *
|
||||
// 8))); row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
|
||||
uint32_t B_loaded =
|
||||
*(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
|
||||
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
|
||||
//uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
|
||||
// uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N /
|
||||
// 8)) * 8);
|
||||
|
||||
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
|
||||
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x
|
||||
// % (cta_N / 8)) * 8);
|
||||
// - zero and * scale
|
||||
// TODO (Haotian): can save 4 assembly instructions if formulated as deq = q * scale - zero * scale.
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
|
||||
// TODO (Haotian): can save 4 assembly instructions if formulated as deq =
// q * scale - zero * scale.
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n"
|
||||
: "=r"(B_loaded_fp16.x)
|
||||
: "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
|
||||
: "=r"(B_loaded_fp16.x)
|
||||
: "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n"
|
||||
: "=r"(B_loaded_fp16.y)
|
||||
: "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
|
||||
: "=r"(B_loaded_fp16.y)
|
||||
: "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n"
|
||||
: "=r"(B_loaded_fp16.z)
|
||||
: "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
|
||||
: "=r"(B_loaded_fp16.z)
|
||||
: "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n"
|
||||
: "=r"(B_loaded_fp16.w)
|
||||
: "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
|
||||
: "=r"(B_loaded_fp16.w)
|
||||
: "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
|
||||
/*
|
||||
if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
|
||||
printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
|
||||
if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 ==
|
||||
0 && threadIdx.x == 17 && threadIdx.y == 0){ printf("[x] %X %X %X %X\n",
|
||||
B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
|
||||
}
|
||||
*/
|
||||
|
||||
// write back
|
||||
*(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) = B_loaded_fp16;
|
||||
*(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) =
|
||||
B_loaded_fp16;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
@ -173,34 +194,43 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
|
||||
{
|
||||
unsigned int addr;
|
||||
__asm__ __volatile__(
|
||||
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
|
||||
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, "
|
||||
"addr; }\n"
|
||||
: "=r"(addr)
|
||||
: "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
|
||||
);
|
||||
|
||||
: "l"((void*)((&(A_shared[(k_0_1 * 16)])) +
|
||||
(((((int)threadIdx.x) & 15) * 40) +
|
||||
((((int)threadIdx.x) >> 4) * 8)))));
|
||||
|
||||
__asm__ __volatile__(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
|
||||
"{%0, %1, %2, %3}, [%4];\n"
|
||||
: "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
|
||||
: "r"(addr)
|
||||
);
|
||||
: "=r"(((unsigned*)(A_shared_warp + 0))[0]),
|
||||
"=r"(((unsigned*)(A_shared_warp + 0))[1]),
|
||||
"=r"(((unsigned*)(A_shared_warp + 0))[2]),
|
||||
"=r"(((unsigned*)(A_shared_warp + 0))[3])
|
||||
: "r"(addr));
|
||||
}
|
||||
|
||||
for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) {
|
||||
{
|
||||
unsigned int addr;
|
||||
__asm__ __volatile__(
|
||||
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
|
||||
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, "
|
||||
"addr; }\n"
|
||||
: "=r"(addr)
|
||||
: "l"((void *)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + (((int)threadIdx.y) * (N / 2))) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * (N + 8)) + ((((int)threadIdx.x) >> 4) * 8))))
|
||||
);
|
||||
: "l"((void*)((&(B_shared[(((k_0_1 * (N * 16 + 128)) +
|
||||
(((int)threadIdx.y) * (N / 2))) +
|
||||
(ax1_0 * 16))])) +
|
||||
(((((int)threadIdx.x) & 15) * (N + 8)) +
|
||||
((((int)threadIdx.x) >> 4) * 8)))));
|
||||
__asm__ __volatile__(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
|
||||
"{%0, %1, %2, %3}, [%4];\n"
|
||||
: "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
|
||||
: "r"(addr)
|
||||
);
|
||||
: "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[0]),
|
||||
"=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[1]),
|
||||
"=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[2]),
|
||||
"=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[3])
|
||||
: "r"(addr));
|
||||
}
|
||||
}
|
||||
for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) {
|
||||
@ -209,48 +239,110 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
|
||||
__asm__ __volatile__(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
|
||||
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
|
||||
: "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
|
||||
"=f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
|
||||
"=f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
|
||||
"=f"(((float*)(C_warp + (j_0_4 * 8)))[3])
|
||||
: "r"(((unsigned*)(A_shared_warp + 0))[0]),
|
||||
"r"(((unsigned*)(A_shared_warp + 0))[1]),
|
||||
"r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]),
|
||||
"f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
|
||||
"f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
|
||||
"f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
|
||||
"f"(((float*)(C_warp + (j_0_4 * 8)))[3]));
|
||||
}
|
||||
|
||||
{
|
||||
__asm__ __volatile__(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||
: "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
|
||||
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
|
||||
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
|
||||
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||
: "r"(((unsigned*)(A_shared_warp + 0))[0]),
|
||||
"r"(((unsigned*)(A_shared_warp + 0))[1]),
|
||||
"r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]),
|
||||
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
|
||||
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
|
||||
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
|
||||
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||
}
|
||||
|
||||
{
|
||||
__asm__ __volatile__(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
|
||||
: "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
|
||||
: "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
|
||||
"=f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
|
||||
"=f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
|
||||
"=f"(((float*)(C_warp + (j_0_4 * 8)))[3])
|
||||
: "r"(((unsigned*)(A_shared_warp + 0))[2]),
|
||||
"r"(((unsigned*)(A_shared_warp + 0))[3]),
|
||||
"r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]),
|
||||
"f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
|
||||
"f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
|
||||
"f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
|
||||
"f"(((float*)(C_warp + (j_0_4 * 8)))[3]));
|
||||
}
|
||||
|
||||
{
|
||||
__asm__ __volatile__(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||
: "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||
: "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
|
||||
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
|
||||
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
|
||||
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||
: "r"(((unsigned*)(A_shared_warp + 0))[2]),
|
||||
"r"(((unsigned*)(A_shared_warp + 0))[3]),
|
||||
"r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]),
|
||||
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
|
||||
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
|
||||
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
|
||||
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||
}
|
||||
#else
|
||||
{
|
||||
__asm__ __volatile__(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
|
||||
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
|
||||
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, "
|
||||
"%13};\n"
|
||||
: "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
|
||||
"=f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
|
||||
"=f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
|
||||
"=f"(((float*)(C_warp + (j_0_4 * 8)))[3])
|
||||
: "r"(((unsigned*)(A_shared_warp + 0))[0]),
|
||||
"r"(((unsigned*)(A_shared_warp + 0))[1]),
|
||||
"r"(((unsigned*)(A_shared_warp + 0))[2]),
|
||||
"r"(((unsigned*)(A_shared_warp + 0))[3]),
|
||||
"r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]),
|
||||
"r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]),
|
||||
"f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
|
||||
"f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
|
||||
"f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
|
||||
"f"(((float*)(C_warp + (j_0_4 * 8)))[3]));
|
||||
}
|
||||
|
||||
{
|
||||
__asm__ __volatile__(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
|
||||
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, "
|
||||
"%13};\n"
|
||||
: "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
|
||||
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
|
||||
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
|
||||
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||
: "r"(((unsigned*)(A_shared_warp + 0))[0]),
|
||||
"r"(((unsigned*)(A_shared_warp + 0))[1]),
|
||||
"r"(((unsigned*)(A_shared_warp + 0))[2]),
|
||||
"r"(((unsigned*)(A_shared_warp + 0))[3]),
|
||||
"r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]),
|
||||
"r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]),
|
||||
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
|
||||
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
|
||||
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
|
||||
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||
}
|
||||
|
||||
#endif
|
||||
@ -261,24 +353,20 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
|
||||
// TODO: Shang: Hoist loop invariance.
|
||||
for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) {
|
||||
for (int local_id = 0; local_id < 8; ++local_id) {
|
||||
int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
|
||||
if (row_offset < M)
|
||||
{
|
||||
*(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
|
||||
int row_offset = (((int)blockIdx_y) / j_factors1) * 16 +
|
||||
((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
|
||||
if (row_offset < M) {
|
||||
*(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 +
|
||||
local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
__global__ void __launch_bounds__(64) dequantize_weights(
|
||||
int* __restrict__ B,
|
||||
half* __restrict__ scaling_factors,
|
||||
int* __restrict__ zeros,
|
||||
half* __restrict__ C,
|
||||
int G
|
||||
)
|
||||
{
|
||||
__global__ void __launch_bounds__(64)
|
||||
dequantize_weights(int* __restrict__ B, half* __restrict__ scaling_factors,
|
||||
int* __restrict__ zeros, half* __restrict__ C, int G) {
|
||||
int j_factors1 = 4;
|
||||
int row_stride2 = 4;
|
||||
int split_k_iters = 1;
|
||||
@ -310,14 +398,30 @@ __global__ void __launch_bounds__(64) dequantize_weights(
|
||||
|
||||
uint32_t B_loaded = *(uint32_t*)B_ptr2;
|
||||
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n"
|
||||
: "=r"(B_loaded_fp16.x)
|
||||
: "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
|
||||
: "=r"(B_loaded_fp16.x)
|
||||
: "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n"
|
||||
: "=r"(B_loaded_fp16.y)
|
||||
: "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
|
||||
: "=r"(B_loaded_fp16.y)
|
||||
: "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n"
|
||||
: "=r"(B_loaded_fp16.z)
|
||||
: "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
|
||||
: "=r"(B_loaded_fp16.z)
|
||||
: "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n"
|
||||
: "=r"(B_loaded_fp16.w)
|
||||
: "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
|
||||
: "=r"(B_loaded_fp16.w)
|
||||
: "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
|
||||
|
||||
*(uint4*)B_shared_ptr2 = B_loaded_fp16;
|
||||
|
||||
@ -329,14 +433,10 @@ __global__ void __launch_bounds__(64) dequantize_weights(
|
||||
} // namespace awq
|
||||
} // namespace vllm
|
||||
|
||||
torch::Tensor awq_dequantize(
|
||||
torch::Tensor _kernel,
|
||||
torch::Tensor awq_dequantize(torch::Tensor _kernel,
|
||||
torch::Tensor _scaling_factors,
|
||||
torch::Tensor _zeros,
|
||||
int split_k_iters,
|
||||
int thx,
|
||||
int thy)
|
||||
{
|
||||
torch::Tensor _zeros, int split_k_iters, int thx,
|
||||
int thy) {
|
||||
int in_c = _kernel.size(0);
|
||||
int qout_c = _kernel.size(1);
|
||||
int out_c = qout_c * 8;
|
||||
@ -362,12 +462,15 @@ torch::Tensor awq_dequantize(
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors));
|
||||
|
||||
auto options = torch::TensorOptions().dtype(_scaling_factors.dtype()).device(_scaling_factors.device());
|
||||
auto options = torch::TensorOptions()
|
||||
.dtype(_scaling_factors.dtype())
|
||||
.device(_scaling_factors.device());
|
||||
at::Tensor _de_kernel = torch::empty({in_c, out_c}, options);
|
||||
|
||||
auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
|
||||
auto de_kernel = reinterpret_cast<half*>(_de_kernel.data_ptr<at::Half>());
|
||||
auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
|
||||
auto scaling_factors =
|
||||
reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
|
||||
auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
|
||||
|
||||
dim3 num_blocks(x_blocks, y_blocks);
|
||||
@ -386,26 +489,26 @@ torch::Tensor awq_dequantize(
|
||||
// zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b]
|
||||
// assume that batch_size < 16 for now
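The comment above records the packing convention for the quantized tensors: eight 4-bit values per int32. As a rough illustration of that density (ignoring any interleaved lane order the kernels may use internally), unpacking one word into eight nibbles looks like the sketch below.

#include <array>
#include <cstdint>

// Plain little-nibble unpack of eight 4-bit values from one 32-bit word.
// The in-register order used by the CUDA kernels may differ; this only
// illustrates the 8-per-int32 packing density.
std::array<uint8_t, 8> unpack_int4x8(uint32_t packed) {
  std::array<uint8_t, 8> out{};
  for (int i = 0; i < 8; ++i) out[i] = (packed >> (4 * i)) & 0xF;
  return out;
}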
|
||||
|
||||
torch::Tensor awq_gemm(
|
||||
torch::Tensor _in_feats,
|
||||
torch::Tensor _kernel,
|
||||
torch::Tensor _scaling_factors,
|
||||
torch::Tensor _zeros,
|
||||
int split_k_iters)
|
||||
{
|
||||
torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
|
||||
torch::Tensor _scaling_factors, torch::Tensor _zeros,
|
||||
int split_k_iters) {
|
||||
int num_in_feats = _in_feats.size(0);
|
||||
int num_in_channels = _in_feats.size(1);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));
|
||||
|
||||
auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
|
||||
at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options);
|
||||
auto options = torch::TensorOptions()
|
||||
.dtype(_in_feats.dtype())
|
||||
.device(_in_feats.device());
|
||||
at::Tensor _out_feats =
|
||||
torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options);
|
||||
int num_out_feats = _out_feats.size(-2);
|
||||
int num_out_channels = _out_feats.size(-1);
|
||||
|
||||
auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>());
|
||||
auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
|
||||
auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
|
||||
auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
|
||||
auto scaling_factors =
|
||||
reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
|
||||
auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
|
||||
int group_size = num_in_channels / _scaling_factors.size(0);
|
||||
|
||||
@ -419,28 +522,28 @@ torch::Tensor awq_gemm(
|
||||
throw std::invalid_argument("OC is not multiple of Group size");
|
||||
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
if (num_out_channels % 128 == 0)
|
||||
{
|
||||
if (num_out_channels % 128 == 0) {
|
||||
int j_factors1 = num_out_channels / 128 / 1;
|
||||
dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
|
||||
// threadIdx.x: 32
|
||||
// threadIdx.y: i_factors[2] * j_factors[2]
|
||||
dim3 threads_per_block(32, 2);
|
||||
vllm::awq::gemm_forward_4bit_cuda_m16nXk32<128><<<num_blocks, threads_per_block, 0, stream>>>(
|
||||
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels,
|
||||
num_out_channels, out_feats);
|
||||
}
|
||||
else if (num_out_channels % 64 == 0)
|
||||
{
|
||||
vllm::awq::gemm_forward_4bit_cuda_m16nXk32<128>
|
||||
<<<num_blocks, threads_per_block, 0, stream>>>(
|
||||
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros,
|
||||
num_in_feats, num_in_channels, num_out_channels, out_feats);
|
||||
} else if (num_out_channels % 64 == 0) {
|
||||
int j_factors1 = num_out_channels / 64 / 1;
|
||||
dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
|
||||
dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 *
|
||||
split_k_iters);
|
||||
|
||||
// threadIdx.x: 32
|
||||
// threadIdx.y: i_factors[2] * j_factors[2]
|
||||
dim3 threads_per_block(32, 2);
|
||||
vllm::awq::gemm_forward_4bit_cuda_m16nXk32<64><<<num_blocks, threads_per_block, 0, stream>>>(
|
||||
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels,
|
||||
num_out_channels, out_feats);
|
||||
vllm::awq::gemm_forward_4bit_cuda_m16nXk32<64>
|
||||
<<<num_blocks, threads_per_block, 0, stream>>>(
|
||||
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros,
|
||||
num_in_feats, num_in_channels, num_out_channels, out_feats);
|
||||
}
|
||||
return _out_feats.sum(0);
|
||||
}
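awq_gemm allocates _out_feats with a leading split_k_iters dimension, lets each z-block accumulate a disjoint, strided slice of the K dimension, and reduces with _out_feats.sum(0) at the end. A minimal host sketch of that split-K pattern (dense float operands, hypothetical helper name):

#include <vector>

// One partial GEMM per split over a strided K partition, then a final
// reduction over the split dimension (the _out_feats.sum(0) step).
std::vector<float> splitk_gemm(const std::vector<float>& A,  // [M, K], row-major
                               const std::vector<float>& B,  // [K, N], row-major
                               int M, int N, int K, int split_k_iters) {
  std::vector<float> partial(static_cast<size_t>(split_k_iters) * M * N, 0.0f);
  for (int s = 0; s < split_k_iters; ++s)
    for (int m = 0; m < M; ++m)
      for (int k = s; k < K; k += split_k_iters)  // strided K partition
        for (int n = 0; n < N; ++n)
          partial[(static_cast<size_t>(s) * M + m) * N + n] +=
              A[static_cast<size_t>(m) * K + k] * B[static_cast<size_t>(k) * N + n];

  std::vector<float> C(static_cast<size_t>(M) * N, 0.0f);
  for (int s = 0; s < split_k_iters; ++s)
    for (size_t i = 0; i < C.size(); ++i)
      C[i] += partial[static_cast<size_t>(s) * M * N + i];
  return C;
}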
|
||||
|
@ -43,7 +43,8 @@ void cutlass_scaled_mm_dq(torch::Tensor &c, torch::Tensor const &a,
|
||||
// Check for strides and alignment
|
||||
TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
|
||||
TORCH_CHECK(b.stride(0) == 1); // Column-major
|
||||
TORCH_CHECK(c.stride(0) % 16 == 0 && b.stride(1) % 16 == 0); // 16 Byte Alignment
|
||||
TORCH_CHECK(c.stride(0) % 16 == 0 &&
|
||||
b.stride(1) % 16 == 0); // 16 Byte Alignment
|
||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
|
||||
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
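Restated for intuition, the checks in this hunk require a row-major A and C (inner stride 1), a column-major B (outer stride 1), and leading strides that are multiples of 16 (the original comment labels this 16-byte alignment). A standalone predicate with the same logic, as a sketch rather than the actual check:

#include <cstdint>

// Same layout conditions as the TORCH_CHECKs above, expressed on raw strides.
bool scaled_mm_layout_ok(int64_t a_stride1, int64_t b_stride0, int64_t b_stride1,
                         int64_t c_stride0, int64_t c_stride1) {
  bool row_major = (a_stride1 == 1) && (c_stride1 == 1);
  bool col_major = (b_stride0 == 1);
  bool aligned = (c_stride0 % 16 == 0) && (b_stride1 % 16 == 0);
  return row_major && col_major && aligned;
}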
|
||||
|
@ -11,33 +11,26 @@
|
||||
|
||||
#include "hip_float8_impl.h"
|
||||
|
||||
struct alignas(1) hip_fp8
|
||||
{
|
||||
struct from_bits_t
|
||||
{
|
||||
};
|
||||
HIP_FP8_HOST_DEVICE static constexpr from_bits_t from_bits() { return from_bits_t(); }
|
||||
struct alignas(1) hip_fp8 {
|
||||
struct from_bits_t {};
|
||||
HIP_FP8_HOST_DEVICE static constexpr from_bits_t from_bits() {
|
||||
return from_bits_t();
|
||||
}
|
||||
uint8_t data;
|
||||
|
||||
hip_fp8() = default;
|
||||
HIP_FP8_HOST_DEVICE constexpr hip_fp8(const hip_fp8&) = default;
|
||||
HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v) = delete;
|
||||
explicit HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v, from_bits_t)
|
||||
: data(v)
|
||||
{
|
||||
}
|
||||
: data(v) {}
|
||||
|
||||
#ifdef __HIP__MI300__
|
||||
// NOTE: ON-DEVICE... always optimal bias
|
||||
explicit HIP_FP8_DEVICE hip_fp8(float v)
|
||||
: data(hip_fp8_impl::to_fp8_from_fp32(v))
|
||||
{
|
||||
}
|
||||
: data(hip_fp8_impl::to_fp8_from_fp32(v)) {}
|
||||
|
||||
explicit HIP_FP8_DEVICE hip_fp8(_Float16 v)
|
||||
: hip_fp8(static_cast<float>(v))
|
||||
{
|
||||
}
|
||||
: hip_fp8(static_cast<float>(v)) {}
|
||||
|
||||
// Host only implementation using s/w simulation
|
||||
explicit HIP_FP8_HOST
|
||||
@ -45,25 +38,24 @@ struct alignas(1) hip_fp8
|
||||
// both Host and DEVICE for non-MI300 using s/w simulation
|
||||
explicit HIP_FP8_HOST_DEVICE
|
||||
#endif // __HIP__MI300__
|
||||
hip_fp8(float v)
|
||||
{
|
||||
data = hip_fp8_impl::to_float8<4, 3, float, true /*negative_zero_nan*/, true /*clip*/>(v);
|
||||
hip_fp8(float v) {
|
||||
data = hip_fp8_impl::to_float8<4, 3, float, true /*negative_zero_nan*/,
|
||||
true /*clip*/>(v);
|
||||
}
|
||||
|
||||
explicit HIP_FP8_HOST_DEVICE hip_fp8(double v)
|
||||
: hip_fp8(static_cast<float>(v))
|
||||
{
|
||||
}
|
||||
: hip_fp8(static_cast<float>(v)) {}
|
||||
|
||||
#ifdef __HIP__MI300__
|
||||
// upcast using device specific intrinsic
|
||||
explicit inline HIP_FP8_DEVICE operator float() const
|
||||
{
|
||||
explicit inline HIP_FP8_DEVICE operator float() const {
|
||||
float fval;
|
||||
uint32_t i32val = static_cast<uint32_t>(data);
|
||||
|
||||
// upcast
|
||||
asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
|
||||
asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0"
|
||||
: "=v"(fval)
|
||||
: "v"(i32val));
|
||||
|
||||
return fval;
|
||||
}
|
||||
@ -73,95 +65,73 @@ struct alignas(1) hip_fp8
|
||||
explicit inline HIP_FP8_HOST_DEVICE operator float() const
|
||||
#endif // __HIP__MI300__
|
||||
{
|
||||
return hip_fp8_impl::from_float8<4, 3, float, true /*negative_zero_nan*/>(data);
|
||||
return hip_fp8_impl::from_float8<4, 3, float, true /*negative_zero_nan*/>(
|
||||
data);
|
||||
}
|
||||
};
|
||||
|
||||
namespace std
|
||||
{
|
||||
inline hip_fp8 sin(hip_fp8 a)
|
||||
{
|
||||
return hip_fp8(sinf(float(a)));
|
||||
}
|
||||
inline hip_fp8 cos(hip_fp8 a)
|
||||
{
|
||||
return hip_fp8(cosf(float(a)));
|
||||
}
|
||||
HIP_FP8_HOST_DEVICE constexpr hip_fp8 real(const hip_fp8& a)
|
||||
{
|
||||
return a;
|
||||
}
|
||||
namespace std {
|
||||
inline hip_fp8 sin(hip_fp8 a) { return hip_fp8(sinf(float(a))); }
|
||||
inline hip_fp8 cos(hip_fp8 a) { return hip_fp8(cosf(float(a))); }
|
||||
HIP_FP8_HOST_DEVICE constexpr hip_fp8 real(const hip_fp8& a) { return a; }
|
||||
} // namespace std
|
||||
|
||||
// Special operator overloading
|
||||
inline std::ostream& operator<<(std::ostream& os, const hip_fp8& f8)
|
||||
{
|
||||
inline std::ostream& operator<<(std::ostream& os, const hip_fp8& f8) {
|
||||
return os << float(f8);
|
||||
}
|
||||
|
||||
// all + operator overloading with mixed types
|
||||
// mixed types, always converts to f32, does computation in f32, and returns float
|
||||
inline HIP_FP8_HOST_DEVICE float operator+(const float fa, hip_fp8 b)
|
||||
{
|
||||
// mixed types, always converts to f32, does computation in f32, and returns
|
||||
// float
|
||||
inline HIP_FP8_HOST_DEVICE float operator+(const float fa, hip_fp8 b) {
|
||||
return (fa + float(b));
|
||||
}
|
||||
|
||||
inline HIP_FP8_HOST_DEVICE float operator+(hip_fp8 a, const float fb)
|
||||
{
|
||||
inline HIP_FP8_HOST_DEVICE float operator+(hip_fp8 a, const float fb) {
|
||||
return (float(a) + fb);
|
||||
}
|
||||
|
||||
inline HIP_FP8_HOST_DEVICE hip_fp8 operator+(hip_fp8 a, hip_fp8 b)
|
||||
{
|
||||
inline HIP_FP8_HOST_DEVICE hip_fp8 operator+(hip_fp8 a, hip_fp8 b) {
|
||||
return hip_fp8(float(a) + float(b));
|
||||
}
|
||||
|
||||
inline HIP_FP8_HOST_DEVICE hip_fp8& operator+=(hip_fp8& a, hip_fp8 b)
|
||||
{
|
||||
inline HIP_FP8_HOST_DEVICE hip_fp8& operator+=(hip_fp8& a, hip_fp8 b) {
|
||||
return a = hip_fp8(float(a) + float(b));
|
||||
}
|
||||
|
||||
// overloading multiplication, always returns float,
|
||||
inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, hip_fp8 b)
|
||||
{
|
||||
inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, hip_fp8 b) {
|
||||
return float(a) * float(b);
|
||||
}
|
||||
|
||||
inline HIP_FP8_HOST_DEVICE float operator*(float a, hip_fp8 b)
|
||||
{
|
||||
inline HIP_FP8_HOST_DEVICE float operator*(float a, hip_fp8 b) {
|
||||
return (a * float(b));
|
||||
}
|
||||
|
||||
inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, float b)
|
||||
{
|
||||
inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, float b) {
|
||||
return (float(a) * b);
|
||||
}
|
||||
|
||||
inline HIP_FP8_HOST_DEVICE float operator*(int32_t a, hip_fp8 b)
|
||||
{
|
||||
inline HIP_FP8_HOST_DEVICE float operator*(int32_t a, hip_fp8 b) {
|
||||
return ((float)a * float(b));
|
||||
}
|
||||
|
||||
inline HIP_FP8_HOST_DEVICE float operator*(double a, hip_fp8 b)
|
||||
{
|
||||
inline HIP_FP8_HOST_DEVICE float operator*(double a, hip_fp8 b) {
|
||||
return ((float)a * float(b));
|
||||
}
|
||||
|
||||
// overloading for compare
|
||||
inline HIP_FP8_HOST_DEVICE bool operator==(hip_fp8 a, hip_fp8 b)
|
||||
{
|
||||
inline HIP_FP8_HOST_DEVICE bool operator==(hip_fp8 a, hip_fp8 b) {
|
||||
return (a.data == b.data);
|
||||
}
|
||||
inline HIP_FP8_HOST_DEVICE bool operator!=(hip_fp8 a, hip_fp8 b)
|
||||
{
|
||||
inline HIP_FP8_HOST_DEVICE bool operator!=(hip_fp8 a, hip_fp8 b) {
|
||||
return (a.data != b.data);
|
||||
}
|
||||
|
||||
inline HIP_FP8_HOST_DEVICE bool operator>=(hip_fp8 a, hip_fp8 b)
|
||||
{
|
||||
inline HIP_FP8_HOST_DEVICE bool operator>=(hip_fp8 a, hip_fp8 b) {
|
||||
return static_cast<float>(a) >= static_cast<float>(b);
|
||||
}
|
||||
inline HIP_FP8_HOST_DEVICE bool operator>(hip_fp8 a, hip_fp8 b)
|
||||
{
|
||||
inline HIP_FP8_HOST_DEVICE bool operator>(hip_fp8 a, hip_fp8 b) {
|
||||
return static_cast<float>(a) > static_cast<float>(b);
|
||||
}
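hip_fp8 stores a single byte; the template parameters used above (we=4, wm=3, negative_zero_nan=true) and the ±240 clamp in the MI300 path suggest the FP8 E4M3 "fnuz" encoding (bias 8, no infinities, 0x80 reserved for NaN). A host-side decoder for that assumed encoding, given as a reference for what the byte means rather than the library's actual conversion path:

#include <cmath>
#include <cstdint>
#include <limits>

// Assumed E4M3 fnuz decode: 0x80 is NaN, exponent bias is 8, subnormals have
// no implicit leading 1. Max finite magnitude is (1 + 7/8) * 2^7 = 240.
float fp8_e4m3fnuz_to_float(uint8_t b) {
  if (b == 0x80) return std::numeric_limits<float>::quiet_NaN();
  float sign = (b & 0x80) ? -1.0f : 1.0f;
  int exponent = (b >> 3) & 0xF;
  int mantissa = b & 0x7;
  if (exponent == 0)  // subnormal
    return sign * std::ldexp(mantissa / 8.0f, 1 - 8);
  return sign * std::ldexp(1.0f + mantissa / 8.0f, exponent - 8);
}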
|
||||
|
@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#if defined(__HIPCC__) && (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
#if defined(__HIPCC__) && \
|
||||
(defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
#define __HIP__MI300__
|
||||
#endif
|
||||
|
||||
@ -14,12 +15,10 @@
|
||||
#define HIP_FP8_DEVICE
|
||||
#endif
|
||||
|
||||
namespace hip_fp8_impl
|
||||
{
|
||||
namespace hip_fp8_impl {
|
||||
|
||||
#ifdef __HIP__MI300__
|
||||
HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v)
|
||||
{
|
||||
HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v) {
|
||||
uint8_t i8data;
|
||||
union {
|
||||
float fval;
|
||||
@ -30,7 +29,8 @@ HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v)
|
||||
uint32_t ival = 0;
|
||||
val.fval = v;
|
||||
|
||||
if ((val.i32val & 0x7F800000) != 0x7F800000) { /// propagate NAN/INF, no clipping
|
||||
if ((val.i32val & 0x7F800000) !=
|
||||
0x7F800000) { /// propagate NAN/INF, no clipping
|
||||
val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
|
||||
}
|
||||
|
||||
@ -43,20 +43,14 @@ HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v)
|
||||
}
|
||||
#endif // __HIP__MI300__
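The clamp above relies on v_med3_f32: the median of (v, 240, -240) equals v clamped to the representable E4M3 range [-240, 240] for any finite v, which is why the NaN/Inf case is excluded first. A trivial host equivalent:

#include <algorithm>

// med3(v, 240, -240) == clamp(v, -240, 240) for finite v.
static inline float med3_clamp(float v) {
  return std::max(-240.0f, std::min(v, 240.0f));
}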
|
||||
|
||||
HIP_FP8_HOST inline int clz(uint32_t x)
|
||||
{
|
||||
return __builtin_clz(x);
|
||||
}
|
||||
HIP_FP8_HOST inline int clz(uint32_t x) { return __builtin_clz(x); }
|
||||
#if defined(__HIPCC__) || defined(__CUDA_ARCH__)
|
||||
HIP_FP8_DEVICE inline int clz(uint32_t x)
|
||||
{
|
||||
return __clz(x);
|
||||
}
|
||||
HIP_FP8_DEVICE inline int clz(uint32_t x) { return __clz(x); }
|
||||
#endif
|
||||
|
||||
template <int we, int wm, typename T, bool negative_zero_nan, bool clip>
|
||||
HIP_FP8_HOST_DEVICE uint8_t to_float8(T _x, bool stoch = false, uint32_t rng = 0)
|
||||
{
|
||||
HIP_FP8_HOST_DEVICE uint8_t to_float8(T _x, bool stoch = false,
|
||||
uint32_t rng = 0) {
|
||||
#ifdef __HIPCC__
|
||||
constexpr bool is_half = std::is_same<T, _Float16>::value;
|
||||
#else
|
||||
@ -130,7 +124,8 @@ HIP_FP8_HOST_DEVICE uint8_t to_float8(T _x, bool stoch = false, uint32_t rng = 0
|
||||
// For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent
|
||||
// bits
|
||||
const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0);
|
||||
const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal
|
||||
const int f8_denormal_act_exponent =
|
||||
1 - f8_bias; // actual exponent of f8 denormal
|
||||
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
|
||||
// f8_exponent is the converted f8 exponent with bias encoding
|
||||
// exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
|
||||
@ -146,7 +141,9 @@ are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
|
||||
where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
|
||||
(NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */
|
||||
act_exponent = exponent - bias + 1;
|
||||
exponent_diff = f8_denormal_act_exponent - act_exponent; // actual exponent is exponent-bias+1 as it is denormal
|
||||
exponent_diff =
|
||||
f8_denormal_act_exponent -
|
||||
act_exponent; // actual exponent is exponent-bias+1 as it is denormal
|
||||
} else { // fp32/fp16 is normal with implicit 1
|
||||
act_exponent = exponent - bias;
|
||||
if (act_exponent <= f8_denormal_act_exponent) {
|
||||
@ -157,9 +154,9 @@ where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
|
||||
So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
|
||||
exponent_diff = f8_denormal_act_exponent - act_exponent;
|
||||
} else { // both fp32/fp16 and f8 are in normal range
|
||||
exponent_diff = 0; // exponent_diff=0 does not mean there is no difference
|
||||
// for this case,
|
||||
// act_exponent could be larger. Just that it does not need shift mantissa
|
||||
exponent_diff = 0; // exponent_diff=0 does not mean there is no
|
||||
// difference for this case, act_exponent could be
|
||||
// larger. Just that it does not need shift mantissa
|
||||
}
|
||||
mantissa += (1 << mfmt); // Add the implicit 1 into mantissa
|
||||
}
|
||||
@ -181,13 +178,16 @@ where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
|
||||
bool implicit_one = mantissa & (1 << mfmt);
|
||||
// if there is no implicit 1, it means the f8 is denormal and need to adjust
|
||||
// to denorm exponent
|
||||
f8_exponent = (act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1);
|
||||
f8_exponent = (act_exponent + exponent_diff) /*actual f8 exponent*/ +
|
||||
f8_bias - (implicit_one ? 0 : 1);
|
||||
|
||||
// Now we have the exponent and mantissa adjusted
|
||||
uint32_t drop_mask = (1 << (mfmt - wm)) - 1;
|
||||
bool odd = mantissa & (1 << (mfmt - wm)); // if the least significant bit that
|
||||
// is not truncated is 1
|
||||
mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask;
|
||||
bool odd = mantissa & (1 << (mfmt - wm)); // if the least significant bit
|
||||
// that is not truncated is 1
|
||||
mantissa +=
|
||||
(stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) &
|
||||
drop_mask;
|
||||
|
||||
// Now we deal with overflow
|
||||
if (f8_exponent == 0) {
|
||||
@ -222,8 +222,7 @@ where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
|
||||
}
|
||||
|
||||
template <int we, int wm, typename T = float, bool negative_zero_nan = true>
|
||||
inline HIP_FP8_HOST_DEVICE T from_float8(uint8_t x)
|
||||
{
|
||||
inline HIP_FP8_HOST_DEVICE T from_float8(uint8_t x) {
|
||||
#ifdef __HIPCC__
|
||||
constexpr bool is_half = std::is_same<T, _Float16>::value;
|
||||
#else
|
||||
@ -285,7 +284,8 @@ inline HIP_FP8_HOST_DEVICE T from_float8(uint8_t x)
|
||||
return reinterpret_cast<const T&>(retval);
|
||||
}
|
||||
|
||||
const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0);
|
||||
const int exp_low_cutoff =
|
||||
(1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0);
|
||||
|
||||
// subnormal input
|
||||
if (exponent == 0) {
|
||||
|
@ -9,29 +9,27 @@
|
||||
#include "../../../attention/dtype_float32.cuh"
|
||||
#include "../../../attention/dtype_bfloat16.cuh"
|
||||
|
||||
namespace vllm
|
||||
{
|
||||
namespace vllm {
|
||||
#ifdef USE_ROCM
|
||||
|
||||
namespace fp8 {
|
||||
#ifdef ENABLE_FP8
|
||||
|
||||
template <typename Tout, typename Tin>
|
||||
__inline__ __device__ Tout vec_conversion(const Tin& x)
|
||||
{
|
||||
__inline__ __device__ Tout vec_conversion(const Tin& x) {
|
||||
return x;
|
||||
}
|
||||
|
||||
template <typename Tout, typename Tin>
|
||||
__inline__ __device__ Tout scaled_vec_conversion(const Tin& x, const float scale)
|
||||
{
|
||||
__inline__ __device__ Tout scaled_vec_conversion(const Tin& x,
|
||||
const float scale) {
|
||||
return x;
|
||||
}
|
||||
|
||||
// fp8 -> half
|
||||
template <>
|
||||
__inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(const uint8_t& a)
|
||||
{
|
||||
__inline__ __device__ uint16_t
|
||||
vec_conversion<uint16_t, uint8_t>(const uint8_t& a) {
|
||||
hip_fp8 f8{a, hip_fp8::from_bits()};
|
||||
__half_raw res;
|
||||
res.data = static_cast<float>(f8);
|
||||
@ -40,9 +38,10 @@ __inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(const uint8_t&
|
||||
|
||||
// fp8x2 -> half2
|
||||
template <>
|
||||
__inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(const uint16_t& a)
|
||||
{
|
||||
#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
|
||||
__inline__ __device__ uint32_t
|
||||
vec_conversion<uint32_t, uint16_t>(const uint16_t& a) {
|
||||
#if defined(__HIP__MI300__) && \
|
||||
defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
|
||||
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
|
||||
union {
|
||||
__half2_raw h2r;
|
||||
@ -65,8 +64,7 @@ __inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(const uint16_t
|
||||
|
||||
// fp8x4 -> half2x2
|
||||
template <>
|
||||
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a)
|
||||
{
|
||||
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a) {
|
||||
union {
|
||||
uint2 u32x2;
|
||||
uint32_t u32[2];
|
||||
@ -78,8 +76,7 @@ __inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a)
|
||||
|
||||
// fp8x8 -> half2x4
|
||||
template <>
|
||||
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a)
|
||||
{
|
||||
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a) {
|
||||
union {
|
||||
uint4 u64x2;
|
||||
uint2 u64[2];
|
||||
@ -93,8 +90,8 @@ using __nv_bfloat16 = __hip_bfloat16;
|
||||
|
||||
// fp8 -> __nv_bfloat16
|
||||
template <>
|
||||
__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a)
|
||||
{
|
||||
__inline__ __device__ __nv_bfloat16
|
||||
vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) {
|
||||
hip_fp8 f8{a, hip_fp8::from_bits()};
|
||||
float f{f8};
|
||||
return __float2bfloat16(f);
|
||||
@ -104,8 +101,8 @@ using __nv_bfloat162 = __hip_bfloat162;
|
||||
|
||||
// fp8x2 -> __nv_bfloat162
|
||||
template <>
|
||||
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a)
|
||||
{
|
||||
__inline__ __device__ __nv_bfloat162
|
||||
vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) {
|
||||
__nv_bfloat162 res;
|
||||
res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
|
||||
res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
|
||||
@ -114,8 +111,8 @@ __inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(co
|
||||
|
||||
// fp8x4 -> bf16_4_t
|
||||
template <>
|
||||
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a)
|
||||
{
|
||||
__inline__ __device__ bf16_4_t
|
||||
vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a) {
|
||||
bf16_4_t res;
|
||||
res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
|
||||
res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
|
||||
@ -124,8 +121,7 @@ __inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(const uint32_t
|
||||
|
||||
// fp8x8 -> bf16_8_t
|
||||
template <>
|
||||
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a)
|
||||
{
|
||||
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a) {
|
||||
bf16_4_t tmp1, tmp2;
|
||||
tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
|
||||
tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
|
||||
@ -139,17 +135,17 @@ __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a)
|
||||
|
||||
// fp8 -> float
|
||||
template <>
|
||||
__inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a)
|
||||
{
|
||||
__inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a) {
|
||||
hip_fp8 fp8{a, hip_fp8::from_bits()};
|
||||
return static_cast<float>(fp8);
|
||||
}
|
||||
|
||||
// fp8x2 -> float2
|
||||
template <>
|
||||
__inline__ __device__ float2 vec_conversion<float2, uint16_t>(const uint16_t& a)
|
||||
{
|
||||
#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
|
||||
__inline__ __device__ float2
|
||||
vec_conversion<float2, uint16_t>(const uint16_t& a) {
|
||||
#if defined(__HIP__MI300__) && \
|
||||
defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
|
||||
float2 res;
|
||||
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
|
||||
res.x = f2[0];
|
||||
@ -165,8 +161,8 @@ __inline__ __device__ float2 vec_conversion<float2, uint16_t>(const uint16_t& a)
|
||||
|
||||
// fp8x4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(const uint32_t& a)
|
||||
{
|
||||
__inline__ __device__ Float4_
|
||||
vec_conversion<Float4_, uint32_t>(const uint32_t& a) {
|
||||
Float4_ res;
|
||||
res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
|
||||
res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
|
||||
@ -175,8 +171,7 @@ __inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(const uint32_t&
|
||||
|
||||
// fp8x8 -> float8
|
||||
template <>
|
||||
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a)
|
||||
{
|
||||
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a) {
|
||||
Float4_ tmp1, tmp2;
|
||||
tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
|
||||
tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
|
||||
@ -190,8 +185,8 @@ __inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a)
|
||||
|
||||
// half -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(const uint16_t& a)
|
||||
{
|
||||
__inline__ __device__ uint8_t
|
||||
vec_conversion<uint8_t, uint16_t>(const uint16_t& a) {
|
||||
__half_raw tmp;
|
||||
tmp.x = a;
|
||||
|
||||
@ -201,24 +196,23 @@ __inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(const uint16_t&
|
||||
|
||||
// bf16 -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a)
|
||||
{
|
||||
__inline__ __device__ uint8_t
|
||||
vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a) {
|
||||
hip_fp8 res{__bfloat162float(a)};
|
||||
return res.data;
|
||||
}
|
||||
|
||||
// float -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a)
|
||||
{
|
||||
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a) {
|
||||
hip_fp8 f8(a);
|
||||
return f8.data;
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ float4 vec_conversion<float4, uint32_t>(const uint32_t& a)
|
||||
{
|
||||
__inline__ __device__ float4
|
||||
vec_conversion<float4, uint32_t>(const uint32_t& a) {
|
||||
Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
|
||||
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
|
||||
return res;
|
||||
@ -226,8 +220,8 @@ __inline__ __device__ float4 vec_conversion<float4, uint32_t>(const uint32_t& a)
|
||||
|
||||
// float2 -> half2
|
||||
template <>
|
||||
__inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(const float2& a)
|
||||
{
|
||||
__inline__ __device__ uint32_t
|
||||
vec_conversion<uint32_t, float2>(const float2& a) {
|
||||
union {
|
||||
half2 float16;
|
||||
uint32_t uint32;
|
||||
@ -239,8 +233,7 @@ __inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(const float2& a)
|
||||
|
||||
// Float4 -> half2x2
|
||||
template <>
|
||||
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a)
|
||||
{
|
||||
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a) {
|
||||
uint2 b;
|
||||
float2 val;
|
||||
val.x = a.x.x;
|
||||
@ -255,8 +248,7 @@ __inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a)
|
||||
|
||||
// Float4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a)
|
||||
{
|
||||
__inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a) {
|
||||
float4 b;
|
||||
b.x = a.x.x;
|
||||
b.y = a.x.y;
|
||||
@ -267,8 +259,7 @@ __inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a)
|
||||
|
||||
// Float8 -> half2x4
|
||||
template <>
|
||||
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a)
|
||||
{
|
||||
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a) {
|
||||
uint4 b;
|
||||
b.x = vec_conversion<uint32_t, float2>(a.x);
|
||||
b.y = vec_conversion<uint32_t, float2>(a.y);
|
||||
@ -279,16 +270,16 @@ __inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a)
|
||||
|
||||
// float2 -> bfloat162
|
||||
template <>
|
||||
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2& a)
|
||||
{
|
||||
__inline__ __device__ __nv_bfloat162
|
||||
vec_conversion<__nv_bfloat162, float2>(const float2& a) {
|
||||
__nv_bfloat162 b = __float22bfloat162_rn(a);
|
||||
return b;
|
||||
}
|
||||
|
||||
// Float4 -> bfloat162x2
|
||||
template <>
|
||||
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(const Float4_& a)
|
||||
{
|
||||
__inline__ __device__ bf16_4_t
|
||||
vec_conversion<bf16_4_t, Float4_>(const Float4_& a) {
|
||||
bf16_4_t b;
|
||||
b.x = __float22bfloat162_rn(a.x);
|
||||
b.y = __float22bfloat162_rn(a.y);
|
||||
@ -297,8 +288,8 @@ __inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(const Float4_&
|
||||
|
||||
// Float8 -> bfloat162x4
|
||||
template <>
|
||||
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(const Float8_& a)
|
||||
{
|
||||
__inline__ __device__ bf16_8_t
|
||||
vec_conversion<bf16_8_t, Float8_>(const Float8_& a) {
|
||||
bf16_8_t b;
|
||||
b.x = __float22bfloat162_rn(a.x);
|
||||
b.y = __float22bfloat162_rn(a.y);
|
||||
@ -307,20 +298,19 @@ __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(const Float8_&
|
||||
return b;
|
||||
}
|
||||
|
||||
/* Scaled and vectorized conversions, for data exchange between high and low
precision domains

/* Scaled and vectorized conversions, for data exchange between high and low precision domains

Convention of the scale in API, e.g: FP8_data = Quantization( High_Precision_data / scale )
s.t.
Quantize(HP / scale) => FP8
Dequant(FP8) * scale => HP
Convention of the scale in API, e.g: FP8_data = Quantization(
High_Precision_data / scale ) s.t. Quantize(HP / scale) => FP8 Dequant(FP8) *
scale => HP

*/

// fp8 -> half
template <>
__inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, const float scale)
{
__inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, const float scale) {
hip_fp8 f8{a, hip_fp8::from_bits()};
__half_raw res;
res.data = static_cast<float>(f8) * scale;
@ -329,9 +319,10 @@ __inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, uint8_t>(const ui
|
||||
|
||||
// fp8x2 -> half2
|
||||
template <>
|
||||
__inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, const float scale)
|
||||
{
|
||||
#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
|
||||
__inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(
|
||||
const uint16_t& a, const float scale) {
|
||||
#if defined(__HIP__MI300__) && \
|
||||
defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
|
||||
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
|
||||
union {
|
||||
__half2_raw h2r;
|
||||
@ -346,29 +337,32 @@ __inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(const u
|
||||
uint32_t u32;
|
||||
} tmp;
|
||||
|
||||
tmp.u16[0] = scaled_vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a), scale);
|
||||
tmp.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a >> 8U), scale);
|
||||
tmp.u16[0] =
|
||||
scaled_vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a), scale);
|
||||
tmp.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>(
|
||||
static_cast<uint8_t>(a >> 8U), scale);
|
||||
return tmp.u32;
|
||||
#endif
|
||||
}
|
||||
|
||||
// fp8x4 -> half2x2
|
||||
template <>
|
||||
__inline__ __device__ uint2 scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, const float scale)
|
||||
{
|
||||
__inline__ __device__ uint2
|
||||
scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, const float scale) {
|
||||
union {
|
||||
uint2 u32x2;
|
||||
uint32_t u32[2];
|
||||
} tmp;
|
||||
tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale);
|
||||
tmp.u32[1] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale);
|
||||
tmp.u32[1] =
|
||||
scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale);
|
||||
return tmp.u32x2;
|
||||
}
|
||||
|
||||
// fp8x8 -> half2x4
|
||||
template <>
|
||||
__inline__ __device__ uint4 scaled_vec_conversion<uint4, uint2>(const uint2& a, const float scale)
|
||||
{
|
||||
__inline__ __device__ uint4
|
||||
scaled_vec_conversion<uint4, uint2>(const uint2& a, const float scale) {
|
||||
union {
|
||||
uint4 u64x2;
|
||||
uint2 u64[2];
|
||||
@ -382,8 +376,9 @@ using __nv_bfloat16 = __hip_bfloat16;
|
||||
|
||||
// fp8 -> __nv_bfloat16
|
||||
template <>
|
||||
__inline__ __device__ __nv_bfloat16 scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, const float scale)
|
||||
{
|
||||
__inline__ __device__ __nv_bfloat16
|
||||
scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a,
|
||||
const float scale) {
|
||||
hip_fp8 f8{a, hip_fp8::from_bits()};
|
||||
float f{f8};
|
||||
return __float2bfloat16(f * scale);
|
||||
@ -393,28 +388,31 @@ using __nv_bfloat162 = __hip_bfloat162;
|
||||
|
||||
// fp8x2 -> __nv_bfloat162
|
||||
template <>
|
||||
__inline__ __device__ __nv_bfloat162 scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a, const float scale)
|
||||
{
|
||||
__inline__ __device__ __nv_bfloat162
|
||||
scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a,
|
||||
const float scale) {
|
||||
__nv_bfloat162 res;
|
||||
res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale);
|
||||
res.y = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale);
|
||||
res.y =
|
||||
scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale);
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x4 -> bf16_4_t
|
||||
template <>
|
||||
__inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a, const float scale)
|
||||
{
|
||||
__inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
|
||||
const uint32_t& a, const float scale) {
|
||||
bf16_4_t res;
|
||||
res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale);
|
||||
res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), scale);
|
||||
res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U),
|
||||
scale);
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x8 -> bf16_8_t
|
||||
template <>
|
||||
__inline__ __device__ bf16_8_t scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, const float scale)
|
||||
{
|
||||
__inline__ __device__ bf16_8_t
|
||||
scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, const float scale) {
|
||||
bf16_4_t tmp1, tmp2;
|
||||
tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale);
|
||||
tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale);
|
||||
@ -428,17 +426,18 @@ __inline__ __device__ bf16_8_t scaled_vec_conversion<bf16_8_t, uint2>(const uint
|
||||
|
||||
// fp8 -> float
|
||||
template <>
|
||||
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(const uint8_t& a, const float scale)
|
||||
{
|
||||
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
|
||||
const uint8_t& a, const float scale) {
|
||||
hip_fp8 fp8{a, hip_fp8::from_bits()};
|
||||
return static_cast<float>(fp8) * scale;
|
||||
}
|
||||
|
||||
// fp8x2 -> float2
|
||||
template <>
|
||||
__inline__ __device__ float2 scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, const float scale)
|
||||
{
|
||||
#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
|
||||
__inline__ __device__ float2
|
||||
scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, const float scale) {
|
||||
#if defined(__HIP__MI300__) && \
|
||||
defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
|
||||
float2 res;
|
||||
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
|
||||
res.x = f2[0] * scale;
|
||||
@ -447,15 +446,16 @@ __inline__ __device__ float2 scaled_vec_conversion<float2, uint16_t>(const uint1
|
||||
#else
|
||||
float2 res;
|
||||
res.x = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a), scale);
|
||||
res.y = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U), scale);
|
||||
res.y = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U),
|
||||
scale);
|
||||
return res;
|
||||
#endif
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ Float4_ scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale)
|
||||
{
|
||||
__inline__ __device__ Float4_
|
||||
scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale) {
|
||||
Float4_ res;
|
||||
res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale);
|
||||
res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale);
|
||||
@ -464,8 +464,8 @@ __inline__ __device__ Float4_ scaled_vec_conversion<Float4_, uint32_t>(const uin
|
||||
|
||||
// fp8x8 -> float8
|
||||
template <>
|
||||
__inline__ __device__ Float8_ scaled_vec_conversion<Float8_, uint2>(const uint2& a, const float scale)
|
||||
{
|
||||
__inline__ __device__ Float8_
|
||||
scaled_vec_conversion<Float8_, uint2>(const uint2& a, const float scale) {
|
||||
Float4_ tmp1, tmp2;
|
||||
tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale);
|
||||
tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale);
|
||||
@ -477,15 +477,14 @@ __inline__ __device__ Float8_ scaled_vec_conversion<Float8_, uint2>(const uint2&
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
/* Quantize(HP / scale) => FP8 */
|
||||
|
||||
// TODO(Hai): vectorized to add
|
||||
|
||||
// half -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, const float scale)
|
||||
{
|
||||
__inline__ __device__ uint8_t
|
||||
scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, const float scale) {
|
||||
__half_raw tmp;
|
||||
tmp.x = a;
|
||||
|
||||
@ -495,24 +494,24 @@ __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, uint16_t>(const uin
|
||||
|
||||
// bf16 -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a, const float scale)
|
||||
{
|
||||
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
|
||||
const __nv_bfloat16& a, const float scale) {
|
||||
hip_fp8 res{__bfloat162float(a) / scale};
|
||||
return res.data;
|
||||
}
|
||||
|
||||
// float -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, float>(const float& a, const float scale)
|
||||
{
|
||||
__inline__ __device__ uint8_t
|
||||
scaled_vec_conversion<uint8_t, float>(const float& a, const float scale) {
|
||||
hip_fp8 f8(a / scale);
|
||||
return f8.data;
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ float4 scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, const float scale)
|
||||
{
|
||||
__inline__ __device__ float4
|
||||
scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, const float scale) {
|
||||
Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale);
|
||||
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
|
||||
return res;
|
||||
@ -539,9 +538,10 @@ __inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) {
assert(false);
}

// The following macro is used to dispatch the conversion function based on the
// data type of the key and value cache. The FN is a macro that calls a function
// with template<typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>.
// The following macro is used to dispatch the conversion function based on
// the data type of the key and value cache. The FN is a macro that calls a
// function with template<typename scalar_t, typename cache_t,
// Fp8KVCacheDataType kv_dt>.
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
if (KV_DTYPE == "auto") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
@ -562,13 +562,14 @@ __inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) {
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else { \
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else { \
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
} \
}

} // fp8
} // namespace fp8
#endif // USE_ROCM
} // namespace vllm

@ -11,8 +11,10 @@ namespace vllm {

__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
float old;
old = (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) :
__uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
old = (value >= 0)
? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
: __uint_as_float(
atomicMin((unsigned int*)addr, __float_as_uint(value)));

return old;
}
@ -20,7 +22,8 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
#define FP8_E4M3_MAX std::numeric_limits<c10::Float8_e4m3fn>::max()

template <typename scalar_t>
__device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(const scalar_t val, const float scale) {
__device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(
const scalar_t val, const float scale) {
float x = static_cast<float>(val) / scale;
float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
return static_cast<c10::Float8_e4m3fn>(r);
@ -33,8 +36,7 @@ __device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(const scalar
// a value <= 0.0 and we need to wait for all thread blocks to
// finish before consuming *scale.
template <typename scalar_t>
__global__ void segmented_max_reduction(
float* __restrict__ scale,
__global__ void segmented_max_reduction(float* __restrict__ scale,
const scalar_t* __restrict__ input,
int64_t num_elems) {
__shared__ float cache[1024];
@ -64,13 +66,13 @@ __global__ void segmented_max_reduction(
// Finally, since cache[0] contains the maximum for this thread block,
// atomically write the max to the target location
if (threadIdx.x == 0) {
atomicMaxFloat(scale, cache[0] / std::numeric_limits<c10::Float8_e4m3fn>::max());
atomicMaxFloat(scale,
cache[0] / std::numeric_limits<c10::Float8_e4m3fn>::max());
}
}

template <typename scalar_t>
__global__ void scaled_fp8_quant_kernel(
c10::Float8_e4m3fn* __restrict__ out,
__global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
const scalar_t* __restrict__ input,
const float* __restrict__ scale,
int64_t num_elems) {
@ -83,8 +85,7 @@ __global__ void scaled_fp8_quant_kernel(

} // namespace vllm

void static_scaled_fp8_quant(
torch::Tensor& out, // [..., d]
void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch::Tensor& input, // [..., d]
torch::Tensor& scale) // [1]
{
@ -95,19 +96,14 @@ void static_scaled_fp8_quant(
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(),
"scaled_fp8_quant_kernel",
[&] {
input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<c10::Float8_e4m3fn>(),
input.data_ptr<scalar_t>(),
scale.data_ptr<float>(),
num_elems);
out.data_ptr<c10::Float8_e4m3fn>(), input.data_ptr<scalar_t>(),
scale.data_ptr<float>(), num_elems);
});
}

void dynamic_scaled_fp8_quant(
torch::Tensor& out, // [..., d]
void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch::Tensor& input, // [..., d]
torch::Tensor& scale) // [1]
{
@ -118,18 +114,11 @@ void dynamic_scaled_fp8_quant(
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(),
"scaled_fp8_quant_kernel",
[&] {
input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>(
scale.data_ptr<float>(),
input.data_ptr<scalar_t>(),
num_elems);
scale.data_ptr<float>(), input.data_ptr<scalar_t>(), num_elems);
vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<c10::Float8_e4m3fn>(),
input.data_ptr<scalar_t>(),
scale.data_ptr<float>(),
num_elems);
out.data_ptr<c10::Float8_e4m3fn>(), input.data_ptr<scalar_t>(),
scale.data_ptr<float>(), num_elems);
});
}

@ -406,7 +406,6 @@ template <>
|
||||
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
|
||||
const uint8_t& a, const float scale,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
|
||||
// fp8 -> half
|
||||
__half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
|
||||
uint16_t tmp = res.x;
|
||||
@ -523,9 +522,10 @@ __inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) {
|
||||
assert(false);
|
||||
}
|
||||
|
||||
// The following macro is used to dispatch the conversion function based on the
|
||||
// data type of the key and value cache. The FN is a macro that calls a function
|
||||
// with template<typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>.
|
||||
// The following macro is used to dispatch the conversion function based on
|
||||
// the data type of the key and value cache. The FN is a macro that calls a
|
||||
// function with template<typename scalar_t, typename cache_t,
|
||||
// Fp8KVCacheDataType kv_dt>.
|
||||
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
|
||||
if (KV_DTYPE == "auto") { \
|
||||
if (SRC_DTYPE == at::ScalarType::Float) { \
|
||||
@ -546,7 +546,8 @@ __inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) {
|
||||
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
|
||||
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
|
||||
} else { \
|
||||
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
|
||||
TORCH_CHECK(false, \
|
||||
"Unsupported input type of kv cache: ", SRC_DTYPE); \
|
||||
} \
|
||||
} else if (KV_DTYPE == "fp8_e5m2") { \
|
||||
if (SRC_DTYPE == at::ScalarType::Float) { \
|
||||
@ -556,7 +557,8 @@ __inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) {
|
||||
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
|
||||
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
|
||||
} else { \
|
||||
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
|
||||
TORCH_CHECK(false, \
|
||||
"Unsupported input type of kv cache: ", SRC_DTYPE); \
|
||||
} \
|
||||
} else { \
|
||||
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
|
||||
|
@ -9,40 +9,36 @@ namespace vllm {
namespace gptq {
// atomicAdd for half types, to support CC < 7.x

__device__ __forceinline__ void atomicAdd_half(half* address, half val)
{
unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
__device__ __forceinline__ void atomicAdd_half(half* address, half val) {
unsigned int* address_as_ui =
(unsigned int*)((char*)address - ((size_t)address & 2));
unsigned int old = *address_as_ui;
unsigned int assumed;

do
{
do {
assumed = old;
__half_raw hsum;
hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
half tmpres = __hadd(hsum, val);
hsum = __half_raw(tmpres);
old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16)
: (old & 0xffff0000) | hsum.x;
old = atomicCAS(address_as_ui, assumed, old);
}
while (assumed != old);
} while (assumed != old);
}

// atomicAdd for half2 types
|
||||
|
||||
__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
|
||||
{
|
||||
__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) {
|
||||
unsigned int* address_as_ui = (unsigned int*)address;
|
||||
unsigned int old = *address_as_ui;
|
||||
unsigned int assumed;
|
||||
do
|
||||
{
|
||||
do {
|
||||
assumed = old;
|
||||
half2 old_val = *((half2*)&old);
|
||||
half2 new_val = __hadd2(old_val, val);
|
||||
old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
|
||||
}
|
||||
while (assumed != old);
|
||||
} while (assumed != old);
|
||||
}
|
||||
|
||||
//
|
||||
@ -50,10 +46,14 @@ __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
|
||||
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
|
||||
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
|
||||
|
||||
__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
|
||||
__device__ __forceinline__ void atomicAdd(half* address, half val) {
|
||||
atomicAdd_half(address, val);
|
||||
}
|
||||
|
||||
#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
|
||||
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
|
||||
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) {
|
||||
atomicAdd_half2(address, val);
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
@ -1,5 +1,6 @@
/*
Adapted from https://github.com/turboderp/exllamav2 and https://github.com/turboderp/exllama
Adapted from https://github.com/turboderp/exllamav2 and
https://github.com/turboderp/exllama
*/

#ifndef _matrix_view_cuh
@ -13,24 +14,31 @@ Adapted from https://github.com/turboderp/exllamav2 and https://github.com/turbo
namespace vllm {
namespace gptq {

class MatrixView_half
{
class MatrixView_half {
public:
const half* data;
const int height;
const int width;

__device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ MatrixView_half(const half* data, const int height,
const int width)
: data(data), height(height), width(width) {}

__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
|
||||
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
|
||||
__device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }
|
||||
__device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }
|
||||
__device__ __forceinline__ half item(int row, int column) const {
|
||||
return data[row * width + column];
|
||||
}
|
||||
__device__ __forceinline__ half2 item_half2(int row, int column) const {
|
||||
return ((half2*)data)[(row * width + column) / 2];
|
||||
}
|
||||
__device__ __forceinline__ half2 item_half2half2(int row, int column) const {
|
||||
return __half2half2(data[row * width + column]);
|
||||
}
|
||||
__device__ __forceinline__ const half* item_ptr(int row, int column) const {
|
||||
return &data[row * width + column];
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item4(half (&items)[4], int row, int column) const
|
||||
{
|
||||
__device__ __forceinline__ void item4(half (&items)[4], int row,
|
||||
int column) const {
|
||||
half2* ptr = (half2*)item_ptr(row, column);
|
||||
half2 i01 = ptr[0];
|
||||
half2 i23 = ptr[1];
|
||||
@ -39,8 +47,8 @@ public:
|
||||
items[2] = __low2half(i23);
|
||||
items[3] = __high2half(i23);
|
||||
}
|
||||
__device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const
|
||||
{
|
||||
__device__ __forceinline__ void item4_f(float (&items)[4], int row,
|
||||
int column) const {
|
||||
half2* ptr = (half2*)item_ptr(row, column);
|
||||
half2 i01 = ptr[0];
|
||||
half2 i23 = ptr[1];
|
||||
@ -50,8 +58,8 @@ public:
|
||||
items[3] = __half2float(__high2half(i23));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const
|
||||
{
|
||||
__device__ __forceinline__ void item4_h2(half2 (&items)[4], int row,
|
||||
int column) const {
|
||||
half2* ptr = (half2*)item_ptr(row, column);
|
||||
half2 i01 = ptr[0];
|
||||
half2 i23 = ptr[1];
|
||||
@ -62,25 +70,34 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class MatrixView_half_rw
|
||||
{
|
||||
class MatrixView_half_rw {
|
||||
public:
|
||||
half* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
|
||||
: data(data), height(height), width(width)
|
||||
{ }
|
||||
__device__ __forceinline__ MatrixView_half_rw(half* data, const int height,
|
||||
const int width)
|
||||
: data(data), height(height), width(width) {}
|
||||
|
||||
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
|
||||
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
|
||||
__device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; }
|
||||
__device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; }
|
||||
__device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }
|
||||
__device__ __forceinline__ half item(int row, int column) const {
|
||||
return data[row * width + column];
|
||||
}
|
||||
__device__ __forceinline__ half2 item_half2(int row, int column) const {
|
||||
return ((half2*)data)[(row * width + column) / 2];
|
||||
}
|
||||
__device__ __forceinline__ half* item_ptr(int row, int column) {
|
||||
return &data[row * width + column];
|
||||
}
|
||||
__device__ __forceinline__ void set(int row, int column, half value) {
|
||||
data[row * width + column] = value;
|
||||
}
|
||||
__device__ __forceinline__ void set_half2(int row, int column, half2 value) {
|
||||
((half2*)data)[(row * width + column) / 2] = value;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3)
|
||||
{
|
||||
__device__ __forceinline__ void set4(int row, int column, half v0, half v1,
|
||||
half v2, half v3) {
|
||||
half2 v01 = __halves2half2(v0, v1);
|
||||
half2 v23 = __halves2half2(v2, v3);
|
||||
half2* ptr = (half2*)item_ptr(row, column);
|
||||
@ -89,33 +106,32 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class MatrixView_q4_row
|
||||
{
|
||||
class MatrixView_q4_row {
|
||||
public:
|
||||
const uint32_t* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
|
||||
: data(data), height(height), width(width)
|
||||
{ }
|
||||
__device__ __forceinline__ MatrixView_q4_row(const uint32_t* data,
|
||||
const int height,
|
||||
const int width)
|
||||
: data(data), height(height), width(width) {}
|
||||
|
||||
__device__ __forceinline__ int item(int row, int column) const
|
||||
{
|
||||
__device__ __forceinline__ int item(int row, int column) const {
|
||||
int shift = (column & 0x07) * 4;
|
||||
return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
|
||||
{
|
||||
__device__ __forceinline__ void item2(int (&items)[2], int row,
|
||||
int column) const {
|
||||
int shift = (column & 0x07) * 4;
|
||||
uint32_t d = data[row * width / 8 + column / 8] >> shift;
|
||||
items[0] = d & 0x0f;
|
||||
items[1] = (d >> 4) & 0x0f;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
|
||||
{
|
||||
__device__ __forceinline__ void item4(int (&items)[4], int row,
|
||||
int column) const {
|
||||
int shift = (column & 0x07) * 4;
|
||||
uint32_t d = data[row * width / 8 + column / 8] >> shift;
|
||||
items[0] = d & 0x0f;
|
||||
@ -125,54 +141,57 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class MatrixView_q4_column
|
||||
{
|
||||
class MatrixView_q4_column {
|
||||
public:
|
||||
const uint32_t* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width)
|
||||
: data(data), height(height), width(width)
|
||||
{ }
|
||||
__device__ __forceinline__ MatrixView_q4_column(const uint32_t* data,
|
||||
const int height,
|
||||
const int width)
|
||||
: data(data), height(height), width(width) {}
|
||||
|
||||
__device__ __forceinline__ int item(int row, int column) const
|
||||
{
|
||||
__device__ __forceinline__ int item(int row, int column) const {
|
||||
int shift = (row & 0x07) * 4;
|
||||
return (data[row / 8 * width + column] >> shift) & 0x0f;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; }
|
||||
__device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; }
|
||||
__device__ __forceinline__ uint32_t item_uint32_t(int row, int column) {
|
||||
return data[row / 8 * width + column];
|
||||
}
|
||||
__device__ __forceinline__ const uint32_t* item_uint32_ptr(int row,
|
||||
int column) {
|
||||
return &data[row / 8 * width + column];
|
||||
}
|
||||
};
|
||||
|
||||
class MatrixView_q2_row
|
||||
{
|
||||
class MatrixView_q2_row {
|
||||
public:
|
||||
const uint32_t* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_q2_row(const uint32_t* data, const int height, const int width)
|
||||
: data(data), height(height), width(width)
|
||||
{ }
|
||||
__device__ __forceinline__ MatrixView_q2_row(const uint32_t* data,
|
||||
const int height,
|
||||
const int width)
|
||||
: data(data), height(height), width(width) {}
|
||||
|
||||
__device__ __forceinline__ int item(int row, int column) const
|
||||
{
|
||||
__device__ __forceinline__ int item(int row, int column) const {
|
||||
int shift = (column & 0x0f) * 2;
|
||||
return (data[row * width / 16 + column / 16] >> shift) & 0x03;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
|
||||
{
|
||||
__device__ __forceinline__ void item2(int (&items)[2], int row,
|
||||
int column) const {
|
||||
int shift = (column & 0x0f) * 2;
|
||||
uint32_t d = data[row * width / 16 + column / 16] >> shift;
|
||||
items[0] = d & 0x03;
|
||||
items[1] = (d >> 2) & 0x03;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
|
||||
{
|
||||
__device__ __forceinline__ void item4(int (&items)[4], int row,
|
||||
int column) const {
|
||||
int shift = (column & 0x0f) * 2;
|
||||
uint32_t d = data[row * width / 16 + column / 16] >> shift;
|
||||
items[0] = d & 0x03;
|
||||
@ -182,26 +201,27 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class MatrixView_q3_row
|
||||
{
|
||||
class MatrixView_q3_row {
|
||||
public:
|
||||
const uint32_t* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_q3_row(const uint32_t* data, const int height, const int width)
|
||||
: data(data), height(height), width(width)
|
||||
{ }
|
||||
__device__ __forceinline__ MatrixView_q3_row(const uint32_t* data,
|
||||
const int height,
|
||||
const int width)
|
||||
: data(data), height(height), width(width) {}
|
||||
|
||||
__device__ __forceinline__ int item(int row, int column) const
|
||||
{
|
||||
__device__ __forceinline__ int item(int row, int column) const {
|
||||
int z_w = column * 3 / 32;
|
||||
int z_mod = column & 0x1f;
|
||||
|
||||
if (z_mod == 10) {
|
||||
return (data[row * width * 3 / 32 + z_w] >> 30) | ((data[row * width * 3 / 32 + (z_w + 1)] << 2) & 0x4);
|
||||
return (data[row * width * 3 / 32 + z_w] >> 30) |
|
||||
((data[row * width * 3 / 32 + (z_w + 1)] << 2) & 0x4);
|
||||
} else if (z_mod == 21) {
|
||||
return (data[row * width * 3 / 32 + z_w] >> 31) | ((data[row * width * 3 / 32 + (z_w + 1)] << 1) & 0x6);
|
||||
return (data[row * width * 3 / 32 + z_w] >> 31) |
|
||||
((data[row * width * 3 / 32 + (z_w + 1)] << 1) & 0x6);
|
||||
} else if (z_mod < 10) {
|
||||
return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3)) & 0x07;
|
||||
} else if (z_mod < 21) {
|
||||
@ -211,18 +231,20 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
|
||||
{
|
||||
__device__ __forceinline__ void item4(int (&items)[4], int row,
|
||||
int column) const {
|
||||
int shift = (column & 0x1f);
|
||||
uint32_t d;
|
||||
if (shift <= 4) {
|
||||
d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3);
|
||||
} else if (shift == 8) {
|
||||
d = (data[row * width / 32 * 3 + column * 3 / 32] >> 24) | ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0x0f) << 8);
|
||||
d = (data[row * width / 32 * 3 + column * 3 / 32] >> 24) |
|
||||
((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0x0f) << 8);
|
||||
} else if (shift <= 16) {
|
||||
d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 32);
|
||||
} else if (shift == 20) {
|
||||
d = (data[row * width / 32 * 3 + column * 3 / 32] >> 28) | ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0xff) << 4);
|
||||
d = (data[row * width / 32 * 3 + column * 3 / 32] >> 28) |
|
||||
((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0xff) << 4);
|
||||
} else {
|
||||
d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 64);
|
||||
}
|
||||
@ -233,33 +255,32 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class MatrixView_q8_row
|
||||
{
|
||||
class MatrixView_q8_row {
|
||||
public:
|
||||
const uint32_t* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_q8_row(const uint32_t* data, const int height, const int width)
|
||||
: data(data), height(height), width(width)
|
||||
{ }
|
||||
__device__ __forceinline__ MatrixView_q8_row(const uint32_t* data,
|
||||
const int height,
|
||||
const int width)
|
||||
: data(data), height(height), width(width) {}
|
||||
|
||||
__device__ __forceinline__ int item(int row, int column) const
|
||||
{
|
||||
__device__ __forceinline__ int item(int row, int column) const {
|
||||
int shift = (column & 0x03) * 8;
|
||||
return (data[row * width / 4 + column / 4] >> shift) & 0xff;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
|
||||
{
|
||||
__device__ __forceinline__ void item2(int (&items)[2], int row,
|
||||
int column) const {
|
||||
int shift = (column & 0x03) * 8;
|
||||
uint32_t d = data[row * width / 4 + column / 4] >> shift;
|
||||
items[0] = d & 0xff;
|
||||
items[1] = (d >> 8) & 0xff;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
|
||||
{
|
||||
__device__ __forceinline__ void item4(int (&items)[4], int row,
|
||||
int column) const {
|
||||
int shift = (column & 0x03) * 2;
|
||||
uint32_t d = data[row * width / 4 + column / 4] >> shift;
|
||||
items[0] = d & 0xff;
|
||||

File diff suppressed because it is too large
@ -14,18 +14,12 @@ namespace gptq {
//
// ffddbb99 77553311 eeccaa88 66442200

__forceinline__ __device__ void shuffle_2bit_16
(
uint32_t* q,
int stride
)
{
__forceinline__ __device__ void shuffle_2bit_16(uint32_t* q, int stride) {
uint32_t qa = q[0];
uint32_t qb = 0;

#pragma unroll
for (int i = 0; i < 8; i++)
{
for (int i = 0; i < 8; i++) {
uint32_t qa0 = qa & 0x03;
uint32_t qa1 = (qa & 0x0c) >> 2;
qa >>= 4;
@ -35,14 +29,9 @@ __forceinline__ __device__ void shuffle_2bit_16
q[0] = qb;
}

__forceinline__ __device__ void dequant_2bit_16
(
const uint32_t q_0,
half2 (&dq)[8],
int stride,
const uint32_t zero
)
{
__forceinline__ __device__ void dequant_2bit_16(const uint32_t q_0,
half2 (&dq)[8], int stride,
const uint32_t zero) {
const uint32_t c0 = 0x64006400;
const half y4_ = __float2half_rn(1.0f / 4.0f);
const half y16_ = __float2half_rn(1.0f / 16.0f);

@ -11,12 +11,7 @@ namespace gptq {
|
||||
// vjjjhhhf ffdddbbb uiiiggge eecccaaa
|
||||
// vtttrrrp ppnnnlll usssqqqo oommmkkk
|
||||
|
||||
__forceinline__ __device__ void shuffle_3bit_32
|
||||
(
|
||||
uint32_t* q,
|
||||
int stride
|
||||
)
|
||||
{
|
||||
__forceinline__ __device__ void shuffle_3bit_32(uint32_t* q, int stride) {
|
||||
uint32_t qa = q[0 * stride];
|
||||
uint32_t qb = q[1 * stride];
|
||||
uint32_t qc = q[2 * stride];
|
||||
@ -40,9 +35,27 @@ __forceinline__ __device__ void shuffle_3bit_32
|
||||
uint32_t zb = 0;
|
||||
uint32_t zc = 0;
|
||||
|
||||
for (int i = 0; i < 5; i++) { uint32_t t0 = qa & 0x07; uint32_t t1 = (qa & 0x38) >> 3; qa >>= 6; za |= (t0 << (i * 3)); za |= (t1 << (i * 3 + 16)); }
|
||||
for (int i = 0; i < 5; i++) { uint32_t t0 = qb & 0x07; uint32_t t1 = (qb & 0x38) >> 3; qb >>= 6; zb |= (t0 << (i * 3)); zb |= (t1 << (i * 3 + 16)); }
|
||||
for (int i = 0; i < 5; i++) { uint32_t t0 = qc & 0x07; uint32_t t1 = (qc & 0x38) >> 3; qc >>= 6; zc |= (t0 << (i * 3)); zc |= (t1 << (i * 3 + 16)); }
|
||||
for (int i = 0; i < 5; i++) {
|
||||
uint32_t t0 = qa & 0x07;
|
||||
uint32_t t1 = (qa & 0x38) >> 3;
|
||||
qa >>= 6;
|
||||
za |= (t0 << (i * 3));
|
||||
za |= (t1 << (i * 3 + 16));
|
||||
}
|
||||
for (int i = 0; i < 5; i++) {
|
||||
uint32_t t0 = qb & 0x07;
|
||||
uint32_t t1 = (qb & 0x38) >> 3;
|
||||
qb >>= 6;
|
||||
zb |= (t0 << (i * 3));
|
||||
zb |= (t1 << (i * 3 + 16));
|
||||
}
|
||||
for (int i = 0; i < 5; i++) {
|
||||
uint32_t t0 = qc & 0x07;
|
||||
uint32_t t1 = (qc & 0x38) >> 3;
|
||||
qc >>= 6;
|
||||
zc |= (t0 << (i * 3));
|
||||
zc |= (t1 << (i * 3 + 16));
|
||||
}
|
||||
|
||||
// za: 9997775 55333111 8886664 44222000
|
||||
// zb: jjjhhhf ffdddbbb iiiggge eecccaaa
|
||||
@ -65,16 +78,11 @@ __forceinline__ __device__ void shuffle_3bit_32
|
||||
q[2 * stride] = zc;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_3bit_32
|
||||
(
|
||||
const uint32_t q_0,
|
||||
__forceinline__ __device__ void dequant_3bit_32(const uint32_t q_0,
|
||||
const uint32_t q_1,
|
||||
const uint32_t q_2,
|
||||
half2 (&dq)[16],
|
||||
int stride,
|
||||
const uint32_t zero
|
||||
)
|
||||
{
|
||||
half2 (&dq)[16], int stride,
|
||||
const uint32_t zero) {
|
||||
const uint32_t c0 = 0x64006400;
|
||||
const half y8_ = __float2half_rn(1.0f / 8.0f);
|
||||
const half y64_ = __float2half_rn(1.0f / 64.0f);
|
||||
|
@ -13,18 +13,12 @@ namespace gptq {
|
||||
//
|
||||
// 77775555 33331111 66664444 22220000
|
||||
|
||||
__forceinline__ __device__ void shuffle_4bit_8
|
||||
(
|
||||
uint32_t* q,
|
||||
int stride
|
||||
)
|
||||
{
|
||||
__forceinline__ __device__ void shuffle_4bit_8(uint32_t* q, int stride) {
|
||||
uint32_t qa = q[0];
|
||||
uint32_t qb = 0;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++)
|
||||
{
|
||||
for (int i = 0; i < 4; i++) {
|
||||
uint32_t qa0 = qa & 0x0f;
|
||||
uint32_t qa1 = (qa & 0xf0) >> 4;
|
||||
qa >>= 8;
|
||||
@ -34,14 +28,9 @@ __forceinline__ __device__ void shuffle_4bit_8
|
||||
q[0] = qb;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_4bit_8
|
||||
(
|
||||
const uint32_t q_0,
|
||||
half2 (&dq)[4],
|
||||
int stride,
|
||||
const uint32_t zero
|
||||
)
|
||||
{
|
||||
__forceinline__ __device__ void dequant_4bit_8(const uint32_t q_0,
|
||||
half2 (&dq)[4], int stride,
|
||||
const uint32_t zero) {
|
||||
const uint32_t c0 = 0x64006400;
|
||||
const half y16_ = __float2half_rn(1.0f / 16.0f);
|
||||
const half2 y16 = __halves2half2(y16_, y16_);
|
||||
@ -63,14 +52,9 @@ __forceinline__ __device__ void dequant_4bit_8
|
||||
dq[3] = __hfma2(q3.as_half2, y16, z16);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
|
||||
(
|
||||
const uint32_t zero,
|
||||
const half scale,
|
||||
half2 (&z1z16)[2],
|
||||
half2 (&y1y16)[2]
|
||||
)
|
||||
{
|
||||
__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale(
|
||||
const uint32_t zero, const half scale, half2 (&z1z16)[2],
|
||||
half2 (&y1y16)[2]) {
|
||||
half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
|
||||
half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));
|
||||
|
||||
@ -86,13 +70,9 @@ __forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
|
||||
y1y16[1] = __hmul2(scale2, __half2half2(y16));
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_4bit_8_prep_zero
|
||||
(
|
||||
const uint32_t zero,
|
||||
__forceinline__ __device__ void dequant_4bit_8_prep_zero(const uint32_t zero,
|
||||
half2 (&z1z16)[2],
|
||||
half2(&y1y16)[2]
|
||||
)
|
||||
{
|
||||
half2 (&y1y16)[2]) {
|
||||
half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
|
||||
half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));
|
||||
|
||||
@ -106,39 +86,38 @@ __forceinline__ __device__ void dequant_4bit_8_prep_zero
|
||||
y1y16[1] = __half2half2(y16);
|
||||
}
|
||||
|
||||
|
||||
__forceinline__ __device__ void dequant_4bit_8_gptq
|
||||
(
|
||||
const uint32_t q_0,
|
||||
__forceinline__ __device__ void dequant_4bit_8_gptq(const uint32_t q_0,
|
||||
half2 (&dq)[4],
|
||||
half2 (&z1z16)[2],
|
||||
half2 (&y1y16)[2],
|
||||
int stride,
|
||||
bool scaled
|
||||
)
|
||||
{
|
||||
int stride, bool scaled) {
|
||||
const uint32_t c0 = 0x64006400;
|
||||
|
||||
uint32_t qa = q_0;
|
||||
half2_uint32 q0((qa & 0x000f000f) | c0); // half2( q[0] + 1024, q[1] + 1024 )
|
||||
half2_uint32 q1((qa & 0x00f000f0) | c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 )
|
||||
half2_uint32 q0((qa & 0x000f000f) |
|
||||
c0); // half2( q[0] + 1024, q[1] + 1024 )
|
||||
half2_uint32 q1((qa & 0x00f000f0) |
|
||||
c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 )
|
||||
qa >>= 8;
|
||||
half2_uint32 q2((qa & 0x000f000f) | c0); // half2( q[4] + 1024, q[5] + 1024 )
|
||||
half2_uint32 q3((qa & 0x00f000f0) | c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 )
|
||||
half2_uint32 q2((qa & 0x000f000f) |
|
||||
c0); // half2( q[4] + 1024, q[5] + 1024 )
|
||||
half2_uint32 q3((qa & 0x00f000f0) |
|
||||
c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 )
|
||||
|
||||
if (scaled)
|
||||
{
|
||||
dq[0] = __hfma2(q0.as_half2, y1y16[0], z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s)
|
||||
dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s)
|
||||
if (scaled) {
|
||||
dq[0] = __hfma2(q0.as_half2, y1y16[0],
|
||||
z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s)
|
||||
dq[1] = __hfma2(q1.as_half2, y1y16[1],
|
||||
z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s)
|
||||
dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]);
|
||||
dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]);
|
||||
}
|
||||
else
|
||||
{
|
||||
} else {
|
||||
dq[0] = __hadd2(q0.as_half2, z1z16[0]); // half2( q[0] - z, q[1] - z )
|
||||
dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] - z, q[3] - z )
|
||||
dq[1] = __hfma2(q1.as_half2, y1y16[1],
|
||||
z1z16[1]); // half2( q[2] - z, q[3] - z )
|
||||
dq[2] = __hadd2(q2.as_half2, z1z16[0]); // half2( q[4] - z, q[5] - z )
|
||||
dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); // half2( q[6] - z, q[7] - z )
|
||||
dq[3] = __hfma2(q3.as_half2, y1y16[1],
|
||||
z1z16[1]); // half2( q[6] - z, q[7] - z )
|
||||
}
|
||||
}
|
||||
} // namespace gptq
|
||||
|
@ -10,28 +10,18 @@ Copied from https://github.com/turboderp/exllamav2
namespace vllm {
namespace gptq {

__forceinline__ __device__ void shuffle_8bit_4
(
uint32_t* q,
int stride
)
{
}
__forceinline__ __device__ void shuffle_8bit_4(uint32_t* q, int stride) {}

__forceinline__ __device__ void dequant_8bit_8
(
const uint32_t q_0,
__forceinline__ __device__ void dequant_8bit_8(const uint32_t q_0,
const uint32_t q_1,
half2 (&dq)[4],
int stride,
const uint32_t zero
)
{
half2 (&dq)[4], int stride,
const uint32_t zero) {
half dqh[8];
for (int i = 0; i < 4; i++) dqh[i] = dq_ns(exb(q_0, i * 8, 0xff), zero);
for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), zero);

for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
for (int i = 0; i < 4; i++)
dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}

} // namespace gptq

@ -8,16 +8,14 @@ Copied from https://github.com/turboderp/exllamav2
|
||||
namespace vllm {
|
||||
namespace gptq {
|
||||
|
||||
union half2_uint32
|
||||
{
|
||||
union half2_uint32 {
|
||||
uint32_t as_uint32;
|
||||
half2 as_half2;
|
||||
__device__ half2_uint32(uint32_t val) : as_uint32(val) {}
|
||||
__device__ half2_uint32(half2 val) : as_half2(val) {}
|
||||
};
|
||||
|
||||
union half_uint16
|
||||
{
|
||||
union half_uint16 {
|
||||
uint16_t as_uint16;
|
||||
half as_half;
|
||||
__device__ half_uint16(uint16_t val) : as_uint16(val) {}
|
||||
@ -26,32 +24,30 @@ union half_uint16
|
||||
|
||||
// Max_scale premultiplied by 1/256
|
||||
|
||||
__forceinline__ __device__ half dq_scale(const int qs, const half max_scale)
|
||||
{
|
||||
__forceinline__ __device__ half dq_scale(const int qs, const half max_scale) {
|
||||
int qs_i = qs + 1;
|
||||
half qs_h = __int2half_rn(qs_i * qs_i);
|
||||
qs_h = __hmul(qs_h, max_scale);
|
||||
return qs_h;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ half dq(const int q, const int qzero, const half scale)
|
||||
{
|
||||
__forceinline__ __device__ half dq(const int q, const int qzero,
|
||||
const half scale) {
|
||||
return __hmul(__int2half_rn(q - qzero), scale);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ half dq_ns(const int q, const int qzero)
|
||||
{
|
||||
__forceinline__ __device__ half dq_ns(const int q, const int qzero) {
|
||||
// return __hsub(__int2half_rn(q), __int2half_rn(qzero));
|
||||
return __int2half_rn(q - qzero);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask)
|
||||
{
|
||||
__forceinline__ __device__ int exb(const uint32_t q, const int shift,
|
||||
const int mask) {
|
||||
return (int)((q >> shift) & mask);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask)
|
||||
{
|
||||
__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0,
|
||||
const int shift, const int mask) {
|
||||
return (int)(__funnelshift_rc(q0, q1, shift) & mask);
|
||||
}
|
||||
|
||||
|
@ -22,11 +22,15 @@
|
||||
#include "gptq_marlin.cuh"
|
||||
#include "gptq_marlin_dtypes.cuh"
|
||||
|
||||
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) static_assert(\
|
||||
std::is_same<scalar_t, half>::value || std::is_same<scalar_t, nv_bfloat16>::value, \
|
||||
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
|
||||
static_assert(std::is_same<scalar_t, half>::value || \
|
||||
std::is_same<scalar_t, nv_bfloat16>::value, \
|
||||
"only float16 and bfloat16 is supported");
|
||||
|
||||
template <typename T> inline std::string str(T x) { return std::to_string(x); }
|
||||
template <typename T>
|
||||
inline std::string str(T x) {
|
||||
return std::to_string(x);
|
||||
}
|
||||
|
||||
namespace gptq_marlin {
|
||||
|
||||
@ -41,17 +45,18 @@ template <typename scalar_t, // compute dtype, half or nv_float16
|
||||
const int num_bits, // number of bits used for weights
|
||||
const int threads, // number of threads in a threadblock
|
||||
const int thread_m_blocks, // number of 16x16 blocks in the m
|
||||
// dimension (batchsize) of the threadblock
|
||||
// dimension (batchsize) of the
|
||||
// threadblock
|
||||
const int thread_n_blocks, // same for n dimension (output)
|
||||
const int thread_k_blocks, // same for k dimension (reduction)
|
||||
const int stages, // number of stages for the async global->shared
|
||||
// fetch pipeline
|
||||
const bool has_act_order, // whether act_order is enabled
|
||||
const int group_blocks = -1 // number of consecutive 16x16 blocks with
|
||||
// a separate quantization scale
|
||||
const int group_blocks = -1 // number of consecutive 16x16 blocks
|
||||
// with a separate quantization scale
|
||||
>
|
||||
__global__ void
|
||||
Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
||||
__global__ void Marlin(
|
||||
const int4* __restrict__ A, // fp16 input matrix of shape mxk
|
||||
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
|
||||
int4* __restrict__ C, // fp16 output buffer of shape mxn
|
||||
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
|
||||
@ -88,17 +93,19 @@ __device__ inline void mma(const typename ScalarType<scalar_t>::FragA &a_frag,
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
float* c = reinterpret_cast<float*>(&frag_c);
if constexpr (std::is_same<scalar_t, half>::value) {
asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]),
"r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]),
"r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else {
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
}
@ -107,7 +114,8 @@ __device__ inline void mma(const typename ScalarType<scalar_t>::FragA &a_frag,
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
template <typename scalar_t>
__device__ inline void ldsm4(typename ScalarType<scalar_t>::FragA &frag_a, const void *smem_ptr) {
__device__ inline void ldsm4(typename ScalarType<scalar_t>::FragA& frag_a,
const void* smem_ptr) {
uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
@ -118,7 +126,8 @@ __device__ inline void ldsm4(typename ScalarType<scalar_t>::FragA &frag_a, const
// Lookup-table based 3-input logical operation; explicitly used for
// dequantization as the compiler does not seem to automatically recognize it in
// all cases.
template <int lut> __device__ inline int lop3(int a, int b, int c) {
template <int lut>
__device__ inline int lop3(int a, int b, int c) {
int res;
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(res)
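An editor's aside (not part of the commit): the <lut> template argument is the 8-bit immediate that lop3.b32 applies bitwise to its three inputs. It can be derived by evaluating the intended boolean expression on the canonical truth tables 0xF0 (a), 0xCC (b) and 0xAA (c). A minimal sketch reusing the lop3 helper above, with and_or_example and kLut as illustrative names:

  __device__ inline int and_or_example(int a, int b, int c) {
    constexpr int kLut = (0xf0 & 0xcc) | 0xaa;  // encodes (a & b) | c, i.e. 0xEA
    return lop3<kLut>(a, b, c);                 // same result as (a & b) | c
  }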
@ -140,8 +149,10 @@ __device__ inline uint32_t prmt(uint32_t a) {
// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
// values. We mostly follow the strategy in the link below, with some small
// changes:
// - FP16: https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
// - BF16: https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
template <typename scalar_t>
__device__ inline typename ScalarType<scalar_t>::FragB dequant_4bit(int q) {
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
@ -170,7 +181,8 @@ __device__ inline typename ScalarType<half>::FragB dequant_4bit<half>(int q) {
}

template <>
__device__ inline typename ScalarType<nv_bfloat16>::FragB dequant_4bit<nv_bfloat16>(int q) {
__device__ inline typename ScalarType<nv_bfloat16>::FragB
dequant_4bit<nv_bfloat16>(int q) {
static constexpr uint32_t MASK = 0x000f000f;
static constexpr uint32_t EX = 0x43004300;

@ -193,10 +205,12 @@ __device__ inline typename ScalarType<nv_bfloat16>::FragB dequant_4bit<nv_bfloat
return frag_b;
}
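An editor's aside (not part of the commit), on the magic constants in this dequant path: for fp16 the trick ORs each 4-bit value q into the mantissa of the bit pattern 0x6400 (which is half(1024.0)), yielding half(1024 + q) with no int-to-float conversion, and a single subtraction of 1032 recenters that to q - 8; the bf16 specialization uses 0x4300 (bfloat16(128.0)) in the same way. A scalar sketch of the idea, with dequant_4bit_scalar_sketch as an illustrative name:

  __device__ inline half dequant_4bit_scalar_sketch(uint32_t q /* 0..15 */) {
    uint16_t bits = static_cast<uint16_t>(0x6400u | (q & 0xFu));  // half(1024 + q)
    return __hsub(__ushort_as_half(bits), __ushort_as_half(0x6408u));  // minus 1032.0
  }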
|
||||
// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or bf16
|
||||
// Reference:
|
||||
// - FP16: https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
|
||||
// - BF16: https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
|
||||
// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or
|
||||
// bf16 Reference:
|
||||
// - FP16:
|
||||
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
|
||||
// - BF16:
|
||||
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
|
||||
template <typename scalar_t>
|
||||
__device__ inline typename ScalarType<scalar_t>::FragB dequant_8bit(int q) {
|
||||
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
|
||||
@ -222,11 +236,13 @@ __device__ inline typename ScalarType<half>::FragB dequant_8bit<half>(int q) {
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline typename ScalarType<nv_bfloat16>::FragB dequant_8bit<nv_bfloat16>(int q) {
|
||||
__device__ inline typename ScalarType<nv_bfloat16>::FragB
|
||||
dequant_8bit<nv_bfloat16>(int q) {
|
||||
typename ScalarType<nv_bfloat16>::FragB frag_b;
|
||||
|
||||
float fp32_intermediates[4];
|
||||
uint32_t * fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
|
||||
uint32_t* fp32_intermediates_casted =
|
||||
reinterpret_cast<uint32_t*>(fp32_intermediates);
|
||||
|
||||
static constexpr uint32_t fp32_base = 0x4B000000;
|
||||
fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
|
||||
@ -240,8 +256,10 @@ __device__ inline typename ScalarType<nv_bfloat16>::FragB dequant_8bit<nv_bfloat
|
||||
fp32_intermediates[3] -= 8388736.f;
|
||||
|
||||
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&frag_b);
|
||||
bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632);
|
||||
bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632);
|
||||
bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
|
||||
fp32_intermediates_casted[1], 0x7632);
|
||||
bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
|
||||
fp32_intermediates_casted[3], 0x7632);
|
||||
|
||||
return frag_b;
|
||||
}
|
||||
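An editor's aside (not part of the commit), on the fp32_base trick above: placing an unsigned byte q in the low byte of 0x4B000000 (the bit pattern of 8388608.0f, i.e. 2^23) produces the float 8388608 + q, so subtracting 8388736.f recovers q - 128 exactly, and the bf16 result is just the upper half of that float's bit pattern (here obtained via the standard conversion). A scalar sketch, with dequant_8bit_scalar_sketch as an illustrative name:

  __device__ inline nv_bfloat16 dequant_8bit_scalar_sketch(uint32_t q /* 0..255 */) {
    float f = __uint_as_float(0x4B000000u | (q & 0xFFu)) - 8388736.f;  // == q - 128
    return __float2bfloat16(f);  // exact: |q - 128| <= 128 fits bf16's mantissa
  }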
@ -250,9 +268,11 @@ __device__ inline typename ScalarType<nv_bfloat16>::FragB dequant_8bit<nv_bfloat
|
||||
// only for grouped quantization.
|
||||
template <typename scalar_t>
|
||||
__device__ inline void scale(typename ScalarType<scalar_t>::FragB& frag_b,
|
||||
typename ScalarType<scalar_t>::FragS &frag_s, int i) {
|
||||
typename ScalarType<scalar_t>::FragS& frag_s,
|
||||
int i) {
|
||||
using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
|
||||
scalar_t2 s = ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t *>(&frag_s)[i]);
|
||||
scalar_t2 s =
|
||||
ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_s)[i]);
|
||||
frag_b[0] = __hmul2(frag_b[0], s);
|
||||
frag_b[1] = __hmul2(frag_b[1], s);
|
||||
}
|
||||
@ -280,7 +300,8 @@ __device__ inline void scale4(typename ScalarType<scalar_t>::FragB &frag_b,
|
||||
|
||||
// Given 2 floats multiply by 2 scales (halves)
|
||||
template <typename scalar_t>
|
||||
__device__ inline void scale_float(float *c, typename ScalarType<scalar_t>::FragS &s) {
|
||||
__device__ inline void scale_float(float* c,
|
||||
typename ScalarType<scalar_t>::FragS& s) {
|
||||
scalar_t* s_ptr = reinterpret_cast<scalar_t*>(&s);
|
||||
c[0] = __fmul_rn(c[0], ScalarType<scalar_t>::num2float(s_ptr[0]));
|
||||
c[1] = __fmul_rn(c[1], ScalarType<scalar_t>::num2float(s_ptr[1]));
|
||||
@ -325,7 +346,6 @@ __global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr,
|
||||
int const* __restrict__ perm_int_ptr,
|
||||
int4* __restrict__ out_int4_ptr, int size_m,
|
||||
int size_k, int block_rows) {
|
||||
|
||||
int start_row = block_rows * blockIdx.x;
|
||||
int finish_row = start_row + block_rows;
|
||||
if (finish_row > size_m) {
|
||||
@ -341,8 +361,7 @@ __global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr,
|
||||
|
||||
int offset = row * row_stride;
|
||||
|
||||
half const *a_row_half =
|
||||
reinterpret_cast<half const *>(a_int4_ptr + offset);
|
||||
half const* a_row_half = reinterpret_cast<half const*>(a_int4_ptr + offset);
|
||||
half* out_half = reinterpret_cast<half*>(out_int4_ptr + offset);
|
||||
|
||||
int base_k = 0;
|
||||
@ -378,17 +397,18 @@ template <typename scalar_t, // compute dtype, half or nv_float16
|
||||
const int num_bits, // number of bits used for weights
|
||||
const int threads, // number of threads in a threadblock
|
||||
const int thread_m_blocks, // number of 16x16 blocks in the m
|
||||
// dimension (batchsize) of the threadblock
|
||||
// dimension (batchsize) of the
|
||||
// threadblock
|
||||
const int thread_n_blocks, // same for n dimension (output)
|
||||
const int thread_k_blocks, // same for k dimension (reduction)
|
||||
const int stages, // number of stages for the async global->shared
|
||||
// fetch pipeline
|
||||
const bool has_act_order, // whether act_order is enabled
|
||||
const int group_blocks = -1 // number of consecutive 16x16 blocks with
|
||||
// a separate quantization scale
|
||||
const int group_blocks = -1 // number of consecutive 16x16 blocks
|
||||
// with a separate quantization scale
|
||||
>
|
||||
__global__ void
|
||||
Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
||||
__global__ void Marlin(
|
||||
const int4* __restrict__ A, // fp16 input matrix of shape mxk
|
||||
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
|
||||
int4* __restrict__ C, // fp16 output buffer of shape mxn
|
||||
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
|
||||
@ -465,27 +485,22 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
||||
auto init_slice = [&]() {
|
||||
slice_iters =
|
||||
iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
|
||||
if (slice_iters < 0 || slice_col_par >= n_tiles * parallel)
|
||||
slice_iters = 0;
|
||||
if (slice_iters == 0)
|
||||
return;
|
||||
if (slice_row + slice_iters > k_tiles)
|
||||
slice_iters = k_tiles - slice_row;
|
||||
if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;
|
||||
if (slice_iters == 0) return;
|
||||
if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
|
||||
slice_count = 1;
|
||||
slice_idx = 0;
|
||||
int col_first = iters * div_ceil(k_tiles * slice_col_par, iters);
|
||||
if (col_first <= k_tiles * (slice_col_par + 1)) {
|
||||
int col_off = col_first - k_tiles * slice_col_par;
|
||||
slice_count = div_ceil(k_tiles - col_off, iters);
|
||||
if (col_off > 0)
|
||||
slice_count++;
|
||||
if (col_off > 0) slice_count++;
|
||||
int delta_first = iters * blockIdx.x - col_first;
|
||||
if (delta_first < 0 || (col_off == 0 && delta_first == 0))
|
||||
slice_idx = slice_count - 1;
|
||||
else {
|
||||
slice_idx = slice_count - 1 - delta_first / iters;
|
||||
if (col_off > 0)
|
||||
slice_idx--;
|
||||
if (col_off > 0) slice_idx--;
|
||||
}
|
||||
}
|
||||
if (slice_col == n_tiles) {
|
||||
@ -785,7 +800,8 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
||||
int4* sh_a_stage = sh_a + a_sh_stage * pipe;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < thread_m_blocks; i++)
|
||||
ldsm4<scalar_t>(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
|
||||
ldsm4<scalar_t>(frag_a[k % 2][i],
|
||||
&sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
|
||||
int4* sh_b_stage = sh_b + b_sh_stage * pipe;
|
||||
|
||||
#pragma unroll
|
||||
@ -906,7 +922,6 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
|
||||
int actual_k = cur_k + k_frag_offsets[i];
|
||||
|
||||
int group_id = sh_g_idx_int_ptr[actual_k];
|
||||
@ -943,8 +958,9 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
||||
|
||||
// Apply scale to frag_b0
|
||||
if constexpr (has_act_order) {
|
||||
scale4<scalar_t>(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j],
|
||||
act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0);
|
||||
scale4<scalar_t>(frag_b0, act_frag_s[k % 2][0][j],
|
||||
act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j],
|
||||
act_frag_s[k % 2][3][j], 0);
|
||||
} else {
|
||||
if constexpr (group_blocks != -1) {
|
||||
scale<scalar_t>(frag_b0, frag_s[k % 2][j], 0);
|
||||
@ -953,8 +969,9 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
||||
|
||||
// Apply scale to frag_b1
|
||||
if constexpr (has_act_order) {
|
||||
scale4<scalar_t>(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j],
|
||||
act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1);
|
||||
scale4<scalar_t>(frag_b1, act_frag_s[k % 2][0][j],
|
||||
act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j],
|
||||
act_frag_s[k % 2][3][j], 1);
|
||||
|
||||
} else {
|
||||
if constexpr (group_blocks != -1) {
|
||||
@ -997,8 +1014,8 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
||||
int red_sh_wr =
|
||||
red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
|
||||
if (i < red_off) {
|
||||
float *c_rd = reinterpret_cast<float *>(
|
||||
&sh[red_sh_delta * j + red_sh_rd]);
|
||||
float* c_rd =
|
||||
reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
|
||||
float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
|
||||
#pragma unroll
|
||||
for (int k = 0; k < 4; k++)
|
||||
@ -1049,16 +1066,16 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
||||
int row = (threadIdx.x % 32) / 4;
|
||||
|
||||
if (!first) {
|
||||
// Interestingly, doing direct global accesses here really seems to mess up the
|
||||
// compiler and lead to slowdowns, hence we also use async-copies even though
|
||||
// these fetches are not actually asynchronous.
|
||||
// Interestingly, doing direct global accesses here really seems to mess up
|
||||
// the compiler and lead to slowdowns, hence we also use async-copies even
|
||||
// though these fetches are not actually asynchronous.
|
||||
#pragma unroll
|
||||
for (int i = 0; i < thread_m_blocks * 4; i++) {
|
||||
cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i],
|
||||
cp_async4_pred(
|
||||
&sh[c_sh_wr + c_sh_wr_delta * i],
|
||||
&C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
|
||||
c_gl_wr_delta_i * (i % 2)],
|
||||
i < (thread_m_blocks - 1) * 4 ||
|
||||
8 * (i / 2) + row < prob_m);
|
||||
i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
|
||||
}
|
||||
cp_async_fence();
|
||||
cp_async_wait<0>();
|
||||
@ -1116,7 +1133,8 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
||||
// We first reorder in shared memory to guarantee the most efficient final
|
||||
// global write patterns
|
||||
auto write = [&](int idx, float c0, float c1, FragS& s) {
|
||||
scalar_t2 res = Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1));
|
||||
scalar_t2 res =
|
||||
Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1));
|
||||
|
||||
// For per-column quantization we finally apply the scale here (only for
|
||||
// 4-bit)
|
||||
@ -1286,14 +1304,18 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
||||
for (int i = 0; i < thread_m_blocks; i++) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; j++) {
|
||||
scale_float<scalar_t>(reinterpret_cast<float *>(&frag_c[i][j][0][0]),
|
||||
scale_float<scalar_t>(
|
||||
reinterpret_cast<float*>(&frag_c[i][j][0][0]),
|
||||
frag_s[j / 2][2 * (j % 2) + 0]);
|
||||
scale_float<scalar_t>(reinterpret_cast<float *>(&frag_c[i][j][0][2]),
|
||||
scale_float<scalar_t>(
|
||||
reinterpret_cast<float*>(&frag_c[i][j][0][2]),
|
||||
frag_s[j / 2][2 * (j % 2) + 0]);
|
||||
|
||||
scale_float<scalar_t>(reinterpret_cast<float *>(&frag_c[i][j][1][0]),
|
||||
scale_float<scalar_t>(
|
||||
reinterpret_cast<float*>(&frag_c[i][j][1][0]),
|
||||
frag_s[j / 2][2 * (j % 2) + 1]);
|
||||
scale_float<scalar_t>(reinterpret_cast<float *>(&frag_c[i][j][1][2]),
|
||||
scale_float<scalar_t>(
|
||||
reinterpret_cast<float*>(&frag_c[i][j][1][2]),
|
||||
frag_s[j / 2][2 * (j % 2) + 1]);
|
||||
}
|
||||
}
|
||||
@ -1320,8 +1342,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
||||
B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
|
||||
if (slice_col == 0) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < b_sh_wr_iters; i++)
|
||||
B_ptr[i] -= b_gl_stride;
|
||||
for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
|
||||
}
|
||||
|
||||
// Update slice k/n for scales loading
|
||||
@ -1341,20 +1362,21 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
||||
}
|
||||
}
|
||||
|
||||
#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
|
||||
HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \
|
||||
#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
|
||||
THREAD_K_BLOCKS, HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \
|
||||
else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \
|
||||
thread_n_blocks == THREAD_N_BLOCKS && \
|
||||
thread_k_blocks == THREAD_K_BLOCKS && \
|
||||
has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \
|
||||
num_threads == NUM_THREADS) { \
|
||||
cudaFuncSetAttribute( \
|
||||
Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
|
||||
THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, GROUP_BLOCKS>, \
|
||||
Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
|
||||
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
|
||||
GROUP_BLOCKS>, \
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
|
||||
Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
|
||||
THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, GROUP_BLOCKS> \
|
||||
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
|
||||
Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
|
||||
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
|
||||
GROUP_BLOCKS><<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
|
||||
A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n, \
|
||||
prob_k, locks); \
|
||||
}
|
||||
@ -1673,8 +1695,7 @@ void marlin_mm_f16i4(const void *A, const void *B, void *C, void *s,
|
||||
// Note that parallel > 1 currently only works for inputs without any
|
||||
// padding
|
||||
par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks);
|
||||
if (par > max_par)
|
||||
par = max_par;
|
||||
if (par > max_par) par = max_par;
|
||||
prob_m = (16 * exec_cfg.max_m_blocks) * par;
|
||||
i += exec_cfg.max_m_blocks * (par - 1);
|
||||
thread_m_blocks = exec_cfg.max_m_blocks;
|
||||
@ -1824,18 +1845,21 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
|
||||
int dev = a.get_device();
|
||||
if (a.scalar_type() == at::ScalarType::Half) {
|
||||
gptq_marlin::marlin_mm_f16i4<half>(
|
||||
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(), b_scales.data_ptr<at::Half>(),
|
||||
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::Half>(), size_m, size_n,
|
||||
size_k, workspace.data_ptr(), num_bits, has_act_order, is_k_full,
|
||||
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
|
||||
thread_k, thread_n, sms, gptq_marlin::max_par);
|
||||
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
|
||||
b_scales.data_ptr<at::Half>(), g_idx.data_ptr(), perm.data_ptr(),
|
||||
a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
|
||||
workspace.data_ptr(), num_bits, has_act_order, is_k_full, num_groups,
|
||||
group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k,
|
||||
thread_n, sms, gptq_marlin::max_par);
|
||||
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
|
||||
gptq_marlin::marlin_mm_f16i4<nv_bfloat16>(
|
||||
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(), c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(),
|
||||
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), size_m, size_n,
|
||||
size_k, workspace.data_ptr(), num_bits, has_act_order, is_k_full,
|
||||
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
|
||||
thread_k, thread_n, sms, gptq_marlin::max_par);
|
||||
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
|
||||
c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(),
|
||||
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
|
||||
size_m, size_n, size_k, workspace.data_ptr(), num_bits, has_act_order,
|
||||
is_k_full, num_groups, group_size, dev,
|
||||
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
|
||||
gptq_marlin::max_par);
|
||||
} else {
|
||||
TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16");
|
||||
}
|
||||
|
@ -11,12 +11,13 @@
|
||||
|
||||
namespace gptq_marlin {
|
||||
|
||||
// 8 warps are a good choice since every SM has 4 schedulers and having more than 1 warp per
|
||||
// schedule allows some more latency hiding. At the same time, we want relatively few warps to have
|
||||
// many registers per warp and small tiles.
|
||||
// 8 warps are a good choice since every SM has 4 schedulers and having more
|
||||
// than 1 warp per schedule allows some more latency hiding. At the same time,
|
||||
// we want relatively few warps to have many registers per warp and small tiles.
|
||||
static constexpr int default_threads = 256;
|
||||
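An editor's note (not part of the commit), spelling out the arithmetic behind the comment above: 256 threads per block is 256 / 32 = 8 warps, i.e. 2 warps per scheduler on an SM with 4 warp schedulers. If one wanted to document that in code, an illustrative assertion would be:

  static_assert(default_threads / 32 == 8, "8 warps per threadblock");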
|
||||
static constexpr int pipe_stages = 4; // 4 pipeline stages fit into shared memory
|
||||
static constexpr int pipe_stages =
|
||||
4; // 4 pipeline stages fit into shared memory
|
||||
|
||||
static constexpr int min_thread_n = 64;
|
||||
static constexpr int min_thread_k = 64;
|
||||
@ -38,10 +39,12 @@ constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
|
||||
// No support for async
|
||||
#else
|
||||
|
||||
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) {
|
||||
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
|
||||
bool pred = true) {
|
||||
const int BYTES = 16;
|
||||
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
asm volatile("{\n"
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .pred p;\n"
|
||||
" setp.ne.b32 p, %0, 0;\n"
|
||||
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
|
||||
@ -52,13 +55,16 @@ __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool
|
||||
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
|
||||
const int BYTES = 16;
|
||||
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
asm volatile("{\n"
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" cp.async.cg.shared.global [%0], [%1], %2;\n"
|
||||
"}\n" ::"r"(smem),
|
||||
"l"(glob_ptr), "n"(BYTES));
|
||||
}
|
||||
|
||||
__device__ inline void cp_async_fence() { asm volatile("cp.async.commit_group;\n" ::); }
|
||||
__device__ inline void cp_async_fence() {
|
||||
asm volatile("cp.async.commit_group;\n" ::);
|
||||
}
|
||||
|
||||
template <int n>
|
||||
__device__ inline void cp_async_wait() {
|
||||
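An editor's usage sketch (not part of the commit), showing how the cp.async helpers above are typically composed into a multi-stage prefetch pipeline; prefetch_example and its parameters are illustrative, smem is assumed to point at __shared__ storage with at least `stages` slots, and stages >= 2:

  template <int stages>
  __device__ void prefetch_example(int4* smem, const int4* gmem, int iters) {
    for (int i = 0; i < iters; i++) {
      cp_async4(&smem[i % stages], &gmem[i]);  // issue a 16-byte async copy
      cp_async_fence();                        // commit it as its own group
      if (i >= stages - 1) {
        cp_async_wait<stages - 2>();  // the oldest outstanding group has landed
        __syncthreads();
        // ... consume smem[(i - (stages - 1)) % stages] here ...
      }
    }
  }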
|
@ -5,12 +5,10 @@
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_bf16.h>
|
||||
|
||||
|
||||
namespace gptq_marlin {
|
||||
|
||||
template <typename scalar_t>
|
||||
class ScalarType {
|
||||
};
|
||||
class ScalarType {};
|
||||
|
||||
template <>
|
||||
class ScalarType<half> {
|
||||
@ -26,13 +24,21 @@ public:
|
||||
using FragC = Vec<float, 4>;
|
||||
using FragS = Vec<half2, 1>;
|
||||
|
||||
static __device__ float inline num2float(const half x) { return __half2float(x); }
|
||||
static __device__ float inline num2float(const half x) {
|
||||
return __half2float(x);
|
||||
}
|
||||
|
||||
static __device__ half2 inline num2num2(const half x) { return __half2half2(x); }
|
||||
static __device__ half2 inline num2num2(const half x) {
|
||||
return __half2half2(x);
|
||||
}
|
||||
|
||||
static __device__ half2 inline nums2num2(const half x1, const half x2) { return __halves2half2(x1, x2); }
|
||||
static __device__ half2 inline nums2num2(const half x1, const half x2) {
|
||||
return __halves2half2(x1, x2);
|
||||
}
|
||||
|
||||
static __host__ __device__ half inline float2num(const float x) { return __float2half(x); }
|
||||
static __host__ __device__ half inline float2num(const float x) {
|
||||
return __float2half(x);
|
||||
}
|
||||
};
|
||||
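An editor's usage sketch (not part of the commit): the point of the ScalarType trait is that kernel code can stay generic over half and nv_bfloat16. A minimal example relying on the trait's scalar_t2 type and conversion helpers, with axpy_example as an illustrative name:

  template <typename scalar_t>
  __device__ typename ScalarType<scalar_t>::scalar_t2 axpy_example(
      float a, typename ScalarType<scalar_t>::scalar_t2 x,
      typename ScalarType<scalar_t>::scalar_t2 y) {
    using Dtype = ScalarType<scalar_t>;
    // broadcast the scale into a packed pair, then do a fused multiply-add
    return __hfma2(Dtype::num2num2(Dtype::float2num(a)), x, y);
  }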
|
||||
template <>
|
||||
@ -47,16 +53,25 @@ public:
|
||||
using FragS = Vec<nv_bfloat162, 1>;
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
static __device__ float inline num2float(const nv_bfloat16 x) { return __bfloat162float(x); }
|
||||
static __device__ float inline num2float(const nv_bfloat16 x) {
|
||||
return __bfloat162float(x);
|
||||
}
|
||||
|
||||
static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) { return __bfloat162bfloat162(x); }
|
||||
static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) {
|
||||
return __bfloat162bfloat162(x);
|
||||
}
|
||||
|
||||
static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1, const nv_bfloat16 x2) { return __halves2bfloat162(x1, x2); }
|
||||
static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1,
|
||||
const nv_bfloat16 x2) {
|
||||
return __halves2bfloat162(x1, x2);
|
||||
}
|
||||
|
||||
static __host__ __device__ nv_bfloat16 inline float2num(const float x) { return __float2bfloat16(x); }
|
||||
static __host__ __device__ nv_bfloat16 inline float2num(const float x) {
|
||||
return __float2bfloat16(x);
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
}
|
||||
} // namespace gptq_marlin
|
||||
|
||||
#endif
|
||||
|
@ -12,10 +12,10 @@ static constexpr int tile_n_size = tile_k_size * 4;
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
|
||||
template <int const num_threads, int const num_bits, bool const has_perm>
|
||||
__global__ void
|
||||
marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
|
||||
uint32_t const *__restrict__ perm_ptr,
|
||||
uint32_t *__restrict__ out_ptr, int size_k, int size_n) {}
|
||||
__global__ void marlin_repack_kernel(
|
||||
uint32_t const* __restrict__ b_q_weight_ptr,
|
||||
uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
|
||||
int size_k, int size_n) {}
|
||||
|
||||
} // namespace gptq_marlin
|
||||
|
||||
@ -30,10 +30,10 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
|
||||
#else
|
||||
|
||||
template <int const num_threads, int const num_bits, bool const has_perm>
|
||||
__global__ void
|
||||
marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
|
||||
uint32_t const *__restrict__ perm_ptr,
|
||||
uint32_t *__restrict__ out_ptr, int size_k, int size_n) {
|
||||
__global__ void marlin_repack_kernel(
|
||||
uint32_t const* __restrict__ b_q_weight_ptr,
|
||||
uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
|
||||
int size_k, int size_n) {
|
||||
constexpr int pack_factor = 32 / num_bits;
|
||||
|
||||
int k_tiles = size_k / tile_k_size;
|
||||
@ -176,7 +176,6 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
|
||||
}
|
||||
|
||||
} else {
|
||||
|
||||
uint32_t b1_vals[tile_ints];
|
||||
uint32_t b2_vals[tile_ints];
|
||||
|
||||
@ -320,8 +319,7 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
|
||||
// Get ptrs
|
||||
uint32_t const* b_q_weight_ptr =
|
||||
reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());
|
||||
uint32_t const *perm_ptr =
|
||||
reinterpret_cast<uint32_t const *>(perm.data_ptr());
|
||||
uint32_t const* perm_ptr = reinterpret_cast<uint32_t const*>(perm.data_ptr());
|
||||
uint32_t* out_ptr = reinterpret_cast<uint32_t*>(out.data_ptr());
|
||||
|
||||
// Get dev info
|
||||
|
@ -25,7 +25,10 @@
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <typename T> inline std::string str(T x) { return std::to_string(x); }
|
||||
template <typename T>
|
||||
inline std::string str(T x) {
|
||||
return std::to_string(x);
|
||||
}
|
||||
|
||||
namespace marlin {
|
||||
|
||||
@ -38,7 +41,8 @@ constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }
|
||||
// corresponding index accesses must be compile-time constants, which is why we
|
||||
// extensively use `#pragma unroll` throughout the kernel code to guarantee
|
||||
// this.
|
||||
template <typename T, int n> struct Vec {
|
||||
template <typename T, int n>
|
||||
struct Vec {
|
||||
T elems[n];
|
||||
__device__ T& operator[](int i) { return elems[i]; }
|
||||
};
|
||||
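An editor's aside (not part of the commit), illustrating the comment above: once a loop over a Vec is fully unrolled, every index is a compile-time constant, so the elements can be kept in registers instead of being spilled to local memory. sum_example is an illustrative name:

  template <int n>
  __device__ float sum_example(Vec<float, n>& frag) {
    float acc = 0.f;
  #pragma unroll
    for (int i = 0; i < n; i++)  // after unrolling, each frag[i] index is constant
      acc += frag[i];
    return acc;
  }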
@ -59,7 +63,8 @@ __device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr,
|
||||
bool pred = true) {
|
||||
const int BYTES = 16;
|
||||
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
asm volatile("{\n"
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .pred p;\n"
|
||||
" setp.ne.b32 p, %0, 0;\n"
|
||||
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
|
||||
@ -71,9 +76,11 @@ __device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr,
|
||||
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
|
||||
const int BYTES = 16;
|
||||
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
asm volatile("{\n"
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" cp.async.cg.shared.global [%0], [%1], %2;\n"
|
||||
"}\n" :: "r"(smem), "l"(glob_ptr), "n"(BYTES));
|
||||
"}\n" ::"r"(smem),
|
||||
"l"(glob_ptr), "n"(BYTES));
|
||||
}
|
||||
|
||||
// Async copy fence.
|
||||
@ -82,7 +89,8 @@ __device__ inline void cp_async_fence() {
|
||||
}
|
||||
|
||||
// Wait until at most `n` async copy stages are still pending.
|
||||
template <int n> __device__ inline void cp_async_wait() {
|
||||
template <int n>
|
||||
__device__ inline void cp_async_wait() {
|
||||
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
|
||||
}
|
||||
|
||||
@ -93,11 +101,12 @@ __device__ inline void mma(const FragA &a_frag, const FragB &frag_b,
|
||||
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
|
||||
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]),
|
||||
"r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
}
|
||||
|
||||
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
|
||||
@ -113,7 +122,8 @@ __device__ inline void ldsm4(FragA &frag_a, const void *smem_ptr) {
|
||||
// Lookup-table based 3-input logical operation; explicitly used for
|
||||
// dequantization as the compiler does not seem to automatically recognize it in
|
||||
// all cases.
|
||||
template <int lut> __device__ inline int lop3(int a, int b, int c) {
|
||||
template <int lut>
|
||||
__device__ inline int lop3(int a, int b, int c) {
|
||||
int res;
|
||||
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||
: "=r"(res)
|
||||
@ -189,20 +199,21 @@ __device__ inline void barrier_release(int *lock, bool reset = false) {
|
||||
|
||||
template <const int threads, // number of threads in a threadblock
|
||||
const int thread_m_blocks, // number of 16x16 blocks in the m
|
||||
// dimension (batchsize) of the threadblock
|
||||
// dimension (batchsize) of the
|
||||
// threadblock
|
||||
const int thread_n_blocks, // same for n dimension (output)
|
||||
const int thread_k_blocks, // same for k dimension (reduction)
|
||||
const int stages, // number of stages for the async global->shared
|
||||
// fetch pipeline
|
||||
const int group_blocks = -1 // number of consecutive 16x16 blocks with
|
||||
// a separate quantization scale
|
||||
const int group_blocks = -1 // number of consecutive 16x16 blocks
|
||||
// with a separate quantization scale
|
||||
>
|
||||
__global__ void
|
||||
Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
||||
__global__ void Marlin(
|
||||
const int4* __restrict__ A, // fp16 input matrix of shape mxk
|
||||
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
|
||||
int4* __restrict__ C, // fp16 output buffer of shape mxn
|
||||
const int4
|
||||
*__restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn
|
||||
const int4* __restrict__ s, // fp16 quantization scales of shape
|
||||
// (k/groupsize)xn
|
||||
int prob_m, // batch dimension m
|
||||
int prob_n, // output dimension n
|
||||
int prob_k, // reduction dimension k
|
||||
@ -261,27 +272,22 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
||||
auto init_slice = [&]() {
|
||||
slice_iters =
|
||||
iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
|
||||
if (slice_iters < 0 || slice_col_par >= n_tiles * parallel)
|
||||
slice_iters = 0;
|
||||
if (slice_iters == 0)
|
||||
return;
|
||||
if (slice_row + slice_iters > k_tiles)
|
||||
slice_iters = k_tiles - slice_row;
|
||||
if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;
|
||||
if (slice_iters == 0) return;
|
||||
if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
|
||||
slice_count = 1;
|
||||
slice_idx = 0;
|
||||
int col_first = iters * ceildiv(k_tiles * slice_col_par, iters);
|
||||
if (col_first <= k_tiles * (slice_col_par + 1)) {
|
||||
int col_off = col_first - k_tiles * slice_col_par;
|
||||
slice_count = ceildiv(k_tiles - col_off, iters);
|
||||
if (col_off > 0)
|
||||
slice_count++;
|
||||
if (col_off > 0) slice_count++;
|
||||
int delta_first = iters * blockIdx.x - col_first;
|
||||
if (delta_first < 0 || (col_off == 0 && delta_first == 0))
|
||||
slice_idx = slice_count - 1;
|
||||
else {
|
||||
slice_idx = slice_count - 1 - delta_first / iters;
|
||||
if (col_off > 0)
|
||||
slice_idx--;
|
||||
if (col_off > 0) slice_idx--;
|
||||
}
|
||||
}
|
||||
if (slice_col == n_tiles) {
|
||||
@ -305,7 +311,8 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
||||
a_gl_stride *
|
||||
(threads / a_gl_rd_delta_o); // between subsequent accesses within a tile
|
||||
constexpr int a_sh_wr_delta =
|
||||
a_sh_stride * (threads / a_gl_rd_delta_o); // between shared memory writes
|
||||
a_sh_stride *
|
||||
(threads / a_gl_rd_delta_o); // between shared memory writes
|
||||
constexpr int a_sh_rd_delta_o =
|
||||
2 * ((threads / 32) /
|
||||
(thread_n_blocks / 4)); // between shared memory tile reads
|
||||
@ -447,8 +454,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
||||
// Only fetch scales if this tile starts a new group
|
||||
if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) {
|
||||
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
|
||||
if (s_sh_wr_pred)
|
||||
cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
|
||||
if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
|
||||
s_gl_rd += s_gl_rd_delta;
|
||||
}
|
||||
}
|
||||
@ -500,11 +506,9 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
||||
FragB frag_b0 = dequant(b_quant);
|
||||
// If there are no groups, we can just scale the final output once and can
|
||||
// avoid doing so for each weight.
|
||||
if (group_blocks != -1)
|
||||
scale(frag_b0, frag_s[k % 2][j], 0);
|
||||
if (group_blocks != -1) scale(frag_b0, frag_s[k % 2][j], 0);
|
||||
FragB frag_b1 = dequant(b_quant_shift);
|
||||
if (group_blocks != -1)
|
||||
scale(frag_b1, frag_s[k % 2][j], 1);
|
||||
if (group_blocks != -1) scale(frag_b1, frag_s[k % 2][j], 1);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < thread_m_blocks; i++) {
|
||||
mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
|
||||
@ -540,8 +544,8 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
||||
int red_sh_wr =
|
||||
red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
|
||||
if (i < red_off) {
|
||||
float *c_rd = reinterpret_cast<float *>(
|
||||
&sh[red_sh_delta * j + red_sh_rd]);
|
||||
float* c_rd =
|
||||
reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
|
||||
float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
|
||||
#pragma unroll
|
||||
for (int k = 0; k < 4; k++)
|
||||
@ -571,9 +575,9 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
||||
};
|
||||
|
||||
// Since multiple threadblocks may process parts of the same column slice, we
|
||||
// finally have to globally reduce over the results. As the striped partitioning
|
||||
// minimizes the number of such reductions and our outputs are usually rather
|
||||
// small, we perform this reduction serially in L2 cache.
|
||||
// finally have to globally reduce over the results. As the striped
|
||||
// partitioning minimizes the number of such reductions and our outputs are
|
||||
// usually rather small, we perform this reduction serially in L2 cache.
|
||||
auto global_reduce = [&](bool first = false, bool last = false) {
|
||||
// We are very careful here to reduce directly in the output buffer to
|
||||
// maximize L2 cache utilization in this step. To do this, we write out
|
||||
@ -592,16 +596,16 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
||||
int row = (threadIdx.x % 32) / 4;
|
||||
|
||||
if (!first) {
|
||||
// Interestingly, doing direct global accesses here really seems to mess up the
|
||||
// compiler and lead to slowdowns, hence we also use async-copies even though
|
||||
// these fetches are not actually asynchronous.
|
||||
// Interestingly, doing direct global accesses here really seems to mess up
|
||||
// the compiler and lead to slowdowns, hence we also use async-copies even
|
||||
// though these fetches are not actually asynchronous.
|
||||
#pragma unroll
|
||||
for (int i = 0; i < thread_m_blocks * 4; i++) {
|
||||
cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i],
|
||||
cp_async4_pred(
|
||||
&sh[c_sh_wr + c_sh_wr_delta * i],
|
||||
&C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
|
||||
c_gl_wr_delta_i * (i % 2)],
|
||||
i < (thread_m_blocks - 1) * 4 ||
|
||||
8 * (i / 2) + row < prob_m);
|
||||
i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
|
||||
}
|
||||
cp_async_fence();
|
||||
cp_async_wait<0>();
|
||||
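An editor's aside (not part of the commit), on the global reduction described above: threadblocks that share a column slice serialize their read-modify-write of the same region of C, so the partial sums stay resident in L2 between passes. The kernel coordinates this with its striped schedule and counting barriers; the sketch below shows only the core idea with a simple spin lock, using illustrative names (C, partial, lock):

  __device__ void serial_reduce_example(float* C, const float* partial, int n,
                                        int* lock) {
    if (threadIdx.x == 0)
      while (atomicCAS(lock, 0, 1) != 0) {}     // acquire: wait for our turn
    __syncthreads();
    for (int i = threadIdx.x; i < n; i += blockDim.x)
      C[i] += partial[i];                       // accumulate into the shared slice
    __threadfence();                            // make the updates visible
    __syncthreads();
    if (threadIdx.x == 0) atomicExch(lock, 0);  // release for the next block
  }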
@ -700,8 +704,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
||||
// Start global fetch and register load pipelines.
|
||||
auto start_pipes = [&]() {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < stages - 1; i++)
|
||||
fetch_to_shared(i, i, i < slice_iters);
|
||||
for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters);
|
||||
zero_accums();
|
||||
wait_for_stage();
|
||||
fetch_to_registers(0, 0);
|
||||
@ -711,9 +714,9 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
||||
|
||||
// Main loop.
|
||||
while (slice_iters) {
|
||||
// We unroll over both the global fetch and the register load pipeline to ensure
|
||||
// all shared memory accesses are static. Note that both pipelines have even
|
||||
// length meaning that the next iteration will always start at index 0.
|
||||
// We unroll over both the global fetch and the register load pipeline to
|
||||
// ensure all shared memory accesses are static. Note that both pipelines have
|
||||
// even length meaning that the next iteration will always start at index 0.
|
||||
#pragma unroll
|
||||
for (int pipe = 0; pipe < stages;) {
|
||||
#pragma unroll
|
||||
@ -728,8 +731,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
||||
matmul(k);
|
||||
}
|
||||
slice_iters--;
|
||||
if (slice_iters == 0)
|
||||
break;
|
||||
if (slice_iters == 0) break;
|
||||
}
|
||||
a_gl_rd += a_gl_rd_delta_o * stages;
|
||||
|
||||
@ -742,8 +744,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
||||
// For per-column scales, we only fetch them here in the final step before
|
||||
// write-out
|
||||
if (group_blocks == -1 && last) {
|
||||
if (s_sh_wr_pred)
|
||||
cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
|
||||
if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
|
||||
cp_async_fence();
|
||||
}
|
||||
thread_block_reduce();
|
||||
@ -775,8 +776,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
||||
B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
|
||||
if (slice_col == 0) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < b_sh_wr_iters; i++)
|
||||
B_ptr[i] -= b_gl_stride;
|
||||
for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
|
||||
}
|
||||
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
|
||||
start_pipes();
|
||||
@ -789,20 +789,21 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
||||
|
||||
template <const int threads, // number of threads in a threadblock
|
||||
const int thread_m_blocks, // number of 16x16 blocks in the m
|
||||
// dimension (batchsize) of the threadblock
|
||||
// dimension (batchsize) of the
|
||||
// threadblock
|
||||
const int thread_n_blocks, // same for n dimension (output)
|
||||
const int thread_k_blocks, // same for k dimension (reduction)
|
||||
const int stages, // number of stages for the async global->shared
|
||||
// fetch pipeline
|
||||
const int group_blocks = -1 // number of consecutive 16x16 blocks with
|
||||
// a separate quantization scale
|
||||
const int group_blocks = -1 // number of consecutive 16x16 blocks
|
||||
// with a separate quantization scale
|
||||
>
|
||||
__global__ void
|
||||
Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
||||
__global__ void Marlin(
|
||||
const int4* __restrict__ A, // fp16 input matrix of shape mxk
|
||||
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
|
||||
int4* __restrict__ C, // fp16 output buffer of shape mxn
|
||||
const int4
|
||||
*__restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn
|
||||
const int4* __restrict__ s, // fp16 quantization scales of shape
|
||||
// (k/groupsize)xn
|
||||
int prob_m, // batch dimension m
|
||||
int prob_n, // output dimension n
|
||||
int prob_k, // reduction dimension k
|
||||
@ -907,7 +908,6 @@ bool is_valid_config(thread_config_t const &th_config, int prob_m, int prob_n,
|
||||
}
|
||||
|
||||
thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) {
|
||||
|
||||
if (prob_m <= 16) {
|
||||
for (auto th_config : small_batch_thread_configs) {
|
||||
if (is_valid_config(th_config, prob_m, prob_n, prob_k)) {
|
||||
@ -1011,8 +1011,7 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, int prob_m,
|
||||
// Note that parallel > 1 currently only works for inputs without any
|
||||
// padding
|
||||
par = (16 * thread_m_blocks - pad) / 64;
|
||||
if (par > max_par)
|
||||
par = max_par;
|
||||
if (par > max_par) par = max_par;
|
||||
prob_m = 64 * par;
|
||||
i += 4 * (par - 1);
|
||||
thread_m_blocks = 4;
|
||||
@ -1046,7 +1045,6 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, int prob_m,
|
||||
torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
torch::Tensor& b_scales, torch::Tensor& workspace,
|
||||
int64_t size_m, int64_t size_n, int64_t size_k) {
|
||||
|
||||
// Verify M
|
||||
TORCH_CHECK(size_m == a.size(0),
|
||||
"Shape mismatch: a.size(0) = " + str(a.size(0)) +
|
||||
@ -1074,9 +1072,9 @@ torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
|
||||
|
||||
int actual_size_n =
|
||||
(b_q_weight.size(1) / marlin::tile_size) * marlin::pack_factor_4bit;
|
||||
TORCH_CHECK(size_n == actual_size_n,
|
||||
"size_n = " + str(size_n) +
|
||||
", actual_size_n = " + str(actual_size_n));
|
||||
TORCH_CHECK(
|
||||
size_n == actual_size_n,
|
||||
"size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n));
|
||||
|
||||
// Verify A device and strides
|
||||
TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
|
||||
|
@ -26,12 +26,14 @@ constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }
|
||||
// corresponding index accesses must be compile-time constants, which is why we
|
||||
// extensively use `#pragma unroll` throughout the kernel code to guarantee
|
||||
// this.
|
||||
template <typename T, int n> struct Vec {
|
||||
template <typename T, int n>
|
||||
struct Vec {
|
||||
T elems[n];
|
||||
__device__ T& operator[](int i) { return elems[i]; }
|
||||
};
|
||||
|
||||
template <int M_, int N_, int K_> struct ShapeBase {
|
||||
template <int M_, int N_, int K_>
|
||||
struct ShapeBase {
|
||||
static constexpr int M = M_, N = N_, K = K_;
|
||||
};
|
||||
|
||||
|
@ -28,7 +28,8 @@ __device__ inline void cp_async4_pred_zfill(void *smem_ptr,
|
||||
const int BYTES = 16;
|
||||
int src_in_bytes = (zfill ? 0 : BYTES);
|
||||
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
asm volatile("{\n"
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .pred p;\n"
|
||||
" setp.ne.b32 p, %0, 0;\n"
|
||||
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
|
||||
@ -40,7 +41,8 @@ __device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr,
|
||||
bool pred = true) {
|
||||
const int BYTES = 16;
|
||||
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
asm volatile("{\n"
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .pred p;\n"
|
||||
" setp.ne.b32 p, %0, 0;\n"
|
||||
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
|
||||
@ -52,7 +54,8 @@ __device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr,
|
||||
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
|
||||
const int BYTES = 16;
|
||||
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
asm volatile("{\n"
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" cp.async.cg.shared.global [%0], [%1], %2;\n"
|
||||
"}\n" ::"r"(smem),
|
||||
"l"(glob_ptr), "n"(BYTES));
|
||||
@ -64,7 +67,8 @@ __device__ inline void cp_async_fence() {
|
||||
}
|
||||
|
||||
// Wait until at most `n` async copy stages are still pending.
|
||||
template <int n> __device__ inline void cp_async_wait() {
|
||||
template <int n>
|
||||
__device__ inline void cp_async_wait() {
|
||||
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
|
||||
}
|
||||
|
||||
|
@ -31,42 +31,47 @@ __device__ inline void mma_sp(const FragB &a_frag0, const FragB &a_frag1,
|
||||
const uint32_t* e = reinterpret_cast<const uint32_t*>(&frag_m);
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
if (psel == 0) {
|
||||
asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
|
||||
asm volatile(
|
||||
"mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
|
||||
"{%12,%13,%14,%15}, %16, 0x0;\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]),
|
||||
"r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]),
|
||||
"f"(c[2]), "f"(c[3]), "r"(e[0]));
|
||||
asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
|
||||
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), "r"(b[2]),
|
||||
"r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]),
|
||||
"r"(e[0]));
|
||||
asm volatile(
|
||||
"mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
|
||||
"{%12,%13,%14,%15}, %16, 0x0;\n"
|
||||
: "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
|
||||
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]),
|
||||
"r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]),
|
||||
"f"(c[6]), "f"(c[7]), "r"(e[0]));
|
||||
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), "r"(b[3]),
|
||||
"r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), "f"(c[6]), "f"(c[7]),
|
||||
"r"(e[0]));
|
||||
} else {
|
||||
asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
|
||||
asm volatile(
|
||||
"mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
|
||||
"{%12,%13,%14,%15}, %16, 0x1;\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]),
|
||||
"r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]),
|
||||
"f"(c[2]), "f"(c[3]), "r"(e[0]));
|
||||
asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
|
||||
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), "r"(b[2]),
|
||||
"r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]),
|
||||
"r"(e[0]));
|
||||
asm volatile(
|
||||
"mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
|
||||
"{%12,%13,%14,%15}, %16, 0x1;\n"
|
||||
: "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
|
||||
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]),
|
||||
"r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]),
|
||||
"f"(c[6]), "f"(c[7]), "r"(e[0]));
|
||||
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), "r"(b[3]),
|
||||
"r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), "f"(c[6]), "f"(c[7]),
|
||||
"r"(e[0]));
|
||||
}
|
||||
}
|
||||
|
||||
// Lookup-table based 3-input logical operation; explicitly used for
|
||||
// dequantization as the compiler does not seem to automatically recognize it in
|
||||
// all cases.
|
||||
template <int lut> __device__ inline int lop3(int a, int b, int c) {
|
||||
template <int lut>
|
||||
__device__ inline int lop3(int a, int b, int c) {
|
||||
int res;
|
||||
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||
: "=r"(res)
|
||||
|
@ -37,7 +37,10 @@
|
||||
|
||||
#endif
|
||||
|
||||
template <typename T> inline std::string str(T x) { return std::to_string(x); }
|
||||
template <typename T>
|
||||
inline std::string str(T x) {
|
||||
return std::to_string(x);
|
||||
}
|
||||
|
||||
namespace marlin_24 {
|
||||
|
||||
@ -57,22 +60,23 @@ static constexpr int max_par = 16;
|
||||
template <const int num_bits, // weight bits
|
||||
const int threads, // number of threads in a threadblock
|
||||
const int thread_m_blocks, // number of 16x16 blocks in the m
|
||||
// dimension (batchsize) of the threadblock
|
||||
// dimension (batchsize) of the
|
||||
// threadblock
|
||||
const int thread_n_blocks, // same for n dimension (output)
|
||||
const int thread_k_blocks, // same for k dimension (reduction)
|
||||
const int stages, // number of stages for the async global->shared
|
||||
// fetch pipeline
|
||||
const int group_blocks = -1 // number of consecutive 16x16 blocks with
|
||||
// a separate quantization scale
|
||||
const int group_blocks = -1 // number of consecutive 16x16 blocks
|
||||
// with a separate quantization scale
|
||||
>
|
||||
__global__ void Marlin_24(
|
||||
const int4* __restrict__ A, // fp16 input matrix of shape mxk
|
||||
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
|
||||
const int4
|
||||
*__restrict__ meta, // 2bit metadata information about 2:4 format on B
|
||||
const int4* __restrict__ meta, // 2bit metadata information about 2:4
|
||||
// format on B
|
||||
int4* __restrict__ C, // fp16 output buffer of shape mxn
|
||||
const int4
|
||||
*__restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn
|
||||
const int4* __restrict__ s, // fp16 quantization scales of shape
|
||||
// (k/groupsize)xn
|
||||
int prob_m, // batch dimension m
|
||||
int prob_n, // output dimension n
|
||||
int prob_k, // reduction dimension k
|
||||
@ -95,22 +99,23 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
|
||||
template <const int num_bits, // weight bits
|
||||
const int threads, // number of threads in a threadblock
|
||||
const int thread_m_blocks, // number of 16x16 blocks in the m
|
||||
// dimension (batchsize) of the threadblock
|
||||
// dimension (batchsize) of the
|
||||
// threadblock
|
||||
const int thread_n_blocks, // same for n dimension (output)
|
||||
const int thread_k_blocks, // same for k dimension (reduction)
|
||||
const int stages, // number of stages for the async global->shared
|
||||
// fetch pipeline
|
||||
const int group_blocks = -1 // number of consecutive 16x16 blocks with
|
||||
// a separate quantization scale
|
||||
const int group_blocks = -1 // number of consecutive 16x16 blocks
|
||||
// with a separate quantization scale
|
||||
>
|
||||
__global__ void Marlin_24(
|
||||
const int4* __restrict__ A, // fp16 input matrix of shape mxk
|
||||
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
|
||||
const int4
|
||||
*__restrict__ meta, // 2bit metadata information about 2:4 format on B
|
||||
const int4* __restrict__ meta, // 2bit metadata information about 2:4
|
||||
// format on B
|
||||
int4* __restrict__ C, // fp16 output buffer of shape mxn
|
||||
const int4
|
||||
*__restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn
|
||||
const int4* __restrict__ s, // fp16 quantization scales of shape
|
||||
// (k/groupsize)xn
|
||||
int prob_m, // batch dimension m
|
||||
int prob_n, // output dimension n
|
||||
int prob_k, // reduction dimension k
|
||||
@ -174,27 +179,22 @@ __global__ void Marlin_24(
|
||||
auto init_slice = [&]() {
|
||||
slice_iters =
|
||||
iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
|
||||
if (slice_iters < 0 || slice_col_par >= n_tiles * parallel)
|
||||
slice_iters = 0;
|
||||
if (slice_iters == 0)
|
||||
return;
|
||||
if (slice_row + slice_iters > k_tiles)
|
||||
slice_iters = k_tiles - slice_row;
|
||||
if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;
|
||||
if (slice_iters == 0) return;
|
||||
if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
|
||||
slice_count = 1;
|
||||
slice_idx = 0;
|
||||
int col_first = iters * ceildiv(k_tiles * slice_col_par, iters);
|
||||
if (col_first <= k_tiles * (slice_col_par + 1)) {
|
||||
int col_off = col_first - k_tiles * slice_col_par;
|
||||
slice_count = ceildiv(k_tiles - col_off, iters);
|
||||
if (col_off > 0)
|
||||
slice_count++;
|
||||
if (col_off > 0) slice_count++;
|
||||
int delta_first = iters * blockIdx.x - col_first;
|
||||
if (delta_first < 0 || (col_off == 0 && delta_first == 0))
|
||||
slice_idx = slice_count - 1;
|
||||
else {
|
||||
slice_idx = slice_count - 1 - delta_first / iters;
|
||||
if (col_off > 0)
|
||||
slice_idx--;
|
||||
if (col_off > 0) slice_idx--;
|
||||
}
|
||||
}
|
||||
if (slice_col == n_tiles) {
@ -392,8 +392,7 @@ __global__ void Marlin_24(
for (int i = 0; i < b_sh_wr_iters; i++) {
#pragma unroll
for (int j = 0; j < b_thread_vecs; j++) {
cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j],
B_ptr[i] + j);
cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j);
}
B_ptr[i] += b_gl_rd_delta_o;
}
@ -401,15 +400,13 @@ __global__ void Marlin_24(
#pragma unroll
for (int i = 0; i < m_sh_iters; i++) {
if (m_sh_wr_pred)
cp_async4(&sh_meta_stage[m_sh_wr_delta * i + m_sh_wr],
meta_ptr[i]);
cp_async4(&sh_meta_stage[m_sh_wr_delta * i + m_sh_wr], meta_ptr[i]);
meta_ptr[i] += m_gl_rd_delta_o;
}
// Only fetch scales if this tile starts a new group
if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) {
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
if (s_sh_wr_pred)
cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
s_gl_rd += s_gl_rd_delta;
}
}
@ -519,8 +516,8 @@ __global__ void Marlin_24(
(threadIdx.x % b_sh_stride_threads);

// Parallel logarithmic shared memory reduction. We make sure to avoid any
// unnecessary read or write iterations, e.g., for two warps we write only once
// by warp 1 and read only once by warp 0.
// unnecessary read or write iterations, e.g., for two warps we write only
// once by warp 1 and read only once by warp 0.
#pragma unroll
for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
#pragma unroll
@ -531,8 +528,8 @@ __global__ void Marlin_24(
int red_sh_wr =
red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
if (i < red_off) {
float *c_rd = reinterpret_cast<float *>(
&sh[red_sh_delta * j + red_sh_rd]);
float* c_rd =
reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
#pragma unroll
for (int k = 0; k < 4; k++)
@ -562,9 +559,9 @@ __global__ void Marlin_24(
};
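
The comment re-wrapped above describes the block-level "logarithmic" reduction. Stripped of Marlin's fragment layout, the write-once/read-once idea looks roughly like the sketch below; the names and the lane-0-only accumulation are assumptions of the sketch, not the kernel's actual data layout.

// Each warp arrives with a per-warp partial sum in `partial` (held by lane 0).
// Every round, the upper half of the still-active warps stores its partial and
// the lower half adds it into registers, so each value is written to shared
// memory exactly once and read exactly once.
template <int NUM_WARPS>  // power of two
__device__ float tree_reduce_sketch(float partial,
                                    float* smem /* >= NUM_WARPS / 2 floats */) {
  const int warp_id = threadIdx.x / 32;
  const int lane_id = threadIdx.x % 32;
  for (int active = NUM_WARPS / 2; active > 0; active /= 2) {
    if (lane_id == 0 && warp_id >= active && warp_id < 2 * active)
      smem[warp_id - active] = partial;  // upper half writes its partial
    __syncthreads();
    if (lane_id == 0 && warp_id < active) partial += smem[warp_id];  // lower half reads
    __syncthreads();
  }
  return partial;  // meaningful in warp 0, lane 0
}

For two warps this degenerates to exactly the case the comment names: warp 1 writes once, warp 0 reads once.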

// Since multiple threadblocks may process parts of the same column slice, we
// finally have to globally reduce over the results. As the striped partitioning
// minimizes the number of such reductions and our outputs are usually rather
// small, we perform this reduction serially in L2 cache.
// finally have to globally reduce over the results. As the striped
// partitioning minimizes the number of such reductions and our outputs are
// usually rather small, we perform this reduction serially in L2 cache.
auto global_reduce = [&](bool first = false, bool last = false) {
// We are very careful here to reduce directly in the output buffer to
// maximize L2 cache utilization in this step. To do this, we write out
@ -584,9 +581,9 @@ __global__ void Marlin_24(
int col = 2 * ((threadIdx.x % 32) % 4);

if (!first) {
// Interestingly, doing direct global accesses here really seems to mess up the
// compiler and lead to slowdowns, hence we also use async-copies even though
// these fetches are not actually asynchronous.
// Interestingly, doing direct global accesses here really seems to mess up
// the compiler and lead to slowdowns, hence we also use async-copies even
// though these fetches are not actually asynchronous.
#pragma unroll
for (int i = 0; i < thread_m_blocks * 4; i++) {
cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i],
@ -722,8 +719,7 @@ __global__ void Marlin_24(
// Start global fetch and register load pipelines.
auto start_pipes = [&]() {
#pragma unroll
for (int i = 0; i < stages - 1; i++)
fetch_to_shared(i, i, i < slice_iters);
for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters);
zero_accums();
wait_for_stage();
fetch_to_registers(0, 0);
@ -733,9 +729,9 @@ __global__ void Marlin_24(

// Main loop.
while (slice_iters) {
// We unroll over both the global fetch and the register load pipeline to ensure
// all shared memory accesses are static. Note that both pipelines have even
// length meaning that the next iteration will always start at index 0.
// We unroll over both the global fetch and the register load pipeline to
// ensure all shared memory accesses are static. Note that both pipelines have
// even length meaning that the next iteration will always start at index 0.
#pragma unroll
for (int pipe = 0; pipe < stages;) {
fetch_to_shared((pipe + stages - 1) % stages, pipe,
@ -747,8 +743,7 @@ __global__ void Marlin_24(

pipe++;
slice_iters--;
if (slice_iters == 0)
break;
if (slice_iters == 0) break;
}
a_gl_rd += a_gl_rd_delta_o * stages;
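
The reflowed comment above is the reason the main loop is unrolled over the prefetch pipeline: with a compile-time `stages` and full unrolling, every `% stages` buffer index collapses to a literal, so all shared-memory addressing stays static. A stripped-down sketch of that schedule follows; the callables, the predicate, and the function name are placeholders, not the kernel's real signatures.

// `fetch_to_shared(buf, pred)` starts an async global->shared copy into slot
// `buf`; `compute_stage(buf)` consumes a slot that finished loading earlier.
template <int stages, typename FetchFn, typename ComputeFn>
__device__ void pipelined_loop_sketch(int iters_left, FetchFn fetch_to_shared,
                                      ComputeFn compute_stage) {
  // Prologue: fill all but one pipeline slot before computing anything.
  for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i < iters_left);

  while (iters_left > 0) {
#pragma unroll  // full unroll => `pipe` and every `% stages` index is a constant
    for (int pipe = 0; pipe < stages; pipe++) {
      // Refill the slot that sits `stages - 1` steps ahead of the one we are
      // about to consume, then compute on the slot that has already landed.
      fetch_to_shared((pipe + stages - 1) % stages, iters_left >= stages);
      compute_stage(pipe);
      if (--iters_left == 0) break;
    }
  }
}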

@ -762,13 +757,11 @@ __global__ void Marlin_24(
// write-out
if constexpr (group_blocks == -1) {
if constexpr (num_bits == 8) {
if (s_sh_wr_pred)
cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
cp_async_fence();
} else {
if (last) {
if (s_sh_wr_pred)
cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
cp_async_fence();
}
}
@ -851,11 +844,9 @@ __global__ void Marlin_24(
meta_ptr[i] += (m_sh_stride)-m_gl_rd_delta_o * k_tiles;
if (slice_col == 0) {
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++)
B_ptr[i] -= b_gl_stride;
for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
#pragma unroll
for (int i = 0; i < m_sh_iters; i++)
meta_ptr[i] -= m_gl_stride;
for (int i = 0; i < m_sh_iters; i++) meta_ptr[i] -= m_gl_stride;
}
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
start_pipes();
@ -904,8 +895,8 @@ void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C,

if (thread_k == -1 || thread_m == -1) {
if (prob_n <= 16) {
// For small batchizes, better partitioningif is slightly more important than
// better compute utilization
// For small batchizes, better partitioningif is slightly more important
// than better compute utilization
thread_k = 128;
thread_m = 128;
} else {
@ -946,8 +937,7 @@ void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C,
// Note that parallel > 1 currently only works for inputs without any
// padding
par = (16 * thread_n_blocks - pad) / 64;
if (par > max_par)
par = max_par;
if (par > max_par) par = max_par;
prob_n = 64 * par;
i += 4 * (par - 1);
thread_n_blocks = 4;
@ -1037,9 +1027,9 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
" is not divisible by tile_size = " + str(marlin_24::tile_size));

int actual_size_n = (b_q_weight.size(1) / marlin_24::tile_size) * pack_factor;
TORCH_CHECK(size_n == actual_size_n,
"size_n = " + str(size_n) +
", actual_size_n = " + str(actual_size_n));
TORCH_CHECK(
size_n == actual_size_n,
"size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n));

// Verify meta
TORCH_CHECK(b_meta.size(0) == size_k / 8 / 2 / 2,

@ -32,12 +32,8 @@ __global__ void NUQ4MatMulKernel(
#else
float2* __restrict__ mul,
#endif
const __half* __restrict__ lookup_table,
int height,
int width,
int batch,
int vec_height
) {
const __half* __restrict__ lookup_table, int height, int width, int batch,
int vec_height) {

const int blockwidth2 = BLOCKWIDTH / 2;

@ -80,7 +76,9 @@ __global__ void NUQ4MatMulKernel(

__syncthreads();
if (threadIdx.x < blockwidth2)
blockvec[threadIdx.x] = vec[b * vec_height / 2 + (row / BLOCKHEIGHT4) * blockwidth2 + threadIdx.x];
blockvec[threadIdx.x] =
vec[b * vec_height / 2 + (row / BLOCKHEIGHT4) * blockwidth2 +
threadIdx.x];
__syncthreads();

while (k < blockwidth2) {
@ -143,7 +141,8 @@ __global__ void NUQ4MatMulKernel(
#ifndef USE_ROCM
res = __hadd(__hadd(res2.x, res2.y), res);
#else
res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)), res);
res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)),
res);
#endif

i += width;
@ -183,22 +182,16 @@ __global__ void NUQ4MatMulKernel(
} // namespace vllm

// 4-bit matvec kernel (LUT-based)
void squeezellm_gemm(
torch::Tensor vec,
torch::Tensor mat,
torch::Tensor mul,
torch::Tensor lookup_table
) {
void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor lookup_table) {
int height = mat.size(0);
int width = mat.size(1);

int batch = vec.size(0);
int vec_height = vec.size(1);

dim3 blocks(
(height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
(width + BLOCKWIDTH - 1) / BLOCKWIDTH
);
dim3 blocks((height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
(width + BLOCKWIDTH - 1) / BLOCKWIDTH);
dim3 threads(BLOCKWIDTH);

const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
@ -211,14 +204,12 @@ void squeezellm_gemm(
#endif
mat.data_ptr<int>(),
#ifndef USE_ROCM
(half2*) mul.data<at::Half>(),
(__half*) lookup_table.data<at::Half>(),
(half2*)mul.data<at::Half>(), (__half*)lookup_table.data<at::Half>(),
#else
(float2*)mul.data_ptr<float>(),
(__half*)lookup_table.data_ptr<at::Half>(),
#endif
height, width, batch, vec_height
);
height, width, batch, vec_height);
}

#undef BLOCKWIDTH

@ -1,5 +1,6 @@
/*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh
* Adapted from
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
@ -43,17 +44,18 @@ __inline__ __device__ T blockReduceSum(T val) {
static_assert(maxBlockSize <= 1024);
if constexpr (maxBlockSize > WARP_SIZE) {
val = warpReduceSum<T>(val);
// Calculates max number of lanes that need to participate in the last warpReduce
// Calculates max number of lanes that need to participate in the last
// warpReduce
constexpr int maxActiveLanes = (maxBlockSize + WARP_SIZE - 1) / WARP_SIZE;
static __shared__ T shared[maxActiveLanes];
int lane = threadIdx.x % WARP_SIZE;
int wid = threadIdx.x / WARP_SIZE;
if (lane == 0)
shared[wid] = val;
if (lane == 0) shared[wid] = val;

__syncthreads();

val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane] : (T)(0.0f);
val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane]
: (T)(0.0f);
val = warpReduceSum<T, _nextPow2(maxActiveLanes)>(val);
} else {
// A single warpReduce is equal to blockReduce
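
The `blockReduceSum` touched above leans on `warpReduceSum`, which is not shown in this hunk. For orientation, a typical butterfly-shuffle warp reduction looks like the generic sketch below; it is illustrative only and not necessarily the exact implementation in reduction_utils.cuh.

// After log2(numLanes) xor-shuffle rounds, every participating lane holds the
// sum of all numLanes inputs; that is why blockReduceSum only needs to stash
// one value per warp in shared memory before the final reduction.
template <typename T, int numLanes = 32>
__inline__ __device__ T warp_reduce_sum_sketch(T val) {
  static_assert((numLanes & (numLanes - 1)) == 0, "numLanes must be a power of two");
#pragma unroll
  for (int mask = numLanes / 2; mask > 0; mask >>= 1)
    val += __shfl_xor_sync(0xffffffffu, val, mask);  // exchange with lane ^ mask
  return val;
}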

57
format.sh
@ -26,6 +26,7 @@ RUFF_VERSION=$(ruff --version | awk '{print $2}')
MYPY_VERSION=$(mypy --version | awk '{print $2}')
CODESPELL_VERSION=$(codespell --version)
ISORT_VERSION=$(isort --vn)
CLANGFORMAT_VERSION=$(clang-format --version | awk '{print $3}')

# # params: tool name, tool version, required version
tool_version_check() {
@ -40,6 +41,7 @@ tool_version_check "ruff" $RUFF_VERSION "$(grep "ruff==" requirements-dev.txt |
tool_version_check "mypy" "$MYPY_VERSION" "$(grep mypy requirements-dev.txt | cut -d'=' -f3)"
tool_version_check "isort" "$ISORT_VERSION" "$(grep isort requirements-dev.txt | cut -d'=' -f3)"
tool_version_check "codespell" "$CODESPELL_VERSION" "$(grep codespell requirements-dev.txt | cut -d'=' -f3)"
tool_version_check "clang-format" "$CLANGFORMAT_VERSION" "$(grep clang-format requirements-dev.txt | cut -d'=' -f3)"

YAPF_FLAGS=(
'--recursive'
@ -179,7 +181,6 @@ lint_changed() {
}

# Run Ruff
echo 'vLLM ruff:'
### This flag lints individual files. --files *must* be the first command line
### arg to use this option.
if [[ "$1" == '--files' ]]; then
@ -192,6 +193,7 @@ else
# Format only the files that changed in last commit.
lint_changed
fi
echo 'vLLM ruff: Done'

# check spelling of specified files
isort_check() {
@ -233,6 +235,59 @@ else
fi
echo 'vLLM isort: Done'

# Clang-format section
# Exclude some files for formatting because they are vendored
# NOTE: Keep up to date with .github/workflows/clang-format.yml
CLANG_FORMAT_EXCLUDES=(
'csrc/moe/topk_softmax_kernels.cu'
'csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu'
'csrc/punica/bgmv/bgmv_config.h'
'csrc/punica/bgmv/bgmv_impl.cuh'
'csrc/punica/bgmv/vec_dtypes.cuh'
'csrc/punica/punica_ops.cu'
'csrc/punica/type_convert.h'
)

# Format specified files with clang-format
clang_format() {
clang-format -i "$@"
}

# Format files that differ from main branch with clang-format.
clang_format_changed() {
# The `if` guard ensures that the list of filenames is not empty, which
# could cause clang-format to receive 0 positional arguments, making it hang
# waiting for STDIN.
#
# `diff-filter=ACM` and $MERGEBASE is to ensure we only format files that
# exist on both branches.
MERGEBASE="$(git merge-base origin/main HEAD)"

# Get the list of changed files, excluding the specified ones
changed_files=$(git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.h' '*.cpp' '*.cu' '*.cuh' | grep -vFf <(printf "%s\n" "${CLANG_FORMAT_EXCLUDES[@]}"))
if [ -n "$changed_files" ]; then
echo "$changed_files" | xargs -P 5 clang-format -i
fi
}

# Format all files with clang-format
clang_format_all() {
find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \
| grep -vFf <(printf "%s\n" "${CLANG_FORMAT_EXCLUDES[@]}") \
| xargs clang-format -i
}

# Run clang-format
if [[ "$1" == '--files' ]]; then
clang_format "${@:2}"
elif [[ "$1" == '--all' ]]; then
clang_format_all
else
clang_format_changed
fi
echo 'vLLM clang-format: Done'


if ! git diff --quiet &>/dev/null; then
echo 'Reformatted files. Please review and stage the changes.'
echo 'Changes not staged for commit:'

@ -5,6 +5,7 @@ tomli==2.0.1
ruff==0.1.5
codespell==2.2.6
isort==5.13.2
clang-format==18.1.5

# type checking
mypy==1.9.0