[CI/Build] Enforce style for C++ and CUDA code with clang-format (#4722)
This commit is contained in:
parent 9b9a10d6cb
commit 5f6d10c14c

.clang-format (new file, 26 lines)
@@ -0,0 +1,26 @@
BasedOnStyle: Google
UseTab: Never
IndentWidth: 2
ColumnLimit: 80

# Force pointers to the type for C++.
DerivePointerAlignment: false
PointerAlignment: Left

# Reordering #include statements can (and currently will) introduce errors
SortIncludes: false

# Style choices
AlignConsecutiveAssignments: false
AlignConsecutiveDeclarations: false
IndentPPDirectives: BeforeHash

IncludeCategories:
  - Regex: '^<'
    Priority: 4
  - Regex: '^"(llvm|llvm-c|clang|clang-c|mlir|mlir-c)/'
    Priority: 3
  - Regex: '^"(qoda|\.\.)/'
    Priority: 2
  - Regex: '.*'
    Priority: 1
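As a rough illustration only (a hypothetical helper, not code from this
commit), C++ formatted under the configuration above gets 2-space indents,
left-aligned pointers ("float* x" rather than "float *x"), and lines wrapped
at 80 columns:

// Illustration of the style the .clang-format above enforces.
// The function and all names here are made up for this example.
template <typename T>
void scale_rows(T* out, const T* in, int num_rows, int num_cols,
                float scale) {
  for (int i = 0; i < num_rows; ++i) {
    for (int j = 0; j < num_cols; ++j) {
      out[i * num_cols + j] = static_cast<T>(scale * in[i * num_cols + j]);
    }
  }
}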
.github/workflows/clang-format.yml (vendored, new file, 42 lines)
@@ -0,0 +1,42 @@
name: clang-format

on:
  # Trigger the workflow on push or pull request,
  # but only for the main branch
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

jobs:
  clang-format:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.11"]
    steps:
    - uses: actions/checkout@v2
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v2
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install clang-format==18.1.5
    - name: Running clang-format
      run: |
        EXCLUDES=(
            'csrc/moe/topk_softmax_kernels.cu'
            'csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu'
            'csrc/punica/bgmv/bgmv_config.h'
            'csrc/punica/bgmv/bgmv_impl.cuh'
            'csrc/punica/bgmv/vec_dtypes.cuh'
            'csrc/punica/punica_ops.cu'
            'csrc/punica/type_convert.h'
        )
        find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \
            | grep -vFf <(printf "%s\n" "${EXCLUDES[@]}") \
            | xargs clang-format --dry-run --Werror
@@ -63,31 +63,25 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
   VLLM_DISPATCH_FLOATING_TYPES( \
-    input.scalar_type(), \
-    "act_and_mul_kernel", \
-    [&] { \
-      vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \
-        out.data_ptr<scalar_t>(), \
-        input.data_ptr<scalar_t>(), \
-        d); \
+      input.scalar_type(), "act_and_mul_kernel", [&] { \
+        vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>> \
+            <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
+                                         input.data_ptr<scalar_t>(), d); \
       });
 
-void silu_and_mul(
-  torch::Tensor& out,    // [..., d]
+void silu_and_mul(torch::Tensor& out,    // [..., d]
     torch::Tensor& input)  // [..., 2 * d]
 {
   LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
 }
 
-void gelu_and_mul(
-  torch::Tensor& out,    // [..., d]
+void gelu_and_mul(torch::Tensor& out,    // [..., d]
     torch::Tensor& input)  // [..., 2 * d]
 {
   LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
 }
 
-void gelu_tanh_and_mul(
-  torch::Tensor& out,    // [..., d]
+void gelu_tanh_and_mul(torch::Tensor& out,    // [..., d]
     torch::Tensor& input)  // [..., 2 * d]
 {
   LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel);
@@ -118,14 +112,10 @@ __global__ void activation_kernel(
   dim3 block(std::min(d, 1024)); \
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
-  VLLM_DISPATCH_FLOATING_TYPES( \
-    input.scalar_type(), \
-    "activation_kernel", \
-    [&] { \
-      vllm::activation_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \
-        out.data_ptr<scalar_t>(), \
-        input.data_ptr<scalar_t>(), \
-        d); \
+  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] { \
+    vllm::activation_kernel<scalar_t, KERNEL<scalar_t>> \
+        <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
+                                     input.data_ptr<scalar_t>(), d); \
   });
 
 namespace vllm {
@@ -140,21 +130,20 @@ __device__ __forceinline__ T gelu_new_kernel(const T& x) {
 template <typename T>
 __device__ __forceinline__ T gelu_fast_kernel(const T& x) {
   const float f = (float)x;
-  const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x));
+  const T t =
+      (T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x));
   return ((T)0.5) * x * (((T)1.0) + t);
 }
 
 }  // namespace vllm
 
-void gelu_new(
-  torch::Tensor& out,    // [..., d]
+void gelu_new(torch::Tensor& out,    // [..., d]
     torch::Tensor& input)  // [..., d]
 {
   LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
 }
 
-void gelu_fast(
-  torch::Tensor& out,    // [..., d]
+void gelu_fast(torch::Tensor& out,    // [..., d]
     torch::Tensor& input)  // [..., d]
 {
   LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
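For reference, the expression rewrapped in gelu_fast_kernel above evaluates the
tanh-based GELU approximation (with f = (float)x):

\[
\mathrm{gelu\_fast}(x) \approx
  0.5\,x\,\bigl(1 + \tanh(0.79788456\,x\,(1 + 0.044715\,x^{2}))\bigr)
\]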
@@ -1,5 +1,6 @@
 /*
- * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
+ * Adapted from
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
  * Copyright (c) 2023, The vLLM team.
  * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
  *
@@ -1,5 +1,6 @@
 /*
- * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * Adapted from
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
  * Copyright (c) 2023, The vLLM team.
  * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
  *
@@ -82,30 +83,27 @@ inline __device__ float block_sum(float* red_smem, float sum) {
 
 // TODO(woosuk): Merge the last two dimensions of the grid.
 // Grid: (num_heads, num_seqs, max_num_partitions).
-template<
-  typename scalar_t,
-  typename cache_t,
-  int HEAD_SIZE,
-  int BLOCK_SIZE,
-  int NUM_THREADS,
-  vllm::Fp8KVCacheDataType KV_DTYPE,
+template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
+          int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
   int PARTITION_SIZE = 0>  // Zero means no partitioning.
 __device__ void paged_attention_kernel(
     float* __restrict__ exp_sums,  // [num_seqs, num_heads, max_num_partitions]
-    float* __restrict__ max_logits,  // [num_seqs, num_heads, max_num_partitions]
-    scalar_t* __restrict__ out,  // [num_seqs, num_heads, max_num_partitions, head_size]
+    float* __restrict__ max_logits,  // [num_seqs, num_heads,
+                                     // max_num_partitions]
+    scalar_t* __restrict__ out,  // [num_seqs, num_heads, max_num_partitions,
+                                 // head_size]
     const scalar_t* __restrict__ q,  // [num_seqs, num_heads, head_size]
-    const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads, head_size/x, block_size, x]
-    const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads, head_size, block_size]
+    const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
+                                          // head_size/x, block_size, x]
+    const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
+                                          // head_size, block_size]
     const int num_kv_heads,  // [num_heads]
     const float scale,
     const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
     const int* __restrict__ seq_lens,      // [num_seqs]
     const int max_num_blocks_per_seq,
     const float* __restrict__ alibi_slopes,  // [num_heads]
-    const int q_stride,
-    const int kv_block_stride,
-    const int kv_head_stride,
+    const int q_stride, const int kv_block_stride, const int kv_head_stride,
     const float kv_scale) {
   const int seq_idx = blockIdx.y;
   const int partition_idx = blockIdx.z;
@@ -118,22 +116,29 @@ __device__ void paged_attention_kernel(
   }
 
   const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
-  const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
+  const int num_blocks_per_partition =
+      USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
 
   // [start_block_idx, end_block_idx) is the range of blocks to process.
-  const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
-  const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks);
+  const int start_block_idx =
+      USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
+  const int end_block_idx =
+      MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks);
   const int num_blocks = end_block_idx - start_block_idx;
 
   // [start_token_idx, end_token_idx) is the range of tokens to process.
   const int start_token_idx = start_block_idx * BLOCK_SIZE;
-  const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len);
+  const int end_token_idx =
+      MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len);
   const int num_tokens = end_token_idx - start_token_idx;
 
   constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
-  constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS
+  constexpr int NUM_THREAD_GROUPS =
+      NUM_THREADS / THREAD_GROUP_SIZE;  // Note: This assumes THREAD_GROUP_SIZE
+                                        // divides NUM_THREADS
   assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
-  constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
+  constexpr int NUM_TOKENS_PER_THREAD_GROUP =
+      DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
   const int thread_idx = threadIdx.x;
   const int warp_idx = thread_idx / WARP_SIZE;
@@ -143,13 +148,14 @@ __device__ void paged_attention_kernel(
   const int num_heads = gridDim.x;
   const int num_queries_per_kv = num_heads / num_kv_heads;
   const int kv_head_idx = head_idx / num_queries_per_kv;
-  const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
+  const float alibi_slope =
+      alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
 
   // A vector type to store a part of a key or a query.
-  // The vector size is configured in such a way that the threads in a thread group
-  // fetch or compute 16 bytes at a time.
-  // For example, if the size of a thread group is 4 and the data type is half,
-  // then the vector size is 16 / (4 * sizeof(half)) == 2.
+  // The vector size is configured in such a way that the threads in a thread
+  // group fetch or compute 16 bytes at a time. For example, if the size of a
+  // thread group is 4 and the data type is half, then the vector size is 16 /
+  // (4 * sizeof(half)) == 2.
   constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
   using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
   using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
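As a standalone sketch (plain host C++ restating the VEC_SIZE arithmetic shown
in the hunk above; elem_bytes stands in for sizeof(scalar_t) and nothing here
is code from this commit), the 16-byte-per-thread-group rule works out like
this:

#include <cstdio>

// Same arithmetic as MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1).
constexpr int vec_size(int thread_group_size, int elem_bytes) {
  int v = 16 / (thread_group_size * elem_bytes);
  return v > 1 ? v : 1;
}

int main() {
  // Matches the comment: a group of 4 threads with half (2 bytes) -> 2 elems.
  static_assert(vec_size(4, 2) == 2, "16 / (4 * sizeof(half)) == 2");
  std::printf("VEC_SIZE for 4 threads of half: %d\n", vec_size(4, 2));
  return 0;
}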
@@ -163,18 +169,21 @@ __device__ void paged_attention_kernel(
 
   // Load the query to registers.
   // Each thread in a thread group has a different part of the query.
-  // For example, if the the thread group size is 4, then the first thread in the group
-  // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
-  // th vectors of the query, and so on.
-  // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
+  // For example, if the the thread group size is 4, then the first thread in
+  // the group has 0, 4, 8, ... th vectors of the query, and the second thread
+  // has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because
+  // q is split from a qkv tensor, it may not be contiguous.
   const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
   __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
 #pragma unroll
-  for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
+  for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
+       i += NUM_THREAD_GROUPS) {
     const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
-    q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
+    q_vecs[thread_group_offset][i] =
+        *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
   }
-  __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs
+  __syncthreads();  // TODO(naed90): possible speedup if this is replaced with a
+                    // memory wall right before we use q_vecs
 
   // Memory planning.
   extern __shared__ char shared_mem[];
@@ -193,44 +202,50 @@ __device__ void paged_attention_kernel(
   // Each thread group in a warp fetches a key from the block, and computes
   // dot product with the query.
   const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
-  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
-    // NOTE(woosuk): The block number is stored in int32. However, we cast it to int64
-    // because int32 can lead to overflow when this variable is multiplied by large numbers
-    // (e.g., kv_block_stride).
-    const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
+  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
+       block_idx += NUM_WARPS) {
+    // NOTE(woosuk): The block number is stored in int32. However, we cast it to
+    // int64 because int32 can lead to overflow when this variable is multiplied
+    // by large numbers (e.g., kv_block_stride).
+    const int64_t physical_block_number =
+        static_cast<int64_t>(block_table[block_idx]);
 
     // Load a key to registers.
     // Each thread in a thread group has a different part of the key.
-    // For example, if the the thread group size is 4, then the first thread in the group
-    // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th
-    // vectors of the key, and so on.
+    // For example, if the the thread group size is 4, then the first thread in
+    // the group has 0, 4, 8, ... th vectors of the key, and the second thread
+    // has 1, 5, 9, ... th vectors of the key, and so on.
     for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
-      const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
+      const int physical_block_offset =
+          (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
       const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
       K_vec k_vecs[NUM_VECS_PER_THREAD];
 
 #pragma unroll
       for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
-        const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride
-                               + kv_head_idx * kv_head_stride
-                               + physical_block_offset * x;
+        const cache_t* k_ptr =
+            k_cache + physical_block_number * kv_block_stride +
+            kv_head_idx * kv_head_stride + physical_block_offset * x;
         const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
         const int offset1 = (vec_idx * VEC_SIZE) / x;
         const int offset2 = (vec_idx * VEC_SIZE) % x;
 
         if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
-          k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
+          k_vecs[j] = *reinterpret_cast<const K_vec*>(
+              k_ptr + offset1 * BLOCK_SIZE * x + offset2);
         } else {
           // Vector conversion from Quant_vec to K_vec.
           Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
              k_ptr + offset1 * BLOCK_SIZE * x + offset2);
-          k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(k_vec_quant, kv_scale);
+          k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
+              k_vec_quant, kv_scale);
         }
       }
 
       // Compute dot product.
       // This includes a reduction across the threads in the same thread group.
-      float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
+      float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(
+                             q_vecs[thread_group_offset], k_vecs);
       // Add the ALiBi bias if slopes are given.
       qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;
 
@@ -285,13 +300,12 @@ __device__ void paged_attention_kernel(
 
   // If partitioning is enabled, store the max logit and exp_sum.
   if (USE_PARTITIONING && thread_idx == 0) {
-    float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
-                            + head_idx * max_num_partitions
-                            + partition_idx;
+    float* max_logits_ptr = max_logits +
+                            seq_idx * num_heads * max_num_partitions +
+                            head_idx * max_num_partitions + partition_idx;
     *max_logits_ptr = qk_max;
-    float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
-                          + head_idx * max_num_partitions
-                          + partition_idx;
+    float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions +
+                          head_idx * max_num_partitions + partition_idx;
     *exp_sums_ptr = exp_sum;
   }
 
@@ -304,7 +318,8 @@ __device__ void paged_attention_kernel(
 
   constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
   constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
-  constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
+  constexpr int NUM_ROWS_PER_THREAD =
+      DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
 
   // NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
   float accs[NUM_ROWS_PER_THREAD];
@@ -315,18 +330,21 @@ __device__ void paged_attention_kernel(
 
   scalar_t zero_value;
   zero(zero_value);
-  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
-    // NOTE(woosuk): The block number is stored in int32. However, we cast it to int64
-    // because int32 can lead to overflow when this variable is multiplied by large numbers
-    // (e.g., kv_block_stride).
-    const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
+  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
+       block_idx += NUM_WARPS) {
+    // NOTE(woosuk): The block number is stored in int32. However, we cast it to
+    // int64 because int32 can lead to overflow when this variable is multiplied
+    // by large numbers (e.g., kv_block_stride).
+    const int64_t physical_block_number =
+        static_cast<int64_t>(block_table[block_idx]);
     const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
     const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
     L_vec logits_vec;
-    from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx));
+    from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx -
+                                                           start_token_idx));
 
-    const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride
-                           + kv_head_idx * kv_head_stride;
+    const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
+                           kv_head_idx * kv_head_stride;
 #pragma unroll
     for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
       const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
@@ -337,14 +355,17 @@ __device__ void paged_attention_kernel(
       if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
         v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
       } else {
-        V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
+        V_quant_vec v_quant_vec =
+            *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
         // Vector conversion from V_quant_vec to V_vec.
-        v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec, kv_scale);
+        v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
+                                                                  kv_scale);
       }
       if (block_idx == num_seq_blocks - 1) {
-        // NOTE(woosuk): When v_vec contains the tokens that are out of the context,
-        // we should explicitly zero out the values since they may contain NaNs.
-        // See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
+        // NOTE(woosuk): When v_vec contains the tokens that are out of the
+        // context, we should explicitly zero out the values since they may
+        // contain NaNs. See
+        // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
         scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
 #pragma unroll
         for (int j = 0; j < V_VEC_SIZE; j++) {
@@ -367,8 +388,8 @@ __device__ void paged_attention_kernel(
     accs[i] = acc;
   }
 
-  // NOTE(woosuk): A barrier is required because the shared memory space for logits
-  // is reused for the output.
+  // NOTE(woosuk): A barrier is required because the shared memory space for
+  // logits is reused for the output.
   __syncthreads();
 
   // Perform reduction across warps.
@@ -405,9 +426,9 @@ __device__ void paged_attention_kernel(
 
   // Write the final output.
   if (warp_idx == 0) {
-    scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
-                        + head_idx * max_num_partitions * HEAD_SIZE
-                        + partition_idx * HEAD_SIZE;
+    scalar_t* out_ptr =
+        out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
+        head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE;
 #pragma unroll
     for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
       const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
@@ -419,77 +440,73 @@ __device__ void paged_attention_kernel(
 }
 
 // Grid: (num_heads, num_seqs, 1).
-template<
-  typename scalar_t,
-  typename cache_t,
-  int HEAD_SIZE,
-  int BLOCK_SIZE,
+template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
   int NUM_THREADS,
   vllm::Fp8KVCacheDataType KV_DTYPE>
 __global__ void paged_attention_v1_kernel(
     scalar_t* __restrict__ out,      // [num_seqs, num_heads, head_size]
     const scalar_t* __restrict__ q,  // [num_seqs, num_heads, head_size]
-    const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads, head_size/x, block_size, x]
-    const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads, head_size, block_size]
+    const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
+                                          // head_size/x, block_size, x]
+    const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
+                                          // head_size, block_size]
     const int num_kv_heads,  // [num_heads]
     const float scale,
     const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ seq_lens,      // [num_seqs]
     const int max_num_blocks_per_seq,
     const float* __restrict__ alibi_slopes,  // [num_heads]
-    const int q_stride,
-    const int kv_block_stride,
-    const int kv_head_stride,
+    const int q_stride, const int kv_block_stride, const int kv_head_stride,
     const float kv_scale) {
-  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE>(
-    /* exp_sums */ nullptr, /* max_logits */ nullptr,
-    out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens,
-    max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale);
+  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
+                         KV_DTYPE>(
+      /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
+      v_cache, num_kv_heads, scale, block_tables, seq_lens,
+      max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
+      kv_head_stride, kv_scale);
 }
 
 // Grid: (num_heads, num_seqs, max_num_partitions).
-template<
-  typename scalar_t,
-  typename cache_t,
-  int HEAD_SIZE,
-  int BLOCK_SIZE,
-  int NUM_THREADS,
-  vllm::Fp8KVCacheDataType KV_DTYPE,
+template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
+          int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
   int PARTITION_SIZE>
 __global__ void paged_attention_v2_kernel(
     float* __restrict__ exp_sums,  // [num_seqs, num_heads, max_num_partitions]
-    float* __restrict__ max_logits,  // [num_seqs, num_heads, max_num_partitions]
-    scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads, max_num_partitions, head_size]
+    float* __restrict__ max_logits,  // [num_seqs, num_heads,
+                                     // max_num_partitions]
+    scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads,
+                                     // max_num_partitions, head_size]
     const scalar_t* __restrict__ q,  // [num_seqs, num_heads, head_size]
-    const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads, head_size/x, block_size, x]
-    const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads, head_size, block_size]
+    const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
+                                          // head_size/x, block_size, x]
+    const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
+                                          // head_size, block_size]
     const int num_kv_heads,  // [num_heads]
     const float scale,
     const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
     const int* __restrict__ seq_lens,      // [num_seqs]
     const int max_num_blocks_per_seq,
     const float* __restrict__ alibi_slopes,  // [num_heads]
-    const int q_stride,
-    const int kv_block_stride,
-    const int kv_head_stride,
+    const int q_stride, const int kv_block_stride, const int kv_head_stride,
     const float kv_scale) {
-  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE, PARTITION_SIZE>(
+  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
+                         KV_DTYPE, PARTITION_SIZE>(
       exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
-      block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes,
-      q_stride, kv_block_stride, kv_head_stride, kv_scale);
+      block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
+      kv_block_stride, kv_head_stride, kv_scale);
 }
 
 // Grid: (num_heads, num_seqs).
-template<
-  typename scalar_t,
-  int HEAD_SIZE,
-  int NUM_THREADS,
+template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
   int PARTITION_SIZE>
 __global__ void paged_attention_v2_reduce_kernel(
     scalar_t* __restrict__ out,  // [num_seqs, num_heads, head_size]
-    const float* __restrict__ exp_sums,  // [num_seqs, num_heads, max_num_partitions]
-    const float* __restrict__ max_logits,  // [num_seqs, num_heads, max_num_partitions]
-    const scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads, max_num_partitions, head_size]
+    const float* __restrict__ exp_sums,  // [num_seqs, num_heads,
+                                         // max_num_partitions]
+    const float* __restrict__ max_logits,  // [num_seqs, num_heads,
+                                           // max_num_partitions]
+    const scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads,
+                                           // max_num_partitions, head_size]
     const int* __restrict__ seq_lens,  // [num_seqs]
     const int max_num_partitions) {
   const int num_heads = gridDim.x;
@@ -499,9 +516,11 @@ __global__ void paged_attention_v2_reduce_kernel(
   const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
   if (num_partitions == 1) {
     // No need to reduce. Only copy tmp_out to out.
-    scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
-    const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
-                                  + head_idx * max_num_partitions * HEAD_SIZE;
+    scalar_t* out_ptr =
+        out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+    const scalar_t* tmp_out_ptr =
+        tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
+        head_idx * max_num_partitions * HEAD_SIZE;
     for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
       out_ptr[i] = tmp_out_ptr[i];
     }
@@ -520,8 +539,9 @@ __global__ void paged_attention_v2_reduce_kernel(
 
   // Load max logits to shared memory.
   float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
-  const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
-                                + head_idx * max_num_partitions;
+  const float* max_logits_ptr = max_logits +
+                                seq_idx * num_heads * max_num_partitions +
+                                head_idx * max_num_partitions;
   float max_logit = -FLT_MAX;
   for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
     const float l = max_logits_ptr[i];
@@ -550,9 +570,11 @@ __global__ void paged_attention_v2_reduce_kernel(
   max_logit = VLLM_SHFL_SYNC(max_logit, 0);
 
   // Load rescaled exp sums to shared memory.
-  float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
-  const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
-                              + head_idx * max_num_partitions;
+  float* shared_exp_sums =
+      reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
+  const float* exp_sums_ptr = exp_sums +
+                              seq_idx * num_heads * max_num_partitions +
+                              head_idx * max_num_partitions;
   float global_exp_sum = 0.0f;
   for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
     float l = shared_max_logits[i];
@@ -565,14 +587,17 @@ __global__ void paged_attention_v2_reduce_kernel(
   const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
 
   // Aggregate tmp_out to out.
-  const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
-                                + head_idx * max_num_partitions * HEAD_SIZE;
-  scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+  const scalar_t* tmp_out_ptr =
+      tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
+      head_idx * max_num_partitions * HEAD_SIZE;
+  scalar_t* out_ptr =
+      out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
 #pragma unroll
   for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
     float acc = 0.0f;
     for (int j = 0; j < num_partitions; ++j) {
-      acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum;
+      acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] *
+             inv_global_exp_sum;
     }
     from_float(out_ptr[i], acc);
   }
@@ -582,44 +607,25 @@ __global__ void paged_attention_v2_reduce_kernel(
 
 #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
   VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
-    ((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
-      KV_DTYPE>), shared_mem_size); \
-  vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
-    KV_DTYPE><<<grid, block, shared_mem_size, stream>>>( \
-    out_ptr, \
-    query_ptr, \
-    key_cache_ptr, \
-    value_cache_ptr, \
-    num_kv_heads, \
-    scale, \
-    block_tables_ptr, \
-    seq_lens_ptr, \
-    max_num_blocks_per_seq, \
-    alibi_slopes_ptr, \
-    q_stride, \
-    kv_block_stride, \
-    kv_head_stride, \
+      ((void*)vllm::paged_attention_v1_kernel< \
+           T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE>), \
+      shared_mem_size); \
+  vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
+                                  NUM_THREADS, KV_DTYPE> \
+      <<<grid, block, shared_mem_size, stream>>>( \
+          out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
+          scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
+          alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
     kv_scale);
 
 // TODO(woosuk): Tune NUM_THREADS.
-template<
-  typename T,
-  typename CACHE_T,
-  int BLOCK_SIZE,
-  vllm::Fp8KVCacheDataType KV_DTYPE,
-  int NUM_THREADS = 128>
+template <typename T, typename CACHE_T, int BLOCK_SIZE,
+          vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS = 128>
 void paged_attention_v1_launcher(
-    torch::Tensor& out,
-    torch::Tensor& query,
-    torch::Tensor& key_cache,
-    torch::Tensor& value_cache,
-    int num_kv_heads,
-    float scale,
-    torch::Tensor& block_tables,
-    torch::Tensor& seq_lens,
-    int max_seq_len,
-    const c10::optional<torch::Tensor>& alibi_slopes,
-    float kv_scale) {
+    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
+    torch::Tensor& value_cache, int num_kv_heads, float scale,
+    torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
+    const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale) {
   int num_seqs = query.size(0);
   int num_heads = query.size(1);
   int head_size = query.size(2);
@@ -632,8 +638,9 @@ void paged_attention_v1_launcher(
   assert(head_size % thread_group_size == 0);
 
   // NOTE: alibi_slopes is optional.
-  const float* alibi_slopes_ptr = alibi_slopes ?
-    reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
+  const float* alibi_slopes_ptr =
+      alibi_slopes
+          ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
           : nullptr;
 
   T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
@@ -644,7 +651,8 @@ void paged_attention_v1_launcher(
   int* seq_lens_ptr = seq_lens.data_ptr<int>();
 
   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
-  int padded_max_seq_len = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
+  int padded_max_seq_len =
+      DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
   int logits_size = padded_max_seq_len * sizeof(float);
   int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
   // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
@@ -685,17 +693,8 @@ void paged_attention_v1_launcher(
 
 #define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \
   paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE>( \
-    out, \
-    query, \
-    key_cache, \
-    value_cache, \
-    num_kv_heads, \
-    scale, \
-    block_tables, \
-    seq_lens, \
-    max_seq_len, \
-    alibi_slopes, \
-    kv_scale);
+      out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
+      seq_lens, max_seq_len, alibi_slopes, kv_scale);
 
 // NOTE(woosuk): To reduce the compilation time, we omitted block sizes
 // 1, 2, 4, 64, 128, 256.
@@ -718,72 +717,43 @@ void paged_attention_v1_launcher(
 void paged_attention_v1(
     torch::Tensor& out,    // [num_seqs, num_heads, head_size]
     torch::Tensor& query,  // [num_seqs, num_heads, head_size]
-    torch::Tensor& key_cache,    // [num_blocks, num_heads, head_size/x, block_size, x]
-    torch::Tensor& value_cache,  // [num_blocks, num_heads, head_size, block_size]
+    torch::Tensor&
+        key_cache,  // [num_blocks, num_heads, head_size/x, block_size, x]
+    torch::Tensor&
+        value_cache,  // [num_blocks, num_heads, head_size, block_size]
     int num_kv_heads,  // [num_heads]
     float scale,
     torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
     torch::Tensor& seq_lens,      // [num_seqs]
-    int block_size,
-    int max_seq_len,
+    int block_size, int max_seq_len,
     const c10::optional<torch::Tensor>& alibi_slopes,
-    const std::string& kv_cache_dtype,
-    float kv_scale) {
+    const std::string& kv_cache_dtype, float kv_scale){
 
-  DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, CALL_V1_LAUNCHER_BLOCK_SIZE)
-}
+    DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
+                               CALL_V1_LAUNCHER_BLOCK_SIZE)}
 
 #define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
-  vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
-    KV_DTYPE, PARTITION_SIZE> \
+  vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
+                                  NUM_THREADS, KV_DTYPE, PARTITION_SIZE> \
       <<<grid, block, shared_mem_size, stream>>>( \
-    exp_sums_ptr, \
-    max_logits_ptr, \
-    tmp_out_ptr, \
-    query_ptr, \
-    key_cache_ptr, \
-    value_cache_ptr, \
-    num_kv_heads, \
-    scale, \
-    block_tables_ptr, \
-    seq_lens_ptr, \
-    max_num_blocks_per_seq, \
-    alibi_slopes_ptr, \
-    q_stride, \
-    kv_block_stride, \
-    kv_head_stride, \
-    kv_scale); \
-  vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE> \
+          exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
+          value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
+          seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
+          kv_block_stride, kv_head_stride, kv_scale); \
+  vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
+                                         PARTITION_SIZE> \
       <<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
-    out_ptr, \
-    exp_sums_ptr, \
-    max_logits_ptr, \
-    tmp_out_ptr, \
-    seq_lens_ptr, \
+          out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
     max_num_partitions);
 
-template<
-  typename T,
-  typename CACHE_T,
-  int BLOCK_SIZE,
-  vllm::Fp8KVCacheDataType KV_DTYPE,
-  int NUM_THREADS = 128,
+template <typename T, typename CACHE_T, int BLOCK_SIZE,
+          vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS = 128,
   int PARTITION_SIZE = 512>
 void paged_attention_v2_launcher(
-    torch::Tensor& out,
-    torch::Tensor& exp_sums,
-    torch::Tensor& max_logits,
-    torch::Tensor& tmp_out,
-    torch::Tensor& query,
-    torch::Tensor& key_cache,
-    torch::Tensor& value_cache,
-    int num_kv_heads,
-    float scale,
-    torch::Tensor& block_tables,
-    torch::Tensor& seq_lens,
-    int max_seq_len,
-    const c10::optional<torch::Tensor>& alibi_slopes,
-    float kv_scale) {
+    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
+    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
+    torch::Tensor& value_cache, int num_kv_heads, float scale,
+    torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
+    const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale) {
   int num_seqs = query.size(0);
   int num_heads = query.size(1);
   int head_size = query.size(2);
@@ -796,8 +766,9 @@ void paged_attention_v2_launcher(
   assert(head_size % thread_group_size == 0);
 
   // NOTE: alibi_slopes is optional.
-  const float* alibi_slopes_ptr = alibi_slopes ?
-    reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
+  const float* alibi_slopes_ptr =
+      alibi_slopes
+          ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
           : nullptr;
 
   T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
@@ -855,19 +826,8 @@ void paged_attention_v2_launcher(
 
 #define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \
   paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE>( \
-    out, \
-    exp_sums, \
-    max_logits, \
-    tmp_out, \
-    query, \
-    key_cache, \
-    value_cache, \
-    num_kv_heads, \
-    scale, \
-    block_tables, \
-    seq_lens, \
-    max_seq_len, \
-    alibi_slopes, \
+      out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
+      num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
     kv_scale);
 
 // NOTE(woosuk): To reduce the compilation time, we omitted block sizes
@@ -892,20 +852,22 @@ void paged_attention_v2(
     torch::Tensor& out,         // [num_seqs, num_heads, head_size]
     torch::Tensor& exp_sums,    // [num_seqs, num_heads, max_num_partitions]
     torch::Tensor& max_logits,  // [num_seqs, num_heads, max_num_partitions]
-    torch::Tensor& tmp_out,     // [num_seqs, num_heads, max_num_partitions, head_size]
+    torch::Tensor&
+        tmp_out,  // [num_seqs, num_heads, max_num_partitions, head_size]
     torch::Tensor& query,  // [num_seqs, num_heads, head_size]
-    torch::Tensor& key_cache,    // [num_blocks, num_heads, head_size/x, block_size, x]
-    torch::Tensor& value_cache,  // [num_blocks, num_heads, head_size, block_size]
+    torch::Tensor&
+        key_cache,  // [num_blocks, num_heads, head_size/x, block_size, x]
+    torch::Tensor&
+        value_cache,  // [num_blocks, num_heads, head_size, block_size]
     int num_kv_heads,  // [num_heads]
     float scale,
     torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
     torch::Tensor& seq_lens,      // [num_seqs]
-    int block_size,
-    int max_seq_len,
+    int block_size, int max_seq_len,
     const c10::optional<torch::Tensor>& alibi_slopes,
-    const std::string& kv_cache_dtype,
-    float kv_scale) {
-  DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, CALL_V2_LAUNCHER_BLOCK_SIZE)
+    const std::string& kv_cache_dtype, float kv_scale) {
+  DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
+                             CALL_V2_LAUNCHER_BLOCK_SIZE)
 }
 
 #undef WARP_SIZE
@@ -1,5 +1,6 @@
 /*
- * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * Adapted from
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
  * Copyright (c) 2023, The vLLM team.
  * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
  *
@@ -1,6 +1,8 @@
 /*
- * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
- * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
+ * Adapted from
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * and
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
  * Copyright (c) 2023, The vLLM team.
  * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
  *
@ -280,7 +282,8 @@ inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Vector fused multiply-add.
|
// Vector fused multiply-add.
|
||||||
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
|
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b,
|
||||||
|
__nv_bfloat162 c) {
|
||||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||||
assert(false);
|
assert(false);
|
||||||
#else
|
#else
|
||||||
@ -288,7 +291,8 @@ inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bf
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) {
|
inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b,
|
||||||
|
__nv_bfloat162 c) {
|
||||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||||
assert(false);
|
assert(false);
|
||||||
#else
|
#else
|
||||||
|
@@ -1,6 +1,8 @@
 /*
- * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
- * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
+ * Adapted from
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * and
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
  * Copyright (c) 2023, The vLLM team.
  * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
  *
@@ -130,7 +132,9 @@ inline __device__ uint32_t float2_to_half2(float2 f) {
   } tmp;
 #ifndef USE_ROCM
   #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
-  asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
+  asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n"
+               : "=r"(tmp.u32)
+               : "f"(f.y), "f"(f.x));
   #else
   asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
   asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
@@ -327,9 +331,13 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) {
 inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
   uint32_t d;
 #ifndef USE_ROCM
-  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
+  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
+               : "=r"(d)
+               : "r"(a), "r"(b), "r"(c));
 #else
-  asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c));
+  asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n"
+               : "=v"(d)
+               : "v"(a), "v"(b), "v"(c));
 #endif
   return d;
 }
@@ -470,13 +478,9 @@ inline __device__ void from_float(uint4& dst, Float8_ src) {
 }

 // From float16 to float32.
-inline __device__ float to_float(uint16_t u) {
-  return half_to_float(u);
-}
+inline __device__ float to_float(uint16_t u) { return half_to_float(u); }

-inline __device__ float2 to_float(uint32_t u) {
-  return half2_to_float2(u);
-}
+inline __device__ float2 to_float(uint32_t u) { return half2_to_float2(u); }

 inline __device__ Float4_ to_float(uint2 u) {
   Float4_ tmp;
@@ -495,8 +499,6 @@ inline __device__ Float8_ to_float(uint4 u) {
 }

 // Zero-out a variable.
-inline __device__ void zero(uint16_t& dst) {
-  dst = uint16_t(0);
-}
+inline __device__ void zero(uint16_t& dst) { dst = uint16_t(0); }

 }  // namespace vllm
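The one-line collapses above follow from the Google-based style keeping short single-statement functions on a single line. An illustrative sketch with hypothetical helper names, not the vLLM definitions:

// Before formatting each helper spanned three lines; a short single-statement
// body now fits on one line.
inline float to_float_sketch(float u) { return u; }
inline void zero_sketch(float& dst) { dst = 0.f; }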
@@ -1,6 +1,8 @@
 /*
- * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
- * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
+ * Adapted from
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * and
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
  * Copyright (c) 2023, The vLLM team.
  * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
  *
@@ -66,9 +68,7 @@ struct FloatVec<float4> {
 };

 // Vector addition.
-inline __device__ float add(float a, float b) {
-  return a + b;
-}
+inline __device__ float add(float a, float b) { return a + b; }

 inline __device__ float2 add(float2 a, float2 b) {
   float2 c;
@@ -129,9 +129,7 @@ inline __device__ float4 mul(float a, float4 b) {
 }

 // Vector fused multiply-add.
-inline __device__ float fma(float a, float b, float c) {
-  return a * b + c;
-}
+inline __device__ float fma(float a, float b, float c) { return a * b + c; }

 inline __device__ float2 fma(float2 a, float2 b, float2 c) {
   float2 d;
@@ -208,9 +206,7 @@ inline __device__ float sum(Float8_ v) {
 }

 // Vector dot product.
-inline __device__ float dot(float a, float b) {
-  return a * b;
-}
+inline __device__ float dot(float a, float b) { return a * b; }

 inline __device__ float dot(float2 a, float2 b) {
   float2 c = mul<float2, float2, float2>(a, b);
@@ -232,42 +228,24 @@ inline __device__ float dot(Float8_ a, Float8_ b) {
 }

 // From float to float.
-inline __device__ void from_float(float& dst, float src) {
-  dst = src;
-}
+inline __device__ void from_float(float& dst, float src) { dst = src; }

-inline __device__ void from_float(float2& dst, float2 src) {
-  dst = src;
-}
+inline __device__ void from_float(float2& dst, float2 src) { dst = src; }

-inline __device__ void from_float(float4& dst, float4 src) {
-  dst = src;
-}
+inline __device__ void from_float(float4& dst, float4 src) { dst = src; }

 // From float to float.
-inline __device__ float to_float(float u) {
-  return u;
-}
+inline __device__ float to_float(float u) { return u; }

-inline __device__ float2 to_float(float2 u) {
-  return u;
-}
+inline __device__ float2 to_float(float2 u) { return u; }

-inline __device__ float4 to_float(float4 u) {
-  return u;
-}
+inline __device__ float4 to_float(float4 u) { return u; }

-inline __device__ Float4_ to_float(Float4_ u) {
-  return u;
-}
+inline __device__ Float4_ to_float(Float4_ u) { return u; }

-inline __device__ Float8_ to_float(Float8_ u) {
-  return u;
-}
+inline __device__ Float8_ to_float(Float8_ u) { return u; }

 // Zero-out a variable.
-inline __device__ void zero(float& dst) {
-  dst = 0.f;
-}
+inline __device__ void zero(float& dst) { dst = 0.f; }

 }  // namespace vllm
28
csrc/cache.h
@@ -5,36 +5,24 @@
 #include <map>
 #include <vector>

-void swap_blocks(
-    torch::Tensor& src,
-    torch::Tensor& dst,
+void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
                  const torch::Tensor& block_mapping);

-void copy_blocks(
-    std::vector<torch::Tensor>& key_caches,
+void copy_blocks(std::vector<torch::Tensor>& key_caches,
                  std::vector<torch::Tensor>& value_caches,
                  const torch::Tensor& block_mapping);

-void reshape_and_cache(
-    torch::Tensor& key,
-    torch::Tensor& value,
-    torch::Tensor& key_cache,
-    torch::Tensor& value_cache,
+void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
+                       torch::Tensor& key_cache, torch::Tensor& value_cache,
                        torch::Tensor& slot_mapping,
-    const std::string& kv_cache_dtype,
-    const float kv_scale);
+                       const std::string& kv_cache_dtype, const float kv_scale);

-void reshape_and_cache_flash(
-    torch::Tensor& key,
-    torch::Tensor& value,
+void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
                              torch::Tensor& key_cache,
                              torch::Tensor& value_cache,
                              torch::Tensor& slot_mapping,
                              const std::string& kv_cache_dtype);

 // Just for unittest
-void convert_fp8(
-    torch::Tensor& dst_cache,
-    torch::Tensor& src_cache,
-    const float scale,
-    const std::string& kv_cache_dtype);
+void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
                 const float scale, const std::string& kv_cache_dtype);
@@ -21,16 +21,13 @@
 typedef __hip_bfloat16 __nv_bfloat16;
 #endif

-void swap_blocks(
-    torch::Tensor& src,
-    torch::Tensor& dst,
+void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
                  const torch::Tensor& block_mapping) {
   torch::Device src_device = src.device();
   torch::Device dst_device = dst.device();
   cudaMemcpyKind memcpy_type;
   if (src_device.is_cuda() && dst_device.is_cuda()) {
-    TORCH_CHECK(
-        src_device.index() == dst_device.index(),
+    TORCH_CHECK(src_device.index() == dst_device.index(),
                 "src and dst must be on the same GPU");
     memcpy_type = cudaMemcpyDeviceToDevice;
   } else if (src_device.is_cuda() && dst_device.is_cpu()) {
@@ -50,7 +47,8 @@ void swap_blocks(
   char* dst_ptr = static_cast<char*>(dst.data_ptr());

   const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
-  const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device);
+  const at::cuda::OptionalCUDAGuard device_guard(
+      src_device.is_cuda() ? src_device : dst_device);
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   // NOTE(woosuk): This can be slow if the number of blocks is large.
   const int64_t num_blocks = block_mapping.size(0);
@@ -59,12 +57,8 @@ void swap_blocks(
     int64_t dst_block_number = block_mapping[i][1].item<int64_t>();
     int64_t src_offset = src_block_number * block_size_in_bytes;
     int64_t dst_offset = dst_block_number * block_size_in_bytes;
-    cudaMemcpyAsync(
-        dst_ptr + dst_offset,
-        src_ptr + src_offset,
-        block_size_in_bytes,
-        memcpy_type,
-        stream);
+    cudaMemcpyAsync(dst_ptr + dst_offset, src_ptr + src_offset,
+                    block_size_in_bytes, memcpy_type, stream);
   }
 }

@@ -72,8 +66,7 @@ namespace vllm {

 // Grid: (num_layers, num_pairs)
 template <typename scalar_t>
-__global__ void copy_blocks_kernel(
-    int64_t* key_cache_ptrs,
+__global__ void copy_blocks_kernel(int64_t* key_cache_ptrs,
                                    int64_t* value_cache_ptrs,
                                    const int64_t* __restrict__ block_mapping,
                                    const int numel_per_block) {
@@ -81,7 +74,8 @@ __global__ void copy_blocks_kernel(
   const int pair_idx = blockIdx.y;

   scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
-  scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
+  scalar_t* value_cache =
+      reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
   int64_t src_block_number = block_mapping[2 * pair_idx];
   int64_t dst_block_number = block_mapping[2 * pair_idx + 1];

@@ -101,8 +95,7 @@ __global__ void copy_blocks_kernel(

 }  // namespace vllm

-void copy_blocks(
-    std::vector<torch::Tensor>& key_caches,
+void copy_blocks(std::vector<torch::Tensor>& key_caches,
                  std::vector<torch::Tensor>& value_caches,
                  const torch::Tensor& block_mapping) {
   int num_layers = key_caches.size();
@@ -118,8 +111,10 @@ void copy_blocks(
   int64_t key_cache_ptrs[num_layers];
   int64_t value_cache_ptrs[num_layers];
   for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
-    key_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
-    value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
+    key_cache_ptrs[layer_idx] =
+        reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
+    value_cache_ptrs[layer_idx] =
+        reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
   }

   // block_mapping is a 2D tensor with shape (num_pairs, 2).
@@ -127,10 +122,12 @@ void copy_blocks(

   // Move the data structures to the GPU.
   // NOTE: This synchronizes the CPU and GPU.
-  torch::Tensor key_cache_ptrs_tensor = torch::from_blob(
-      key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
-  torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
-      value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
+  torch::Tensor key_cache_ptrs_tensor =
+      torch::from_blob(key_cache_ptrs, {num_layers}, torch::kInt64)
+          .to(cache_device);
+  torch::Tensor value_cache_ptrs_tensor =
+      torch::from_blob(value_cache_ptrs, {num_layers}, torch::kInt64)
+          .to(cache_device);

   // Launch the kernel.
   const int numel_per_block = key_caches[0][0].numel();
@@ -143,8 +140,7 @@ void copy_blocks(
       vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
           key_cache_ptrs_tensor.data_ptr<int64_t>(),
           value_cache_ptrs_tensor.data_ptr<int64_t>(),
-          block_mapping.data_ptr<int64_t>(),
-          numel_per_block);
+          block_mapping.data_ptr<int64_t>(), numel_per_block);
     }));
 }

@@ -154,15 +150,13 @@ template<typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
 __global__ void reshape_and_cache_kernel(
     const scalar_t* __restrict__ key,    // [num_tokens, num_heads, head_size]
     const scalar_t* __restrict__ value,  // [num_tokens, num_heads, head_size]
-    cache_t* __restrict__ key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
-    cache_t* __restrict__ value_cache,   // [num_blocks, num_heads, head_size, block_size]
+    cache_t* __restrict__ key_cache,    // [num_blocks, num_heads, head_size/x,
+                                        // block_size, x]
+    cache_t* __restrict__ value_cache,  // [num_blocks, num_heads, head_size,
+                                        // block_size]
     const int64_t* __restrict__ slot_mapping,  // [num_tokens]
-    const int key_stride,
-    const int value_stride,
-    const int num_heads,
-    const int head_size,
-    const int block_size,
-    const int x,
+    const int key_stride, const int value_stride, const int num_heads,
+    const int head_size, const int block_size, const int x,
     const float kv_scale) {
   const int64_t token_idx = blockIdx.x;
   const int64_t slot_idx = slot_mapping[token_idx];
@@ -184,23 +178,24 @@ __global__ void reshape_and_cache_kernel(
     const int x_idx = head_offset / x;
     const int x_offset = head_offset % x;

-    const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
-                                + head_idx * (head_size / x) * block_size * x
-                                + x_idx * block_size * x
-                                + block_offset * x
-                                + x_offset;
-    const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size
-                                  + head_idx * head_size * block_size
-                                  + head_offset * block_size
-                                  + block_offset;
+    const int64_t tgt_key_idx =
+        block_idx * num_heads * (head_size / x) * block_size * x +
+        head_idx * (head_size / x) * block_size * x + x_idx * block_size * x +
+        block_offset * x + x_offset;
+    const int64_t tgt_value_idx =
+        block_idx * num_heads * head_size * block_size +
+        head_idx * head_size * block_size + head_offset * block_size +
+        block_offset;
     scalar_t tgt_key = key[src_key_idx];
     scalar_t tgt_value = value[src_value_idx];
     if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
       key_cache[tgt_key_idx] = tgt_key;
       value_cache[tgt_value_idx] = tgt_value;
     } else {
-      key_cache[tgt_key_idx] = fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, kv_scale);
-      value_cache[tgt_value_idx] = fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, kv_scale);
+      key_cache[tgt_key_idx] =
+          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, kv_scale);
+      value_cache[tgt_value_idx] =
+          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, kv_scale);
     }
   }
 }
@@ -209,15 +204,13 @@ template<typename scalar_t>
 __global__ void reshape_and_cache_flash_kernel(
     const scalar_t* __restrict__ key,    // [num_tokens, num_heads, head_size]
     const scalar_t* __restrict__ value,  // [num_tokens, num_heads, head_size]
-    scalar_t* __restrict__ k_cache,      // [num_blocks, block_size, num_heads, head_size]
-    scalar_t* __restrict__ v_cache,      // [num_blocks, block_size, num_heads, head_size]
+    scalar_t* __restrict__ k_cache,  // [num_blocks, block_size, num_heads,
+                                     // head_size]
+    scalar_t* __restrict__ v_cache,  // [num_blocks, block_size, num_heads,
+                                     // head_size]
     const int64_t* __restrict__ slot_mapping,  // [num_tokens]
-    const int block_stride,
-    const int key_stride,
-    const int value_stride,
-    const int num_heads,
-    const int head_size,
-    const int block_size) {
+    const int block_stride, const int key_stride, const int value_stride,
+    const int num_heads, const int head_size, const int block_size) {
   const int64_t token_idx = blockIdx.x;
   const int64_t slot_idx = slot_mapping[token_idx];
   // NOTE: slot_idx can be -1 if the token is padded
@@ -232,10 +225,9 @@ __global__ void reshape_and_cache_flash_kernel(
     const int64_t src_value_idx = token_idx * value_stride + i;
     const int head_idx = i / head_size;
     const int head_offset = i % head_size;
-    const int64_t tgt_value_idx = block_idx * block_stride
-                                  + block_offset * num_heads * head_size
-                                  + head_idx * head_size
-                                  + head_offset;
+    const int64_t tgt_value_idx = block_idx * block_stride +
+                                  block_offset * num_heads * head_size +
+                                  head_idx * head_size + head_offset;
     k_cache[tgt_value_idx] = key[src_key_idx];
     v_cache[tgt_value_idx] = value[src_value_idx];
   }
@@ -246,29 +238,24 @@ __global__ void reshape_and_cache_flash_kernel(
 // CACHE_T is the data type of key and value tensors.
 // KV_DTYPE is the real data type of kv-cache.
 #define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE)               \
-  vllm::reshape_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE><<<grid, block, 0, stream>>>( \
+  vllm::reshape_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE>             \
+      <<<grid, block, 0, stream>>>(                                   \
          reinterpret_cast<KV_T*>(key.data_ptr()),                     \
          reinterpret_cast<KV_T*>(value.data_ptr()),                   \
          reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),            \
          reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),          \
-      slot_mapping.data_ptr<int64_t>(),                               \
-      key_stride,                                                     \
-      value_stride,                                                   \
-      num_heads,                                                      \
-      head_size,                                                      \
-      block_size,                                                     \
-      x,                                                              \
-      kv_scale);
+          slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
+          num_heads, head_size, block_size, x, kv_scale);

 void reshape_and_cache(
     torch::Tensor& key,    // [num_tokens, num_heads, head_size]
     torch::Tensor& value,  // [num_tokens, num_heads, head_size]
-    torch::Tensor& key_cache,    // [num_blocks, num_heads, head_size/x, block_size, x]
-    torch::Tensor& value_cache,  // [num_blocks, num_heads, head_size, block_size]
+    torch::Tensor&
+        key_cache,  // [num_blocks, num_heads, head_size/x, block_size, x]
+    torch::Tensor&
+        value_cache,  // [num_blocks, num_heads, head_size, block_size]
     torch::Tensor& slot_mapping,  // [num_tokens]
-    const std::string& kv_cache_dtype,
-    const float kv_scale)
-{
+    const std::string& kv_cache_dtype, const float kv_scale) {
   int num_tokens = key.size(0);
   int num_heads = key.size(1);
   int head_size = key.size(2);
@@ -283,7 +270,8 @@ void reshape_and_cache(
   const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

-  DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype, CALL_RESHAPE_AND_CACHE)
+  DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
+                             CALL_RESHAPE_AND_CACHE)
 }

 void reshape_and_cache_flash(
@@ -292,8 +280,7 @@ void reshape_and_cache_flash(
     torch::Tensor& k_cache,  // [num_blocks, block_size, num_heads, head_size]
     torch::Tensor& v_cache,  // [num_blocks, block_size, num_heads, head_size]
     torch::Tensor& slot_mapping,  // [num_tokens]
-    const std::string& kv_cache_dtype)
-{
+    const std::string& kv_cache_dtype) {
   // FIXME: only support auto datatype, does not support fp8
   if (kv_cache_dtype != "auto") {
     TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
@@ -313,36 +300,28 @@ void reshape_and_cache_flash(
   const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
-      key.scalar_type(),
-      "reshape_and_cache_flash",
-      [&] {
-        vllm::reshape_and_cache_flash_kernel<scalar_t><<<grid, block, 0, stream>>>(
-            key.data_ptr<scalar_t>(),
-            value.data_ptr<scalar_t>(),
-            k_cache.data_ptr<scalar_t>(),
-            v_cache.data_ptr<scalar_t>(),
-            slot_mapping.data_ptr<int64_t>(),
-            block_stride,
-            key_stride,
-            value_stride,
-            num_heads,
-            head_size,
-            block_size);
+      key.scalar_type(), "reshape_and_cache_flash", [&] {
+        vllm::reshape_and_cache_flash_kernel<scalar_t>
+            <<<grid, block, 0, stream>>>(
+                key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
+                k_cache.data_ptr<scalar_t>(), v_cache.data_ptr<scalar_t>(),
+                slot_mapping.data_ptr<int64_t>(), block_stride, key_stride,
+                value_stride, num_heads, head_size, block_size);
       });
 }

 namespace vllm {

 template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
-__global__ void convert_fp8_kernel(
-    const Tin* __restrict__ src_cache,
+__global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache,
                                    Tout* __restrict__ dst_cache,
                                    const float kv_scale,
                                    const int64_t block_stride) {
   const int64_t block_idx = blockIdx.x;
   for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
     int64_t idx = block_idx * block_stride + i;
-    dst_cache[idx] = fp8::scaled_convert<Tout, Tin, kv_dt>(src_cache[idx], kv_scale);
+    dst_cache[idx] =
+        fp8::scaled_convert<Tout, Tin, kv_dt>(src_cache[idx], kv_scale);
   }
 }

@@ -351,23 +330,16 @@ __global__ void convert_fp8_kernel(
 #define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE)                                \
   vllm::convert_fp8_kernel<Tout, Tin, KV_DTYPE><<<grid, block, 0, stream>>>( \
       reinterpret_cast<Tin*>(src_cache.data_ptr()),                          \
-      reinterpret_cast<Tout*>(dst_cache.data_ptr()),                         \
-      kv_scale,                                                              \
-      block_stride);
+      reinterpret_cast<Tout*>(dst_cache.data_ptr()), kv_scale, block_stride);

 // Only for testing.
-void convert_fp8(
-    torch::Tensor& dst_cache,
-    torch::Tensor& src_cache,
-    const float kv_scale,
-    const std::string& kv_cache_dtype)
-{
+void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
+                 const float kv_scale, const std::string& kv_cache_dtype) {
   torch::Device src_device = src_cache.device();
   torch::Device dst_device = dst_cache.device();
   TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU")
   TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU")
-  TORCH_CHECK(
-      src_device.index() == dst_device.index(),
+  TORCH_CHECK(src_device.index() == dst_device.index(),
               "src and dst must be on the same GPU");
   at::cuda::OptionalCUDAGuard device_guard(src_device);

@@ -398,13 +370,15 @@ void convert_fp8(
   } else if (src_cache.dtype() == at::ScalarType::Half) {
     CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
   } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
-    CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, vllm::Fp8KVCacheDataType::kFp8E4M3);
+    CALL_CONVERT_FP8(uint8_t, __nv_bfloat16,
+                     vllm::Fp8KVCacheDataType::kFp8E4M3);
   } else if (dst_cache.dtype() == at::ScalarType::Float) {
     CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
   } else if (dst_cache.dtype() == at::ScalarType::Half) {
     CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
   } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
-    CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
+    CALL_CONVERT_FP8(__nv_bfloat16, uint8_t,
+                     vllm::Fp8KVCacheDataType::kFp8E4M3);
   }
  } else {
    TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype);
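The macro hunks above mostly re-flow arguments and re-align the line-continuation backslashes under the column limit. A self-contained, hypothetical C++ sketch of that pattern (none of these names are vLLM symbols):

#include <cstdio>

template <typename T>
void launch_sketch(int num_tokens, int dim) {
  std::printf("launching for %d tokens x %d dims\n", num_tokens, dim);
}

// A line-continuation macro kept within the column limit, with the
// backslashes aligned the way the formatter leaves them above.
#define CALL_KERNEL_SKETCH(T)                             \
  do {                                                    \
    launch_sketch<T>(/*num_tokens=*/tokens, /*dim=*/dim); \
  } while (0)

int main() {
  int tokens = 4;
  int dim = 8;
  CALL_KERNEL_SKETCH(float);
  return 0;
}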
@@ -81,12 +81,10 @@ void silu_and_mul(torch::Tensor &out, torch::Tensor &input) {
   int num_tokens = input.numel() / input.size(-1);
   int d = input.size(-1) / 2;

-  VLLM_DISPATCH_FLOATING_TYPES(
-      input.scalar_type(), "silu_and_mul_impl", [&] {
+  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "silu_and_mul_impl", [&] {
     CPU_KERNEL_GUARD_IN(silu_and_mul_impl)
-    activation_kernel<scalar_t, silu_act, true>(num_tokens, d,
-                                                input.data_ptr<scalar_t>(),
-                                                out.data_ptr<scalar_t>());
+    activation_kernel<scalar_t, silu_act, true>(
+        num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
     CPU_KERNEL_GUARD_OUT(silu_and_mul_impl)
   });
 }
@@ -97,12 +95,10 @@ void gelu_and_mul(torch::Tensor &out, // [..., d]
   int num_tokens = input.numel() / input.size(-1);
   int d = input.size(-1) / 2;

-  VLLM_DISPATCH_FLOATING_TYPES(
-      input.scalar_type(), "gelu_and_mul_impl", [&] {
+  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_and_mul_impl", [&] {
     CPU_KERNEL_GUARD_IN(gelu_and_mul_impl)
-    activation_kernel<scalar_t, gelu_act, true>(num_tokens, d,
-                                                input.data_ptr<scalar_t>(),
-                                                out.data_ptr<scalar_t>());
+    activation_kernel<scalar_t, gelu_act, true>(
+        num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
     CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl)
   });
 }
@@ -2,7 +2,8 @@

 namespace {

-template <typename scalar_t> struct KernelVecType {
+template <typename scalar_t>
+struct KernelVecType {
   using q_load_vec_type = void;
   using q_vec_type = void;
   using k_load_vec_type = void;
@@ -11,7 +12,8 @@ template <typename scalar_t> struct KernelVecType {
   using v_load_vec_type = void;
 };

-template <> struct KernelVecType<float> {
+template <>
+struct KernelVecType<float> {
   using q_load_vec_type = vec_op::FP32Vec4;
   using q_vec_type = vec_op::FP32Vec16;
   using k_load_vec_type = vec_op::FP32Vec16;
@@ -21,7 +23,8 @@ template <> struct KernelVecType<float> {
 };

 #ifdef __AVX512BF16__
-template <> struct KernelVecType<c10::BFloat16> {
+template <>
+struct KernelVecType<c10::BFloat16> {
   using q_load_vec_type = vec_op::BF16Vec8;
   using q_vec_type = vec_op::BF16Vec32;
   using k_load_vec_type = vec_op::BF16Vec32;
@@ -30,7 +33,8 @@ template <> struct KernelVecType<c10::BFloat16> {
   using v_load_vec_type = vec_op::BF16Vec16;
 };
 #else
-template <> struct KernelVecType<c10::BFloat16> {
+template <>
+struct KernelVecType<c10::BFloat16> {
   using q_load_vec_type = vec_op::BF16Vec8;
   using q_vec_type = vec_op::FP32Vec16;
   using k_load_vec_type = vec_op::BF16Vec16;
@@ -67,9 +71,10 @@ FORCE_INLINE std::pair<T, T> reduceSoftmax(T *data, const int size,
 }

 template <typename T>
-FORCE_INLINE std::pair<T, T>
-reduceSoftmaxAlibi(T *data, const int size, const int capacity,
-                   const float alibi_slope, const int start_index,
+FORCE_INLINE std::pair<T, T> reduceSoftmaxAlibi(T* data, const int size,
+                                                const int capacity,
+                                                const float alibi_slope,
+                                                const int start_index,
                                                 const int seq_len) {
   data[0] += alibi_slope * (start_index - seq_len + 1);
   T max = data[0];
@@ -215,16 +220,16 @@ FORCE_INLINE void reduceValueBlock(const float *prob, const scalar_t *v_block,
 namespace {
 template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE>
 struct paged_attention_v1_impl {
-  static void
-  call(scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size]
+  static void call(
+      scalar_t* __restrict__ out,  // [num_seqs, num_heads, head_size]
       const scalar_t* __restrict__ q,        // [num_seqs, num_heads, head_size]
       const scalar_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
                                              // head_size/x, block_size, x]
       const scalar_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
                                              // head_size, block_size]
       const int num_kv_heads, const float scale,
-      const int
-          *__restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
+      const int* __restrict__ block_tables,  // [num_seqs,
+                                             // max_num_blocks_per_seq]
       const int* __restrict__ seq_lens,  // [num_seqs]
       const int max_num_blocks_per_seq,
       const float* __restrict__ alibi_slopes,  // [num_heads]
@@ -257,8 +262,7 @@ struct paged_attention_v1_impl {
       const int64_t kv_head_idx = head_idx / num_queries_per_kv;
       const scalar_t* __restrict__ q_vec_ptr =
           q + seq_idx * q_stride + head_idx * HEAD_SIZE;
-      const int last_block_token_num =
-          seq_len - (block_num - 1) * BLOCK_SIZE;
+      const int last_block_token_num = seq_len - (block_num - 1) * BLOCK_SIZE;
       float* __restrict__ thread_block_logits =
           logits + omp_get_thread_num() * max_seq_len_padded;

@@ -282,8 +286,7 @@ struct paged_attention_v1_impl {
                            block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0,
                            seq_len);
       } else {
-        reduceSoftmax(thread_block_logits, seq_len,
-                      block_num * BLOCK_SIZE);
+        reduceSoftmax(thread_block_logits, seq_len, block_num * BLOCK_SIZE);
       }

       // Compute value
@@ -348,8 +351,8 @@ template <typename T, int BLOCK_SIZE>
 void paged_attention_v1_impl_launcher(
     torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
     torch::Tensor& value_cache, int num_kv_heads, float scale,
-    torch::Tensor &block_tables, torch::Tensor &seq_lens,
-    int max_seq_len, const c10::optional<torch::Tensor> &alibi_slopes) {
+    torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
+    const c10::optional<torch::Tensor>& alibi_slopes) {
   int num_seqs = query.size(0);
   int num_heads = query.size(1);
   int head_size = query.size(2);
@@ -415,9 +418,8 @@ void paged_attention_v1_impl_launcher(
 void paged_attention_v1(torch::Tensor& out, torch::Tensor& query,
                         torch::Tensor& key_cache, torch::Tensor& value_cache,
                         int num_kv_heads, float scale,
-                        torch::Tensor &block_tables,
-                        torch::Tensor &seq_lens, int block_size,
-                        int max_seq_len,
+                        torch::Tensor& block_tables, torch::Tensor& seq_lens,
+                        int block_size, int max_seq_len,
                         const c10::optional<torch::Tensor>& alibi_slopes,
                         const std::string& kv_cache_dtype, float kv_scale) {
   TORCH_CHECK(kv_scale == 1.0f);
@@ -435,9 +437,10 @@ template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int PARTITION_SIZE>
 struct paged_attention_v2_impl {
   static void call(
       scalar_t* __restrict__ out,    // [num_seqs, num_heads, head_size]
-      float *__restrict__ exp_sums,  // [num_seqs, num_heads, max_num_partitions]
-      float
-          *__restrict__ max_logits,  // [num_seqs, num_heads, max_num_partitions]
+      float* __restrict__ exp_sums,    // [num_seqs, num_heads,
+                                       // max_num_partitions]
+      float* __restrict__ max_logits,  // [num_seqs, num_heads,
+                                       // max_num_partitions]
       scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads,
                                        // max_num_partitions, head_size]
       const scalar_t* __restrict__ q,  // [num_seqs, num_heads, head_size]
@@ -446,8 +449,8 @@ struct paged_attention_v2_impl {
       const scalar_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
                                              // head_size, block_size]
       const int num_kv_heads, const float scale,
-      const int
-          *__restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
+      const int* __restrict__ block_tables,  // [num_seqs,
+                                             // max_num_blocks_per_seq]
       const int* __restrict__ seq_lens,  // [num_seqs]
       const int max_num_blocks_per_seq,
       const float* __restrict__ alibi_slopes,  // [num_heads]
@@ -468,8 +471,7 @@ struct paged_attention_v2_impl {
       const int seq_len = seq_lens[seq_idx];
       const int start_token_idx = partition_idx * PARTITION_SIZE;

-      if (start_token_idx >= seq_len)
-        continue;
+      if (start_token_idx >= seq_len) continue;

       const int partition_num =
           (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
@@ -477,8 +479,7 @@ struct paged_attention_v2_impl {
       const int token_num =
           (std::min(seq_len, start_token_idx + PARTITION_SIZE) -
            start_token_idx);
-      const int block_num =
-          (token_num + BLOCK_SIZE - 1) / BLOCK_SIZE;
+      const int block_num = (token_num + BLOCK_SIZE - 1) / BLOCK_SIZE;
       const int last_block_token_num =
           token_num - (block_num - 1) * BLOCK_SIZE;
       const int* seq_block_table = block_tables +
@@ -510,8 +511,8 @@ struct paged_attention_v2_impl {
             logits, token_num, block_num * BLOCK_SIZE,
             alibi_slopes[head_idx], start_token_idx, seq_len);
       } else {
-        max_and_sum = reduceSoftmax(logits, token_num,
-                                    block_num * BLOCK_SIZE);
+        max_and_sum =
+            reduceSoftmax(logits, token_num, block_num * BLOCK_SIZE);
       }

       auto&& [max_logit, exp_sum] = max_and_sum;
@@ -587,8 +588,7 @@ struct paged_attention_v2_impl {
       const int partition_num =
           (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;

-      if (partition_num == 1)
-        continue;
+      if (partition_num == 1) continue;

       reducePartitonSoftmax(
           max_logits + seq_idx * num_heads * max_num_partitions +
@@ -603,8 +603,8 @@ struct paged_attention_v2_impl {
     using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
     static_assert(v_load_vec_type::get_elem_num() == BLOCK_SIZE);
     constexpr int head_elem_num_per_group =
-        16;  // Note: didn't align with the cacheline size, due to some HEAD_SIZE
-             // didn't align with 64 bytes
+        16;  // Note: didn't align with the cacheline size, due to some
+             // HEAD_SIZE didn't align with 64 bytes
     static_assert(HEAD_SIZE % head_elem_num_per_group == 0);
     constexpr int head_group_num = HEAD_SIZE / head_elem_num_per_group;
     const float* __restrict__ rescale_factors = exp_sums;
@@ -616,8 +616,7 @@ struct paged_attention_v2_impl {
       const int partition_num =
           (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;

-      if (partition_num == 1)
-        continue;
+      if (partition_num == 1) continue;

       const float* __restrict__ seq_head_rescale_factors =
           rescale_factors + seq_idx * num_heads * max_num_partitions +
@@ -713,8 +712,8 @@ void paged_attention_v2_impl_launcher(
 #define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE)                              \
   paged_attention_v2_impl_launcher<T, BLOCK_SIZE>(                          \
       out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache,    \
-      num_kv_heads, scale, block_tables, seq_lens, block_size,              \
-      max_seq_len, alibi_slopes);
+      num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, \
+      alibi_slopes);

 #define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
   switch (block_size) { \
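A recurring change in the csrc/cpu hunks above is that '&' and '*' move next to the type (torch::Tensor &x becomes torch::Tensor& x), since the enforced style binds pointers and references to the type. A one-line illustrative declaration with hypothetical names, not vLLM symbols:

// '&' and '*' sit with the type after formatting; the identifiers here are
// placeholders used only for this sketch.
void bind_to_type_sketch(const float* alibi_slopes, int& num_heads,
                         const int* block_tables);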
@@ -5,17 +5,18 @@

 namespace {
 template <typename scalar_t>
-void copy_blocks_cpu_impl(
-    std::vector<torch::Tensor> &key_caches,
+void copy_blocks_cpu_impl(std::vector<torch::Tensor>& key_caches,
                           std::vector<torch::Tensor>& value_caches,
                           const torch::Tensor& mapping_pairs,
-                          const int element_num_per_block, const int layer_num) {
+                          const int element_num_per_block,
+                          const int layer_num) {
   const size_t pair_num = mapping_pairs.size(0);
   const size_t block_bytes = sizeof(scalar_t) * element_num_per_block;
 #pragma omp parallel for collapse(2)
   for (int layer = 0; layer < layer_num; ++layer) {
     for (size_t pair = 0; pair < pair_num; ++pair) {
-      int64_t source_offset = element_num_per_block * mapping_pairs[pair][0].item<int64_t>();
+      int64_t source_offset =
+          element_num_per_block * mapping_pairs[pair][0].item<int64_t>();
       int64_t target_offset =
           element_num_per_block * mapping_pairs[pair][1].item<int64_t>();
       scalar_t* key_cache_ptr = key_caches[layer].data_ptr<scalar_t>();
@@ -87,8 +87,8 @@ void fused_add_rms_norm_impl(scalar_t *__restrict__ input,
 }
 }  // namespace

-void rms_norm(torch::Tensor &out, torch::Tensor &input,
-              torch::Tensor &weight, float epsilon) {
+void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
+              float epsilon) {
   int hidden_size = input.size(-1);
   int num_tokens = input.numel() / hidden_size;

@@ -4,16 +4,16 @@
 namespace {
 template <typename scalar_t>
 void rotary_embedding_impl(
-    const int64_t
-        *__restrict__ positions,  // [batch_size, seq_len] or [num_tokens]
-    scalar_t
-        *__restrict__ query,  /// [batch_size, seq_len, num_heads, head_size] or
-                              /// [num_tokens, num_heads, head_size]
-    scalar_t
-        *__restrict__ key,  // [batch_size, seq_len, num_kv_heads, head_size] or
-                            // [num_tokens, num_kv_heads, head_size]
-    const scalar_t
-        *__restrict__ cos_sin_cache,  // [max_position, 2, rot_dim // 2]
+    const int64_t* __restrict__ positions,  // [batch_size, seq_len] or
+                                            // [num_tokens]
+    scalar_t* __restrict__ query,  /// [batch_size, seq_len, num_heads,
+                                   /// head_size] or [num_tokens, num_heads,
+                                   /// head_size]
+    scalar_t* __restrict__ key,  // [batch_size, seq_len, num_kv_heads,
+                                 // head_size] or [num_tokens, num_kv_heads,
+                                 // head_size]
+    const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
+                                                 // 2]
     const int rot_dim, const int64_t query_stride, const int64_t key_stride,
     const int num_heads, const int num_kv_heads, const int head_size,
     const int num_tokens) {
@@ -94,16 +94,16 @@ void rotary_embedding_impl(

 template <typename scalar_t>
 void rotary_embedding_gptj_impl(
-    const int64_t
-        *__restrict__ positions,  // [batch_size, seq_len] or [num_tokens]
-    scalar_t
-        *__restrict__ query,  /// [batch_size, seq_len, num_heads, head_size] or
-                              /// [num_tokens, num_heads, head_size]
-    scalar_t
-        *__restrict__ key,  // [batch_size, seq_len, num_kv_heads, head_size] or
-                            // [num_tokens, num_kv_heads, head_size]
-    const scalar_t
-        *__restrict__ cos_sin_cache,  // [max_position, 2, rot_dim // 2]
+    const int64_t* __restrict__ positions,  // [batch_size, seq_len] or
+                                            // [num_tokens]
+    scalar_t* __restrict__ query,  /// [batch_size, seq_len, num_heads,
+                                   /// head_size] or [num_tokens, num_heads,
+                                   /// head_size]
+    scalar_t* __restrict__ key,  // [batch_size, seq_len, num_kv_heads,
+                                 // head_size] or [num_tokens, num_kv_heads,
+                                 // head_size]
+    const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
+                                                 // 2]
     const int rot_dim, const int64_t query_stride, const int64_t key_stride,
     const int num_heads, const int num_kv_heads, const int head_size,
     const int num_tokens) {
@ -8,66 +8,37 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

  pybind11::module ops = m.def_submodule("ops", "vLLM custom operators");

  // Attention ops
  ops.def("paged_attention_v1", &paged_attention_v1,
          "Compute the attention between an input query and the cached "
          "keys/values using PagedAttention.");
  ops.def("paged_attention_v2", &paged_attention_v2, "PagedAttention V2.");

  // Activation ops
  ops.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU.");
  ops.def("gelu_and_mul", &gelu_and_mul,
          "Activation function used in GeGLU with `none` approximation.");
  ops.def("gelu_tanh_and_mul", &gelu_tanh_and_mul,
          "Activation function used in GeGLU with `tanh` approximation.");
  ops.def("gelu_new", &gelu_new, "GELU implementation used in GPT-2.");
  ops.def("gelu_fast", &gelu_fast, "Approximate GELU implementation.");

  // Layernorm
  ops.def("rms_norm", &rms_norm,
          "Apply Root Mean Square (RMS) Normalization to the input tensor.");

  ops.def("fused_add_rms_norm", &fused_add_rms_norm,
          "In-place fused Add and RMS Normalization");

  // Rotary embedding
  ops.def("rotary_embedding", &rotary_embedding,
          "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");

  // Cache ops
  pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
  cache_ops.def("swap_blocks", &swap_blocks,
                "Swap in (out) the cache blocks from src to dst");
  cache_ops.def("copy_blocks", &copy_blocks,
                "Copy the cache blocks from src to dst");
  cache_ops.def("reshape_and_cache", &reshape_and_cache,
                "Reshape the key and value tensors and cache them");
}
@ -17,7 +17,8 @@

#endif

#ifndef USE_ROCM
  #define VLLM_SHFL_XOR_SYNC(var, lane_mask) \
    __shfl_xor_sync(uint32_t(-1), var, lane_mask)
#else
  #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
#endif

@ -29,7 +30,8 @@

#endif

#ifndef USE_ROCM
  #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) \
    __shfl_down_sync(uint32_t(-1), var, lane_delta)
#else
  #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta)
#endif

@ -41,4 +43,3 @@

  #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
    hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
#endif
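The two-line wrap of the macros above is purely a formatting change; they still expand to the usual CUDA/HIP warp shuffles. As an illustrative sketch only (the helper name is hypothetical and not part of this commit), a butterfly warp reduction can be built directly on VLLM_SHFL_XOR_SYNC:

// Sketch: warp-wide sum using the macro defined above.
// Assumes all 32 lanes of the warp are active.
template <typename T>
__device__ T warp_reduce_sum(T val) {
  // XOR-butterfly over lane masks 16, 8, 4, 2, 1 leaves the full sum in
  // every lane of the warp.
  for (int mask = 16; mask > 0; mask >>= 1)
    val += VLLM_SHFL_XOR_SYNC(val, mask);
  return val;
}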
@ -2,9 +2,6 @@

#include <torch/extension.h>

int get_device_attribute(int attribute, int device_id);

int get_max_shared_memory_per_block_device_attribute(int device_id);
@ -2,25 +2,19 @@

#include <hip/hip_runtime.h>
#include <hip/hip_runtime_api.h>
#endif
int get_device_attribute(int attribute, int device_id) {
  int device, value;
  if (device_id < 0) {
    cudaGetDevice(&device);
  } else {
    device = device_id;
  }
  cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute),
                         device);
  return value;
}

int get_max_shared_memory_per_block_device_attribute(int device_id) {
  int attribute;
  // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
  // cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
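For reference (not part of this diff), the helper above simply forwards to cudaDeviceGetAttribute, so a host-side caller can query, say, the opt-in shared-memory limit as in the following hypothetical sketch:

#include <cstdio>

int get_device_attribute(int attribute, int device_id);  // from csrc/cuda_utils.h above

int main() {
  // 97 == cudaDevAttrMaxSharedMemoryPerBlockOptin on CUDA (see the comment above).
  int smem_bytes = get_device_attribute(97, /*device_id=*/0);
  std::printf("opt-in shared memory per block: %d bytes\n", smem_bytes);
  return 0;
}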
@ -80,8 +80,7 @@ void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out,

    }
    case at::ScalarType::Half: {
      fa->allreduce<half>(stream, reinterpret_cast<half*>(inp.data_ptr()),
                          reinterpret_cast<half*>(out.data_ptr()), out.numel());
      break;
    }
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))

@ -137,8 +137,7 @@ DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg,

    // Latency = 1 p2p write
    sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1;
    // wait until we got true from all ranks
    while (!self_sg->start[blockIdx.x][threadIdx.x]);
  }
  __syncthreads();
}

@ -162,8 +161,7 @@ DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg,

    // Latency = 1 p2p write
    sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1;
    // wait until we got true from all ranks
    while (!self_sg->end[blockIdx.x][threadIdx.x]);
  }
  if constexpr (!final_sync) __syncthreads();
}

@ -192,8 +190,7 @@ __global__ void __launch_bounds__(512, 1)

  // do the actual reduction
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += gridDim.x * blockDim.x) {
    ((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
  }
  end_sync<ngpus, true>(sg, self_sg, rank);
}
@ -12,8 +12,7 @@

  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)

#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)  \

@ -22,8 +21,8 @@

  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)

#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME,                               \
                     VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))

#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...)        \
  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \

@ -33,5 +32,4 @@

  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)

#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
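The dispatch macros above only changed line wrapping; their usage pattern is unchanged. A minimal sketch of how they are consumed (the function name is hypothetical and the dispatch_utils.h include is assumed):

#include <cstdio>
#include <torch/extension.h>
// #include "dispatch_utils.h"  // provides VLLM_DISPATCH_FLOATING_TYPES

void print_element_size(const torch::Tensor& t) {
  VLLM_DISPATCH_FLOATING_TYPES(t.scalar_type(), "print_element_size", [&] {
    // Inside the lambda, `scalar_t` is bound to the concrete C++ type
    // (float, at::Half, or at::BFloat16) matching t.scalar_type().
    std::printf("element size: %zu bytes\n", sizeof(scalar_t));
  });
}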
@ -23,9 +23,7 @@ __global__ void rms_norm_kernel(

    scalar_t* __restrict__ out,           // [..., hidden_size]
    const scalar_t* __restrict__ input,   // [..., hidden_size]
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float epsilon, const int num_tokens, const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;

@ -41,11 +39,11 @@ __global__ void rms_norm_kernel(

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float)input[blockIdx.x * hidden_size + idx];
    out[blockIdx.x * hidden_size + idx] =
        ((scalar_t)(x * s_variance)) * weight[idx];
  }
}

/* Converter structs for the conversion from torch types to HIP/CUDA types,
   and the associated type conversions within HIP/CUDA. These helpers need
   to be implemented for now because the relevant type conversion
@ -57,7 +55,9 @@ __global__ void rms_norm_kernel(

   If true, the struct should be fully defined as shown in the examples below.
 */
template <typename torch_type>
struct _typeConvert {
  static constexpr bool exists = false;
};

#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
// CUDA < 12.0 runs into issues with packed type conversion

@ -68,9 +68,15 @@ struct _typeConvert<c10::Half> {

  using packed_hip_type = __half2;

  __device__ static inline float convert(hip_type x) { return __half2float(x); }
  __device__ static inline float2 convert(packed_hip_type x) {
    return __half22float2(x);
  }
  __device__ static inline hip_type convert(float x) {
    return __float2half_rn(x);
  }
  __device__ static inline packed_hip_type convert(float2 x) {
    return __float22half2_rn(x);
  }
};

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800

@ -82,13 +88,22 @@ struct _typeConvert<c10::BFloat16> {

  using hip_type = __nv_bfloat16;
  using packed_hip_type = __nv_bfloat162;

  __device__ static inline float convert(hip_type x) {
    return __bfloat162float(x);
  }
  __device__ static inline float2 convert(packed_hip_type x) {
    return __bfloat1622float2(x);
  }
  __device__ static inline hip_type convert(float x) {
    return __float2bfloat16(x);
  }
  __device__ static inline packed_hip_type convert(float2 x) {
    return __float22bfloat162_rn(x);
  }
};
#endif  // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#endif  // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >=
        // 12000))

/* Vector POD struct to generate vectorized and packed FP16/BF16 ops
   for appropriate specializations of fused_add_rms_norm_kernel.
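As a device-side illustration (hypothetical helper, not part of this commit, and assuming the CUDA >= 12.0 / ROCm guard above is satisfied), the half specialization can be used to unpack a __half2 into two floats:

__device__ float packed_half_sum(__half2 x) {
  // Uses the packed overload of _typeConvert<c10::Half>::convert shown above.
  float2 f = vllm::_typeConvert<c10::Half>::convert(x);
  return f.x + f.y;
}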
@ -117,8 +132,7 @@ struct alignas(16) _f16Vec {

    }
  } else {
#pragma unroll
    for (int i = 0; i < width; ++i) data[i] += other.data[i];
  }
  return *this;
}

@ -134,8 +148,7 @@ struct alignas(16) _f16Vec {

    }
  } else {
#pragma unroll
    for (int i = 0; i < width; ++i) data[i] *= other.data[i];
  }
  return *this;
}

@ -185,14 +198,12 @@ struct alignas(16) _f16Vec {

   packed and vectorized operations, which help with the
   memory latency bottleneck. */
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
fused_add_rms_norm_kernel(
    scalar_t* __restrict__ input,         // [..., hidden_size]
    scalar_t* __restrict__ residual,      // [..., hidden_size]
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float epsilon, const int num_tokens, const int hidden_size) {
  // Sanity checks on our vector struct and type-punned pointer arithmetic
  static_assert(std::is_pod_v<_f16Vec<scalar_t, width>>);
  static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);
|
|||||||
/* These and the argument pointers are all declared `restrict` as they are
|
/* These and the argument pointers are all declared `restrict` as they are
|
||||||
not aliased in practice. Argument pointers should not be dereferenced
|
not aliased in practice. Argument pointers should not be dereferenced
|
||||||
in this kernel as that would be undefined behavior */
|
in this kernel as that would be undefined behavior */
|
||||||
auto* __restrict__ input_v = reinterpret_cast<_f16Vec<scalar_t, width>*>(input);
|
auto* __restrict__ input_v =
|
||||||
auto* __restrict__ residual_v = reinterpret_cast<_f16Vec<scalar_t, width>*>(residual);
|
reinterpret_cast<_f16Vec<scalar_t, width>*>(input);
|
||||||
auto* __restrict__ weight_v = reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);
|
auto* __restrict__ residual_v =
|
||||||
|
reinterpret_cast<_f16Vec<scalar_t, width>*>(residual);
|
||||||
|
auto* __restrict__ weight_v =
|
||||||
|
reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);
|
||||||
|
|
||||||
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
|
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
|
||||||
int id = blockIdx.x * vec_hidden_size + idx;
|
int id = blockIdx.x * vec_hidden_size + idx;
|
||||||
@ -218,7 +232,8 @@ __global__ std::enable_if_t<
|
|||||||
calculation of max_block_size in fused_add_rms_norm */
|
calculation of max_block_size in fused_add_rms_norm */
|
||||||
if (num_tokens < 256) {
|
if (num_tokens < 256) {
|
||||||
variance = blockReduceSum<float, 1024>(variance);
|
variance = blockReduceSum<float, 1024>(variance);
|
||||||
} else variance = blockReduceSum<float, 256>(variance);
|
} else
|
||||||
|
variance = blockReduceSum<float, 256>(variance);
|
||||||
if (threadIdx.x == 0) {
|
if (threadIdx.x == 0) {
|
||||||
s_variance = rsqrtf(variance / hidden_size + epsilon);
|
s_variance = rsqrtf(variance / hidden_size + epsilon);
|
||||||
}
|
}
|
||||||
@ -233,19 +248,16 @@ __global__ std::enable_if_t<

  }
}

/* Generic fused_add_rms_norm_kernel
   The width field is not used here but necessary for other specializations.
 */
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
fused_add_rms_norm_kernel(
    scalar_t* __restrict__ input,         // [..., hidden_size]
    scalar_t* __restrict__ residual,      // [..., hidden_size]
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float epsilon, const int num_tokens, const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;

@ -260,7 +272,8 @@ __global__ std::enable_if_t<

     calculation of max_block_size in fused_add_rms_norm */
  if (num_tokens < 256) {
    variance = blockReduceSum<float, 1024>(variance);
  } else
    variance = blockReduceSum<float, 256>(variance);
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }

@ -268,14 +281,14 @@ __global__ std::enable_if_t<

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float)residual[blockIdx.x * hidden_size + idx];
    input[blockIdx.x * hidden_size + idx] =
        ((scalar_t)(x * s_variance)) * weight[idx];
  }
}

}  // namespace vllm

void rms_norm(torch::Tensor& out,     // [..., hidden_size]
              torch::Tensor& input,   // [..., hidden_size]
              torch::Tensor& weight,  // [hidden_size]
              float epsilon) {
@ -286,37 +299,24 @@ void rms_norm(

  dim3 block(std::min(hidden_size, 1024));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
    vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
        out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
        weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
  });
}

#define LAUNCH_FUSED_ADD_RMS_NORM(width)                                       \
  VLLM_DISPATCH_FLOATING_TYPES(                                                \
      input.scalar_type(), "fused_add_rms_norm_kernel", [&] {                  \
        vllm::fused_add_rms_norm_kernel<scalar_t, width>                       \
            <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),           \
                                         residual.data_ptr<scalar_t>(),        \
                                         weight.data_ptr<scalar_t>(), epsilon, \
                                         num_tokens, hidden_size);             \
      });

void fused_add_rms_norm(torch::Tensor& input,     // [..., hidden_size]
                        torch::Tensor& residual,  // [..., hidden_size]
                        torch::Tensor& weight,    // [hidden_size]
                        float epsilon) {

@ -342,8 +342,8 @@ void fused_add_rms_norm(

  auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
  auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
  auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
  bool ptrs_are_aligned =
      inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
  if (ptrs_are_aligned && hidden_size % 8 == 0) {
    LAUNCH_FUSED_ADD_RMS_NORM(8);
  } else {
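The alignment test above gates the vectorized kernel: _f16Vec<scalar_t, 8> is a 16-byte load/store for fp16/bf16, so all three pointers must be 16-byte aligned and hidden_size must be divisible by 8. A standalone sketch of the same check (the helper name is hypothetical and not part of the diff):

#include <cstdint>

inline bool can_use_vectorized_path(const void* inp, const void* res,
                                    const void* wt, int hidden_size) {
  auto aligned16 = [](const void* p) {
    return reinterpret_cast<std::uintptr_t>(p) % 16 == 0;
  };
  // Mirrors the ptrs_are_aligned && hidden_size % 8 == 0 condition above.
  return aligned16(inp) && aligned16(res) && aligned16(wt) &&
         hidden_size % 8 == 0;
}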
@ -3,5 +3,6 @@

#include <torch/extension.h>

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("topk_softmax", &topk_softmax,
        "Apply topk softmax to the gating outputs.");
}

@ -2,8 +2,6 @@

#include <torch/extension.h>

void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
                  torch::Tensor& token_expert_indices,
                  torch::Tensor& gating_output);
@ -12,11 +12,12 @@

namespace vllm {

namespace {
__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
                                         int32_t col) {
  // don't worry about overflow because num_experts is relatively small
  return row * total_col + col;
}
}  // namespace

template <typename scalar_t>
__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,

@ -24,15 +25,17 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,

                                            int32_t* expert_ids,
                                            int32_t* total_tokens_post_pad,
                                            int32_t num_experts,
                                            int32_t block_size, size_t numel) {
  const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
  const size_t start_idx = threadIdx.x * tokens_per_thread;

  extern __shared__ int32_t shared_mem[];

  int32_t* tokens_cnts =
      shared_mem;  // 2d tensor with shape (num_experts + 1, num_experts)
  int32_t* cumsum =
      shared_mem + (num_experts + 1) *
                       num_experts;  // 1d tensor with shape (num_experts + 1)

  for (int i = 0; i < num_experts; ++i) {
    tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;

@ -40,8 +43,8 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,

  /**
   * In the first step we compute token_cnts[thread_index + 1][expert_index],
   * which counts how many tokens in the token shard of thread_index are
   * assigned to expert expert_index.
   */
  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])];

@ -52,7 +55,8 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,

  // For each expert we accumulate the token counts from the different threads.
  tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
  for (int i = 1; i <= blockDim.x; ++i) {
    tokens_cnts[index(num_experts, i, threadIdx.x)] +=
        tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
  }

  __syncthreads();

@ -61,7 +65,10 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,

  if (threadIdx.x == 0) {
    cumsum[0] = 0;
    for (int i = 1; i <= num_experts; ++i) {
      cumsum[i] = cumsum[i - 1] +
                  CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)],
                          block_size) *
                      block_size;
    }
    *total_tokens_post_pad = cumsum[num_experts];
  }

@ -69,57 +76,59 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,

  __syncthreads();

  /**
   * For each expert, each thread processes the tokens of the corresponding
   * blocks and stores the corresponding expert_id for each block.
   */
  for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
       i += block_size) {
    expert_ids[i / block_size] = threadIdx.x;
  }

  /**
   * Each thread processes a token shard, calculating the index of each token
   * after sorting by expert number. Given the example topk_ids =
   * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *,
   * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a
   * padding value(preset in python).
   */
  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    int32_t expert_id = topk_ids[i];
    /** The cumsum[expert_id] stores the starting index of the tokens that the
     * expert with expert_id needs to process, and
     * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
     * processed by the expert with expert_id within the current thread's token
     * shard.
     */
    int32_t rank_post_pad =
        tokens_cnts[index(num_experts, threadIdx.x, expert_id)] +
        cumsum[expert_id];
    sorted_token_ids[rank_post_pad] = i;
    ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
  }
}
}  // namespace vllm

void moe_align_block_size(torch::Tensor topk_ids, int num_experts,
                          int block_size, torch::Tensor sorted_token_ids,
                          torch::Tensor experts_ids,
                          torch::Tensor num_tokens_post_pad) {
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_INTEGRAL_TYPES(
      topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
        // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
        // tensors
        const int32_t shared_mem =
            ((num_experts + 1) * num_experts + (num_experts + 1)) *
            sizeof(int32_t);

        // set dynamic shared mem
        auto kernel = vllm::moe_align_block_size_kernel<scalar_t>;
        AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
            (void*)kernel, shared_mem));
        kernel<<<1, num_experts, shared_mem, stream>>>(
            topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
            experts_ids.data_ptr<int32_t>(),
            num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
            topk_ids.numel());
      });
}
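For clarity, here is a small host-side (CPU) reference of the padding and sorting logic described in the kernel comments above, run on the worked example topk_ids = [0,1,2,1,2,3,0,3,4] with block_size = 4. It is an illustrative sketch, not code from this commit; -1 stands in for the padding value that is preset from Python.

#include <cstdio>
#include <vector>

int main() {
  const std::vector<int> topk_ids = {0, 1, 2, 1, 2, 3, 0, 3, 4};
  const int num_experts = 5, block_size = 4, pad = -1;

  // Count tokens per expert, then round each expert's count up to a multiple
  // of block_size (same role as the CEILDIV(...) * block_size cumsum above).
  std::vector<int> cnt(num_experts, 0);
  for (int e : topk_ids) ++cnt[e];
  std::vector<int> cumsum(num_experts + 1, 0);
  for (int e = 0; e < num_experts; ++e)
    cumsum[e + 1] =
        cumsum[e] + (cnt[e] + block_size - 1) / block_size * block_size;

  // Scatter each token index into its expert's padded segment.
  std::vector<int> sorted_token_ids(cumsum[num_experts], pad);
  std::vector<int> fill(num_experts, 0);
  for (int i = 0; i < (int)topk_ids.size(); ++i) {
    int e = topk_ids[i];
    sorted_token_ids[cumsum[e] + fill[e]++] = i;
  }

  // Prints: 0 6 -1 -1 1 3 -1 -1 2 4 -1 -1 5 7 -1 -1 8 -1 -1 -1
  for (int v : sorted_token_ids) std::printf("%d ", v);
  std::printf("\n");
  return 0;
}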
214 csrc/ops.h
@ -2,204 +2,115 @@

#include <torch/extension.h>

void paged_attention_v1(torch::Tensor& out, torch::Tensor& query,
                        torch::Tensor& key_cache, torch::Tensor& value_cache,
                        int num_kv_heads, float scale,
                        torch::Tensor& block_tables, torch::Tensor& seq_lens,
                        int block_size, int max_seq_len,
                        const c10::optional<torch::Tensor>& alibi_slopes,
                        const std::string& kv_cache_dtype, float kv_scale);

void paged_attention_v2(torch::Tensor& out, torch::Tensor& exp_sums,
                        torch::Tensor& max_logits, torch::Tensor& tmp_out,
                        torch::Tensor& query, torch::Tensor& key_cache,
                        torch::Tensor& value_cache, int num_kv_heads,
                        float scale, torch::Tensor& block_tables,
                        torch::Tensor& seq_lens, int block_size,
                        int max_seq_len,
                        const c10::optional<torch::Tensor>& alibi_slopes,
                        const std::string& kv_cache_dtype, float kv_scale);

void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
              float epsilon);

void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
                        torch::Tensor& weight, float epsilon);

void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
                      torch::Tensor& key, int head_size,
                      torch::Tensor& cos_sin_cache, bool is_neox);

void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
                              torch::Tensor& key, int head_size,
                              torch::Tensor& cos_sin_cache, bool is_neox,
                              int rot_dim,
                              torch::Tensor& cos_sin_cache_offsets);

void silu_and_mul(torch::Tensor& out, torch::Tensor& input);

void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);

void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);

void gelu_new(torch::Tensor& out, torch::Tensor& input);

void gelu_fast(torch::Tensor& out, torch::Tensor& input);

#ifndef USE_ROCM
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
                        const torch::Tensor& codebooks,
                        const torch::Tensor& scales,
                        const torch::Tensor& codebook_partition_sizes,
                        const std::optional<torch::Tensor>& bias);

torch::Tensor aqlm_dequant(const torch::Tensor& codes,
                           const torch::Tensor& codebooks,
                           const torch::Tensor& codebook_partition_sizes);

torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
                       torch::Tensor _scaling_factors, torch::Tensor _zeros,
                       int split_k_iters);

torch::Tensor awq_dequantize(torch::Tensor _kernel,
                             torch::Tensor _scaling_factors,
                             torch::Tensor _zeros, int split_k_iters, int thx,
                             int thy);

torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                          torch::Tensor& b_scales, torch::Tensor& workspace,
                          int64_t size_m, int64_t size_n, int64_t size_k);

torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                  torch::Tensor& b_meta,
                                  torch::Tensor& b_scales,
                                  torch::Tensor& workspace, int64_t num_bits,
                                  int64_t size_m, int64_t size_n,
                                  int64_t size_k);

torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                               torch::Tensor& b_scales, torch::Tensor& g_idx,
                               torch::Tensor& perm, torch::Tensor& workspace,
                               int64_t num_bits, int64_t size_m, int64_t size_n,
                               int64_t size_k, bool is_k_full);

torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
                                 int64_t size_k, int64_t size_n,
                                 int64_t num_bits);

int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a,
                         torch::Tensor const& b, torch::Tensor const& a_scales,
                         torch::Tensor const& b_scales);

#endif

void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
                     torch::Tensor lookup_table);

torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
                        torch::Tensor b_gptq_qzeros,
                        torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
                        bool use_exllama, int bit);

void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int bit);

void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
                             torch::Tensor& scale);

void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
                              torch::Tensor& scale);

void moe_align_block_size(torch::Tensor topk_ids, int num_experts,
                          int block_size, torch::Tensor sorted_token_ids,
                          torch::Tensor experts_ids,
                          torch::Tensor num_tokens_post_pad);

@ -219,7 +130,8 @@ int meta_size();

void register_buffer(fptr_t _fa, torch::Tensor& t,
                     const std::vector<std::string>& handles,
                     const std::vector<int64_t>& offsets);
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
    fptr_t _fa);
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
                            const std::vector<std::vector<int64_t>>& offsets);
#endif
@ -9,12 +9,8 @@ namespace vllm {

template <typename scalar_t, bool IS_NEOX>
inline __device__ void apply_token_rotary_embedding(
    scalar_t* __restrict__ arr, const scalar_t* __restrict__ cos_ptr,
    const scalar_t* __restrict__ sin_ptr, int rot_offset, int embed_dim) {
  int x_index, y_index;
  scalar_t cos, sin;
  if (IS_NEOX) {

@ -39,17 +35,15 @@ inline __device__ void apply_token_rotary_embedding(

template <typename scalar_t, bool IS_NEOX>
inline __device__ void apply_rotary_embedding(
    scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,
                                   // head_size] or [num_tokens, num_heads,
                                   // head_size]
    scalar_t* __restrict__ key,    // [batch_size, seq_len, num_kv_heads,
                                   // head_size] or [num_tokens, num_kv_heads,
                                   // head_size]
    const scalar_t* cache_ptr, const int head_size, const int num_heads,
    const int num_kv_heads, const int rot_dim, const int token_idx,
    const int64_t query_stride, const int64_t key_stride) {
  const int embed_dim = rot_dim / 2;
  const scalar_t* cos_ptr = cache_ptr;
  const scalar_t* sin_ptr = cache_ptr + embed_dim;

@ -59,8 +53,8 @@ inline __device__ void apply_rotary_embedding(

    const int head_idx = i / embed_dim;
    const int64_t token_head = token_idx * query_stride + head_idx * head_size;
    const int rot_offset = i % embed_dim;
    apply_token_rotary_embedding<scalar_t, IS_NEOX>(
        query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
  }

  const int nk = num_kv_heads * embed_dim;
|
|||||||
const int head_idx = i / embed_dim;
|
const int head_idx = i / embed_dim;
|
||||||
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
|
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
|
||||||
const int rot_offset = i % embed_dim;
|
const int rot_offset = i % embed_dim;
|
||||||
apply_token_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
|
apply_token_rotary_embedding<scalar_t, IS_NEOX>(
|
||||||
sin_ptr, rot_offset, embed_dim);
|
key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename scalar_t, bool IS_NEOX>
|
template <typename scalar_t, bool IS_NEOX>
|
||||||
__global__ void rotary_embedding_kernel(
|
__global__ void rotary_embedding_kernel(
|
||||||
const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens]
|
const int64_t* __restrict__ positions, // [batch_size, seq_len] or
|
||||||
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
|
// [num_tokens]
|
||||||
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
|
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
|
||||||
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
|
// head_size] or [num_tokens, num_heads,
|
||||||
const int rot_dim,
|
// head_size]
|
||||||
const int64_t query_stride,
|
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
|
||||||
const int64_t key_stride,
|
// head_size] or [num_tokens, num_kv_heads,
|
||||||
const int num_heads,
|
// head_size]
|
||||||
const int num_kv_heads,
|
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
|
||||||
const int head_size) {
|
// 2]
|
||||||
|
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
|
||||||
|
const int num_heads, const int num_kv_heads, const int head_size) {
|
||||||
// Each thread block is responsible for one token.
|
// Each thread block is responsible for one token.
|
||||||
const int token_idx = blockIdx.x;
|
const int token_idx = blockIdx.x;
|
||||||
int64_t pos = positions[token_idx];
|
int64_t pos = positions[token_idx];
|
||||||
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
|
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
|
||||||
|
|
||||||
apply_rotary_embedding<scalar_t, IS_NEOX>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride);
|
apply_rotary_embedding<scalar_t, IS_NEOX>(
|
||||||
|
query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
|
||||||
|
token_idx, query_stride, key_stride);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename scalar_t, bool IS_NEOX>
|
template <typename scalar_t, bool IS_NEOX>
|
||||||
__global__ void batched_rotary_embedding_kernel(
|
__global__ void batched_rotary_embedding_kernel(
|
||||||
const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens]
|
const int64_t* __restrict__ positions, // [batch_size, seq_len] or
|
||||||
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
|
// [num_tokens]
|
||||||
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
|
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
|
||||||
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
|
// head_size] or [num_tokens, num_heads,
|
||||||
const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len] or [num_tokens]
|
// head_size]
|
||||||
const int rot_dim,
|
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
|
||||||
const int64_t query_stride,
|
// head_size] or [num_tokens, num_kv_heads,
|
||||||
const int64_t key_stride,
|
// head_size]
|
||||||
const int num_heads,
|
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
|
||||||
const int num_kv_heads,
|
// 2]
|
||||||
const int head_size) {
|
const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len]
|
||||||
|
// or [num_tokens]
|
||||||
|
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
|
||||||
|
const int num_heads, const int num_kv_heads, const int head_size) {
|
||||||
// Each thread block is responsible for one token.
|
// Each thread block is responsible for one token.
|
||||||
const int token_idx = blockIdx.x;
|
const int token_idx = blockIdx.x;
|
||||||
int64_t pos = positions[token_idx];
|
int64_t pos = positions[token_idx];
|
||||||
int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx];
|
int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx];
|
||||||
const scalar_t* cache_ptr = cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim;
|
const scalar_t* cache_ptr =
|
||||||
|
cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim;
|
||||||
|
|
||||||
apply_rotary_embedding<scalar_t, IS_NEOX>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride);
|
apply_rotary_embedding<scalar_t, IS_NEOX>(
|
||||||
|
query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
|
||||||
|
token_idx, query_stride, key_stride);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
|
|
||||||
void rotary_embedding(
|
void rotary_embedding(
|
||||||
torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens]
|
torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens]
|
||||||
torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
|
torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or
|
||||||
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
|
// [num_tokens, num_heads * head_size]
|
||||||
|
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or
|
||||||
|
// [num_tokens, num_kv_heads * head_size]
|
||||||
int head_size,
|
int head_size,
|
||||||
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
|
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
|
||||||
bool is_neox) {
|
bool is_neox) {
|
||||||
@ -135,33 +141,18 @@ void rotary_embedding(
|
|||||||
dim3 block(std::min(num_heads * rot_dim / 2, 512));
|
dim3 block(std::min(num_heads * rot_dim / 2, 512));
|
||||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
VLLM_DISPATCH_FLOATING_TYPES(
|
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
|
||||||
query.scalar_type(),
|
|
||||||
"rotary_embedding",
|
|
||||||
[&] {
|
|
||||||
if (is_neox) {
|
if (is_neox) {
|
||||||
vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
|
vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
|
||||||
positions.data_ptr<int64_t>(),
|
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
|
||||||
query.data_ptr<scalar_t>(),
|
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(), rot_dim,
|
||||||
key.data_ptr<scalar_t>(),
|
query_stride, key_stride, num_heads, num_kv_heads, head_size);
|
||||||
cos_sin_cache.data_ptr<scalar_t>(),
|
|
||||||
rot_dim,
|
|
||||||
query_stride,
|
|
||||||
key_stride,
|
|
||||||
num_heads,
|
|
||||||
num_kv_heads,
|
|
||||||
head_size);
|
|
||||||
} else {
|
} else {
|
||||||
vllm::rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
|
vllm::rotary_embedding_kernel<scalar_t, false>
|
||||||
positions.data_ptr<int64_t>(),
|
<<<grid, block, 0, stream>>>(
|
||||||
query.data_ptr<scalar_t>(),
|
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
|
||||||
key.data_ptr<scalar_t>(),
|
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
|
||||||
cos_sin_cache.data_ptr<scalar_t>(),
|
rot_dim, query_stride, key_stride, num_heads, num_kv_heads,
|
||||||
rot_dim,
|
|
||||||
query_stride,
|
|
||||||
key_stride,
|
|
||||||
num_heads,
|
|
||||||
num_kv_heads,
|
|
||||||
head_size);
|
head_size);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
@ -173,12 +164,13 @@ and process in batched manner.
|
|||||||
*/
|
*/
|
||||||
void batched_rotary_embedding(
|
void batched_rotary_embedding(
|
||||||
torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens]
|
torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens]
|
||||||
torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
|
torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or
|
||||||
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
|
// [num_tokens, num_heads * head_size]
|
||||||
|
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or
|
||||||
|
// [num_tokens, num_kv_heads * head_size]
|
||||||
int head_size,
|
int head_size,
|
||||||
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
|
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
|
||||||
bool is_neox,
|
bool is_neox, int rot_dim,
|
||||||
int rot_dim,
|
|
||||||
torch::Tensor& cos_sin_cache_offsets // [num_tokens]
|
torch::Tensor& cos_sin_cache_offsets // [num_tokens]
|
||||||
) {
|
) {
|
||||||
int64_t num_tokens = cos_sin_cache_offsets.size(0);
|
int64_t num_tokens = cos_sin_cache_offsets.size(0);
|
||||||
@ -191,36 +183,21 @@ void batched_rotary_embedding(
|
|||||||
dim3 block(std::min(num_heads * rot_dim / 2, 512));
|
dim3 block(std::min(num_heads * rot_dim / 2, 512));
|
||||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
VLLM_DISPATCH_FLOATING_TYPES(
|
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
|
||||||
query.scalar_type(),
|
|
||||||
"rotary_embedding",
|
|
||||||
[&] {
|
|
||||||
if (is_neox) {
|
if (is_neox) {
|
||||||
vllm::batched_rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
|
vllm::batched_rotary_embedding_kernel<scalar_t, true>
|
||||||
positions.data_ptr<int64_t>(),
|
<<<grid, block, 0, stream>>>(
|
||||||
query.data_ptr<scalar_t>(),
|
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
|
||||||
key.data_ptr<scalar_t>(),
|
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
|
||||||
cos_sin_cache.data_ptr<scalar_t>(),
|
cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
|
||||||
cos_sin_cache_offsets.data_ptr<int64_t>(),
|
key_stride, num_heads, num_kv_heads, head_size);
|
||||||
rot_dim,
|
|
||||||
query_stride,
|
|
||||||
key_stride,
|
|
||||||
num_heads,
|
|
||||||
num_kv_heads,
|
|
||||||
head_size);
|
|
||||||
} else {
|
} else {
|
||||||
vllm::batched_rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
|
vllm::batched_rotary_embedding_kernel<scalar_t, false>
|
||||||
positions.data_ptr<int64_t>(),
|
<<<grid, block, 0, stream>>>(
|
||||||
query.data_ptr<scalar_t>(),
|
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
|
||||||
key.data_ptr<scalar_t>(),
|
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
|
||||||
cos_sin_cache.data_ptr<scalar_t>(),
|
cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
|
||||||
cos_sin_cache_offsets.data_ptr<int64_t>(),
|
key_stride, num_heads, num_kv_heads, head_size);
|
||||||
rot_dim,
|
|
||||||
query_stride,
|
|
||||||
key_stride,
|
|
||||||
num_heads,
|
|
||||||
num_kv_heads,
|
|
||||||
head_size);
|
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
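A rough usage sketch (not part of the commit): the rotary_embedding entry point above mutates query and key in place. The forward declaration below simply repeats the signature from the diff; the header choice, tensor sizes, and dtypes are assumptions made for illustration only.

// Hypothetical driver for the reformatted rotary_embedding host function.
#include <torch/extension.h>

// Declaration copied from the function definition shown in the diff.
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
                      torch::Tensor& key, int head_size,
                      torch::Tensor& cos_sin_cache, bool is_neox);

void run_rotary_example() {
  const int64_t num_tokens = 8, num_heads = 4, num_kv_heads = 4;
  const int head_size = 64, rot_dim = 64, max_position = 2048;
  auto half_cuda = torch::dtype(torch::kHalf).device(torch::kCUDA);
  auto positions = torch::randint(
      max_position, {num_tokens}, torch::dtype(torch::kInt64).device(torch::kCUDA));
  auto query = torch::randn({num_tokens, num_heads * head_size}, half_cuda);
  auto key = torch::randn({num_tokens, num_kv_heads * head_size}, half_cuda);
  auto cos_sin_cache = torch::randn({max_position, rot_dim}, half_cuda);
  // Applies GPT-NeoX style RoPE in place to query and key.
  rotary_embedding(positions, query, key, head_size, cos_sin_cache,
                   /*is_neox=*/true);
}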
114
csrc/pybind.cpp
@ -8,114 +8,85 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  pybind11::module ops = m.def_submodule("ops", "vLLM custom operators");

  // Attention ops
  ops.def("paged_attention_v1", &paged_attention_v1,
          "Compute the attention between an input query and the cached "
          "keys/values using PagedAttention.");
  ops.def("paged_attention_v2", &paged_attention_v2, "PagedAttention V2.");

  // Activation ops
  ops.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU.");
  ops.def("gelu_and_mul", &gelu_and_mul,
          "Activation function used in GeGLU with `none` approximation.");
  ops.def("gelu_tanh_and_mul", &gelu_tanh_and_mul,
          "Activation function used in GeGLU with `tanh` approximation.");
  ops.def("gelu_new", &gelu_new, "GELU implementation used in GPT-2.");
  ops.def("gelu_fast", &gelu_fast, "Approximate GELU implementation.");

  // Layernorm
  ops.def("rms_norm", &rms_norm,
          "Apply Root Mean Square (RMS) Normalization to the input tensor.");

  ops.def("fused_add_rms_norm", &fused_add_rms_norm,
          "In-place fused Add and RMS Normalization");

  // Rotary embedding
  ops.def("rotary_embedding", &rotary_embedding,
          "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");

  ops.def("batched_rotary_embedding", &batched_rotary_embedding,
          "Apply GPT-NeoX or GPT-J style rotary embedding to query and key "
          "(supports multiple loras)");

  // Quantization ops
#ifndef USE_ROCM
  ops.def("aqlm_gemm", &aqlm_gemm, "Quantized GEMM for AQLM");
  ops.def("aqlm_dequant", &aqlm_dequant, "Decompression method for AQLM");
  ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
  ops.def("marlin_gemm", &marlin_gemm,
          "Marlin (Dense) Optimized Quantized GEMM for GPTQ");
  ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm,
          "Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ");
  ops.def("gptq_marlin_gemm", &gptq_marlin_gemm,
          "gptq_marlin Optimized Quantized GEMM for GPTQ");
  ops.def("gptq_marlin_repack", &gptq_marlin_repack,
          "gptq_marlin repack from GPTQ");
  ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
  ops.def("cutlass_scaled_mm_dq", &cutlass_scaled_mm_dq,
          "CUTLASS w8a8 GEMM, supporting symmetric per-tensor or "
          "per-row/column quantization.");
#endif

  ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
  ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
  ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
  ops.def("static_scaled_fp8_quant", &static_scaled_fp8_quant,
          "Compute FP8 quantized tensor for given scaling factor");
  ops.def("dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant,
          "Compute FP8 quantized tensor and scaling factor");
  ops.def("moe_align_block_size", &moe_align_block_size,
          "Aligning the number of tokens to be processed by each expert such "
          "that it is divisible by the block size.");

  // Cache ops
  pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
  cache_ops.def("swap_blocks", &swap_blocks,
                "Swap in (out) the cache blocks from src to dst");
  cache_ops.def("copy_blocks", &copy_blocks,
                "Copy the cache blocks from src to dst");
  cache_ops.def("reshape_and_cache", &reshape_and_cache,
                "Reshape the key and value tensors and cache them");
  cache_ops.def("reshape_and_cache_flash", &reshape_and_cache_flash,
                "Reshape the key and value tensors and cache them");
  cache_ops.def("convert_fp8", &convert_fp8,
                "Convert the key and value cache to fp8 data type");

  // Cuda utils
  pybind11::module cuda_utils =
      m.def_submodule("cuda_utils", "vLLM cuda utils");
  cuda_utils.def("get_device_attribute", &get_device_attribute,
                 "Gets the specified device attribute.");

  cuda_utils.def("get_max_shared_memory_per_block_device_attribute",
                 &get_max_shared_memory_per_block_device_attribute,
                 "Gets the maximum shared memory per block device attribute.");

@ -134,5 +105,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  custom_ar.def("register_graph_buffers", &register_graph_buffers,
                "register_graph_buffers");
#endif

}
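For context, the registration pattern used throughout this file is plain pybind11: one PYBIND11_MODULE, submodules created with def_submodule, and one def per op with a docstring. A minimal standalone sketch of the same pattern (the module name and function here are placeholders, not vLLM symbols):

#include <pybind11/pybind11.h>

namespace {
int add(int a, int b) { return a + b; }
}  // namespace

PYBIND11_MODULE(example_ext, m) {
  // Mirrors the structure above: a submodule plus documented op bindings.
  pybind11::module ops = m.def_submodule("ops", "example operators");
  ops.def("add", &add, "Add two integers.");
}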
@ -25,30 +25,26 @@
#include <iostream>
#include <cstdlib>

namespace vllm {
namespace aqlm {

__global__ void Code1x16MatVec(
    const int4* __restrict__ A, const int4* __restrict__ B,
    int4* __restrict__ C, const int4* __restrict__ codebook, const int prob_m,
    const int prob_k,
    const int4 codebook_a_sizes,  // cumulative sizes of A spanning each
                                  // codebook, at most 3 long.
    const int codebook_stride     // as int4.
) {
  int a_gl_stride = prob_k / 8 / 8;
  int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
  bool pred = a_gl_rd < prob_m;

  if (pred) {
    // advance to the correct codebook, this easy because we only multiply one
    // column of the codebook.
    auto codebook_size = &codebook_a_sizes.x;
    while (a_gl_rd >= *codebook_size) {
      codebook += codebook_stride;
      ++codebook_size;
    }
@ -67,8 +63,7 @@ __global__ void Code1x16MatVec(
  // We pad shared memory to avoid bank conflicts during reads
  __syncthreads();
  for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) {
    if (b_gl_rd + i < prob_k / 8) sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
  }
  __syncthreads();
  b_gl_rd += 32 * 8;
@ -79,19 +74,16 @@ __global__ void Code1x16MatVec(
#pragma unroll
    for (int i = 0; i < 8; i++) {
      uint32_t dec[4];
      // We bypass the L1 cache to avoid massive amounts of memory streaming
      // that doesn't actually help us; this brings > 2x speedup.
      asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
                   : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3])
                   : "l"((void*)&codebook[enc[i]]));
      half2* a = reinterpret_cast<half2*>(&dec);
      half2* b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]);
      half2 res2 = {};
#pragma unroll
      for (int j = 0; j < 4; j++) res2 = __hfma2(a[j], b[j], res2);
      res += __half2float(res2.x) + __half2float(res2.y);
      b_sh_rd++;
    }
@ -101,21 +93,18 @@ __global__ void Code1x16MatVec(

  if (pred) {
#pragma unroll
    for (int i = 16; i > 0; i /= 2) res += __shfl_down_sync(0xffffffff, res, i);
    if (threadIdx.x % 32 == 0)
      reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res);
  }
}

__global__ void Code2x8MatVec(
    const int4* __restrict__ A, const int4* __restrict__ B,
    int4* __restrict__ C, const int4* __restrict__ codebook, int prob_m,
    int prob_k,
    const int4 codebook_a_sizes,  // cumulative sizes of A spanning each
                                  // codebook, at most 3 long.
    const int codebook_stride     // as int4.

) {
@ -123,12 +112,11 @@ __global__ void Code2x8MatVec(
  int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
  bool pred = a_gl_rd < prob_m;

  if (pred) {
    // advance to the correct codebook, this easy because we only multiply one
    // column of the codebook.
    auto codebook_size = &codebook_a_sizes.x;
    while (a_gl_rd >= *codebook_size) {
      codebook += codebook_stride;
      ++codebook_size;
    }
@ -149,8 +137,7 @@ __global__ void Code2x8MatVec(
  for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) {
    int4 dec = codebook[i];
#pragma unroll
    for (int j = 0; j < 8; j++) sh_code[8 * i + (j + lane) % 8] = dec;
  }
  __syncthreads();

@ -161,8 +148,7 @@ __global__ void Code2x8MatVec(
  // We pad shared memory to avoid bank conflicts during reads
  __syncthreads();
  for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) {
    if (b_gl_rd + i < prob_k / 8) sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
  }
  __syncthreads();
  b_gl_rd += 32 * 8;
@ -172,8 +158,10 @@ __global__ void Code2x8MatVec(
    const uint8_t* enc = reinterpret_cast<const uint8_t*>(&A[a_gl_rd]);
#pragma unroll
    for (int i = 0; i < 8; i++) {
      half2* a0 =
          reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]);
      half2* a1 =
          reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]);
      half2* b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]);
      half2 res2 = {};
#pragma unroll
@ -188,33 +176,28 @@ __global__ void Code2x8MatVec(

  if (pred) {
#pragma unroll
    for (int i = 16; i > 0; i /= 2) res += __shfl_down_sync(0xffffffff, res, i);
    if (threadIdx.x % 32 == 0)
      reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res);
  }
}

__global__ void Code1x16Dequant(
    const int4* __restrict__ A, int4* __restrict__ C,
    const int4* __restrict__ codebook, int prob_m, int prob_k,
    const int4 codebook_a_sizes,  // cumulative sizes of A spanning each
                                  // codebook, at most 3 long, sums to m.
    const int codebook_stride     // as int4
) {
  int a_gl_stride = prob_k / 8 / 8;
  int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
  bool pred = a_gl_rd < prob_m;

  if (pred) {
    // advance to the correct codebook, this easy because we only multiply one
    // column of the codebook.
    auto codebook_size = &codebook_a_sizes.x;
    while (a_gl_rd >= *codebook_size) {
      codebook += codebook_stride;
      ++codebook_size;
    }
@ -235,13 +218,11 @@ __global__ void Code1x16Dequant(
    for (int i = 0; i < 8; i++) {
      int4 chunk;
      auto dec = reinterpret_cast<uint32_t*>(&chunk);
      // We bypass the L1 cache to avoid massive amounts of memory streaming
      // that doesn't actually help us; this brings > 2x speedup.
      asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
                   : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3])
                   : "l"((void*)&codebook[enc[i]]));

      C[a_gl_rd * 8 + i] = chunk;
    }
@ -250,26 +231,23 @@ __global__ void Code1x16Dequant(
  }
}

__global__ void Code2x8Dequant(
    const int4* __restrict__ A, int4* __restrict__ C,
    const int4* __restrict__ codebook, int prob_m, int prob_k,
    const int4
        codebook_a_sizes,  // cumulative sizes of A spanning each codebook, at
                           // most 3 long, corresponds to cols.
    const int codebook_stride  // as int4
) {
  int a_gl_stride = prob_k / 8 / 8;
  int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
  bool pred = a_gl_rd < prob_m;

  if (pred) {
    // advance to the correct codebook, this easy because we only multiply one
    // column of the codebook.
    auto codebook_size = &codebook_a_sizes.x;
    while (a_gl_rd >= *codebook_size) {
      codebook += codebook_stride;
      ++codebook_size;
    }
@ -291,8 +269,7 @@ __global__ void Code2x8Dequant(
  for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) {
    int4 dec = codebook[i];
#pragma unroll
    for (int j = 0; j < 8; j++) sh_code[8 * i + (j + lane) % 8] = dec;
  }
  __syncthreads();

@ -305,8 +282,10 @@ __global__ void Code2x8Dequant(
#pragma unroll
  for (int i = 0; i < 8; i++) {
    int4 chunk;
    half2* a0 =
        reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]);
    half2* a1 =
        reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]);
#pragma unroll
    for (int j = 0; j < 4; j++)
      reinterpret_cast<half2*>(&chunk)[j] = __hadd2(a0[j], a1[j]);
@ -317,22 +296,15 @@ __global__ void Code2x8Dequant(
  }
}

inline int ceildiv(int a, int b) { return (a + b - 1) / b; }
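ceildiv above is the usual round-up division used to size the launch grids in the helper functions below. A quick standalone check (the row counts are hypothetical, THREAD_M matches the constant that follows):

#include <cassert>

inline int ceildiv(int a, int b) { return (a + b - 1) / b; }

int main() {
  const int prob_m = 4097;  // rows to cover
  const int thread_m = 16;  // rows handled per block, cf. THREAD_M below
  assert(ceildiv(prob_m, thread_m) == 257);  // 256 full blocks + 1 partial
  assert(ceildiv(4096, thread_m) == 256);    // exact multiple needs no extra block
  return 0;
}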
const int THREAD_M = 16;

void code1x16_matvec_cuda(const void* __restrict__ A,
                          const void* __restrict__ B, void* __restrict__ C,
                          const void* __restrict__ codebook, int prob_m,
                          int prob_k, const int4 codebook_a_sizes,
                          const int codebook_stride) {
  int sms;
  cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
  int waves = 0;
@ -346,27 +318,15 @@ void code1x16_matvec_cuda(
  int threads = 32 * thread_m;
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  Code1x16MatVec<<<blocks, threads, 16 * 32 * 9, stream>>>(
      (const int4*)A, (const int4*)B, (int4*)C, (const int4*)codebook, prob_m,
      prob_k, codebook_a_sizes, codebook_stride);
}

void code2x8_matvec_cuda(const void* __restrict__ A, const void* __restrict__ B,
                         void* __restrict__ C,
                         const void* __restrict__ codebook, int prob_m,
                         int prob_k, const int4 codebook_a_sizes,
                         const int codebook_stride) {
  int sms;
  cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
  int waves = 0;
@ -379,29 +339,19 @@ void code2x8_matvec_cuda(
  int blocks = ceildiv(prob_m, thread_m);
  int threads = 32 * thread_m;
  int shared = 16 * (2 * 256 * 8 + 32 * 9);
  cudaFuncSetAttribute(Code2x8MatVec,
                       cudaFuncAttributeMaxDynamicSharedMemorySize, shared);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  Code2x8MatVec<<<blocks, threads, shared, stream>>>(
      (const int4*)A, (const int4*)B, (int4*)C, (const int4*)codebook, prob_m,
      prob_k, codebook_a_sizes, codebook_stride);
}

void code1x16_dequant_cuda(
    const void* __restrict__ A, void* __restrict__ C,
    const void* __restrict__ codebook, int prob_m, int prob_k,
    const int4 codebook_a_sizes,  // cumulative sizes of A spanning each
                                  // codebook, at most 3 long.
    const int codebook_stride     // as int4.
) {
  int sms;
@ -417,24 +367,20 @@ void code1x16_dequant_cuda(
  int threads = 32 * thread_m;
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  Code1x16Dequant<<<blocks, threads, 0, stream>>>(
      (const int4*)A, (int4*)C, (const int4*)codebook, prob_m, prob_k,
      codebook_a_sizes,  // cumulative sizes of A spanning each codebook, at
                         // most 3 long.
      codebook_stride    // as int4.
  );
}

// Dequantizes the code and codebook into weights.
void code2x8_dequant_cuda(
    const void* __restrict__ A, void* __restrict__ C,
    const void* __restrict__ codebook, int prob_m, int prob_k,
    const int4
        codebook_a_sizes,  // cumulative sizes of A spanning each codebook, at
                           // most 3 long, corresponds to cols.
    const int codebook_stride  // as int4
) {
  int sms;
@ -451,50 +397,33 @@ void code2x8_dequant_cuda(
  int shared = 16 * (2 * 256 * 8 + 32 * 9);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

  cudaFuncSetAttribute(Code2x8Dequant,
                       cudaFuncAttributeMaxDynamicSharedMemorySize, shared);
  Code2x8Dequant<<<blocks, threads, shared, stream>>>(
      (const int4*)A, (int4*)C, (const int4*)codebook, prob_m, prob_k,
      codebook_a_sizes, codebook_stride);
}
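A side note on the shared-memory arithmetic above (my observation, not from the diff): 16 * (2 * 256 * 8 + 32 * 9) works out to 70144 bytes, which is more than the default 48 KB of dynamic shared memory per block, so the Code2x8 launches first opt in via cudaFuncSetAttribute with cudaFuncAttributeMaxDynamicSharedMemorySize. A trivial host-side check of the number:

#include <cstdio>

int main() {
  int shared = 16 * (2 * 256 * 8 + 32 * 9);  // bytes of dynamic shared memory
  std::printf("%d bytes = %.1f KB\n", shared, shared / 1024.0);  // 70144 = 68.5 KB
  return 0;
}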
int codebook_stride(const torch::Tensor& codebooks) {
  return codebooks.stride(0) * codebooks.element_size() / sizeof(int4);
}

void code1x16_matvec(
    const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& C,
    const torch::Tensor& codebook,
    const int4 codebook_a_sizes  // cumulative sizes of A spanning each
                                 // codebook, at most 3 long.
) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
  int prob_m = C.size(0);
  int prob_k = B.size(0);

  code1x16_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(),
                       codebook.data_ptr(), prob_m, prob_k, codebook_a_sizes,
                       codebook_stride(codebook));
}

torch::Tensor code1x16_matmat(const torch::Tensor& input,
                              const torch::Tensor& codes,
                              const torch::Tensor& codebooks,
                              const torch::Tensor& scales,
@ -503,22 +432,15 @@ torch::Tensor code1x16_matmat(
  auto input_sizes = input.sizes();
  auto out_features = codes.size(0) * codebooks.size(2);
  auto flat_input = input.reshape({-1, input.size(-1)});
  auto flat_output = torch::empty(
      {flat_input.size(0), out_features},
      torch::TensorOptions().dtype(input.dtype()).device(input.device()));

  for (int i = 0; i < flat_input.size(0); ++i) {
    auto input_vec = flat_input.index({i});
    auto output_vec = flat_output.index({i});
    code1x16_matvec(codes.squeeze(2), input_vec, output_vec, codebooks,
                    codebook_a_sizes);
  }
  flat_output *= scales.flatten().unsqueeze(0);

@ -533,55 +455,35 @@ torch::Tensor code1x16_matmat(
  return output;
}

void code2x8_matvec(const torch::Tensor& A, const torch::Tensor& B,
                    torch::Tensor& C, const torch::Tensor& codebook,
                    const int4 codebook_a_sizes) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
  int prob_m = C.size(0);
  int prob_k = B.size(0);
  code2x8_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(),
                      codebook.data_ptr(), prob_m, prob_k, codebook_a_sizes,
                      2 * codebook_stride(codebook));
}

torch::Tensor code2x8_matmat(const torch::Tensor& input,
                             const torch::Tensor& codes,
                             const torch::Tensor& codebooks,
                             const torch::Tensor& scales,
                             const int4 codebook_a_sizes,
                             const std::optional<torch::Tensor>& bias) {
  auto input_sizes = input.sizes();
  auto out_features = codes.size(0) * codebooks.size(2);
  auto flat_input = input.reshape({-1, input.size(-1)});
  auto flat_output = torch::empty(
      {flat_input.size(0), out_features},
      torch::TensorOptions().dtype(input.dtype()).device(input.device()));

  for (int i = 0; i < flat_input.size(0); ++i) {
    auto input_vec = flat_input.index({i});
    auto output_vec = flat_output.index({i});
    code2x8_matvec(codes.squeeze(2), input_vec, output_vec, codebooks,
                   codebook_a_sizes);
  }
  flat_output *= scales.flatten().unsqueeze(0);
  if (bias.has_value()) {
@ -596,21 +498,18 @@ torch::Tensor code2x8_matmat(
}

// Accumulate the partition sizes.
int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes) {
  int4 cumulative_sizes;
  auto cumulative_size = &cumulative_sizes.x;
  int i = 0;
  int last = 0;
  assert(codebook_partition_sizes.size(0) <= 4);
  for (; i < codebook_partition_sizes.size(0); ++i, ++cumulative_size) {
    *cumulative_size = codebook_partition_sizes[i].item<int>() + last;
    last = *cumulative_size;
  }
  // fill in the rest with unreachable.
  for (; i < 4; ++i, ++cumulative_size) {
    *cumulative_size = last * 10;
  }
  return cumulative_sizes;
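A small worked example of what accumulate_sizes produces (the partition sizes below are made up): given sizes {1024, 1024, 3072}, the int4 holds the running totals {1024, 2048, 5120}, with any unused slot set to an unreachable sentinel.

// Mirror of the accumulate_sizes logic on plain ints, for illustration only.
#include <cassert>

int main() {
  const int sizes[3] = {1024, 1024, 3072};  // hypothetical codebook partitions
  int cumulative[4];
  int last = 0;
  for (int i = 0; i < 3; ++i) {
    cumulative[i] = sizes[i] + last;
    last = cumulative[i];
  }
  cumulative[3] = last * 10;  // unreachable sentinel, as in the code above
  assert(cumulative[0] == 1024 && cumulative[1] == 2048);
  assert(cumulative[2] == 5120 && cumulative[3] == 51200);
  return 0;
}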
@ -619,41 +518,36 @@ int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes)
}  // namespace aqlm
}  // namespace vllm

torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
                        const torch::Tensor& codebooks,
                        const torch::Tensor& scales,
                        const torch::Tensor& codebook_partition_sizes,
                        const std::optional<torch::Tensor>& bias) {
  int4 cumulative_sizes =
      vllm::aqlm::accumulate_sizes(codebook_partition_sizes);

  int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0);
  int const entries = codebooks.size(1);

  if (nbooks == 1 && entries == (1 << 16)) {
    return vllm::aqlm::code1x16_matmat(input, codes, codebooks, scales,
                                       cumulative_sizes, bias);
  }
  if (nbooks == 2 && entries == (1 << 8)) {
    return vllm::aqlm::code2x8_matmat(input, codes, codebooks, scales,
                                      cumulative_sizes, bias);
  }

  TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries,
              " entries is not currently supported.")
  return {};
}

torch::Tensor aqlm_dequant(const torch::Tensor& codes,
                           const torch::Tensor& codebooks,
                           const torch::Tensor& codebook_partition_sizes) {
  int4 cumulative_sizes =
      vllm::aqlm::accumulate_sizes(codebook_partition_sizes);

  int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0);
  int const entries = codebooks.size(1);
@ -670,43 +564,35 @@ torch::Tensor aqlm_dequant(
  auto weights = torch::empty({out_features, in_features},
                              torch::TensorOptions()
                                  .dtype(codebooks.dtype())
                                  .device(codebooks.device()));

  if (nbooks == 1 && entries == (1 << 16)) {
    vllm::aqlm::code1x16_dequant_cuda(codes.data_ptr(), weights.data_ptr(),
                                      codebooks.data_ptr(), out_features,
                                      in_features, cumulative_sizes,
                                      vllm::aqlm::codebook_stride(codebooks));

    // if you wanted to flip to scaling the weights, (though it's 30%-ish slower
    // and not consistent with gemv implementation.) weights *=
    // scales.index({"...", 0, 0});

    return weights;
  }

  if (nbooks == 2 && entries == (1 << 8)) {
    vllm::aqlm::code2x8_dequant_cuda(codes.data_ptr(), weights.data_ptr(),
                                     codebooks.data_ptr(), out_features,
                                     in_features, cumulative_sizes,
                                     vllm::aqlm::codebook_stride(codebooks));

    // if you wanted to flip to scaling the weights, (though it's 30%-ish slower
    // and not consistent with gemv implementation) weights *=
    // scales.index({"...", 0, 0});

    return weights;
  }

  TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries,
              " entries is not currently supported.")
  return {};
}
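Illustration of the shape-driven dispatch in aqlm_gemm and aqlm_dequant above (all numbers assumed): a "1x16" model has one codebook per weight partition with 2^16 entries, a "2x8" model has two codebooks per partition with 2^8 entries each, and anything else trips the TORCH_CHECK.

#include <cassert>

int main() {
  const int num_partitions = 3;  // e.g. fused projections sharing one weight (assumed)

  // "1x16": codebooks.size(0) == num_partitions, 2^16 entries per codebook.
  int nbooks = 3 / num_partitions;
  int entries = 1 << 16;
  assert(nbooks == 1 && entries == (1 << 16));  // code1x16_matmat path

  // "2x8": codebooks.size(0) == 2 * num_partitions, 2^8 entries per codebook.
  nbooks = 6 / num_partitions;
  entries = 1 << 8;
  assert(nbooks == 2 && entries == (1 << 8));  // code2x8_matmat path
  return 0;
}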
@ -1,11 +1,11 @@
|
|||||||
/*
|
/*
|
||||||
Adapted from https://github.com/mit-han-lab/llm-awq
|
Adapted from https://github.com/mit-han-lab/llm-awq
|
||||||
Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
|
Modified from NVIDIA FasterTransformer:
|
||||||
|
https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
|
||||||
@article{lin2023awq,
|
@article{lin2023awq,
|
||||||
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
|
title={AWQ: Activation-aware Weight Quantization for LLM Compression and
|
||||||
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
|
Acceleration}, author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang,
|
||||||
journal={arXiv},
|
Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023}
|
||||||
year={2023}
|
|
||||||
}
|
}
|
||||||
*/
|
*/
|
||||||
|
|
||||||
@ -14,8 +14,7 @@ Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransfor
|
|||||||
namespace vllm {
|
namespace vllm {
|
||||||
namespace awq {
|
namespace awq {
|
||||||
|
|
||||||
__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
|
__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) {
|
||||||
{
|
|
||||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
|
||||||
assert(false);
|
assert(false);
|
||||||
#else
|
#else
|
||||||
@ -30,33 +29,40 @@ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
|
|||||||
  static constexpr uint32_t TOP_MASK = 0x00f000f0;
  static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;

-  // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
-  // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
-  // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
-  // elt_67 to fp16 without having to shift them to the bottom bits before hand.
+  // Note that the entire sequence only requires 1 shift instruction. This is
+  // thanks to the register packing format and the fact that we force our
+  // integers to be unsigned, and account for this in the fp16 subtractions. In
+  // addition, I exploit the fact that sub and fma have the same throughput in
+  // order to convert elt_23 and elt_67 to fp16 without having to shift them to
+  // the bottom bits before hand.

-  // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue
-  // immediately before required.
+  // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW
+  // dependency if we issue immediately before required.
  const uint32_t top_i4s = i4s >> 8;
  // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(h[0])
-               : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
+               : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
+                 "n"(immLut));
  // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(h[1])
-               : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
+               : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
+                 "n"(immLut));
  // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(h[2])
-               : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
+               : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
+                 "n"(immLut));
  // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(h[3])
-               : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
+               : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
+                 "n"(immLut));

-  // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the
-  // half2 ctor. In this case, I chose performance reliability over code readability.
+  // I use inline PTX below because I am not sure if the compiler will emit
+  // float2half instructions if I use the half2 ctor. In this case, I chose
+  // performance reliability over code readability.

  // This is the half2 {1032, 1032} represented as an integer.
  // static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
@@ -71,13 +77,21 @@ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)

  // Finally, we construct the output numbers.
  // Convert elt_01
-  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
+  asm volatile("sub.f16x2 %0, %1, %2;\n"
+               : "=r"(h[0])
+               : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
  // Convert elt_23
-  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
+  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
+               : "=r"(h[1])
+               : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
  // Convert elt_45
-  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
+  asm volatile("sub.f16x2 %0, %1, %2;\n"
+               : "=r"(h[2])
+               : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
  // Convert elt_67
-  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
+  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
+               : "=r"(h[3])
+               : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));

  return result;
#endif
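A quick host-side illustration, not part of this commit, of why the magic numbers above work: OR-ing a 4-bit value q into the fp16 payload of 0x6400 (which encodes 1024.0) produces 1024 + q for the low nibble and 1024 + 16*q for the high nibble, so one sub and one fma (scale by ONE_SIXTEENTH, add NEG_64) recover q exactly. Plain doubles stand in for fp16 here.

#include <cstdint>
#include <cstdio>

int main() {
  for (uint32_t q = 0; q < 16; ++q) {
    // Low nibble: 0x6400 | q decodes to 1024 + q, and the sub path removes 1024.
    double lo = (1024.0 + q) - 1024.0;
    // High nibble: the payload is 16*q, and fma(x, 1/16, -64) removes the bias.
    double hi = (1024.0 + 16.0 * q) * (1.0 / 16.0) - 64.0;
    std::printf("q=%2u  lo=%4.0f  hi=%4.0f\n", q, lo, hi);
  }
  return 0;
}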
@@ -1,14 +1,12 @@
 /*
 Adapted from https://github.com/mit-han-lab/llm-awq
 @article{lin2023awq,
-  title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
-  author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
-  journal={arXiv},
-  year={2023}
+  title={AWQ: Activation-aware Weight Quantization for LLM Compression and
+  Acceleration}, author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang,
+  Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023}
 }
 */


 #include <torch/extension.h>
 #include <c10/cuda/CUDAGuard.h>

@@ -20,26 +18,20 @@ namespace vllm {
 namespace awq {

 // Pack two half values.
-static inline __device__ __host__ unsigned
-__pack_half2(const half x, const half y) {
+static inline __device__ __host__ unsigned __pack_half2(const half x,
+                                                        const half y) {
   unsigned v0 = *((unsigned short*)&x);
   unsigned v1 = *((unsigned short*)&y);
   return (v1 << 16) | v0;
 }

 template <int N>
-__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
-    int G,
-    int split_k_iters,
-    half* __restrict__ A,
-    int* __restrict__ B,
-    half* __restrict__ scaling_factors,
-    int* __restrict__ zeros,
-    int M,
-    int IC,
-    int OC,
-    half* __restrict__ C)
-{
+__global__ void __launch_bounds__(64)
+    gemm_forward_4bit_cuda_m16nXk32(int G, int split_k_iters,
+                                    half* __restrict__ A, int* __restrict__ B,
+                                    half* __restrict__ scaling_factors,
+                                    int* __restrict__ zeros, int M, int IC,
+                                    int OC, half* __restrict__ C) {
   // Only support matrix n = 64 or 128
   assert(N == 64 || N == 128);
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
@ -70,43 +62,46 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
|
|||||||
static constexpr int row_stride = 2 * 32 * 8 / N;
|
static constexpr int row_stride = 2 * 32 * 8 / N;
|
||||||
bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < N;
|
bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < N;
|
||||||
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
|
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
|
||||||
bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id
|
bool ld_A_flag =
|
||||||
|
(blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp +
|
||||||
|
threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id
|
||||||
// bool wb_C_flag = (threadIdx.x / 4) < M;
|
// bool wb_C_flag = (threadIdx.x / 4) < M;
|
||||||
|
|
||||||
half* A_ptr = A
|
half* A_ptr =
|
||||||
+ (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
|
A +
|
||||||
+ (((int)threadIdx.x) % (32 / 8)) * 8;
|
(((int)blockIdx_y) / j_factors1 * 16 +
|
||||||
|
(((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) *
|
||||||
|
IC +
|
||||||
|
(((int)threadIdx.x) % (32 / 8)) * 8;
|
||||||
|
|
||||||
int* B_ptr = B
|
int* B_ptr = B + ((int)threadIdx.y) * (OC / 8) * (256 / N) +
|
||||||
+ ((int)threadIdx.y) * (OC / 8) * (256 / N)
|
(((int)threadIdx.x) / (N / 8)) * (OC / 8) +
|
||||||
+ (((int)threadIdx.x) / (N / 8)) * (OC / 8)
|
(((int)blockIdx_y) % j_factors1) * (N / 8) +
|
||||||
+ (((int)blockIdx_y) % j_factors1) * (N / 8)
|
(((int)threadIdx.x) % (N / 8)) * 1;
|
||||||
+ (((int)threadIdx.x) % (N / 8)) * 1;
|
|
||||||
// Why * 1 in the above line?
|
// Why * 1 in the above line?
|
||||||
|
|
||||||
half* A_shared_ptr = A_shared
|
half* A_shared_ptr = A_shared +
|
||||||
+ ((int)threadIdx.y) * row_stride_warp * (32 + 8)
|
((int)threadIdx.y) * row_stride_warp * (32 + 8) +
|
||||||
+ (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
|
(((int)threadIdx.x) / (32 / 8)) * (32 + 8) +
|
||||||
+ (((int)threadIdx.x) % (32 / 8) ) * 8;
|
(((int)threadIdx.x) % (32 / 8)) * 8;
|
||||||
|
|
||||||
half* B_shared_ptr = B_shared
|
half* B_shared_ptr = B_shared +
|
||||||
+ ((int)threadIdx.y) * (row_stride / 2) * (N + 8)
|
((int)threadIdx.y) * (row_stride / 2) * (N + 8) +
|
||||||
+ (((int)threadIdx.x) / (N / 8)) * (N + 8)
|
(((int)threadIdx.x) / (N / 8)) * (N + 8) +
|
||||||
+ (((int)threadIdx.x) % (N / 8)) * 8;
|
(((int)threadIdx.x) % (N / 8)) * 8;
|
||||||
|
|
||||||
int* zeros_ptr = zeros
|
int* zeros_ptr = zeros + (((int)blockIdx_y) % j_factors1) * (N / 8) +
|
||||||
+ (((int)blockIdx_y) % j_factors1) * (N / 8)
|
((int)threadIdx.x) % (N / 8);
|
||||||
+ ((int)threadIdx.x) % (N / 8);
|
|
||||||
|
|
||||||
half* scaling_factors_ptr = scaling_factors
|
half* scaling_factors_ptr = scaling_factors +
|
||||||
+ (((int)blockIdx_y) % j_factors1) * N
|
(((int)blockIdx_y) % j_factors1) * N +
|
||||||
+ (((int)threadIdx.x) % (N / 8)) * 8;
|
(((int)threadIdx.x) % (N / 8)) * 8;
|
||||||
|
|
||||||
half* C_ptr = C
|
half* C_ptr =
|
||||||
+ static_cast<long long>(blockIdx_z) * M * OC // blockIdz.x -> split_k dim
|
C +
|
||||||
+ (((int)blockIdx_y) % j_factors1) * N
|
static_cast<long long>(blockIdx_z) * M * OC // blockIdz.x -> split_k dim
|
||||||
+ ((int)threadIdx.y) * (N / 2)
|
+ (((int)blockIdx_y) % j_factors1) * N + ((int)threadIdx.y) * (N / 2) +
|
||||||
+ (((int)threadIdx.x) % 4) * 2;
|
(((int)threadIdx.x) % 4) * 2;
|
||||||
|
|
||||||
// preload s.f. and zeros
|
// preload s.f. and zeros
|
||||||
int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
|
int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
|
||||||
@ -115,57 +110,83 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
|
|||||||
int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
|
int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
|
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
|
||||||
if (ld_A_flag)
|
if (ld_A_flag) {
|
||||||
{
|
|
||||||
*(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
|
*(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
|
||||||
}
|
} else {
|
||||||
else
|
|
||||||
{
|
|
||||||
*(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
|
*(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
// for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
|
// for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
|
||||||
uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
|
uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
|
||||||
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
|
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
|
||||||
uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
|
uint4 B_loaded_scale =
|
||||||
|
*(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
|
||||||
/*
|
/*
|
||||||
if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
|
if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 &&
|
||||||
printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
|
threadIdx.y == 0){ printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x,
|
||||||
|
B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x,
|
||||||
|
B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
|
||||||
}
|
}
|
||||||
*/
|
*/
|
||||||
// uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
|
// uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
|
||||||
int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
|
int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
|
||||||
|
|
||||||
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < N / 16; ++ax0_ax1_fused_0) {
|
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < N / 16; ++ax0_ax1_fused_0) {
|
||||||
|
|
||||||
// B: 32 x 136 (128+8) float16
|
// B: 32 x 136 (128+8) float16
|
||||||
// each warp: 32 x 4
|
// each warp: 32 x 4
|
||||||
// each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
|
// each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus
|
||||||
// *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
|
// zero -> WB UINT4
|
||||||
// row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
|
// *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) *
|
||||||
uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
|
// 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15)
|
||||||
|
// * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 *
|
||||||
|
// 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) *
|
||||||
|
// 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) *
|
||||||
|
// 8))); row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
|
||||||
|
uint32_t B_loaded =
|
||||||
|
*(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
|
||||||
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
|
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
|
||||||
//uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
|
// uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N /
|
||||||
|
// 8)) * 8);
|
||||||
|
|
||||||
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
|
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x
|
||||||
|
// % (cta_N / 8)) * 8);
|
||||||
// - zero and * scale
|
// - zero and * scale
|
||||||
// TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale.
|
// TODO (Haotian): can save 4 assembly instructions if sormulate as deq =
|
||||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
|
// q * scale - zero * scale.
|
||||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
|
asm volatile("sub.f16x2 %0, %1, %2;\n"
|
||||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
|
: "=r"(B_loaded_fp16.x)
|
||||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
|
: "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
|
||||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
|
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
|
||||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
|
: "=r"(B_loaded_fp16.x)
|
||||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
|
: "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
|
||||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
|
asm volatile("sub.f16x2 %0, %1, %2;\n"
|
||||||
|
: "=r"(B_loaded_fp16.y)
|
||||||
|
: "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
|
||||||
|
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
|
||||||
|
: "=r"(B_loaded_fp16.y)
|
||||||
|
: "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
|
||||||
|
asm volatile("sub.f16x2 %0, %1, %2;\n"
|
||||||
|
: "=r"(B_loaded_fp16.z)
|
||||||
|
: "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
|
||||||
|
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
|
||||||
|
: "=r"(B_loaded_fp16.z)
|
||||||
|
: "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
|
||||||
|
asm volatile("sub.f16x2 %0, %1, %2;\n"
|
||||||
|
: "=r"(B_loaded_fp16.w)
|
||||||
|
: "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
|
||||||
|
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
|
||||||
|
: "=r"(B_loaded_fp16.w)
|
||||||
|
: "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
|
||||||
/*
|
/*
|
||||||
if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
|
if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 ==
|
||||||
printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
|
0 && threadIdx.x == 17 && threadIdx.y == 0){ printf("[x] %X %X %X %X\n",
|
||||||
|
B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
|
||||||
}
|
}
|
||||||
*/
|
*/
|
||||||
|
|
||||||
// write back
|
// write back
|
||||||
*(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) = B_loaded_fp16;
|
*(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) =
|
||||||
|
B_loaded_fp16;
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
@ -173,34 +194,43 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
|
|||||||
{
|
{
|
||||||
unsigned int addr;
|
unsigned int addr;
|
||||||
__asm__ __volatile__(
|
__asm__ __volatile__(
|
||||||
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
|
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, "
|
||||||
|
"addr; }\n"
|
||||||
: "=r"(addr)
|
: "=r"(addr)
|
||||||
: "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
|
: "l"((void*)((&(A_shared[(k_0_1 * 16)])) +
|
||||||
);
|
(((((int)threadIdx.x) & 15) * 40) +
|
||||||
|
((((int)threadIdx.x) >> 4) * 8)))));
|
||||||
|
|
||||||
__asm__ __volatile__(
|
__asm__ __volatile__(
|
||||||
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
|
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
|
||||||
"{%0, %1, %2, %3}, [%4];\n"
|
"{%0, %1, %2, %3}, [%4];\n"
|
||||||
: "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
|
: "=r"(((unsigned*)(A_shared_warp + 0))[0]),
|
||||||
: "r"(addr)
|
"=r"(((unsigned*)(A_shared_warp + 0))[1]),
|
||||||
);
|
"=r"(((unsigned*)(A_shared_warp + 0))[2]),
|
||||||
|
"=r"(((unsigned*)(A_shared_warp + 0))[3])
|
||||||
|
: "r"(addr));
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) {
|
for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) {
|
||||||
{
|
{
|
||||||
unsigned int addr;
|
unsigned int addr;
|
||||||
__asm__ __volatile__(
|
__asm__ __volatile__(
|
||||||
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
|
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, "
|
||||||
|
"addr; }\n"
|
||||||
: "=r"(addr)
|
: "=r"(addr)
|
||||||
: "l"((void *)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + (((int)threadIdx.y) * (N / 2))) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * (N + 8)) + ((((int)threadIdx.x) >> 4) * 8))))
|
: "l"((void*)((&(B_shared[(((k_0_1 * (N * 16 + 128)) +
|
||||||
);
|
(((int)threadIdx.y) * (N / 2))) +
|
||||||
|
(ax1_0 * 16))])) +
|
||||||
|
(((((int)threadIdx.x) & 15) * (N + 8)) +
|
||||||
|
((((int)threadIdx.x) >> 4) * 8)))));
|
||||||
__asm__ __volatile__(
|
__asm__ __volatile__(
|
||||||
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
|
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
|
||||||
"{%0, %1, %2, %3}, [%4];\n"
|
"{%0, %1, %2, %3}, [%4];\n"
|
||||||
: "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
|
: "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[0]),
|
||||||
: "r"(addr)
|
"=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[1]),
|
||||||
);
|
"=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[2]),
|
||||||
|
"=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[3])
|
||||||
|
: "r"(addr));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) {
|
for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) {
|
||||||
@ -209,48 +239,110 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
|
|||||||
__asm__ __volatile__(
|
__asm__ __volatile__(
|
||||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||||
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||||
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
|
: "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
|
||||||
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
|
"=f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
|
||||||
|
"=f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
|
||||||
|
"=f"(((float*)(C_warp + (j_0_4 * 8)))[3])
|
||||||
|
: "r"(((unsigned*)(A_shared_warp + 0))[0]),
|
||||||
|
"r"(((unsigned*)(A_shared_warp + 0))[1]),
|
||||||
|
"r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]),
|
||||||
|
"f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
|
||||||
|
"f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
|
||||||
|
"f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
|
||||||
|
"f"(((float*)(C_warp + (j_0_4 * 8)))[3]));
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
__asm__ __volatile__(
|
__asm__ __volatile__(
|
||||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||||
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||||
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
: "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
|
||||||
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
|
||||||
|
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
|
||||||
|
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||||
|
: "r"(((unsigned*)(A_shared_warp + 0))[0]),
|
||||||
|
"r"(((unsigned*)(A_shared_warp + 0))[1]),
|
||||||
|
"r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]),
|
||||||
|
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
|
||||||
|
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
|
||||||
|
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
|
||||||
|
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
__asm__ __volatile__(
|
__asm__ __volatile__(
|
||||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||||
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||||
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
|
: "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
|
||||||
: "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
|
"=f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
|
||||||
|
"=f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
|
||||||
|
"=f"(((float*)(C_warp + (j_0_4 * 8)))[3])
|
||||||
|
: "r"(((unsigned*)(A_shared_warp + 0))[2]),
|
||||||
|
"r"(((unsigned*)(A_shared_warp + 0))[3]),
|
||||||
|
"r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]),
|
||||||
|
"f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
|
||||||
|
"f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
|
||||||
|
"f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
|
||||||
|
"f"(((float*)(C_warp + (j_0_4 * 8)))[3]));
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
__asm__ __volatile__(
|
__asm__ __volatile__(
|
||||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||||
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||||
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
: "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
|
||||||
: "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
|
||||||
|
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
|
||||||
|
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||||
|
: "r"(((unsigned*)(A_shared_warp + 0))[2]),
|
||||||
|
"r"(((unsigned*)(A_shared_warp + 0))[3]),
|
||||||
|
"r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]),
|
||||||
|
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
|
||||||
|
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
|
||||||
|
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
|
||||||
|
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
{
|
{
|
||||||
__asm__ __volatile__(
|
__asm__ __volatile__(
|
||||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
|
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
|
||||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
|
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, "
|
||||||
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
|
"%13};\n"
|
||||||
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
|
: "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
|
||||||
|
"=f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
|
||||||
|
"=f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
|
||||||
|
"=f"(((float*)(C_warp + (j_0_4 * 8)))[3])
|
||||||
|
: "r"(((unsigned*)(A_shared_warp + 0))[0]),
|
||||||
|
"r"(((unsigned*)(A_shared_warp + 0))[1]),
|
||||||
|
"r"(((unsigned*)(A_shared_warp + 0))[2]),
|
||||||
|
"r"(((unsigned*)(A_shared_warp + 0))[3]),
|
||||||
|
"r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]),
|
||||||
|
"r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]),
|
||||||
|
"f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
|
||||||
|
"f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
|
||||||
|
"f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
|
||||||
|
"f"(((float*)(C_warp + (j_0_4 * 8)))[3]));
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
__asm__ __volatile__(
|
__asm__ __volatile__(
|
||||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
|
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
|
||||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
|
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, "
|
||||||
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
"%13};\n"
|
||||||
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
: "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
|
||||||
|
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
|
||||||
|
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
|
||||||
|
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||||
|
: "r"(((unsigned*)(A_shared_warp + 0))[0]),
|
||||||
|
"r"(((unsigned*)(A_shared_warp + 0))[1]),
|
||||||
|
"r"(((unsigned*)(A_shared_warp + 0))[2]),
|
||||||
|
"r"(((unsigned*)(A_shared_warp + 0))[3]),
|
||||||
|
"r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]),
|
||||||
|
"r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]),
|
||||||
|
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
|
||||||
|
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
|
||||||
|
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
|
||||||
|
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
@ -261,24 +353,20 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
|
|||||||
// TODO: Shang: Hoist loop invariance.
|
// TODO: Shang: Hoist loop invariance.
|
||||||
for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) {
|
for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) {
|
||||||
for (int local_id = 0; local_id < 8; ++local_id) {
|
for (int local_id = 0; local_id < 8; ++local_id) {
|
||||||
int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
|
int row_offset = (((int)blockIdx_y) / j_factors1) * 16 +
|
||||||
if (row_offset < M)
|
((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
|
||||||
{
|
if (row_offset < M) {
|
||||||
*(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
|
*(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 +
|
||||||
|
local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
__global__ void __launch_bounds__(64) dequantize_weights(
|
__global__ void __launch_bounds__(64)
|
||||||
int* __restrict__ B,
|
dequantize_weights(int* __restrict__ B, half* __restrict__ scaling_factors,
|
||||||
half* __restrict__ scaling_factors,
|
int* __restrict__ zeros, half* __restrict__ C, int G) {
|
||||||
int* __restrict__ zeros,
|
|
||||||
half* __restrict__ C,
|
|
||||||
int G
|
|
||||||
)
|
|
||||||
{
|
|
||||||
int j_factors1 = 4;
|
int j_factors1 = 4;
|
||||||
int row_stride2 = 4;
|
int row_stride2 = 4;
|
||||||
int split_k_iters = 1;
|
int split_k_iters = 1;
|
||||||
@ -310,14 +398,30 @@ __global__ void __launch_bounds__(64) dequantize_weights(
|
|||||||
|
|
||||||
uint32_t B_loaded = *(uint32_t*)B_ptr2;
|
uint32_t B_loaded = *(uint32_t*)B_ptr2;
|
||||||
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
|
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
|
||||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
|
asm volatile("sub.f16x2 %0, %1, %2;\n"
|
||||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
|
: "=r"(B_loaded_fp16.x)
|
||||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
|
: "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
|
||||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
|
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
|
||||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
|
: "=r"(B_loaded_fp16.x)
|
||||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
|
: "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
|
||||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
|
asm volatile("sub.f16x2 %0, %1, %2;\n"
|
||||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
|
: "=r"(B_loaded_fp16.y)
|
||||||
|
: "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
|
||||||
|
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
|
||||||
|
: "=r"(B_loaded_fp16.y)
|
||||||
|
: "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
|
||||||
|
asm volatile("sub.f16x2 %0, %1, %2;\n"
|
||||||
|
: "=r"(B_loaded_fp16.z)
|
||||||
|
: "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
|
||||||
|
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
|
||||||
|
: "=r"(B_loaded_fp16.z)
|
||||||
|
: "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
|
||||||
|
asm volatile("sub.f16x2 %0, %1, %2;\n"
|
||||||
|
: "=r"(B_loaded_fp16.w)
|
||||||
|
: "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
|
||||||
|
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
|
||||||
|
: "=r"(B_loaded_fp16.w)
|
||||||
|
: "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
|
||||||
|
|
||||||
*(uint4*)B_shared_ptr2 = B_loaded_fp16;
|
*(uint4*)B_shared_ptr2 = B_loaded_fp16;
|
||||||
|
|
||||||
@@ -329,14 +433,10 @@ __global__ void __launch_bounds__(64) dequantize_weights(
 } // namespace awq
 } // namespace vllm

-torch::Tensor awq_dequantize(
-    torch::Tensor _kernel,
-    torch::Tensor _scaling_factors,
-    torch::Tensor _zeros,
-    int split_k_iters,
-    int thx,
-    int thy)
-{
+torch::Tensor awq_dequantize(torch::Tensor _kernel,
+                             torch::Tensor _scaling_factors,
+                             torch::Tensor _zeros, int split_k_iters, int thx,
+                             int thy) {
   int in_c = _kernel.size(0);
   int qout_c = _kernel.size(1);
   int out_c = qout_c * 8;
@@ -362,12 +462,15 @@ torch::Tensor awq_dequantize(

   const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors));

-  auto options = torch::TensorOptions().dtype(_scaling_factors.dtype()).device(_scaling_factors.device());
+  auto options = torch::TensorOptions()
+                     .dtype(_scaling_factors.dtype())
+                     .device(_scaling_factors.device());
   at::Tensor _de_kernel = torch::empty({in_c, out_c}, options);

   auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
   auto de_kernel = reinterpret_cast<half*>(_de_kernel.data_ptr<at::Half>());
-  auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
+  auto scaling_factors =
+      reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
   auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());

   dim3 num_blocks(x_blocks, y_blocks);
@@ -386,26 +489,26 @@ torch::Tensor awq_dequantize(
 // zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b]
 // assume that batch_size < 16 for now

-torch::Tensor awq_gemm(
-    torch::Tensor _in_feats,
-    torch::Tensor _kernel,
-    torch::Tensor _scaling_factors,
-    torch::Tensor _zeros,
-    int split_k_iters)
-{
+torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
+                       torch::Tensor _scaling_factors, torch::Tensor _zeros,
+                       int split_k_iters) {
   int num_in_feats = _in_feats.size(0);
   int num_in_channels = _in_feats.size(1);
   const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));

-  auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
-  at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options);
+  auto options = torch::TensorOptions()
+                     .dtype(_in_feats.dtype())
+                     .device(_in_feats.device());
+  at::Tensor _out_feats =
+      torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options);
   int num_out_feats = _out_feats.size(-2);
   int num_out_channels = _out_feats.size(-1);

   auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>());
   auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
   auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
-  auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
+  auto scaling_factors =
+      reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
   auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
   int group_size = num_in_channels / _scaling_factors.size(0);

@@ -419,28 +522,28 @@ torch::Tensor awq_gemm(
     throw std::invalid_argument("OC is not multiple of Group size");

   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  if (num_out_channels % 128 == 0)
-  {
+  if (num_out_channels % 128 == 0) {
     int j_factors1 = num_out_channels / 128 / 1;
     dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
     // threadIdx.x: 32
     // threadIdx.y: i_factors[2] * j_factors[2]
     dim3 threads_per_block(32, 2);
-    vllm::awq::gemm_forward_4bit_cuda_m16nXk32<128><<<num_blocks, threads_per_block, 0, stream>>>(
-        group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels,
-        num_out_channels, out_feats);
-  }
-  else if (num_out_channels % 64 == 0)
-  {
+    vllm::awq::gemm_forward_4bit_cuda_m16nXk32<128>
+        <<<num_blocks, threads_per_block, 0, stream>>>(
+            group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros,
+            num_in_feats, num_in_channels, num_out_channels, out_feats);
+  } else if (num_out_channels % 64 == 0) {
     int j_factors1 = num_out_channels / 64 / 1;
-    dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
+    dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 *
+                    split_k_iters);

     // threadIdx.x: 32
     // threadIdx.y: i_factors[2] * j_factors[2]
     dim3 threads_per_block(32, 2);
-    vllm::awq::gemm_forward_4bit_cuda_m16nXk32<64><<<num_blocks, threads_per_block, 0, stream>>>(
-        group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels,
-        num_out_channels, out_feats);
+    vllm::awq::gemm_forward_4bit_cuda_m16nXk32<64>
+        <<<num_blocks, threads_per_block, 0, stream>>>(
+            group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros,
+            num_in_feats, num_in_channels, num_out_channels, out_feats);
   }
   return _out_feats.sum(0);
 }
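For reference, each sub.f16x2 / fma.rn.f16x2 pair in the AWQ kernels of this file applies the usual dequantization w = (q - zero) * scale per packed half value. A minimal scalar sketch of that arithmetic (an illustration, not code from this commit) follows; the in-kernel TODO notes the same result could be computed as q * scale - zero * scale, which folds the subtraction into the fma when zero * scale is precomputed.

// Scalar model of the packed f16x2 dequantization: w = (q - zero) * scale.
inline float awq_dequant_scalar(int q /* 4-bit value, 0..15 */, float zero,
                                float scale) {
  float centered = static_cast<float>(q) - zero;  // the sub.f16x2 step
  return centered * scale;  // the fma.rn.f16x2 step with addend ZERO
}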
@@ -43,7 +43,8 @@ void cutlass_scaled_mm_dq(torch::Tensor &c, torch::Tensor const &a,
   // Check for strides and alignment
   TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major
   TORCH_CHECK(b.stride(0) == 1);                      // Column-major
-  TORCH_CHECK(c.stride(0) % 16 == 0 && b.stride(1) % 16 == 0);  // 16 Byte Alignment
+  TORCH_CHECK(c.stride(0) % 16 == 0 &&
+              b.stride(1) % 16 == 0);  // 16 Byte Alignment
   TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());

   at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
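The stride checks above pin down the operand layouts cutlass_scaled_mm_dq expects: a and c row-major (innermost stride 1) and b column-major. A small illustration with plain torch C++ tensors follows; it is an assumption-laden sketch with invented names, not code from this commit.

#include <torch/torch.h>

// Builds operands whose strides satisfy the TORCH_CHECKs shown above.
void make_cutlass_friendly_operands(int64_t M, int64_t N, int64_t K) {
  auto opts = torch::TensorOptions().dtype(torch::kInt8);
  torch::Tensor a = torch::empty({M, K}, opts);      // row-major: a.stride(1) == 1
  torch::Tensor b = torch::empty({N, K}, opts).t();  // column-major view: b.stride(0) == 1
  torch::Tensor c = torch::empty({M, N}, opts);      // row-major: c.stride(1) == 1
  TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);
  TORCH_CHECK(b.stride(0) == 1);
}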
|
@ -11,33 +11,26 @@
|
|||||||
|
|
||||||
#include "hip_float8_impl.h"
|
#include "hip_float8_impl.h"
|
||||||
|
|
||||||
struct alignas(1) hip_fp8
|
struct alignas(1) hip_fp8 {
|
||||||
{
|
struct from_bits_t {};
|
||||||
struct from_bits_t
|
HIP_FP8_HOST_DEVICE static constexpr from_bits_t from_bits() {
|
||||||
{
|
return from_bits_t();
|
||||||
};
|
}
|
||||||
HIP_FP8_HOST_DEVICE static constexpr from_bits_t from_bits() { return from_bits_t(); }
|
|
||||||
uint8_t data;
|
uint8_t data;
|
||||||
|
|
||||||
hip_fp8() = default;
|
hip_fp8() = default;
|
||||||
HIP_FP8_HOST_DEVICE constexpr hip_fp8(const hip_fp8&) = default;
|
HIP_FP8_HOST_DEVICE constexpr hip_fp8(const hip_fp8&) = default;
|
||||||
HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v) = delete;
|
HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v) = delete;
|
||||||
explicit HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v, from_bits_t)
|
explicit HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v, from_bits_t)
|
||||||
: data(v)
|
: data(v) {}
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
#ifdef __HIP__MI300__
|
#ifdef __HIP__MI300__
|
||||||
// NOTE: ON-DEVICE... always optimal bias
|
// NOTE: ON-DEVICE... always optimal bias
|
||||||
explicit HIP_FP8_DEVICE hip_fp8(float v)
|
explicit HIP_FP8_DEVICE hip_fp8(float v)
|
||||||
: data(hip_fp8_impl::to_fp8_from_fp32(v))
|
: data(hip_fp8_impl::to_fp8_from_fp32(v)) {}
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
explicit HIP_FP8_DEVICE hip_fp8(_Float16 v)
|
explicit HIP_FP8_DEVICE hip_fp8(_Float16 v)
|
||||||
: hip_fp8(static_cast<float>(v))
|
: hip_fp8(static_cast<float>(v)) {}
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
// Host only implementation using s/w simulation
|
// Host only implementation using s/w simulation
|
||||||
explicit HIP_FP8_HOST
|
explicit HIP_FP8_HOST
|
||||||
@ -45,25 +38,24 @@ struct alignas(1) hip_fp8
|
|||||||
// both Host and DEVICE for non-MI300 using s/w simulation
|
// both Host and DEVICE for non-MI300 using s/w simulation
|
||||||
explicit HIP_FP8_HOST_DEVICE
|
explicit HIP_FP8_HOST_DEVICE
|
||||||
#endif // __HIP__MI300__
|
#endif // __HIP__MI300__
|
||||||
hip_fp8(float v)
|
hip_fp8(float v) {
|
||||||
{
|
data = hip_fp8_impl::to_float8<4, 3, float, true /*negative_zero_nan*/,
|
||||||
data = hip_fp8_impl::to_float8<4, 3, float, true /*negative_zero_nan*/, true /*clip*/>(v);
|
true /*clip*/>(v);
|
||||||
}
|
}
|
||||||
|
|
||||||
explicit HIP_FP8_HOST_DEVICE hip_fp8(double v)
|
explicit HIP_FP8_HOST_DEVICE hip_fp8(double v)
|
||||||
: hip_fp8(static_cast<float>(v))
|
: hip_fp8(static_cast<float>(v)) {}
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
#ifdef __HIP__MI300__
|
#ifdef __HIP__MI300__
|
||||||
// upcast using device specific intrinsic
|
// upcast using device specific intrinsic
|
||||||
explicit inline HIP_FP8_DEVICE operator float() const
|
explicit inline HIP_FP8_DEVICE operator float() const {
|
||||||
{
|
|
||||||
float fval;
|
float fval;
|
||||||
uint32_t i32val = static_cast<uint32_t>(data);
|
uint32_t i32val = static_cast<uint32_t>(data);
|
||||||
|
|
||||||
// upcast
|
// upcast
|
||||||
asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
|
asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0"
|
||||||
|
: "=v"(fval)
|
||||||
|
: "v"(i32val));
|
||||||
|
|
||||||
return fval;
|
return fval;
|
||||||
}
|
}
|
||||||
@ -73,95 +65,73 @@ struct alignas(1) hip_fp8
|
|||||||
explicit inline HIP_FP8_HOST_DEVICE operator float() const
|
explicit inline HIP_FP8_HOST_DEVICE operator float() const
|
||||||
#endif // __HIP__MI300__
|
#endif // __HIP__MI300__
|
||||||
{
|
{
|
||||||
return hip_fp8_impl::from_float8<4, 3, float, true /*negative_zero_nan*/>(data);
|
return hip_fp8_impl::from_float8<4, 3, float, true /*negative_zero_nan*/>(
|
||||||
|
data);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
namespace std
|
namespace std {
|
||||||
{
|
inline hip_fp8 sin(hip_fp8 a) { return hip_fp8(sinf(float(a))); }
|
||||||
inline hip_fp8 sin(hip_fp8 a)
|
inline hip_fp8 cos(hip_fp8 a) { return hip_fp8(cosf(float(a))); }
|
||||||
{
|
HIP_FP8_HOST_DEVICE constexpr hip_fp8 real(const hip_fp8& a) { return a; }
|
||||||
return hip_fp8(sinf(float(a)));
|
|
||||||
}
|
|
||||||
inline hip_fp8 cos(hip_fp8 a)
|
|
||||||
{
|
|
||||||
return hip_fp8(cosf(float(a)));
|
|
||||||
}
|
|
||||||
HIP_FP8_HOST_DEVICE constexpr hip_fp8 real(const hip_fp8& a)
|
|
||||||
{
|
|
||||||
return a;
|
|
||||||
}
|
|
||||||
} // namespace std
|
} // namespace std
|
||||||
|
|
||||||
// Special operator overloading
|
// Special operator overloading
|
||||||
inline std::ostream& operator<<(std::ostream& os, const hip_fp8& f8)
|
inline std::ostream& operator<<(std::ostream& os, const hip_fp8& f8) {
|
||||||
{
|
|
||||||
return os << float(f8);
|
return os << float(f8);
|
||||||
}
|
}
|
||||||
|
|
||||||
// all + operator overloading with mixed types
|
// all + operator overloading with mixed types
|
||||||
// mixed types, always converts to f32, does computation in f32, and returns float
|
// mixed types, always converts to f32, does computation in f32, and returns
|
||||||
inline HIP_FP8_HOST_DEVICE float operator+(const float fa, hip_fp8 b)
|
// float
|
||||||
{
|
inline HIP_FP8_HOST_DEVICE float operator+(const float fa, hip_fp8 b) {
|
||||||
return (fa + float(b));
|
return (fa + float(b));
|
||||||
}
|
}
|
||||||
|
|
||||||
inline HIP_FP8_HOST_DEVICE float operator+(hip_fp8 a, const float fb)
|
inline HIP_FP8_HOST_DEVICE float operator+(hip_fp8 a, const float fb) {
|
||||||
{
|
|
||||||
return (float(a) + fb);
|
return (float(a) + fb);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline HIP_FP8_HOST_DEVICE hip_fp8 operator+(hip_fp8 a, hip_fp8 b)
|
inline HIP_FP8_HOST_DEVICE hip_fp8 operator+(hip_fp8 a, hip_fp8 b) {
|
||||||
{
|
|
||||||
return hip_fp8(float(a) + float(b));
|
return hip_fp8(float(a) + float(b));
|
||||||
}
|
}
|
||||||
|
|
||||||
inline HIP_FP8_HOST_DEVICE hip_fp8& operator+=(hip_fp8& a, hip_fp8 b)
|
inline HIP_FP8_HOST_DEVICE hip_fp8& operator+=(hip_fp8& a, hip_fp8 b) {
|
||||||
{
|
|
||||||
return a = hip_fp8(float(a) + float(b));
|
return a = hip_fp8(float(a) + float(b));
|
||||||
}
|
}
|
||||||
|
|
||||||
// overloading multiplication, always returns float,
|
// overloading multiplication, always returns float,
|
||||||
inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, hip_fp8 b)
|
inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, hip_fp8 b) {
|
||||||
{
|
|
||||||
return float(a) * float(b);
|
return float(a) * float(b);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline HIP_FP8_HOST_DEVICE float operator*(float a, hip_fp8 b)
|
inline HIP_FP8_HOST_DEVICE float operator*(float a, hip_fp8 b) {
|
||||||
{
|
|
||||||
return (a * float(b));
|
return (a * float(b));
|
||||||
}
|
}
|
||||||
|
|
||||||
inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, float b)
|
inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, float b) {
|
||||||
{
|
|
||||||
return (float(a) * b);
|
return (float(a) * b);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline HIP_FP8_HOST_DEVICE float operator*(int32_t a, hip_fp8 b)
|
inline HIP_FP8_HOST_DEVICE float operator*(int32_t a, hip_fp8 b) {
|
||||||
{
|
|
||||||
return ((float)a * float(b));
|
return ((float)a * float(b));
|
||||||
}
|
}
|
||||||
|
|
||||||
inline HIP_FP8_HOST_DEVICE float operator*(double a, hip_fp8 b)
|
inline HIP_FP8_HOST_DEVICE float operator*(double a, hip_fp8 b) {
|
||||||
{
|
|
||||||
return ((float)a * float(b));
|
return ((float)a * float(b));
|
||||||
}
|
}
|
||||||
|
|
||||||
// overloading for compare
|
// overloading for compare
|
||||||
inline HIP_FP8_HOST_DEVICE bool operator==(hip_fp8 a, hip_fp8 b)
|
inline HIP_FP8_HOST_DEVICE bool operator==(hip_fp8 a, hip_fp8 b) {
|
||||||
{
|
|
||||||
return (a.data == b.data);
|
return (a.data == b.data);
|
||||||
}
|
}
|
||||||
inline HIP_FP8_HOST_DEVICE bool operator!=(hip_fp8 a, hip_fp8 b)
|
inline HIP_FP8_HOST_DEVICE bool operator!=(hip_fp8 a, hip_fp8 b) {
|
||||||
{
|
|
||||||
return (a.data != b.data);
|
return (a.data != b.data);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline HIP_FP8_HOST_DEVICE bool operator>=(hip_fp8 a, hip_fp8 b)
|
inline HIP_FP8_HOST_DEVICE bool operator>=(hip_fp8 a, hip_fp8 b) {
|
||||||
{
|
|
||||||
return static_cast<float>(a) >= static_cast<float>(b);
|
return static_cast<float>(a) >= static_cast<float>(b);
|
||||||
}
|
}
|
||||||
inline HIP_FP8_HOST_DEVICE bool operator>(hip_fp8 a, hip_fp8 b)
|
inline HIP_FP8_HOST_DEVICE bool operator>(hip_fp8 a, hip_fp8 b) {
|
||||||
{
|
|
||||||
return static_cast<float>(a) > static_cast<float>(b);
|
return static_cast<float>(a) > static_cast<float>(b);
|
||||||
}
|
}
|
||||||
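A short usage sketch of the hip_fp8 interface above (not taken from this commit; the include path is an assumption): construction narrows a float to 8 bits, mixed arithmetic promotes back to float, and comparisons operate on the stored bits.

#include <iostream>
#include "hip_float8.h"  // assumed location of the header above

void hip_fp8_demo() {
  hip_fp8 a(1.5f);       // explicit float -> fp8 construction
  hip_fp8 b(0.25f);
  float sum = a + 2.0f;  // operator+(hip_fp8, float) returns float
  float prod = a * b;    // operator*(hip_fp8, hip_fp8) returns float
  std::cout << a << " " << sum << " " << prod << " " << (a == b) << "\n";
}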
@@ -1,6 +1,7 @@
 #pragma once

-#if defined(__HIPCC__) && (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+#if defined(__HIPCC__) && \
+    (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
 #define __HIP__MI300__
 #endif

@@ -14,12 +15,10 @@
 #define HIP_FP8_DEVICE
 #endif

-namespace hip_fp8_impl
-{
+namespace hip_fp8_impl {

 #ifdef __HIP__MI300__
-HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v)
-{
+HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v) {
   uint8_t i8data;
   union {
     float fval;
@ -30,7 +29,8 @@ HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v)
|
|||||||
uint32_t ival = 0;
|
uint32_t ival = 0;
|
||||||
val.fval = v;
|
val.fval = v;
|
||||||
|
|
||||||
if ((val.i32val & 0x7F800000) != 0x7F800000) { /// propagate NAN/INF, no clipping
|
if ((val.i32val & 0x7F800000) !=
|
||||||
|
0x7F800000) { /// propagate NAN/INF, no clipping
|
||||||
val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
|
val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -43,20 +43,14 @@ HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v)
|
|||||||
}
|
}
|
||||||
#endif // __HIP__MI300__
|
#endif // __HIP__MI300__
|
||||||
|
|
||||||
HIP_FP8_HOST inline int clz(uint32_t x)
|
HIP_FP8_HOST inline int clz(uint32_t x) { return __builtin_clz(x); }
|
||||||
{
|
|
||||||
return __builtin_clz(x);
|
|
||||||
}
|
|
||||||
#if defined(__HIPCC__) || defined(__CUDA_ARCH__)
|
#if defined(__HIPCC__) || defined(__CUDA_ARCH__)
|
||||||
HIP_FP8_DEVICE inline int clz(uint32_t x)
|
HIP_FP8_DEVICE inline int clz(uint32_t x) { return __clz(x); }
|
||||||
{
|
|
||||||
return __clz(x);
|
|
||||||
}
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
template <int we, int wm, typename T, bool negative_zero_nan, bool clip>
|
template <int we, int wm, typename T, bool negative_zero_nan, bool clip>
|
||||||
HIP_FP8_HOST_DEVICE uint8_t to_float8(T _x, bool stoch = false, uint32_t rng = 0)
|
HIP_FP8_HOST_DEVICE uint8_t to_float8(T _x, bool stoch = false,
|
||||||
{
|
uint32_t rng = 0) {
|
||||||
#ifdef __HIPCC__
|
#ifdef __HIPCC__
|
||||||
constexpr bool is_half = std::is_same<T, _Float16>::value;
|
constexpr bool is_half = std::is_same<T, _Float16>::value;
|
||||||
#else
|
#else
|
||||||
@ -130,7 +124,8 @@ HIP_FP8_HOST_DEVICE uint8_t to_float8(T _x, bool stoch = false, uint32_t rng = 0
|
|||||||
// For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent
|
// For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent
|
||||||
// bits
|
// bits
|
||||||
const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0);
|
const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0);
|
||||||
const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal
|
const int f8_denormal_act_exponent =
|
||||||
|
1 - f8_bias; // actual exponent of f8 denormal
|
||||||
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
|
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
|
||||||
// f8_exponent is the converted f8 exponent with bias encoding
|
// f8_exponent is the converted f8 exponent with bias encoding
|
||||||
// exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
@ -146,7 +141,9 @@ are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
(NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */
act_exponent = exponent - bias + 1;
exponent_diff =
f8_denormal_act_exponent -
act_exponent;  // actual exponent is exponent-bias+1 as it is denormal
} else {  // fp32/fp16 is normal with implicit 1
act_exponent = exponent - bias;
if (act_exponent <= f8_denormal_act_exponent) {
@ -157,9 +154,9 @@ where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
exponent_diff = f8_denormal_act_exponent - act_exponent;
} else {  // both fp32/fp16 and f8 are in normal range
exponent_diff = 0;  // exponent_diff=0 does not mean there is no
// difference for this case, act_exponent could be
// larger. Just that it does not need shift mantissa
}
mantissa += (1 << mfmt);  // Add the implicit 1 into mantissa
}
@ -181,13 +178,16 @@ where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
bool implicit_one = mantissa & (1 << mfmt);
// if there is no implicit 1, it means the f8 is denormal and need to adjust
// to denorm exponent
f8_exponent = (act_exponent + exponent_diff) /*actual f8 exponent*/ +
f8_bias - (implicit_one ? 0 : 1);

// Now we have the exponent and mantissa adjusted
uint32_t drop_mask = (1 << (mfmt - wm)) - 1;
bool odd = mantissa & (1 << (mfmt - wm));  // if the least significant bit
// that is not truncated is 1
mantissa +=
(stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) &
drop_mask;

// Now we deal with overflow
if (f8_exponent == 0) {
@ -222,8 +222,7 @@ where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
}

template <int we, int wm, typename T = float, bool negative_zero_nan = true>
inline HIP_FP8_HOST_DEVICE T from_float8(uint8_t x) {
#ifdef __HIPCC__
constexpr bool is_half = std::is_same<T, _Float16>::value;
#else
@ -285,7 +284,8 @@ inline HIP_FP8_HOST_DEVICE T from_float8(uint8_t x)
return reinterpret_cast<const T&>(retval);
}

const int exp_low_cutoff =
(1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0);

// subnormal input
if (exponent == 0) {
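For reference, the rounding step in the hunk above implements round-half-to-even by adding the dropped mantissa bits back onto the mantissa, backing off by one at an exact midpoint when the kept bit is already even. A minimal host-side sketch of that trick, with illustrative names, no stochastic path, and assuming drop_bits >= 1 (not part of this commit):

#include <cstdint>

// Truncate `drop_bits` low mantissa bits with round-half-to-even, mirroring
// the drop_mask / odd / midpoint logic above. Illustrative only.
uint32_t round_mantissa_rne(uint32_t mantissa, int drop_bits) {
  const uint32_t drop_mask = (1u << drop_bits) - 1;
  const bool odd = mantissa & (1u << drop_bits);  // LSB that survives
  const bool midpoint = (mantissa & drop_mask) == (1u << (drop_bits - 1));
  // Adding the dropped bits carries into the kept part iff they are >= half;
  // at an exact midpoint, subtract one first when the kept LSB is even.
  mantissa += (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa) & drop_mask;
  return mantissa >> drop_bits;
}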
@ -9,29 +9,27 @@
#include "../../../attention/dtype_float32.cuh"
#include "../../../attention/dtype_bfloat16.cuh"

namespace vllm {
#ifdef USE_ROCM

namespace fp8 {
#ifdef ENABLE_FP8

template <typename Tout, typename Tin>
__inline__ __device__ Tout vec_conversion(const Tin& x) {
return x;
}

template <typename Tout, typename Tin>
__inline__ __device__ Tout scaled_vec_conversion(const Tin& x,
const float scale) {
return x;
}

// fp8 -> half
template <>
__inline__ __device__ uint16_t
vec_conversion<uint16_t, uint8_t>(const uint8_t& a) {
hip_fp8 f8{a, hip_fp8::from_bits()};
__half_raw res;
res.data = static_cast<float>(f8);
@ -40,9 +38,10 @@ __inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(const uint8_t&

// fp8x2 -> half2
template <>
__inline__ __device__ uint32_t
vec_conversion<uint32_t, uint16_t>(const uint16_t& a) {
#if defined(__HIP__MI300__) && \
defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
union {
__half2_raw h2r;
@ -65,8 +64,7 @@ __inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(const uint16_t

// fp8x4 -> half2x2
template <>
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a) {
union {
uint2 u32x2;
uint32_t u32[2];
@ -78,8 +76,7 @@ __inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a)

// fp8x8 -> half2x4
template <>
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a) {
union {
uint4 u64x2;
uint2 u64[2];
@ -93,8 +90,8 @@ using __nv_bfloat16 = __hip_bfloat16;

// fp8 -> __nv_bfloat16
template <>
__inline__ __device__ __nv_bfloat16
vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) {
hip_fp8 f8{a, hip_fp8::from_bits()};
float f{f8};
return __float2bfloat16(f);
@ -104,8 +101,8 @@ using __nv_bfloat162 = __hip_bfloat162;

// fp8x2 -> __nv_bfloat162
template <>
__inline__ __device__ __nv_bfloat162
vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) {
__nv_bfloat162 res;
res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
@ -114,8 +111,8 @@ __inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(co

// fp8x4 -> bf16_4_t
template <>
__inline__ __device__ bf16_4_t
vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a) {
bf16_4_t res;
res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
@ -124,8 +121,7 @@ __inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(const uint32_t

// fp8x8 -> bf16_8_t
template <>
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a) {
bf16_4_t tmp1, tmp2;
tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
@ -139,17 +135,17 @@ __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a)

// fp8 -> float
template <>
__inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a) {
hip_fp8 fp8{a, hip_fp8::from_bits()};
return static_cast<float>(fp8);
}

// fp8x2 -> float2
template <>
__inline__ __device__ float2
vec_conversion<float2, uint16_t>(const uint16_t& a) {
#if defined(__HIP__MI300__) && \
defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
float2 res;
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
res.x = f2[0];
@ -165,8 +161,8 @@ __inline__ __device__ float2 vec_conversion<float2, uint16_t>(const uint16_t& a)

// fp8x4 -> float4
template <>
__inline__ __device__ Float4_
vec_conversion<Float4_, uint32_t>(const uint32_t& a) {
Float4_ res;
res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
@ -175,8 +171,7 @@ __inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(const uint32_t&

// fp8x8 -> float8
template <>
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a) {
Float4_ tmp1, tmp2;
tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
@ -190,8 +185,8 @@ __inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a)

// half -> fp8
template <>
__inline__ __device__ uint8_t
vec_conversion<uint8_t, uint16_t>(const uint16_t& a) {
__half_raw tmp;
tmp.x = a;

@ -201,24 +196,23 @@ __inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(const uint16_t&

// bf16 -> fp8
template <>
__inline__ __device__ uint8_t
vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a) {
hip_fp8 res{__bfloat162float(a)};
return res.data;
}

// float -> fp8
template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a) {
hip_fp8 f8(a);
return f8.data;
}

// fp8x4 -> float4
template <>
__inline__ __device__ float4
vec_conversion<float4, uint32_t>(const uint32_t& a) {
Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
return res;
@ -226,8 +220,8 @@ __inline__ __device__ float4 vec_conversion<float4, uint32_t>(const uint32_t& a)

// float2 -> half2
template <>
__inline__ __device__ uint32_t
vec_conversion<uint32_t, float2>(const float2& a) {
union {
half2 float16;
uint32_t uint32;
@ -239,8 +233,7 @@ __inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(const float2& a)

// Float4 -> half2x2
template <>
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a) {
uint2 b;
float2 val;
val.x = a.x.x;
@ -255,8 +248,7 @@ __inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a)

// Float4 -> float4
template <>
__inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a) {
float4 b;
b.x = a.x.x;
b.y = a.x.y;
@ -267,8 +259,7 @@ __inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a)

// Float8 -> half2x4
template <>
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a) {
uint4 b;
b.x = vec_conversion<uint32_t, float2>(a.x);
b.y = vec_conversion<uint32_t, float2>(a.y);
@ -279,16 +270,16 @@ __inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a)

// float2 -> bfloat162
template <>
__inline__ __device__ __nv_bfloat162
vec_conversion<__nv_bfloat162, float2>(const float2& a) {
__nv_bfloat162 b = __float22bfloat162_rn(a);
return b;
}

// Float4 -> bfloat162x2
template <>
__inline__ __device__ bf16_4_t
vec_conversion<bf16_4_t, Float4_>(const Float4_& a) {
bf16_4_t b;
b.x = __float22bfloat162_rn(a.x);
b.y = __float22bfloat162_rn(a.y);
@ -297,8 +288,8 @@ __inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(const Float4_&

// Float8 -> bfloat162x4
template <>
__inline__ __device__ bf16_8_t
vec_conversion<bf16_8_t, Float8_>(const Float8_& a) {
bf16_8_t b;
b.x = __float22bfloat162_rn(a.x);
b.y = __float22bfloat162_rn(a.y);
@ -307,20 +298,19 @@ __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(const Float8_&
return b;
}

/* Scaled and vectorized conversions, for data exchange between high and low
precision domains

Convention of the scale in API, e.g: FP8_data = Quantization(
High_Precision_data / scale ) s.t. Quantize(HP / scale) => FP8 Dequant(FP8) *
scale => HP

*/

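The comment above pins down the scale convention used by all of the conversions that follow: FP8_data = Quantize(HP / scale), and the high-precision value is recovered as Dequant(FP8) * scale. A one-line round-trip sketch of that convention in plain C++, with placeholder to_fp8 / from_fp8 codecs that are not part of this file:

#include <cstdint>

uint8_t to_fp8(float x);    // placeholder encoder, any fp8 codec would do
float from_fp8(uint8_t v);  // placeholder decoder

// Quantize(HP / scale) => FP8
uint8_t quantize(float hp, float scale) { return to_fp8(hp / scale); }

// Dequant(FP8) * scale => HP (approximately)
float dequantize(uint8_t fp8, float scale) { return from_fp8(fp8) * scale; }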
// fp8 -> half
template <>
__inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, const float scale) {
hip_fp8 f8{a, hip_fp8::from_bits()};
__half_raw res;
res.data = static_cast<float>(f8) * scale;
@ -329,9 +319,10 @@ __inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, uint8_t>(const ui

// fp8x2 -> half2
template <>
__inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(
const uint16_t& a, const float scale) {
#if defined(__HIP__MI300__) && \
defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
union {
__half2_raw h2r;
@ -346,29 +337,32 @@ __inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(const u
uint32_t u32;
} tmp;

tmp.u16[0] =
scaled_vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a), scale);
tmp.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>(
static_cast<uint8_t>(a >> 8U), scale);
return tmp.u32;
#endif
}

// fp8x4 -> half2x2
template <>
__inline__ __device__ uint2
scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, const float scale) {
union {
uint2 u32x2;
uint32_t u32[2];
} tmp;
tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale);
tmp.u32[1] =
scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale);
return tmp.u32x2;
}

// fp8x8 -> half2x4
template <>
__inline__ __device__ uint4
scaled_vec_conversion<uint4, uint2>(const uint2& a, const float scale) {
union {
uint4 u64x2;
uint2 u64[2];
@ -382,8 +376,9 @@ using __nv_bfloat16 = __hip_bfloat16;

// fp8 -> __nv_bfloat16
template <>
__inline__ __device__ __nv_bfloat16
scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a,
const float scale) {
hip_fp8 f8{a, hip_fp8::from_bits()};
float f{f8};
return __float2bfloat16(f * scale);
@ -393,28 +388,31 @@ using __nv_bfloat162 = __hip_bfloat162;

// fp8x2 -> __nv_bfloat162
template <>
__inline__ __device__ __nv_bfloat162
scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a,
const float scale) {
__nv_bfloat162 res;
res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale);
res.y =
scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale);
return res;
}

// fp8x4 -> bf16_4_t
template <>
__inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
const uint32_t& a, const float scale) {
bf16_4_t res;
res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale);
res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U),
scale);
return res;
}

// fp8x8 -> bf16_8_t
template <>
__inline__ __device__ bf16_8_t
scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, const float scale) {
bf16_4_t tmp1, tmp2;
tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale);
tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale);
@ -428,17 +426,18 @@ __inline__ __device__ bf16_8_t scaled_vec_conversion<bf16_8_t, uint2>(const uint

// fp8 -> float
template <>
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
const uint8_t& a, const float scale) {
hip_fp8 fp8{a, hip_fp8::from_bits()};
return static_cast<float>(fp8) * scale;
}

// fp8x2 -> float2
template <>
__inline__ __device__ float2
scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, const float scale) {
#if defined(__HIP__MI300__) && \
defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
float2 res;
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
res.x = f2[0] * scale;
@ -447,15 +446,16 @@ __inline__ __device__ float2 scaled_vec_conversion<float2, uint16_t>(const uint1
#else
float2 res;
res.x = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a), scale);
res.y = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U),
scale);
return res;
#endif
}

// fp8x4 -> float4
template <>
__inline__ __device__ Float4_
scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale) {
Float4_ res;
res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale);
res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale);
@ -464,8 +464,8 @@ __inline__ __device__ Float4_ scaled_vec_conversion<Float4_, uint32_t>(const uin

// fp8x8 -> float8
template <>
__inline__ __device__ Float8_
scaled_vec_conversion<Float8_, uint2>(const uint2& a, const float scale) {
Float4_ tmp1, tmp2;
tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale);
tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale);
@ -477,15 +477,14 @@ __inline__ __device__ Float8_ scaled_vec_conversion<Float8_, uint2>(const uint2&
return res;
}

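The packed variants above all follow the same layout: a 16-bit word holds two fp8 values, the low byte in bits 0-7 and the high byte in bits 8-15, and each byte is converted and scaled independently. A minimal host-side sketch of that unpacking in plain C++, with a placeholder fp8 decoder that is not part of this file:

#include <cstdint>
#include <utility>

float fp8_to_float(uint8_t v);  // placeholder decoder

// Unpack an fp8x2 word into two scaled floats, mirroring the (uint8_t)a /
// (uint8_t)(a >> 8U) pattern used by the device conversions above.
std::pair<float, float> unpack_fp8x2(uint16_t a, float scale) {
  const float lo = fp8_to_float(static_cast<uint8_t>(a)) * scale;
  const float hi = fp8_to_float(static_cast<uint8_t>(a >> 8)) * scale;
  return {lo, hi};
}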
/* Quantize(HP / scale) => FP8 */

// TODO(Hai): vectorized to add

// half -> fp8
template <>
__inline__ __device__ uint8_t
scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, const float scale) {
__half_raw tmp;
tmp.x = a;

@ -495,24 +494,24 @@ __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, uint16_t>(const uin

// bf16 -> fp8
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
const __nv_bfloat16& a, const float scale) {
hip_fp8 res{__bfloat162float(a) / scale};
return res.data;
}

// float -> fp8
template <>
__inline__ __device__ uint8_t
scaled_vec_conversion<uint8_t, float>(const float& a, const float scale) {
hip_fp8 f8(a / scale);
return f8.data;
}

// fp8x4 -> float4
template <>
__inline__ __device__ float4
scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, const float scale) {
Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale);
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
return res;
@ -539,9 +538,10 @@ __inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) {
assert(false);
}

// The following macro is used to dispatch the conversion function based on
// the data type of the key and value cache. The FN is a macro that calls a
// function with template<typename scalar_t, typename cache_t,
// Fp8KVCacheDataType kv_dt>.
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
if (KV_DTYPE == "auto") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
@ -562,13 +562,14 @@ __inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) {
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else { \
TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else { \
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
} \
}

} // namespace fp8
#endif // USE_ROCM
} // namespace vllm
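The comment on DISPATCH_BY_KV_CACHE_DTYPE above says FN must be a macro that invokes something templated on <scalar_t, cache_t, kv_dt>. A hypothetical caller, only to illustrate how FN expands (the names CALL_MY_KERNEL, my_kernel_launcher, and dispatch_example are made up for this sketch and assume the surrounding file's Torch includes):

// Hypothetical launcher templated the way the dispatch macro expects.
template <typename scalar_t, typename cache_t, vllm::Fp8KVCacheDataType kv_dt>
void my_kernel_launcher(torch::Tensor& key_cache);

// FN receives the three template arguments chosen by the dispatcher.
#define CALL_MY_KERNEL(SCALAR_T, CACHE_T, KV_DT) \
  my_kernel_launcher<SCALAR_T, CACHE_T, KV_DT>(key_cache);

void dispatch_example(torch::Tensor& key_cache, const std::string& kv_dtype) {
  DISPATCH_BY_KV_CACHE_DTYPE(key_cache.scalar_type(), kv_dtype,
                             CALL_MY_KERNEL);
}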
@ -11,8 +11,10 @@ namespace vllm {

__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
float old;
old = (value >= 0)
? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
: __uint_as_float(
atomicMin((unsigned int*)addr, __float_as_uint(value)));

return old;
}
@ -20,7 +22,8 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
#define FP8_E4M3_MAX std::numeric_limits<c10::Float8_e4m3fn>::max()

template <typename scalar_t>
__device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(
const scalar_t val, const float scale) {
float x = static_cast<float>(val) / scale;
float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
return static_cast<c10::Float8_e4m3fn>(r);
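The conversion above divides by the scale and then clamps into the representable e4m3 range before narrowing. A host-side analogue of that saturation step in plain C++; the 448.0f constant is the finite maximum of the e4m3fn format as assumed here, not quoted from the commit:

#include <algorithm>

// Divide by the scale, then clamp into [-max, max] before an fp8 e4m3 cast.
float saturate_to_fp8_range(float val, float scale) {
  const float kFp8E4M3Max = 448.0f;  // assumed e4m3fn finite max
  const float x = val / scale;
  return std::max(-kFp8E4M3Max, std::min(x, kFp8E4M3Max));
}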
@ -33,8 +36,7 @@ __device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(const scalar
// a value <= 0.0 and we need to wait for all thread blocks to
// finish before consuming *scale.
template <typename scalar_t>
__global__ void segmented_max_reduction(float* __restrict__ scale,
const scalar_t* __restrict__ input,
int64_t num_elems) {
__shared__ float cache[1024];
@ -64,13 +66,13 @@ __global__ void segmented_max_reduction(
// Finally, since cache[0] contains the maximum for this thread block,
// atomically write the max to the target location
if (threadIdx.x == 0) {
atomicMaxFloat(scale,
cache[0] / std::numeric_limits<c10::Float8_e4m3fn>::max());
}
}

template <typename scalar_t>
__global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
const scalar_t* __restrict__ input,
const float* __restrict__ scale,
int64_t num_elems) {
@ -83,8 +85,7 @@ __global__ void scaled_fp8_quant_kernel(

} // namespace vllm

void static_scaled_fp8_quant(torch::Tensor& out,    // [..., d]
torch::Tensor& input,  // [..., d]
torch::Tensor& scale)  // [1]
{
@ -95,19 +96,14 @@ void static_scaled_fp8_quant(
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<c10::Float8_e4m3fn>(), input.data_ptr<scalar_t>(),
scale.data_ptr<float>(), num_elems);
});
}

void dynamic_scaled_fp8_quant(torch::Tensor& out,    // [..., d]
torch::Tensor& input,  // [..., d]
torch::Tensor& scale)  // [1]
{
@ -118,18 +114,11 @@ void dynamic_scaled_fp8_quant(
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>(
scale.data_ptr<float>(), input.data_ptr<scalar_t>(), num_elems);
vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<c10::Float8_e4m3fn>(), input.data_ptr<scalar_t>(),
scale.data_ptr<float>(), num_elems);
});
}

@ -406,7 +406,6 @@ template <>
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
const uint8_t& a, const float scale,
const __nv_fp8_interpretation_t fp8_type) {

// fp8 -> half
__half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
uint16_t tmp = res.x;
@ -523,9 +522,10 @@ __inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) {
assert(false);
}

// The following macro is used to dispatch the conversion function based on
// the data type of the key and value cache. The FN is a macro that calls a
// function with template<typename scalar_t, typename cache_t,
// Fp8KVCacheDataType kv_dt>.
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
if (KV_DTYPE == "auto") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
@ -546,7 +546,8 @@ __inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) {
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else { \
TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else if (KV_DTYPE == "fp8_e5m2") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
@ -556,7 +557,8 @@ __inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) {
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else { \
TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else { \
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
@ -9,40 +9,36 @@ namespace vllm {
namespace gptq {
// atomicAdd for half types, to support CC < 7.x

__device__ __forceinline__ void atomicAdd_half(half* address, half val) {
unsigned int* address_as_ui =
(unsigned int*)((char*)address - ((size_t)address & 2));
unsigned int old = *address_as_ui;
unsigned int assumed;

do {
assumed = old;
__half_raw hsum;
hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
half tmpres = __hadd(hsum, val);
hsum = __half_raw(tmpres);
old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16)
: (old & 0xffff0000) | hsum.x;
old = atomicCAS(address_as_ui, assumed, old);
} while (assumed != old);
}

// atomicAdd for half2 types

__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) {
unsigned int* address_as_ui = (unsigned int*)address;
unsigned int old = *address_as_ui;
unsigned int assumed;
do {
assumed = old;
half2 old_val = *((half2*)&old);
half2 new_val = __hadd2(old_val, val);
old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
} while (assumed != old);
}

//
@ -50,10 +46,14 @@ __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)

__device__ __forceinline__ void atomicAdd(half* address, half val) {
atomicAdd_half(address, val);
}

#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) {
atomicAdd_half2(address, val);
}
#endif

#endif
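For context, the overloads above let half accumulation compile on architectures without a native half atomicAdd by falling back to the compare-and-swap loop. A toy kernel using that shim; the kernel and buffer names are illustrative only:

#include <cuda_fp16.h>

// Accumulate n half values into *sum. On CC < 7.0 (or ROCm) this resolves to
// the atomicAdd_half compatibility path defined above.
__global__ void accumulate_half(const half* in, half* sum, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    atomicAdd(sum, in[i]);
  }
}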
@ -1,5 +1,6 @@
/*
Adapted from https://github.com/turboderp/exllamav2 and
https://github.com/turboderp/exllama
*/

#ifndef _matrix_view_cuh
@ -13,24 +14,31 @@ Adapted from https://github.com/turboderp/exllamav2 and https://github.com/turbo
namespace vllm {
namespace gptq {

class MatrixView_half {
public:
const half* data;
const int height;
const int width;

__device__ __forceinline__ MatrixView_half(const half* data, const int height,
const int width)
: data(data), height(height), width(width) {}

__device__ __forceinline__ half item(int row, int column) const {
return data[row * width + column];
}
__device__ __forceinline__ half2 item_half2(int row, int column) const {
return ((half2*)data)[(row * width + column) / 2];
}
__device__ __forceinline__ half2 item_half2half2(int row, int column) const {
return __half2half2(data[row * width + column]);
}
__device__ __forceinline__ const half* item_ptr(int row, int column) const {
return &data[row * width + column];
}

__device__ __forceinline__ void item4(half (&items)[4], int row,
int column) const {
half2* ptr = (half2*)item_ptr(row, column);
half2 i01 = ptr[0];
half2 i23 = ptr[1];
@ -39,8 +47,8 @@ public:
items[2] = __low2half(i23);
items[3] = __high2half(i23);
}
__device__ __forceinline__ void item4_f(float (&items)[4], int row,
int column) const {
half2* ptr = (half2*)item_ptr(row, column);
half2 i01 = ptr[0];
half2 i23 = ptr[1];
@ -50,8 +58,8 @@ public:
items[3] = __half2float(__high2half(i23));
}

__device__ __forceinline__ void item4_h2(half2 (&items)[4], int row,
int column) const {
half2* ptr = (half2*)item_ptr(row, column);
half2 i01 = ptr[0];
half2 i23 = ptr[1];
@ -62,25 +70,34 @@ public:
}
};

class MatrixView_half_rw {
public:
half* data;
const int height;
const int width;

__device__ __forceinline__ MatrixView_half_rw(half* data, const int height,
const int width)
: data(data), height(height), width(width) {}

__device__ __forceinline__ half item(int row, int column) const {
return data[row * width + column];
}
__device__ __forceinline__ half2 item_half2(int row, int column) const {
return ((half2*)data)[(row * width + column) / 2];
}
__device__ __forceinline__ half* item_ptr(int row, int column) {
return &data[row * width + column];
}
__device__ __forceinline__ void set(int row, int column, half value) {
data[row * width + column] = value;
}
__device__ __forceinline__ void set_half2(int row, int column, half2 value) {
((half2*)data)[(row * width + column) / 2] = value;
}

__device__ __forceinline__ void set4(int row, int column, half v0, half v1,
half v2, half v3) {
half2 v01 = __halves2half2(v0, v1);
half2 v23 = __halves2half2(v2, v3);
half2* ptr = (half2*)item_ptr(row, column);
@ -89,33 +106,32 @@ public:
}
};

class MatrixView_q4_row {
public:
const uint32_t* data;
const int height;
const int width;

__device__ __forceinline__ MatrixView_q4_row(const uint32_t* data,
const int height,
const int width)
: data(data), height(height), width(width) {}

__device__ __forceinline__ int item(int row, int column) const {
int shift = (column & 0x07) * 4;
return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
}

__device__ __forceinline__ void item2(int (&items)[2], int row,
int column) const {
int shift = (column & 0x07) * 4;
uint32_t d = data[row * width / 8 + column / 8] >> shift;
items[0] = d & 0x0f;
items[1] = (d >> 4) & 0x0f;
}

__device__ __forceinline__ void item4(int (&items)[4], int row,
int column) const {
int shift = (column & 0x07) * 4;
uint32_t d = data[row * width / 8 + column / 8] >> shift;
items[0] = d & 0x0f;
@ -125,54 +141,57 @@ public:
}
};

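MatrixView_q4_row packs eight 4-bit values into each uint32_t: a column's nibble sits at bit offset (column & 7) * 4 inside word column / 8 of its row. A small host-side check of that addressing in plain C++ (names are illustrative, width assumed divisible by 8):

#include <cstdint>

// Extract the 4-bit value at (row, column) from a row-major array of packed
// uint32_t words with `width` columns per row, mirroring item() above.
int q4_row_item(const uint32_t* data, int width, int row, int column) {
  const int shift = (column & 0x07) * 4;  // nibble position within the word
  const uint32_t word = data[row * width / 8 + column / 8];
  return (word >> shift) & 0x0f;
}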
class MatrixView_q4_column {
public:
const uint32_t* data;
const int height;
const int width;

__device__ __forceinline__ MatrixView_q4_column(const uint32_t* data,
const int height,
const int width)
: data(data), height(height), width(width) {}

__device__ __forceinline__ int item(int row, int column) const {
int shift = (row & 0x07) * 4;
return (data[row / 8 * width + column] >> shift) & 0x0f;
}

__device__ __forceinline__ uint32_t item_uint32_t(int row, int column) {
return data[row / 8 * width + column];
}
__device__ __forceinline__ const uint32_t* item_uint32_ptr(int row,
int column) {
return &data[row / 8 * width + column];
}
};

class MatrixView_q2_row {
public:
const uint32_t* data;
const int height;
const int width;

__device__ __forceinline__ MatrixView_q2_row(const uint32_t* data,
const int height,
const int width)
: data(data), height(height), width(width) {}

__device__ __forceinline__ int item(int row, int column) const {
int shift = (column & 0x0f) * 2;
return (data[row * width / 16 + column / 16] >> shift) & 0x03;
}

__device__ __forceinline__ void item2(int (&items)[2], int row,
int column) const {
int shift = (column & 0x0f) * 2;
uint32_t d = data[row * width / 16 + column / 16] >> shift;
items[0] = d & 0x03;
items[1] = (d >> 2) & 0x03;
}

__device__ __forceinline__ void item4(int (&items)[4], int row,
int column) const {
int shift = (column & 0x0f) * 2;
uint32_t d = data[row * width / 16 + column / 16] >> shift;
items[0] = d & 0x03;
@ -182,26 +201,27 @@ public:
}
};

class MatrixView_q3_row {
public:
const uint32_t* data;
const int height;
const int width;

__device__ __forceinline__ MatrixView_q3_row(const uint32_t* data,
const int height,
const int width)
: data(data), height(height), width(width) {}

__device__ __forceinline__ int item(int row, int column) const {
int z_w = column * 3 / 32;
int z_mod = column & 0x1f;

if (z_mod == 10) {
return (data[row * width * 3 / 32 + z_w] >> 30) |
((data[row * width * 3 / 32 + (z_w + 1)] << 2) & 0x4);
} else if (z_mod == 21) {
return (data[row * width * 3 / 32 + z_w] >> 31) |
((data[row * width * 3 / 32 + (z_w + 1)] << 1) & 0x6);
} else if (z_mod < 10) {
return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3)) & 0x07;
} else if (z_mod < 21) {
@ -211,18 +231,20 @@ public:
}
}

__device__ __forceinline__ void item4(int (&items)[4], int row,
int column) const {
int shift = (column & 0x1f);
uint32_t d;
if (shift <= 4) {
d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3);
} else if (shift == 8) {
d = (data[row * width / 32 * 3 + column * 3 / 32] >> 24) |
((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0x0f) << 8);
} else if (shift <= 16) {
d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 32);
} else if (shift == 20) {
d = (data[row * width / 32 * 3 + column * 3 / 32] >> 28) |
((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0xff) << 4);
} else {
d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 64);
}
@ -233,33 +255,32 @@ public:
}
};

class MatrixView_q8_row {
public:
const uint32_t* data;
const int height;
const int width;

__device__ __forceinline__ MatrixView_q8_row(const uint32_t* data,
const int height,
const int width)
: data(data), height(height), width(width) {}

__device__ __forceinline__ int item(int row, int column) const {
int shift = (column & 0x03) * 8;
return (data[row * width / 4 + column / 4] >> shift) & 0xff;
}

__device__ __forceinline__ void item2(int (&items)[2], int row,
int column) const {
int shift = (column & 0x03) * 8;
uint32_t d = data[row * width / 4 + column / 4] >> shift;
|
uint32_t d = data[row * width / 4 + column / 4] >> shift;
|
||||||
items[0] = d & 0xff;
|
items[0] = d & 0xff;
|
||||||
items[1] = (d >> 8) & 0xff;
|
items[1] = (d >> 8) & 0xff;
|
||||||
}
|
}
|
||||||
|
|
||||||
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
|
__device__ __forceinline__ void item4(int (&items)[4], int row,
|
||||||
{
|
int column) const {
|
||||||
int shift = (column & 0x03) * 2;
|
int shift = (column & 0x03) * 2;
|
||||||
uint32_t d = data[row * width / 4 + column / 4] >> shift;
|
uint32_t d = data[row * width / 4 + column / 4] >> shift;
|
||||||
items[0] = d & 0xff;
|
items[0] = d & 0xff;
|
||||||
|
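The row accessors above are plain shift-and-mask arithmetic over packed 32-bit words. As a sanity reference, here is a small host-side C++ sketch (not part of this diff; the helper name q2_item_ref and the test values are made up for illustration) mirroring what MatrixView_q2_row::item() computes, i.e. sixteen 2-bit values per word.

#include <cassert>
#include <cstdint>

// Same arithmetic as MatrixView_q2_row::item(), run on the host.
int q2_item_ref(const uint32_t* data, int width, int row, int column) {
  int shift = (column & 0x0f) * 2;  // position of the 2-bit field in its word
  return (data[row * width / 16 + column / 16] >> shift) & 0x03;
}

int main() {
  uint32_t packed[1] = {0};
  // Pack sixteen values: column c holds (c & 3) in bits 2c..2c+1.
  for (int c = 0; c < 16; c++) packed[0] |= uint32_t(c & 0x03) << (c * 2);
  assert(q2_item_ref(packed, /*width=*/16, /*row=*/0, /*column=*/5) == 1);  // 5 & 3
}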
File diff suppressed because it is too large
@@ -14,18 +14,12 @@ namespace gptq {
//
// ffddbb99 77553311 eeccaa88 66442200

__forceinline__ __device__ void shuffle_2bit_16(uint32_t* q, int stride) {
  uint32_t qa = q[0];
  uint32_t qb = 0;

#pragma unroll
  for (int i = 0; i < 8; i++) {
    uint32_t qa0 = qa & 0x03;
    uint32_t qa1 = (qa & 0x0c) >> 2;
    qa >>= 4;
@@ -35,14 +29,9 @@ __forceinline__ __device__ void shuffle_2bit_16
  q[0] = qb;
}

__forceinline__ __device__ void dequant_2bit_16(const uint32_t q_0,
                                                half2 (&dq)[8], int stride,
                                                const uint32_t zero) {
  const uint32_t c0 = 0x64006400;
  const half y4_ = __float2half_rn(1.0f / 4.0f);
  const half y16_ = __float2half_rn(1.0f / 16.0f);
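The constant c0 = 0x64006400 used by these dequant routines is the fp16 bit pattern of 1024.0 replicated in both 16-bit lanes: OR-ing a small quantized field into the mantissa yields 1024 + q without any integer-to-float conversion, and the 1024 bias is folded into the later zero-point subtraction. A minimal host-side sketch (plain C++; the tiny fp16 decoder below exists only for this illustration):

#include <cstdint>
#include <cstdio>

// Decode a normal (non-denormal) IEEE binary16 bit pattern to float.
static float half_bits_to_float(uint16_t h) {
  int exponent = (h >> 10) & 0x1f;  // biased exponent
  int mantissa = h & 0x3ff;         // 10 mantissa bits
  return (1.0f + mantissa / 1024.0f) * float(1 << (exponent - 15));
}

int main() {
  for (uint16_t q = 0; q < 4; q++) {
    uint16_t bits = 0x6400 | q;  // exponent of 1024.0, q dropped into the mantissa
    printf("q=%u -> %.1f\n", unsigned(q), half_bits_to_float(bits));  // prints 1024 + q
  }
}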
@@ -11,12 +11,7 @@ namespace gptq {
// vjjjhhhf ffdddbbb uiiiggge eecccaaa
// vtttrrrp ppnnnlll usssqqqo oommmkkk

__forceinline__ __device__ void shuffle_3bit_32(uint32_t* q, int stride) {
  uint32_t qa = q[0 * stride];
  uint32_t qb = q[1 * stride];
  uint32_t qc = q[2 * stride];
@@ -40,9 +35,27 @@ __forceinline__ __device__ void shuffle_3bit_32
  uint32_t zb = 0;
  uint32_t zc = 0;

  for (int i = 0; i < 5; i++) {
    uint32_t t0 = qa & 0x07;
    uint32_t t1 = (qa & 0x38) >> 3;
    qa >>= 6;
    za |= (t0 << (i * 3));
    za |= (t1 << (i * 3 + 16));
  }
  for (int i = 0; i < 5; i++) {
    uint32_t t0 = qb & 0x07;
    uint32_t t1 = (qb & 0x38) >> 3;
    qb >>= 6;
    zb |= (t0 << (i * 3));
    zb |= (t1 << (i * 3 + 16));
  }
  for (int i = 0; i < 5; i++) {
    uint32_t t0 = qc & 0x07;
    uint32_t t1 = (qc & 0x38) >> 3;
    qc >>= 6;
    zc |= (t0 << (i * 3));
    zc |= (t1 << (i * 3 + 16));
  }

  // za:  9997775 55333111  8886664 44222000
  // zb:  jjjhhhf ffdddbbb  iiiggge eecccaaa
@@ -65,16 +78,11 @@ __forceinline__ __device__ void shuffle_3bit_32
  q[2 * stride] = zc;
}

__forceinline__ __device__ void dequant_3bit_32(const uint32_t q_0,
                                                const uint32_t q_1,
                                                const uint32_t q_2,
                                                half2 (&dq)[16], int stride,
                                                const uint32_t zero) {
  const uint32_t c0 = 0x64006400;
  const half y8_ = __float2half_rn(1.0f / 8.0f);
  const half y64_ = __float2half_rn(1.0f / 64.0f);
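Each regrouping loop in shuffle_3bit_32 splits alternating 3-bit fields of a packed word into the low and high half-words of the output, which is what lets the dequant side pull even and odd weights with independent shifts. A host-side round-trip sketch of one such loop (plain C++; the input values are chosen arbitrarily for the check):

#include <cassert>
#include <cstdint>

int main() {
  // Pack ten 3-bit values: position i holds (i & 7) in bits 3i..3i+2.
  uint32_t qa = 0;
  for (int i = 0; i < 10; i++) qa |= uint32_t(i & 0x07) << (i * 3);

  uint32_t za = 0;
  uint32_t tmp = qa;
  for (int i = 0; i < 5; i++) {
    uint32_t t0 = tmp & 0x07;         // even-position value
    uint32_t t1 = (tmp & 0x38) >> 3;  // odd-position value
    tmp >>= 6;
    za |= (t0 << (i * 3));            // even values land in bits 0..14
    za |= (t1 << (i * 3 + 16));       // odd values land in bits 16..30
  }
  assert(((za >> (2 * 3)) & 0x07) == 4);       // third even input (index 4)
  assert(((za >> (16 + 2 * 3)) & 0x07) == 5);  // third odd input (index 5)
}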
@@ -13,18 +13,12 @@ namespace gptq {
//
// 77775555 33331111 66664444 22220000

__forceinline__ __device__ void shuffle_4bit_8(uint32_t* q, int stride) {
  uint32_t qa = q[0];
  uint32_t qb = 0;

#pragma unroll
  for (int i = 0; i < 4; i++) {
    uint32_t qa0 = qa & 0x0f;
    uint32_t qa1 = (qa & 0xf0) >> 4;
    qa >>= 8;
@@ -34,14 +28,9 @@ __forceinline__ __device__ void shuffle_4bit_8
  q[0] = qb;
}

__forceinline__ __device__ void dequant_4bit_8(const uint32_t q_0,
                                               half2 (&dq)[4], int stride,
                                               const uint32_t zero) {
  const uint32_t c0 = 0x64006400;
  const half y16_ = __float2half_rn(1.0f / 16.0f);
  const half2 y16 = __halves2half2(y16_, y16_);
@@ -63,14 +52,9 @@ __forceinline__ __device__ void dequant_4bit_8
  dq[3] = __hfma2(q3.as_half2, y16, z16);
}

__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale(
    const uint32_t zero, const half scale, half2 (&z1z16)[2],
    half2 (&y1y16)[2]) {
  half_uint16 z1(0xe400 | zero);  // half(-1024.0f - zero);
  half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));

@@ -86,13 +70,9 @@ __forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
  y1y16[1] = __hmul2(scale2, __half2half2(y16));
}

__forceinline__ __device__ void dequant_4bit_8_prep_zero(const uint32_t zero,
                                                         half2 (&z1z16)[2],
                                                         half2 (&y1y16)[2]) {
  half_uint16 z1(0xe400 | zero);  // half(-1024.0f - zero);
  half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));

@@ -106,39 +86,38 @@ __forceinline__ __device__ void dequant_4bit_8_prep_zero
  y1y16[1] = __half2half2(y16);
}

__forceinline__ __device__ void dequant_4bit_8_gptq(const uint32_t q_0,
                                                    half2 (&dq)[4],
                                                    half2 (&z1z16)[2],
                                                    half2 (&y1y16)[2],
                                                    int stride, bool scaled) {
  const uint32_t c0 = 0x64006400;

  uint32_t qa = q_0;
  half2_uint32 q0((qa & 0x000f000f) |
                  c0);  // half2( q[0]      + 1024, q[1]      + 1024 )
  half2_uint32 q1((qa & 0x00f000f0) |
                  c0);  // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 )
  qa >>= 8;
  half2_uint32 q2((qa & 0x000f000f) |
                  c0);  // half2( q[4]      + 1024, q[5]      + 1024 )
  half2_uint32 q3((qa & 0x00f000f0) |
                  c0);  // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 )

  if (scaled) {
    dq[0] = __hfma2(q0.as_half2, y1y16[0],
                    z1z16[0]);  // half2( q[0] * s - z * s, q[1] * s - z * s)
    dq[1] = __hfma2(q1.as_half2, y1y16[1],
                    z1z16[1]);  // half2( q[2] * s - z * s, q[3] * s - z * s)
    dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]);
    dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]);
  } else {
    dq[0] = __hadd2(q0.as_half2, z1z16[0]);  // half2( q[0] - z, q[1] - z )
    dq[1] = __hfma2(q1.as_half2, y1y16[1],
                    z1z16[1]);  // half2( q[2] - z, q[3] - z )
    dq[2] = __hadd2(q2.as_half2, z1z16[0]);  // half2( q[4] - z, q[5] - z )
    dq[3] = __hfma2(q3.as_half2, y1y16[1],
                    z1z16[1]);  // half2( q[6] - z, q[7] - z )
  }
}
}  // namespace gptq
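Numerically, every path above reduces to the usual affine dequantization: the unscaled branch yields q - zero and the scaled branch yields (q - zero) * scale, just computed two fp16 values at a time. A scalar float reference (an illustrative sketch, not code from this diff):

#include <cstdio>

// Reference for one weight: subtract the zero point, then optionally scale.
float dequant_ref(int q, int zero, float scale, bool scaled) {
  float centered = float(q - zero);
  return scaled ? centered * scale : centered;
}

int main() {
  printf("%.3f\n", dequant_ref(/*q=*/11, /*zero=*/8, /*scale=*/0.05f, true));   // 0.150
  printf("%.3f\n", dequant_ref(/*q=*/3, /*zero=*/8, /*scale=*/0.05f, false));   // -5.000
}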
@@ -10,28 +10,18 @@ Copied from https://github.com/turboderp/exllamav2
namespace vllm {
namespace gptq {

__forceinline__ __device__ void shuffle_8bit_4(uint32_t* q, int stride) {}

__forceinline__ __device__ void dequant_8bit_8(const uint32_t q_0,
                                               const uint32_t q_1,
                                               half2 (&dq)[4], int stride,
                                               const uint32_t zero) {
  half dqh[8];
  for (int i = 0; i < 4; i++) dqh[i] = dq_ns(exb(q_0, i * 8, 0xff), zero);
  for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), zero);

  for (int i = 0; i < 4; i++)
    dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}

}  // namespace gptq
@@ -8,16 +8,14 @@ Copied from https://github.com/turboderp/exllamav2
namespace vllm {
namespace gptq {

union half2_uint32 {
  uint32_t as_uint32;
  half2 as_half2;
  __device__ half2_uint32(uint32_t val) : as_uint32(val) {}
  __device__ half2_uint32(half2 val) : as_half2(val) {}
};

union half_uint16 {
  uint16_t as_uint16;
  half as_half;
  __device__ half_uint16(uint16_t val) : as_uint16(val) {}
@@ -26,32 +24,30 @@ union half_uint16

// Max_scale premultiplied by 1/256

__forceinline__ __device__ half dq_scale(const int qs, const half max_scale) {
  int qs_i = qs + 1;
  half qs_h = __int2half_rn(qs_i * qs_i);
  qs_h = __hmul(qs_h, max_scale);
  return qs_h;
}

__forceinline__ __device__ half dq(const int q, const int qzero,
                                   const half scale) {
  return __hmul(__int2half_rn(q - qzero), scale);
}

__forceinline__ __device__ half dq_ns(const int q, const int qzero) {
  // return __hsub(__int2half_rn(q), __int2half_rn(qzero));
  return __int2half_rn(q - qzero);
}

__forceinline__ __device__ int exb(const uint32_t q, const int shift,
                                   const int mask) {
  return (int)((q >> shift) & mask);
}

__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0,
                                   const int shift, const int mask) {
  return (int)(__funnelshift_rc(q0, q1, shift) & mask);
}
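The two-word exb() overload extracts a bit field that straddles a 32-bit boundary; __funnelshift_rc(q0, q1, shift) behaves like shifting the 64-bit concatenation q1:q0 right by shift and keeping the low word. A host-side equivalent for reference (plain C++ sketch; the test values are invented):

#include <cassert>
#include <cstdint>

// Host equivalent of exb(q1, q0, shift, mask) built on the funnel-shift idea.
int exb_ref(uint32_t q1, uint32_t q0, int shift, int mask) {
  uint64_t joined = (uint64_t(q1) << 32) | q0;  // q1 is the high word
  return int((joined >> shift) & uint32_t(mask));
}

int main() {
  uint32_t q0 = 0xA0000000, q1 = 0x0000000B;
  // The byte spans the word boundary: nibble A from q0, nibble B from q1.
  assert(exb_ref(q1, q0, 28, 0xff) == 0xBA);
}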
@@ -22,11 +22,15 @@
#include "gptq_marlin.cuh"
#include "gptq_marlin_dtypes.cuh"

#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t)               \
  static_assert(std::is_same<scalar_t, half>::value ||          \
                    std::is_same<scalar_t, nv_bfloat16>::value, \
                "only float16 and bfloat16 is supported");

template <typename T>
inline std::string str(T x) {
  return std::to_string(x);
}

namespace gptq_marlin {

@@ -41,17 +45,18 @@ template <typename scalar_t, // compute dtype, half or nv_float16
          const int num_bits,         // number of bits used for weights
          const int threads,          // number of threads in a threadblock
          const int thread_m_blocks,  // number of 16x16 blocks in the m
                                      // dimension (batchsize) of the
                                      // threadblock
          const int thread_n_blocks,  // same for n dimension (output)
          const int thread_k_blocks,  // same for k dimension (reduction)
          const int stages,  // number of stages for the async global->shared
                             // fetch pipeline
          const bool has_act_order,    // whether act_order is enabled
          const int group_blocks = -1  // number of consecutive 16x16 blocks
                                       // with a separate quantization scale
          >
__global__ void Marlin(
    const int4* __restrict__ A,  // fp16 input matrix of shape mxk
    const int4* __restrict__ B,  // 4bit quantized weight matrix of shape kxn
    int4* __restrict__ C,        // fp16 output buffer of shape mxn
    const int4* __restrict__ scales_ptr,  // fp16 quantization scales of shape
@@ -88,17 +93,19 @@ __device__ inline void mma(const typename ScalarType<scalar_t>::FragA &a_frag,
  const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
  float* c = reinterpret_cast<float*>(&frag_c);
  if constexpr (std::is_same<scalar_t, half>::value) {
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
        "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
        : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
        : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
          "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
  } else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
        "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
        : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
        : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
          "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
  } else {
    STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
  }
@@ -107,7 +114,8 @@ __device__ inline void mma(const typename ScalarType<scalar_t>::FragA &a_frag,
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
template <typename scalar_t>
__device__ inline void ldsm4(typename ScalarType<scalar_t>::FragA& frag_a,
                             const void* smem_ptr) {
  uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
@@ -118,7 +126,8 @@ __device__ inline void ldsm4(typename ScalarType<scalar_t>::FragA &frag_a, const
// Lookup-table based 3-input logical operation; explicitly used for
// dequantization as the compiler does not seem to automatically recognize it in
// all cases.
template <int lut>
__device__ inline int lop3(int a, int b, int c) {
  int res;
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(res)
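The template parameter lut is the 8-bit truth table consumed by the PTX lop3.b32 instruction. One way to derive it is to evaluate the desired boolean expression on the fixed constants 0xF0, 0xCC and 0xAA, one per operand. A host-side sketch (plain C++; the example expression is chosen only for illustration):

#include <cstdint>
#include <cstdio>

int main() {
  // Operand placeholders: a contributes 0xF0, b contributes 0xCC, c contributes 0xAA.
  const uint8_t a = 0xF0, b = 0xCC, c = 0xAA;
  // Truth table for the expression (a & b) | c, as an 8-bit immediate.
  uint8_t lut = (a & b) | c;
  printf("lut = 0x%02x\n", lut);  // 0xea
}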
@@ -140,8 +149,10 @@ __device__ inline uint32_t prmt(uint32_t a) {
// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
// values. We mostly follow the strategy in the link below, with some small
// changes:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
template <typename scalar_t>
__device__ inline typename ScalarType<scalar_t>::FragB dequant_4bit(int q) {
  STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
@@ -170,7 +181,8 @@ __device__ inline typename ScalarType<half>::FragB dequant_4bit<half>(int q) {
}

template <>
__device__ inline typename ScalarType<nv_bfloat16>::FragB
dequant_4bit<nv_bfloat16>(int q) {
  static constexpr uint32_t MASK = 0x000f000f;
  static constexpr uint32_t EX = 0x43004300;

@@ -193,10 +205,12 @@ __device__ inline typename ScalarType<nv_bfloat16>::FragB dequant_4bit<nv_bfloat
  return frag_b;
}

// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or
// bf16 Reference:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
template <typename scalar_t>
__device__ inline typename ScalarType<scalar_t>::FragB dequant_8bit(int q) {
  STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
@@ -222,11 +236,13 @@ __device__ inline typename ScalarType<half>::FragB dequant_8bit<half>(int q) {
}

template <>
__device__ inline typename ScalarType<nv_bfloat16>::FragB
dequant_8bit<nv_bfloat16>(int q) {
  typename ScalarType<nv_bfloat16>::FragB frag_b;

  float fp32_intermediates[4];
  uint32_t* fp32_intermediates_casted =
      reinterpret_cast<uint32_t*>(fp32_intermediates);

  static constexpr uint32_t fp32_base = 0x4B000000;
  fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
@@ -240,8 +256,10 @@ __device__ inline typename ScalarType<nv_bfloat16>::FragB dequant_8bit<nv_bfloat
  fp32_intermediates[3] -= 8388736.f;

  uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&frag_b);
  bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
                                   fp32_intermediates_casted[1], 0x7632);
  bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
                                   fp32_intermediates_casted[3], 0x7632);

  return frag_b;
}
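About the magic numbers in this path: __byte_perm drops a source byte into the low bits of the float pattern fp32_base = 0x4B000000, whose value is 2^23 = 8388608.0, so the resulting float is 8388608 + b; subtracting 8388736 (= 8388608 + 128) therefore recovers b - 128 before the bf16 repack. A host-side check of that identity (plain C++ sketch):

#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  for (uint32_t b : {0u, 128u, 255u}) {
    uint32_t bits = 0x4B000000u | b;  // byte b placed in the mantissa of 2^23
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    printf("byte %3u -> %.1f\n", b, f - 8388736.f);  // -128.0, 0.0, 127.0
  }
}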
@@ -250,9 +268,11 @@ __device__ inline typename ScalarType<nv_bfloat16>::FragB dequant_8bit<nv_bfloat
// only for grouped quantization.
template <typename scalar_t>
__device__ inline void scale(typename ScalarType<scalar_t>::FragB& frag_b,
                             typename ScalarType<scalar_t>::FragS& frag_s,
                             int i) {
  using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
  scalar_t2 s =
      ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_s)[i]);
  frag_b[0] = __hmul2(frag_b[0], s);
  frag_b[1] = __hmul2(frag_b[1], s);
}
@@ -280,7 +300,8 @@ __device__ inline void scale4(typename ScalarType<scalar_t>::FragB &frag_b,

// Given 2 floats multiply by 2 scales (halves)
template <typename scalar_t>
__device__ inline void scale_float(float* c,
                                   typename ScalarType<scalar_t>::FragS& s) {
  scalar_t* s_ptr = reinterpret_cast<scalar_t*>(&s);
  c[0] = __fmul_rn(c[0], ScalarType<scalar_t>::num2float(s_ptr[0]));
  c[1] = __fmul_rn(c[1], ScalarType<scalar_t>::num2float(s_ptr[1]));
@@ -325,7 +346,6 @@ __global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr,
                                    int const* __restrict__ perm_int_ptr,
                                    int4* __restrict__ out_int4_ptr, int size_m,
                                    int size_k, int block_rows) {
  int start_row = block_rows * blockIdx.x;
  int finish_row = start_row + block_rows;
  if (finish_row > size_m) {
@@ -341,8 +361,7 @@ __global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr,
    int offset = row * row_stride;

    half const* a_row_half = reinterpret_cast<half const*>(a_int4_ptr + offset);
    half* out_half = reinterpret_cast<half*>(out_int4_ptr + offset);

    int base_k = 0;
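permute_cols_kernel applies the act-order permutation to the activation rows before the GEMM; logically, each output row gathers the columns of the input row through perm (the gather direction here is stated as an assumption for illustration). A reference sketch of that gather (plain C++; the helper name and float element type are illustrative):

#include <cassert>
#include <vector>

// out[r][k] = in[r][perm[k]] for every row r and output column k.
std::vector<float> permute_cols_ref(const std::vector<float>& in,
                                    const std::vector<int>& perm, int rows,
                                    int cols) {
  std::vector<float> out(in.size());
  for (int r = 0; r < rows; r++)
    for (int k = 0; k < cols; k++) out[r * cols + k] = in[r * cols + perm[k]];
  return out;
}

int main() {
  std::vector<float> a = {10, 11, 12, 13};  // one row, four columns
  std::vector<int> perm = {2, 0, 3, 1};     // act-order style permutation
  auto out = permute_cols_ref(a, perm, /*rows=*/1, /*cols=*/4);
  assert(out[0] == 12 && out[1] == 10 && out[2] == 13 && out[3] == 11);
}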
@@ -378,17 +397,18 @@ template <typename scalar_t, // compute dtype, half or nv_float16
          const int num_bits,         // number of bits used for weights
          const int threads,          // number of threads in a threadblock
          const int thread_m_blocks,  // number of 16x16 blocks in the m
                                      // dimension (batchsize) of the
                                      // threadblock
          const int thread_n_blocks,  // same for n dimension (output)
          const int thread_k_blocks,  // same for k dimension (reduction)
          const int stages,  // number of stages for the async global->shared
                             // fetch pipeline
          const bool has_act_order,    // whether act_order is enabled
          const int group_blocks = -1  // number of consecutive 16x16 blocks
                                       // with a separate quantization scale
          >
__global__ void Marlin(
    const int4* __restrict__ A,  // fp16 input matrix of shape mxk
    const int4* __restrict__ B,  // 4bit quantized weight matrix of shape kxn
    int4* __restrict__ C,        // fp16 output buffer of shape mxn
    const int4* __restrict__ scales_ptr,  // fp16 quantization scales of shape
@@ -465,27 +485,22 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
  auto init_slice = [&]() {
    slice_iters =
        iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
    if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;
    if (slice_iters == 0) return;
    if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
    slice_count = 1;
    slice_idx = 0;
    int col_first = iters * div_ceil(k_tiles * slice_col_par, iters);
    if (col_first <= k_tiles * (slice_col_par + 1)) {
      int col_off = col_first - k_tiles * slice_col_par;
      slice_count = div_ceil(k_tiles - col_off, iters);
      if (col_off > 0) slice_count++;
      int delta_first = iters * blockIdx.x - col_first;
      if (delta_first < 0 || (col_off == 0 && delta_first == 0))
        slice_idx = slice_count - 1;
      else {
        slice_idx = slice_count - 1 - delta_first / iters;
        if (col_off > 0) slice_idx--;
      }
    }
    if (slice_col == n_tiles) {
@@ -785,7 +800,8 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
    int4* sh_a_stage = sh_a + a_sh_stage * pipe;
#pragma unroll
    for (int i = 0; i < thread_m_blocks; i++)
      ldsm4<scalar_t>(frag_a[k % 2][i],
                      &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
    int4* sh_b_stage = sh_b + b_sh_stage * pipe;

#pragma unroll
@@ -906,7 +922,6 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
#pragma unroll
      for (int i = 0; i < 4; i++) {
        int actual_k = cur_k + k_frag_offsets[i];

        int group_id = sh_g_idx_int_ptr[actual_k];
@@ -943,8 +958,9 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk

      // Apply scale to frag_b0
      if constexpr (has_act_order) {
        scale4<scalar_t>(frag_b0, act_frag_s[k % 2][0][j],
                         act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j],
                         act_frag_s[k % 2][3][j], 0);
      } else {
        if constexpr (group_blocks != -1) {
          scale<scalar_t>(frag_b0, frag_s[k % 2][j], 0);
@@ -953,8 +969,9 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk

      // Apply scale to frag_b1
      if constexpr (has_act_order) {
        scale4<scalar_t>(frag_b1, act_frag_s[k % 2][0][j],
                         act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j],
                         act_frag_s[k % 2][3][j], 1);

      } else {
        if constexpr (group_blocks != -1) {
@@ -997,8 +1014,8 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
            int red_sh_wr =
                red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
            if (i < red_off) {
              float* c_rd =
                  reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
              float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
#pragma unroll
              for (int k = 0; k < 4; k++)
@@ -1049,16 +1066,16 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
    int row = (threadIdx.x % 32) / 4;

    if (!first) {
      // Interestingly, doing direct global accesses here really seems to mess up
      // the compiler and lead to slowdowns, hence we also use async-copies even
      // though these fetches are not actually asynchronous.
#pragma unroll
      for (int i = 0; i < thread_m_blocks * 4; i++) {
        cp_async4_pred(
            &sh[c_sh_wr + c_sh_wr_delta * i],
            &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
               c_gl_wr_delta_i * (i % 2)],
            i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
      }
      cp_async_fence();
      cp_async_wait<0>();
@@ -1116,7 +1133,8 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
    // We first reorder in shared memory to guarantee the most efficient final
    // global write patterns
    auto write = [&](int idx, float c0, float c1, FragS& s) {
      scalar_t2 res =
          Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1));

      // For per-column quantization we finally apply the scale here (only for
      // 4-bit)
@@ -1286,14 +1304,18 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
        for (int i = 0; i < thread_m_blocks; i++) {
#pragma unroll
          for (int j = 0; j < 4; j++) {
            scale_float<scalar_t>(
                reinterpret_cast<float*>(&frag_c[i][j][0][0]),
                frag_s[j / 2][2 * (j % 2) + 0]);
            scale_float<scalar_t>(
                reinterpret_cast<float*>(&frag_c[i][j][0][2]),
                frag_s[j / 2][2 * (j % 2) + 0]);

            scale_float<scalar_t>(
                reinterpret_cast<float*>(&frag_c[i][j][1][0]),
                frag_s[j / 2][2 * (j % 2) + 1]);
            scale_float<scalar_t>(
                reinterpret_cast<float*>(&frag_c[i][j][1][2]),
                frag_s[j / 2][2 * (j % 2) + 1]);
          }
        }
@@ -1320,8 +1342,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
          B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
        if (slice_col == 0) {
#pragma unroll
          for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
        }

        // Update slice k/n for scales loading
@@ -1341,20 +1362,21 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
    }
  }

#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS,                \
                  THREAD_K_BLOCKS, HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \
  else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS &&     \
           thread_n_blocks == THREAD_N_BLOCKS &&                             \
           thread_k_blocks == THREAD_K_BLOCKS &&                             \
           has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \
           num_threads == NUM_THREADS) {                                     \
    cudaFuncSetAttribute(                                                    \
        Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS,             \
               THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
               GROUP_BLOCKS>,                                                \
        cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);        \
    Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS,                 \
           THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER,     \
           GROUP_BLOCKS><<<blocks, NUM_THREADS, max_shared_mem, stream>>>(   \
        A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n,   \
        prob_k, locks);                                                      \
  }
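__CALL_IF expands to one arm of a long else-if chain, so the launcher can pick a fully specialized template instantiation from runtime parameters and raise its dynamic shared-memory limit before launching. A stripped-down sketch of the pattern (plain C++; kernel_stub and dispatch are hypothetical names, not code from this diff):

#include <stdexcept>

template <int kNumBits, int kThreads>
void kernel_stub() { /* stand-in for a fully specialized kernel launch */ }

void dispatch(int num_bits, int num_threads) {
  if (false) {
  }  // empty head so each macro-generated arm can start with `else if`
  else if (num_bits == 4 && num_threads == 256) {
    kernel_stub<4, 256>();
  } else if (num_bits == 8 && num_threads == 256) {
    kernel_stub<8, 256>();
  } else {
    throw std::runtime_error("unsupported configuration");
  }
}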
@@ -1673,8 +1695,7 @@ void marlin_mm_f16i4(const void *A, const void *B, void *C, void *s,
      // Note that parallel > 1 currently only works for inputs without any
      // padding
      par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks);
      if (par > max_par) par = max_par;
      prob_m = (16 * exec_cfg.max_m_blocks) * par;
      i += exec_cfg.max_m_blocks * (par - 1);
      thread_m_blocks = exec_cfg.max_m_blocks;
@@ -1824,18 +1845,21 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
  int dev = a.get_device();
  if (a.scalar_type() == at::ScalarType::Half) {
    gptq_marlin::marlin_mm_f16i4<half>(
        a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
        b_scales.data_ptr<at::Half>(), g_idx.data_ptr(), perm.data_ptr(),
        a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
        workspace.data_ptr(), num_bits, has_act_order, is_k_full, num_groups,
        group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k,
        thread_n, sms, gptq_marlin::max_par);
  } else if (a.scalar_type() == at::ScalarType::BFloat16) {
    gptq_marlin::marlin_mm_f16i4<nv_bfloat16>(
        a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
        c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(),
        g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
        size_m, size_n, size_k, workspace.data_ptr(), num_bits, has_act_order,
        is_k_full, num_groups, group_size, dev,
        at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
        gptq_marlin::max_par);
  } else {
    TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16");
  }
@@ -11,12 +11,13 @@

namespace gptq_marlin {

// 8 warps are a good choice since every SM has 4 schedulers and having more
// than 1 warp per schedule allows some more latency hiding. At the same time,
// we want relatively few warps to have many registers per warp and small tiles.
static constexpr int default_threads = 256;

static constexpr int pipe_stages =
    4;  // 4 pipeline stages fit into shared memory

static constexpr int min_thread_n = 64;
static constexpr int min_thread_k = 64;
@@ -38,10 +39,12 @@ constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
// No support for async
#else

__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
                                      bool pred = true) {
  const int BYTES = 16;
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   .reg .pred p;\n"
      "   setp.ne.b32 p, %0, 0;\n"
      "   @p cp.async.cg.shared.global [%1], [%2], %3;\n"
@@ -52,13 +55,16 @@ __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
  const int BYTES = 16;
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   cp.async.cg.shared.global [%0], [%1], %2;\n"
      "}\n" ::"r"(smem),
      "l"(glob_ptr), "n"(BYTES));
}

__device__ inline void cp_async_fence() {
  asm volatile("cp.async.commit_group;\n" ::);
}

template <int n>
__device__ inline void cp_async_wait() {
@@ -5,12 +5,10 @@

 #include <cuda_fp16.h>
 #include <cuda_bf16.h>


 namespace gptq_marlin {

 template <typename scalar_t>
-class ScalarType {
-};
+class ScalarType {};

 template <>
 class ScalarType<half> {
@@ -26,13 +24,21 @@ public:
 using FragC = Vec<float, 4>;
 using FragS = Vec<half2, 1>;

-static __device__ float inline num2float(const half x) { return __half2float(x); }
+static __device__ float inline num2float(const half x) {
+return __half2float(x);
+}

-static __device__ half2 inline num2num2(const half x) { return __half2half2(x); }
+static __device__ half2 inline num2num2(const half x) {
+return __half2half2(x);
+}

-static __device__ half2 inline nums2num2(const half x1, const half x2) { return __halves2half2(x1, x2); }
+static __device__ half2 inline nums2num2(const half x1, const half x2) {
+return __halves2half2(x1, x2);
+}

-static __host__ __device__ half inline float2num(const float x) { return __float2half(x); }
+static __host__ __device__ half inline float2num(const float x) {
+return __float2half(x);
+}
 };

 template <>
@@ -47,16 +53,25 @@ public:
 using FragS = Vec<nv_bfloat162, 1>;

 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
-static __device__ float inline num2float(const nv_bfloat16 x) { return __bfloat162float(x); }
+static __device__ float inline num2float(const nv_bfloat16 x) {
+return __bfloat162float(x);
+}

-static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) { return __bfloat162bfloat162(x); }
+static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) {
+return __bfloat162bfloat162(x);
+}

-static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1, const nv_bfloat16 x2) { return __halves2bfloat162(x1, x2); }
+static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1,
+const nv_bfloat16 x2) {
+return __halves2bfloat162(x1, x2);
+}

-static __host__ __device__ nv_bfloat16 inline float2num(const float x) { return __float2bfloat16(x); }
+static __host__ __device__ nv_bfloat16 inline float2num(const float x) {
+return __float2bfloat16(x);
+}
 #endif
 };

-}
+}  // namespace gptq_marlin

 #endif
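Aside (illustration only, not part of this commit): the ScalarType trait above is what lets these kernels be written once for both half and nv_bfloat16. A hedged sketch of that pattern; the helper names are invented for the example, and the bf16 path assumes compute capability 8.0+ just as the header does.

    // Sketch: dtype-generic conversion through the trait; each specialization
    // picks __half2float / __bfloat162float and friends.
    template <typename scalar_t>
    __device__ float widen(scalar_t x) {
      return gptq_marlin::ScalarType<scalar_t>::num2float(x);
    }

    template <typename scalar_t>
    __device__ scalar_t narrow(float x) {
      return gptq_marlin::ScalarType<scalar_t>::float2num(x);
    }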
@@ -12,10 +12,10 @@ static constexpr int tile_n_size = tile_k_size * 4;
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800

 template <int const num_threads, int const num_bits, bool const has_perm>
-__global__ void
-marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
-uint32_t const *__restrict__ perm_ptr,
-uint32_t *__restrict__ out_ptr, int size_k, int size_n) {}
+__global__ void marlin_repack_kernel(
+uint32_t const* __restrict__ b_q_weight_ptr,
+uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
+int size_k, int size_n) {}

 } // namespace gptq_marlin

@@ -30,10 +30,10 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
 #else

 template <int const num_threads, int const num_bits, bool const has_perm>
-__global__ void
-marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
-uint32_t const *__restrict__ perm_ptr,
-uint32_t *__restrict__ out_ptr, int size_k, int size_n) {
+__global__ void marlin_repack_kernel(
+uint32_t const* __restrict__ b_q_weight_ptr,
+uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
+int size_k, int size_n) {
 constexpr int pack_factor = 32 / num_bits;

 int k_tiles = size_k / tile_k_size;
@@ -176,7 +176,6 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
 }

 } else {
-
 uint32_t b1_vals[tile_ints];
 uint32_t b2_vals[tile_ints];

@@ -320,8 +319,7 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
 // Get ptrs
 uint32_t const* b_q_weight_ptr =
 reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());
-uint32_t const *perm_ptr =
-reinterpret_cast<uint32_t const *>(perm.data_ptr());
+uint32_t const* perm_ptr = reinterpret_cast<uint32_t const*>(perm.data_ptr());
 uint32_t* out_ptr = reinterpret_cast<uint32_t*>(out.data_ptr());

 // Get dev info
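Aside (illustration only, not part of this commit): the repack kernel above works in units of pack_factor = 32 / num_bits quantized values per 32-bit word. A small, self-contained host-side sketch of that packing arithmetic, with 4-bit weights assumed as in the kernel's main use case.

    #include <cstdint>

    constexpr int num_bits = 4;
    constexpr int pack_factor = 32 / num_bits;  // 8 nibbles per uint32_t

    // Write value `val` into slot `slot` (0..pack_factor-1) of a packed word.
    inline uint32_t pack_slot(uint32_t word, uint32_t val, int slot) {
      const uint32_t mask = (1u << num_bits) - 1;
      word &= ~(mask << (slot * num_bits));
      return word | ((val & mask) << (slot * num_bits));
    }

    // Read slot `slot` back out of a packed word.
    inline uint32_t unpack_slot(uint32_t word, int slot) {
      return (word >> (slot * num_bits)) & ((1u << num_bits) - 1);
    }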
@@ -25,7 +25,10 @@

 #include <iostream>

-template <typename T> inline std::string str(T x) { return std::to_string(x); }
+template <typename T>
+inline std::string str(T x) {
+return std::to_string(x);
+}

 namespace marlin {

@@ -38,7 +41,8 @@ constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }
 // corresponding index accesses must be compile-time constants, which is why we
 // extensively use `#pragma unroll` throughout the kernel code to guarantee
 // this.
-template <typename T, int n> struct Vec {
+template <typename T, int n>
+struct Vec {
 T elems[n];
 __device__ T& operator[](int i) { return elems[i]; }
 };
@@ -59,7 +63,8 @@ __device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr,
 bool pred = true) {
 const int BYTES = 16;
 uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
-asm volatile("{\n"
+asm volatile(
+"{\n"
 " .reg .pred p;\n"
 " setp.ne.b32 p, %0, 0;\n"
 " @p cp.async.cg.shared.global [%1], [%2], %3;\n"
@@ -71,9 +76,11 @@ __device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr,
 __device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
 const int BYTES = 16;
 uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
-asm volatile("{\n"
+asm volatile(
+"{\n"
 " cp.async.cg.shared.global [%0], [%1], %2;\n"
-"}\n" :: "r"(smem), "l"(glob_ptr), "n"(BYTES));
+"}\n" ::"r"(smem),
+"l"(glob_ptr), "n"(BYTES));
 }

 // Async copy fence.
@@ -82,7 +89,8 @@ __device__ inline void cp_async_fence() {
 }

 // Wait until at most `n` async copy stages are still pending.
-template <int n> __device__ inline void cp_async_wait() {
+template <int n>
+__device__ inline void cp_async_wait() {
 asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
 }

@@ -93,11 +101,12 @@ __device__ inline void mma(const FragA &a_frag, const FragB &frag_b,
 const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
 const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
 float* c = reinterpret_cast<float*>(&frag_c);
-asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
+asm volatile(
+"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
 "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
 : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
-: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]),
-"r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
+: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
+"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
 }

 // Instruction for loading a full 16x16 matrix fragment of operand A from shared
@@ -113,7 +122,8 @@ __device__ inline void ldsm4(FragA &frag_a, const void *smem_ptr) {
 // Lookup-table based 3-input logical operation; explicitly used for
 // dequantization as the compiler does not seem to automatically recognize it in
 // all cases.
-template <int lut> __device__ inline int lop3(int a, int b, int c) {
+template <int lut>
+__device__ inline int lop3(int a, int b, int c) {
 int res;
 asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
 : "=r"(res)
@@ -189,20 +199,21 @@ __device__ inline void barrier_release(int *lock, bool reset = false) {

 template <const int threads, // number of threads in a threadblock
 const int thread_m_blocks, // number of 16x16 blocks in the m
-// dimension (batchsize) of the threadblock
+// dimension (batchsize) of the
+// threadblock
 const int thread_n_blocks, // same for n dimension (output)
 const int thread_k_blocks, // same for k dimension (reduction)
 const int stages, // number of stages for the async global->shared
 // fetch pipeline
-const int group_blocks = -1 // number of consecutive 16x16 blocks with
-// a separate quantization scale
+const int group_blocks = -1 // number of consecutive 16x16 blocks
+// with a separate quantization scale
 >
-__global__ void
-Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
+__global__ void Marlin(
+const int4* __restrict__ A, // fp16 input matrix of shape mxk
 const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
 int4* __restrict__ C, // fp16 output buffer of shape mxn
-const int4
-*__restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn
+const int4* __restrict__ s, // fp16 quantization scales of shape
+// (k/groupsize)xn
 int prob_m, // batch dimension m
 int prob_n, // output dimension n
 int prob_k, // reduction dimension k
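Aside (illustration only, not part of this commit): lop3 above wraps the PTX lop3.b32 instruction, whose 8-bit immediate is the truth table of an arbitrary three-input boolean function. A hedged sketch of how such a constant is derived; the and_or helper is invented for the example and assumes marlin::lop3 is visible in the translation unit.

    // The LUT is f evaluated on the canonical operands 0xF0 (a), 0xCC (b), 0xAA (c).
    // Example: f(a, b, c) = (a & b) | c  ->  (0xF0 & 0xCC) | 0xAA = 0xEA.
    constexpr int kA = 0xF0, kB = 0xCC, kC = 0xAA;
    constexpr int kAndOrLut = (kA & kB) | kC;

    __device__ int and_or(int a, int b, int c) {
      return marlin::lop3<kAndOrLut>(a, b, c);  // bitwise (a & b) | c in one op
    }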
@@ -261,27 +272,22 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
 auto init_slice = [&]() {
 slice_iters =
 iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
-if (slice_iters < 0 || slice_col_par >= n_tiles * parallel)
-slice_iters = 0;
-if (slice_iters == 0)
-return;
-if (slice_row + slice_iters > k_tiles)
-slice_iters = k_tiles - slice_row;
+if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;
+if (slice_iters == 0) return;
+if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
 slice_count = 1;
 slice_idx = 0;
 int col_first = iters * ceildiv(k_tiles * slice_col_par, iters);
 if (col_first <= k_tiles * (slice_col_par + 1)) {
 int col_off = col_first - k_tiles * slice_col_par;
 slice_count = ceildiv(k_tiles - col_off, iters);
-if (col_off > 0)
-slice_count++;
+if (col_off > 0) slice_count++;
 int delta_first = iters * blockIdx.x - col_first;
 if (delta_first < 0 || (col_off == 0 && delta_first == 0))
 slice_idx = slice_count - 1;
 else {
 slice_idx = slice_count - 1 - delta_first / iters;
-if (col_off > 0)
-slice_idx--;
+if (col_off > 0) slice_idx--;
 }
 }
 if (slice_col == n_tiles) {
@@ -305,7 +311,8 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
 a_gl_stride *
 (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile
 constexpr int a_sh_wr_delta =
-a_sh_stride * (threads / a_gl_rd_delta_o); // between shared memory writes
+a_sh_stride *
+(threads / a_gl_rd_delta_o); // between shared memory writes
 constexpr int a_sh_rd_delta_o =
 2 * ((threads / 32) /
 (thread_n_blocks / 4)); // between shared memory tile reads
@@ -447,8 +454,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
 // Only fetch scales if this tile starts a new group
 if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) {
 int4* sh_s_stage = sh_s + s_sh_stage * pipe;
-if (s_sh_wr_pred)
-cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
+if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
 s_gl_rd += s_gl_rd_delta;
 }
 }
@@ -500,11 +506,9 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
 FragB frag_b0 = dequant(b_quant);
 // If there are no groups, we can just scale the final output once and can
 // avoid doing so for each weight.
-if (group_blocks != -1)
-scale(frag_b0, frag_s[k % 2][j], 0);
+if (group_blocks != -1) scale(frag_b0, frag_s[k % 2][j], 0);
 FragB frag_b1 = dequant(b_quant_shift);
-if (group_blocks != -1)
-scale(frag_b1, frag_s[k % 2][j], 1);
+if (group_blocks != -1) scale(frag_b1, frag_s[k % 2][j], 1);
 #pragma unroll
 for (int i = 0; i < thread_m_blocks; i++) {
 mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
@@ -540,8 +544,8 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
 int red_sh_wr =
 red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
 if (i < red_off) {
-float *c_rd = reinterpret_cast<float *>(
-&sh[red_sh_delta * j + red_sh_rd]);
+float* c_rd =
+reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
 float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
 #pragma unroll
 for (int k = 0; k < 4; k++)
@@ -571,9 +575,9 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
 };

 // Since multiple threadblocks may process parts of the same column slice, we
-// finally have to globally reduce over the results. As the striped partitioning
-// minimizes the number of such reductions and our outputs are usually rather
-// small, we perform this reduction serially in L2 cache.
+// finally have to globally reduce over the results. As the striped
+// partitioning minimizes the number of such reductions and our outputs are
+// usually rather small, we perform this reduction serially in L2 cache.
 auto global_reduce = [&](bool first = false, bool last = false) {
 // We are very careful here to reduce directly in the output buffer to
 // maximize L2 cache utilization in this step. To do this, we write out
@@ -592,16 +596,16 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
 int row = (threadIdx.x % 32) / 4;

 if (!first) {
-// Interestingly, doing direct global accesses here really seems to mess up the
-// compiler and lead to slowdowns, hence we also use async-copies even though
-// these fetches are not actually asynchronous.
+// Interestingly, doing direct global accesses here really seems to mess up
+// the compiler and lead to slowdowns, hence we also use async-copies even
+// though these fetches are not actually asynchronous.
 #pragma unroll
 for (int i = 0; i < thread_m_blocks * 4; i++) {
-cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i],
+cp_async4_pred(
+&sh[c_sh_wr + c_sh_wr_delta * i],
 &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
 c_gl_wr_delta_i * (i % 2)],
-i < (thread_m_blocks - 1) * 4 ||
-8 * (i / 2) + row < prob_m);
+i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
 }
 cp_async_fence();
 cp_async_wait<0>();
@@ -700,8 +704,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
 // Start global fetch and register load pipelines.
 auto start_pipes = [&]() {
 #pragma unroll
-for (int i = 0; i < stages - 1; i++)
-fetch_to_shared(i, i, i < slice_iters);
+for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters);
 zero_accums();
 wait_for_stage();
 fetch_to_registers(0, 0);
@@ -711,9 +714,9 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk

 // Main loop.
 while (slice_iters) {
-// We unroll over both the global fetch and the register load pipeline to ensure
-// all shared memory accesses are static. Note that both pipelines have even
-// length meaning that the next iteration will always start at index 0.
+// We unroll over both the global fetch and the register load pipeline to
+// ensure all shared memory accesses are static. Note that both pipelines have
+// even length meaning that the next iteration will always start at index 0.
 #pragma unroll
 for (int pipe = 0; pipe < stages;) {
 #pragma unroll
@@ -728,8 +731,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
 matmul(k);
 }
 slice_iters--;
-if (slice_iters == 0)
-break;
+if (slice_iters == 0) break;
 }
 a_gl_rd += a_gl_rd_delta_o * stages;

@@ -742,8 +744,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
 // For per-column scales, we only fetch them here in the final step before
 // write-out
 if (group_blocks == -1 && last) {
-if (s_sh_wr_pred)
-cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
+if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
 cp_async_fence();
 }
 thread_block_reduce();
@@ -775,8 +776,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
 B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
 if (slice_col == 0) {
 #pragma unroll
-for (int i = 0; i < b_sh_wr_iters; i++)
-B_ptr[i] -= b_gl_stride;
+for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
 }
 s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
 start_pipes();
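Aside (illustration only, not part of this commit): init_slice above assigns each threadblock an equal stripe of tile-iterations across the k and n tile grid. A much-simplified, host-side miniature of that idea under the assumption of a flat iteration count; the real kernel additionally tracks slice_count and slice_idx for the later global reduction, and the helper names here are invented.

    constexpr int ceildiv_host(int a, int b) { return (a + b - 1) / b; }

    struct Stripe { int first; int count; };

    // Block `b` of `num_blocks` gets iterations [first, first + count).
    inline Stripe stripe_for_block(int b, int total_iters, int num_blocks) {
      int iters = ceildiv_host(total_iters, num_blocks);
      int first = iters * b;
      int count = first >= total_iters
                      ? 0
                      : (first + iters > total_iters ? total_iters - first : iters);
      return {first, count};
    }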
@@ -789,20 +789,21 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk

 template <const int threads, // number of threads in a threadblock
 const int thread_m_blocks, // number of 16x16 blocks in the m
-// dimension (batchsize) of the threadblock
+// dimension (batchsize) of the
+// threadblock
 const int thread_n_blocks, // same for n dimension (output)
 const int thread_k_blocks, // same for k dimension (reduction)
 const int stages, // number of stages for the async global->shared
 // fetch pipeline
-const int group_blocks = -1 // number of consecutive 16x16 blocks with
-// a separate quantization scale
+const int group_blocks = -1 // number of consecutive 16x16 blocks
+// with a separate quantization scale
 >
-__global__ void
-Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
+__global__ void Marlin(
+const int4* __restrict__ A, // fp16 input matrix of shape mxk
 const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
 int4* __restrict__ C, // fp16 output buffer of shape mxn
-const int4
-*__restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn
+const int4* __restrict__ s, // fp16 quantization scales of shape
+// (k/groupsize)xn
 int prob_m, // batch dimension m
 int prob_n, // output dimension n
 int prob_k, // reduction dimension k
@@ -907,7 +908,6 @@ bool is_valid_config(thread_config_t const &th_config, int prob_m, int prob_n,
 }

 thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) {
-
 if (prob_m <= 16) {
 for (auto th_config : small_batch_thread_configs) {
 if (is_valid_config(th_config, prob_m, prob_n, prob_k)) {
@@ -1011,8 +1011,7 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, int prob_m,
 // Note that parallel > 1 currently only works for inputs without any
 // padding
 par = (16 * thread_m_blocks - pad) / 64;
-if (par > max_par)
-par = max_par;
+if (par > max_par) par = max_par;
 prob_m = 64 * par;
 i += 4 * (par - 1);
 thread_m_blocks = 4;
@@ -1046,7 +1045,6 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, int prob_m,
 torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
 torch::Tensor& b_scales, torch::Tensor& workspace,
 int64_t size_m, int64_t size_n, int64_t size_k) {
-
 // Verify M
 TORCH_CHECK(size_m == a.size(0),
 "Shape mismatch: a.size(0) = " + str(a.size(0)) +
@@ -1074,9 +1072,9 @@ torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,

 int actual_size_n =
 (b_q_weight.size(1) / marlin::tile_size) * marlin::pack_factor_4bit;
-TORCH_CHECK(size_n == actual_size_n,
-"size_n = " + str(size_n) +
-", actual_size_n = " + str(actual_size_n));
+TORCH_CHECK(
+size_n == actual_size_n,
+"size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n));

 // Verify A device and strides
 TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
@@ -26,12 +26,14 @@ constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }
 // corresponding index accesses must be compile-time constants, which is why we
 // extensively use `#pragma unroll` throughout the kernel code to guarantee
 // this.
-template <typename T, int n> struct Vec {
+template <typename T, int n>
+struct Vec {
 T elems[n];
 __device__ T& operator[](int i) { return elems[i]; }
 };

-template <int M_, int N_, int K_> struct ShapeBase {
+template <int M_, int N_, int K_>
+struct ShapeBase {
 static constexpr int M = M_, N = N_, K = K_;
 };

@@ -28,7 +28,8 @@ __device__ inline void cp_async4_pred_zfill(void *smem_ptr,
 const int BYTES = 16;
 int src_in_bytes = (zfill ? 0 : BYTES);
 uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
-asm volatile("{\n"
+asm volatile(
+"{\n"
 " .reg .pred p;\n"
 " setp.ne.b32 p, %0, 0;\n"
 " @p cp.async.cg.shared.global [%1], [%2], %3;\n"
@@ -40,7 +41,8 @@ __device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr,
 bool pred = true) {
 const int BYTES = 16;
 uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
-asm volatile("{\n"
+asm volatile(
+"{\n"
 " .reg .pred p;\n"
 " setp.ne.b32 p, %0, 0;\n"
 " @p cp.async.cg.shared.global [%1], [%2], %3;\n"
@@ -52,7 +54,8 @@ __device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr,
 __device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
 const int BYTES = 16;
 uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
-asm volatile("{\n"
+asm volatile(
+"{\n"
 " cp.async.cg.shared.global [%0], [%1], %2;\n"
 "}\n" ::"r"(smem),
 "l"(glob_ptr), "n"(BYTES));
@@ -64,7 +67,8 @@ __device__ inline void cp_async_fence() {
 }

 // Wait until at most `n` async copy stages are still pending.
-template <int n> __device__ inline void cp_async_wait() {
+template <int n>
+__device__ inline void cp_async_wait() {
 asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
 }

@@ -31,42 +31,47 @@ __device__ inline void mma_sp(const FragB &a_frag0, const FragB &a_frag1,
 const uint32_t* e = reinterpret_cast<const uint32_t*>(&frag_m);
 float* c = reinterpret_cast<float*>(&frag_c);
 if (psel == 0) {
-asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
+asm volatile(
+"mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
 "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
 "{%12,%13,%14,%15}, %16, 0x0;\n"
 : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
-: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]),
-"r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]),
-"f"(c[2]), "f"(c[3]), "r"(e[0]));
+: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), "r"(b[2]),
+"r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]),
+"r"(e[0]));
-asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
+asm volatile(
+"mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
 "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
 "{%12,%13,%14,%15}, %16, 0x0;\n"
 : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
-: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]),
-"r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]),
-"f"(c[6]), "f"(c[7]), "r"(e[0]));
+: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), "r"(b[3]),
+"r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), "f"(c[6]), "f"(c[7]),
+"r"(e[0]));
 } else {
-asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
+asm volatile(
+"mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
 "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
 "{%12,%13,%14,%15}, %16, 0x1;\n"
 : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
-: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]),
-"r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]),
-"f"(c[2]), "f"(c[3]), "r"(e[0]));
+: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), "r"(b[2]),
+"r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]),
+"r"(e[0]));
-asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
+asm volatile(
+"mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
 "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
 "{%12,%13,%14,%15}, %16, 0x1;\n"
 : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
-: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]),
-"r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]),
-"f"(c[6]), "f"(c[7]), "r"(e[0]));
+: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), "r"(b[3]),
+"r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), "f"(c[6]), "f"(c[7]),
+"r"(e[0]));
 }
 }

 // Lookup-table based 3-input logical operation; explicitly used for
 // dequantization as the compiler does not seem to automatically recognize it in
 // all cases.
-template <int lut> __device__ inline int lop3(int a, int b, int c) {
+template <int lut>
+__device__ inline int lop3(int a, int b, int c) {
 int res;
 asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
 : "=r"(res)
@@ -37,7 +37,10 @@

 #endif

-template <typename T> inline std::string str(T x) { return std::to_string(x); }
+template <typename T>
+inline std::string str(T x) {
+return std::to_string(x);
+}

 namespace marlin_24 {

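Aside (illustration only, not part of this commit): a hedged sketch of how the Vec and ShapeBase templates reformatted above are typically instantiated. The alias names are invented for the example, and namespace qualification is omitted on the assumption that the surrounding namespace is open.

    using FragAExample = Vec<half2, 4>;      // four half2 registers per fragment
    using TileShape = ShapeBase<16, 8, 32>;  // M, N, K as compile-time constants

    __device__ void zero_first(FragAExample& f) {
      f[0] = __float2half2_rn(0.0f);         // Vec::operator[] gives register access
      static_assert(TileShape::K == 32, "K is usable in constant expressions");
    }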
@@ -57,22 +60,23 @@ static constexpr int max_par = 16;
 template <const int num_bits, // weight bits
 const int threads, // number of threads in a threadblock
 const int thread_m_blocks, // number of 16x16 blocks in the m
-// dimension (batchsize) of the threadblock
+// dimension (batchsize) of the
+// threadblock
 const int thread_n_blocks, // same for n dimension (output)
 const int thread_k_blocks, // same for k dimension (reduction)
 const int stages, // number of stages for the async global->shared
 // fetch pipeline
-const int group_blocks = -1 // number of consecutive 16x16 blocks with
-// a separate quantization scale
+const int group_blocks = -1 // number of consecutive 16x16 blocks
+// with a separate quantization scale
 >
 __global__ void Marlin_24(
 const int4* __restrict__ A, // fp16 input matrix of shape mxk
 const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
-const int4
-*__restrict__ meta, // 2bit metadata information about 2:4 format on B
+const int4* __restrict__ meta, // 2bit metadata information about 2:4
+// format on B
 int4* __restrict__ C, // fp16 output buffer of shape mxn
-const int4
-*__restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn
+const int4* __restrict__ s, // fp16 quantization scales of shape
+// (k/groupsize)xn
 int prob_m, // batch dimension m
 int prob_n, // output dimension n
 int prob_k, // reduction dimension k
@@ -95,22 +99,23 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
 template <const int num_bits, // weight bits
 const int threads, // number of threads in a threadblock
 const int thread_m_blocks, // number of 16x16 blocks in the m
-// dimension (batchsize) of the threadblock
+// dimension (batchsize) of the
+// threadblock
 const int thread_n_blocks, // same for n dimension (output)
 const int thread_k_blocks, // same for k dimension (reduction)
 const int stages, // number of stages for the async global->shared
 // fetch pipeline
-const int group_blocks = -1 // number of consecutive 16x16 blocks with
-// a separate quantization scale
+const int group_blocks = -1 // number of consecutive 16x16 blocks
+// with a separate quantization scale
 >
 __global__ void Marlin_24(
 const int4* __restrict__ A, // fp16 input matrix of shape mxk
 const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
-const int4
-*__restrict__ meta, // 2bit metadata information about 2:4 format on B
+const int4* __restrict__ meta, // 2bit metadata information about 2:4
+// format on B
 int4* __restrict__ C, // fp16 output buffer of shape mxn
-const int4
-*__restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn
+const int4* __restrict__ s, // fp16 quantization scales of shape
+// (k/groupsize)xn
 int prob_m, // batch dimension m
 int prob_n, // output dimension n
 int prob_k, // reduction dimension k
@@ -174,27 +179,22 @@ __global__ void Marlin_24(
 auto init_slice = [&]() {
 slice_iters =
 iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
-if (slice_iters < 0 || slice_col_par >= n_tiles * parallel)
-slice_iters = 0;
-if (slice_iters == 0)
-return;
-if (slice_row + slice_iters > k_tiles)
-slice_iters = k_tiles - slice_row;
+if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;
+if (slice_iters == 0) return;
+if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
 slice_count = 1;
 slice_idx = 0;
 int col_first = iters * ceildiv(k_tiles * slice_col_par, iters);
 if (col_first <= k_tiles * (slice_col_par + 1)) {
 int col_off = col_first - k_tiles * slice_col_par;
 slice_count = ceildiv(k_tiles - col_off, iters);
-if (col_off > 0)
-slice_count++;
+if (col_off > 0) slice_count++;
 int delta_first = iters * blockIdx.x - col_first;
 if (delta_first < 0 || (col_off == 0 && delta_first == 0))
 slice_idx = slice_count - 1;
 else {
 slice_idx = slice_count - 1 - delta_first / iters;
-if (col_off > 0)
-slice_idx--;
+if (col_off > 0) slice_idx--;
 }
 }
 if (slice_col == n_tiles) {
@@ -392,8 +392,7 @@ __global__ void Marlin_24(
 for (int i = 0; i < b_sh_wr_iters; i++) {
 #pragma unroll
 for (int j = 0; j < b_thread_vecs; j++) {
-cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j],
-B_ptr[i] + j);
+cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j);
 }
 B_ptr[i] += b_gl_rd_delta_o;
 }
@@ -401,15 +400,13 @@ __global__ void Marlin_24(
 #pragma unroll
 for (int i = 0; i < m_sh_iters; i++) {
 if (m_sh_wr_pred)
-cp_async4(&sh_meta_stage[m_sh_wr_delta * i + m_sh_wr],
-meta_ptr[i]);
+cp_async4(&sh_meta_stage[m_sh_wr_delta * i + m_sh_wr], meta_ptr[i]);
 meta_ptr[i] += m_gl_rd_delta_o;
 }
 // Only fetch scales if this tile starts a new group
 if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) {
 int4* sh_s_stage = sh_s + s_sh_stage * pipe;
-if (s_sh_wr_pred)
-cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
+if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
 s_gl_rd += s_gl_rd_delta;
 }
 }
@@ -519,8 +516,8 @@ __global__ void Marlin_24(
 (threadIdx.x % b_sh_stride_threads);

 // Parallel logarithmic shared memory reduction. We make sure to avoid any
-// unnecessary read or write iterations, e.g., for two warps we write only once
-// by warp 1 and read only once by warp 0.
+// unnecessary read or write iterations, e.g., for two warps we write only
+// once by warp 1 and read only once by warp 0.
 #pragma unroll
 for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
 #pragma unroll
@@ -531,8 +528,8 @@ __global__ void Marlin_24(
 int red_sh_wr =
 red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
 if (i < red_off) {
-float *c_rd = reinterpret_cast<float *>(
-&sh[red_sh_delta * j + red_sh_rd]);
+float* c_rd =
+reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
 float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
 #pragma unroll
 for (int k = 0; k < 4; k++)
@@ -562,9 +559,9 @@ __global__ void Marlin_24(
 };

 // Since multiple threadblocks may process parts of the same column slice, we
-// finally have to globally reduce over the results. As the striped partitioning
-// minimizes the number of such reductions and our outputs are usually rather
-// small, we perform this reduction serially in L2 cache.
+// finally have to globally reduce over the results. As the striped
+// partitioning minimizes the number of such reductions and our outputs are
+// usually rather small, we perform this reduction serially in L2 cache.
 auto global_reduce = [&](bool first = false, bool last = false) {
 // We are very careful here to reduce directly in the output buffer to
 // maximize L2 cache utilization in this step. To do this, we write out
@@ -584,9 +581,9 @@ __global__ void Marlin_24(
 int col = 2 * ((threadIdx.x % 32) % 4);

 if (!first) {
-// Interestingly, doing direct global accesses here really seems to mess up the
-// compiler and lead to slowdowns, hence we also use async-copies even though
-// these fetches are not actually asynchronous.
+// Interestingly, doing direct global accesses here really seems to mess up
+// the compiler and lead to slowdowns, hence we also use async-copies even
+// though these fetches are not actually asynchronous.
 #pragma unroll
 for (int i = 0; i < thread_m_blocks * 4; i++) {
 cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i],
@@ -722,8 +719,7 @@ __global__ void Marlin_24(
 // Start global fetch and register load pipelines.
 auto start_pipes = [&]() {
 #pragma unroll
-for (int i = 0; i < stages - 1; i++)
-fetch_to_shared(i, i, i < slice_iters);
+for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters);
 zero_accums();
 wait_for_stage();
 fetch_to_registers(0, 0);
@@ -733,9 +729,9 @@ __global__ void Marlin_24(

 // Main loop.
 while (slice_iters) {
-// We unroll over both the global fetch and the register load pipeline to ensure
-// all shared memory accesses are static. Note that both pipelines have even
-// length meaning that the next iteration will always start at index 0.
+// We unroll over both the global fetch and the register load pipeline to
+// ensure all shared memory accesses are static. Note that both pipelines have
+// even length meaning that the next iteration will always start at index 0.
 #pragma unroll
 for (int pipe = 0; pipe < stages;) {
 fetch_to_shared((pipe + stages - 1) % stages, pipe,
@@ -747,8 +743,7 @@ __global__ void Marlin_24(

 pipe++;
 slice_iters--;
-if (slice_iters == 0)
-break;
+if (slice_iters == 0) break;
 }
 a_gl_rd += a_gl_rd_delta_o * stages;

@@ -762,13 +757,11 @@ __global__ void Marlin_24(
 // write-out
 if constexpr (group_blocks == -1) {
 if constexpr (num_bits == 8) {
-if (s_sh_wr_pred)
-cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
+if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
 cp_async_fence();
 } else {
 if (last) {
-if (s_sh_wr_pred)
-cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
+if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
 cp_async_fence();
 }
 }
@@ -851,11 +844,9 @@ __global__ void Marlin_24(
 meta_ptr[i] += (m_sh_stride)-m_gl_rd_delta_o * k_tiles;
 if (slice_col == 0) {
 #pragma unroll
-for (int i = 0; i < b_sh_wr_iters; i++)
-B_ptr[i] -= b_gl_stride;
+for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
 #pragma unroll
-for (int i = 0; i < m_sh_iters; i++)
-meta_ptr[i] -= m_gl_stride;
+for (int i = 0; i < m_sh_iters; i++) meta_ptr[i] -= m_gl_stride;
 }
 s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
 start_pipes();
@@ -904,8 +895,8 @@ void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C,

 if (thread_k == -1 || thread_m == -1) {
 if (prob_n <= 16) {
-// For small batchizes, better partitioningif is slightly more important than
-// better compute utilization
+// For small batchizes, better partitioningif is slightly more important
+// than better compute utilization
 thread_k = 128;
 thread_m = 128;
 } else {
@@ -946,8 +937,7 @@ void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C,
 // Note that parallel > 1 currently only works for inputs without any
 // padding
 par = (16 * thread_n_blocks - pad) / 64;
-if (par > max_par)
-par = max_par;
+if (par > max_par) par = max_par;
 prob_n = 64 * par;
 i += 4 * (par - 1);
 thread_n_blocks = 4;
@@ -1037,9 +1027,9 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
 " is not divisible by tile_size = " + str(marlin_24::tile_size));

 int actual_size_n = (b_q_weight.size(1) / marlin_24::tile_size) * pack_factor;
-TORCH_CHECK(size_n == actual_size_n,
-"size_n = " + str(size_n) +
-", actual_size_n = " + str(actual_size_n));
+TORCH_CHECK(
+size_n == actual_size_n,
+"size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n));

 // Verify meta
 TORCH_CHECK(b_meta.size(0) == size_k / 8 / 2 / 2,
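Aside (illustration only, not part of this commit): the shape checks above all follow one pattern, shown here as a tiny standalone helper. The name check_packed_size_n is invented for the example, and the snippet assumes torch/extension.h plus the str() helper from the same source file are available.

    // Validate that a packed weight's second dim unpacks to the expected size_n.
    inline void check_packed_size_n(int64_t size_n, int64_t actual_size_n) {
      TORCH_CHECK(
          size_n == actual_size_n,
          "size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n));
    }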
@ -32,12 +32,8 @@ __global__ void NUQ4MatMulKernel(
|
|||||||
#else
|
#else
|
||||||
float2* __restrict__ mul,
|
float2* __restrict__ mul,
|
||||||
#endif
|
#endif
|
||||||
const __half* __restrict__ lookup_table,
|
const __half* __restrict__ lookup_table, int height, int width, int batch,
|
||||||
int height,
|
int vec_height) {
|
||||||
int width,
|
|
||||||
int batch,
|
|
||||||
int vec_height
|
|
||||||
) {
|
|
||||||
|
|
||||||
const int blockwidth2 = BLOCKWIDTH / 2;
|
const int blockwidth2 = BLOCKWIDTH / 2;
|
||||||
|
|
||||||
@ -80,7 +76,9 @@ __global__ void NUQ4MatMulKernel(
|
|||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
if (threadIdx.x < blockwidth2)
|
if (threadIdx.x < blockwidth2)
|
||||||
blockvec[threadIdx.x] = vec[b * vec_height / 2 + (row / BLOCKHEIGHT4) * blockwidth2 + threadIdx.x];
|
blockvec[threadIdx.x] =
|
||||||
|
vec[b * vec_height / 2 + (row / BLOCKHEIGHT4) * blockwidth2 +
|
||||||
|
threadIdx.x];
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
while (k < blockwidth2) {
|
while (k < blockwidth2) {
|
||||||
@ -143,7 +141,8 @@ __global__ void NUQ4MatMulKernel(
|
|||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
res = __hadd(__hadd(res2.x, res2.y), res);
|
res = __hadd(__hadd(res2.x, res2.y), res);
|
||||||
#else
|
#else
|
||||||
res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)), res);
|
res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)),
|
||||||
|
res);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
i += width;
|
i += width;
|
||||||
@ -183,22 +182,16 @@ __global__ void NUQ4MatMulKernel(
|
|||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
|
|
||||||
// 4-bit matvec kernel (LUT-based)
|
// 4-bit matvec kernel (LUT-based)
|
||||||
void squeezellm_gemm(
|
void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
|
||||||
torch::Tensor vec,
|
torch::Tensor lookup_table) {
|
||||||
torch::Tensor mat,
|
|
||||||
torch::Tensor mul,
|
|
||||||
torch::Tensor lookup_table
|
|
||||||
) {
|
|
||||||
int height = mat.size(0);
|
int height = mat.size(0);
|
||||||
int width = mat.size(1);
|
int width = mat.size(1);
|
||||||
|
|
||||||
int batch = vec.size(0);
|
int batch = vec.size(0);
|
||||||
int vec_height = vec.size(1);
|
int vec_height = vec.size(1);
|
||||||
|
|
||||||
dim3 blocks(
|
dim3 blocks((height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
|
||||||
(height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
|
(width + BLOCKWIDTH - 1) / BLOCKWIDTH);
|
||||||
(width + BLOCKWIDTH - 1) / BLOCKWIDTH
|
|
||||||
);
|
|
||||||
dim3 threads(BLOCKWIDTH);
|
dim3 threads(BLOCKWIDTH);
|
||||||
|
|
||||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
|
||||||
@ -211,14 +204,12 @@ void squeezellm_gemm(
|
|||||||
#endif
|
#endif
|
||||||
mat.data_ptr<int>(),
|
mat.data_ptr<int>(),
|
||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
(half2*) mul.data<at::Half>(),
|
(half2*)mul.data<at::Half>(), (__half*)lookup_table.data<at::Half>(),
|
||||||
(__half*) lookup_table.data<at::Half>(),
|
|
||||||
#else
|
#else
|
||||||
(float2*)mul.data_ptr<float>(),
|
(float2*)mul.data_ptr<float>(),
|
||||||
(__half*)lookup_table.data_ptr<at::Half>(),
|
(__half*)lookup_table.data_ptr<at::Half>(),
|
||||||
#endif
|
#endif
|
||||||
height, width, batch, vec_height
|
height, width, batch, vec_height);
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#undef BLOCKWIDTH
|
#undef BLOCKWIDTH
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
/*
|
/*
|
||||||
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh
|
* Adapted from
|
||||||
|
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh
|
||||||
* Copyright (c) 2023, The vLLM team.
|
* Copyright (c) 2023, The vLLM team.
|
||||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||||
*
|
*
|
||||||
@@ -43,17 +44,18 @@ __inline__ __device__ T blockReduceSum(T val) {
   static_assert(maxBlockSize <= 1024);
   if constexpr (maxBlockSize > WARP_SIZE) {
     val = warpReduceSum<T>(val);
-    // Calculates max number of lanes that need to participate in the last warpReduce
+    // Calculates max number of lanes that need to participate in the last
+    // warpReduce
     constexpr int maxActiveLanes = (maxBlockSize + WARP_SIZE - 1) / WARP_SIZE;
     static __shared__ T shared[maxActiveLanes];
     int lane = threadIdx.x % WARP_SIZE;
     int wid = threadIdx.x / WARP_SIZE;
-    if (lane == 0)
-      shared[wid] = val;
+    if (lane == 0) shared[wid] = val;

     __syncthreads();

-    val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane] : (T)(0.0f);
+    val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane]
+                                                        : (T)(0.0f);
     val = warpReduceSum<T, _nextPow2(maxActiveLanes)>(val);
   } else {
     // A single warpReduce is equal to blockReduce
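The blockReduceSum hunk shows the other two style effects at work: comments and long expressions are wrapped at the 80-column limit, and Google-style short conditionals are folded onto one line (if (lane == 0) shared[wid] = val;). To inspect the fully resolved options that drive these choices, clang-format can dump its effective configuration; this is a sketch run from the repository root and assumes the checked-in .clang-format is picked up via --style=file:

    # Print the resolved configuration and pick out a few options of interest.
    clang-format --style=file --dump-config | grep -E 'ColumnLimit|IndentWidth|PointerAlignment'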
57 format.sh
@@ -26,6 +26,7 @@ RUFF_VERSION=$(ruff --version | awk '{print $2}')
 MYPY_VERSION=$(mypy --version | awk '{print $2}')
 CODESPELL_VERSION=$(codespell --version)
 ISORT_VERSION=$(isort --vn)
+CLANGFORMAT_VERSION=$(clang-format --version | awk '{print $3}')

 # # params: tool name, tool version, required version
 tool_version_check() {
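The new CLANGFORMAT_VERSION line extracts the installed version from the tool's banner. As a hedged illustration: a pip-installed clang-format typically prints a banner of the form "clang-format version 18.1.5", so the third whitespace-separated field is the version number; vendor builds that prepend a distribution name would shift that field.

    # Assuming the banner 'clang-format version 18.1.5', this prints 18.1.5.
    clang-format --version | awk '{print $3}'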
@@ -40,6 +41,7 @@ tool_version_check "ruff" $RUFF_VERSION "$(grep "ruff==" requirements-dev.txt |
 tool_version_check "mypy" "$MYPY_VERSION" "$(grep mypy requirements-dev.txt | cut -d'=' -f3)"
 tool_version_check "isort" "$ISORT_VERSION" "$(grep isort requirements-dev.txt | cut -d'=' -f3)"
 tool_version_check "codespell" "$CODESPELL_VERSION" "$(grep codespell requirements-dev.txt | cut -d'=' -f3)"
+tool_version_check "clang-format" "$CLANGFORMAT_VERSION" "$(grep clang-format requirements-dev.txt | cut -d'=' -f3)"

 YAPF_FLAGS=(
     '--recursive'
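The added tool_version_check call compares the locally installed clang-format against the version pinned in requirements-dev.txt, using the same grep/cut pipeline as the other tools. For the pinned line added later in this diff (clang-format==18.1.5), the pipeline resolves as follows:

    # 'clang-format==18.1.5' split on '=' yields fields 'clang-format', '', '18.1.5'.
    grep clang-format requirements-dev.txt | cut -d'=' -f3   # prints 18.1.5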
@@ -179,7 +181,6 @@ lint_changed() {
 }

 # Run Ruff
-echo 'vLLM ruff:'
 ### This flag lints individual files. --files *must* be the first command line
 ### arg to use this option.
 if [[ "$1" == '--files' ]]; then
@@ -192,6 +193,7 @@ else
   # Format only the files that changed in last commit.
   lint_changed
 fi
+echo 'vLLM ruff: Done'

 # check spelling of specified files
 isort_check() {
@@ -233,6 +235,59 @@ else
 fi
 echo 'vLLM isort: Done'

+# Clang-format section
+# Exclude some files for formatting because they are vendored
+# NOTE: Keep up to date with .github/workflows/clang-format.yml
+CLANG_FORMAT_EXCLUDES=(
+    'csrc/moe/topk_softmax_kernels.cu'
+    'csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu'
+    'csrc/punica/bgmv/bgmv_config.h'
+    'csrc/punica/bgmv/bgmv_impl.cuh'
+    'csrc/punica/bgmv/vec_dtypes.cuh'
+    'csrc/punica/punica_ops.cu'
+    'csrc/punica/type_convert.h'
+)
+
+# Format specified files with clang-format
+clang_format() {
+    clang-format -i "$@"
+}
+
+# Format files that differ from main branch with clang-format.
+clang_format_changed() {
+    # The `if` guard ensures that the list of filenames is not empty, which
+    # could cause clang-format to receive 0 positional arguments, making it hang
+    # waiting for STDIN.
+    #
+    # `diff-filter=ACM` and $MERGEBASE is to ensure we only format files that
+    # exist on both branches.
+    MERGEBASE="$(git merge-base origin/main HEAD)"
+
+    # Get the list of changed files, excluding the specified ones
+    changed_files=$(git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.h' '*.cpp' '*.cu' '*.cuh' | grep -vFf <(printf "%s\n" "${CLANG_FORMAT_EXCLUDES[@]}"))
+    if [ -n "$changed_files" ]; then
+        echo "$changed_files" | xargs -P 5 clang-format -i
+    fi
+}
+
+# Format all files with clang-format
+clang_format_all() {
+    find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \
+        | grep -vFf <(printf "%s\n" "${CLANG_FORMAT_EXCLUDES[@]}") \
+        | xargs clang-format -i
+}
+
+# Run clang-format
+if [[ "$1" == '--files' ]]; then
+    clang_format "${@:2}"
+elif [[ "$1" == '--all' ]]; then
+    clang_format_all
+else
+    clang_format_changed
+fi
+echo 'vLLM clang-format: Done'
+

 if ! git diff --quiet &>/dev/null; then
     echo 'Reformatted files. Please review and stage the changes.'
     echo 'Changes not staged for commit:'
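With this section, format.sh gains the same three entry points for C++/CUDA that it already offers the Python tools: format explicit files, format everything under csrc/, or (by default) format only files that differ from origin/main, with the vendored punica/MoE sources filtered out through grep -vFf against CLANG_FORMAT_EXCLUDES. A usage sketch, assuming it is run from the repository root:

    # Default: format files (Python and now C++/CUDA) that changed vs. origin/main.
    ./format.sh

    # Format everything, including all of csrc/ minus the excluded vendored files.
    ./format.sh --all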
@@ -5,6 +5,7 @@ tomli==2.0.1
 ruff==0.1.5
 codespell==2.2.6
 isort==5.13.2
+clang-format==18.1.5

 # type checking
 mypy==1.9.0
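Pinning clang-format in requirements-dev.txt keeps local runs, the format.sh version check, and CI on the same release, so reformatting is reproducible. A minimal setup sketch:

    # Install the pinned developer tools, including clang-format 18.1.5.
    pip install -r requirements-dev.txt

    # Or install only the formatter at the pinned version.
    pip install clang-format==18.1.5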