[CI/Build] Enforce style for C++ and CUDA code with clang-format (#4722)
parent 9b9a10d6cb
commit 5f6d10c14c
.clang-format (new file, 26 lines)
@@ -0,0 +1,26 @@
BasedOnStyle: Google
UseTab: Never
IndentWidth: 2
ColumnLimit: 80

# Force pointers to the type for C++.
DerivePointerAlignment: false
PointerAlignment: Left

# Reordering #include statements can (and currently will) introduce errors
SortIncludes: false

# Style choices
AlignConsecutiveAssignments: false
AlignConsecutiveDeclarations: false
IndentPPDirectives: BeforeHash

IncludeCategories:
  - Regex: '^<'
    Priority: 4
  - Regex: '^"(llvm|llvm-c|clang|clang-c|mlir|mlir-c)/'
    Priority: 3
  - Regex: '^"(qoda|\.\.)/'
    Priority: 2
  - Regex: '.*'
    Priority: 1
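For reference, here is a minimal before/after sketch of what these settings do to C++/CUDA code. The helper function and header below are made up for illustration only and are not taken from the vLLM tree; the transformations themselves (template spacing, left pointer alignment, no space after C-style casts, 2-space indentation) are the same mechanical changes visible in the kernel diffs further down.

    #include <vector>

    // Hypothetical helper before formatting.
    template<typename T>
    const T *first_element(const std::vector<T> &v) {
        return (const T *) &v[0];
    }

    // The same helper after clang-format 18 with the configuration above.
    template <typename T>
    const T* first_element(const std::vector<T>& v) {
      return (const T*)&v[0];
    }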
.github/workflows/clang-format.yml (new file, 42 lines)
@@ -0,0 +1,42 @@
name: clang-format

on:
  # Trigger the workflow on push or pull request,
  # but only for the main branch
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

jobs:
  clang-format:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.11"]
    steps:
    - uses: actions/checkout@v2
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v2
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install clang-format==18.1.5
    - name: Running clang-format
      run: |
        EXCLUDES=(
            'csrc/moe/topk_softmax_kernels.cu'
            'csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu'
            'csrc/punica/bgmv/bgmv_config.h'
            'csrc/punica/bgmv/bgmv_impl.cuh'
            'csrc/punica/bgmv/vec_dtypes.cuh'
            'csrc/punica/punica_ops.cu'
            'csrc/punica/type_convert.h'
        )
        find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \
            | grep -vFf <(printf "%s\n" "${EXCLUDES[@]}") \
            | xargs clang-format --dry-run --Werror
@@ -10,11 +10,11 @@
namespace vllm {

// Activation and gating kernel template.
-template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
+template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void act_and_mul_kernel(
    scalar_t* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., 2, d]
    const int d) {
  const int64_t token_idx = blockIdx.x;
  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
@@ -23,72 +23,66 @@
  }
}

-template<typename T>
+template <typename T>
__device__ __forceinline__ T silu_kernel(const T& x) {
  // x * sigmoid(x)
-  return (T) (((float) x) / (1.0f + expf((float) -x)));
+  return (T)(((float)x) / (1.0f + expf((float)-x)));
}

-template<typename T>
+template <typename T>
__device__ __forceinline__ T gelu_kernel(const T& x) {
  // Equivalent to PyTorch GELU with 'none' approximation.
  // Refer to:
  // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38
-  const float f = (float) x;
+  const float f = (float)x;
  constexpr float ALPHA = M_SQRT1_2;
-  return (T) (f * 0.5f * (1.0f + ::erf(f * ALPHA)));
+  return (T)(f * 0.5f * (1.0f + ::erf(f * ALPHA)));
}

-template<typename T>
+template <typename T>
__device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
  // Equivalent to PyTorch GELU with 'tanh' approximation.
  // Refer to:
  // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30
-  const float f = (float) x;
+  const float f = (float)x;
  constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f;
  constexpr float KAPPA = 0.044715;
  float x_cube = f * f * f;
  float inner = BETA * (f + KAPPA * x_cube);
-  return (T) (0.5f * f * (1.0f + ::tanhf(inner)));
+  return (T)(0.5f * f * (1.0f + ::tanhf(inner)));
}

}  // namespace vllm

// Launch activation and gating kernel.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
  int d = input.size(-1) / 2; \
  int64_t num_tokens = input.numel() / input.size(-1); \
  dim3 grid(num_tokens); \
  dim3 block(std::min(d, 1024)); \
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
  VLLM_DISPATCH_FLOATING_TYPES( \
-    input.scalar_type(), \
-    "act_and_mul_kernel", \
-    [&] { \
-      vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \
-        out.data_ptr<scalar_t>(), \
-        input.data_ptr<scalar_t>(), \
-        d); \
-    });
+      input.scalar_type(), "act_and_mul_kernel", [&] { \
+        vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>> \
+            <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
+                                         input.data_ptr<scalar_t>(), d); \
+      });

-void silu_and_mul(
-  torch::Tensor& out,    // [..., d]
-  torch::Tensor& input)  // [..., 2 * d]
+void silu_and_mul(torch::Tensor& out,    // [..., d]
+                  torch::Tensor& input)  // [..., 2 * d]
{
  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
}

-void gelu_and_mul(
-  torch::Tensor& out,    // [..., d]
-  torch::Tensor& input)  // [..., 2 * d]
+void gelu_and_mul(torch::Tensor& out,    // [..., d]
+                  torch::Tensor& input)  // [..., 2 * d]
{
  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
}

-void gelu_tanh_and_mul(
-  torch::Tensor& out,    // [..., d]
-  torch::Tensor& input)  // [..., 2 * d]
+void gelu_tanh_and_mul(torch::Tensor& out,    // [..., d]
+                       torch::Tensor& input)  // [..., 2 * d]
{
  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel);
}
@@ -96,11 +90,11 @@
namespace vllm {

// Element-wise activation kernel template.
-template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
+template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void activation_kernel(
    scalar_t* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., d]
    const int d) {
  const int64_t token_idx = blockIdx.x;
  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]);
@@ -108,54 +102,49 @@
  }
}

}  // namespace vllm

// Launch element-wise activation kernel.
#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
  int d = input.size(-1); \
  int64_t num_tokens = input.numel() / d; \
  dim3 grid(num_tokens); \
  dim3 block(std::min(d, 1024)); \
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
-  VLLM_DISPATCH_FLOATING_TYPES( \
-    input.scalar_type(), \
-    "activation_kernel", \
-    [&] { \
-      vllm::activation_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \
-        out.data_ptr<scalar_t>(), \
-        input.data_ptr<scalar_t>(), \
-        d); \
-    });
+  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] { \
+    vllm::activation_kernel<scalar_t, KERNEL<scalar_t>> \
+        <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
+                                     input.data_ptr<scalar_t>(), d); \
+  });

namespace vllm {

-template<typename T>
+template <typename T>
__device__ __forceinline__ T gelu_new_kernel(const T& x) {
-  const float x3 = (float) (x * x * x);
-  const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3))));
-  return ((T) 0.5) * x * (((T) 1.0) + t);
+  const float x3 = (float)(x * x * x);
+  const T t = (T)tanhf((T)(0.79788456f * (float)(x + (T)(0.044715f * x3))));
+  return ((T)0.5) * x * (((T)1.0) + t);
}

-template<typename T>
+template <typename T>
__device__ __forceinline__ T gelu_fast_kernel(const T& x) {
-  const float f = (float) x;
-  const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x));
-  return ((T) 0.5) * x * (((T) 1.0) + t);
+  const float f = (float)x;
+  const T t =
+      (T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x));
+  return ((T)0.5) * x * (((T)1.0) + t);
}

}  // namespace vllm

-void gelu_new(
-  torch::Tensor& out,    // [..., d]
-  torch::Tensor& input)  // [..., d]
+void gelu_new(torch::Tensor& out,    // [..., d]
+              torch::Tensor& input)  // [..., d]
{
  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
}

-void gelu_fast(
-  torch::Tensor& out,    // [..., d]
-  torch::Tensor& input)  // [..., d]
+void gelu_fast(torch::Tensor& out,    // [..., d]
+               torch::Tensor& input)  // [..., d]
{
  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
}
@@ -1,5 +1,6 @@
/*
- * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
+ * Adapted from
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
@@ -22,31 +23,31 @@
namespace vllm {

// A vector type to store Q, K, V elements.
-template<typename T, int VEC_SIZE>
+template <typename T, int VEC_SIZE>
struct Vec {};

// A vector type to store FP32 accumulators.
-template<typename T>
+template <typename T>
struct FloatVec {};

// Template vector operations.
-template<typename Acc, typename A, typename B>
+template <typename Acc, typename A, typename B>
inline __device__ Acc mul(A a, B b);

-template<typename T>
+template <typename T>
inline __device__ float sum(T v);

-template<typename T>
+template <typename T>
inline __device__ float dot(T a, T b) {
  return sum(mul<T, T, T>(a, b));
}

-template<typename A, typename T>
+template <typename A, typename T>
inline __device__ float dot(T a, T b) {
  return sum(mul<A, T, T>(a, b));
}

-template<typename T>
+template <typename T>
inline __device__ void zero(T& dst) {
  constexpr int WORDS = sizeof(T) / 4;
  union {
@@ -61,4 +62,4 @@
  dst = tmp.raw;
}

-} // namespace vllm
+}  // namespace vllm
@@ -1,5 +1,6 @@
/*
- * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * Adapted from
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
@@ -27,15 +28,15 @@
#ifdef USE_ROCM
-#include <hip/hip_bf16.h>
-#include "../quantization/fp8/amd/quant_utils.cuh"
+  #include <hip/hip_bf16.h>
+  #include "../quantization/fp8/amd/quant_utils.cuh"
typedef __hip_bfloat16 __nv_bfloat16;
#else
-#include "../quantization/fp8/nvidia/quant_utils.cuh"
+  #include "../quantization/fp8/nvidia/quant_utils.cuh"
#endif

#ifndef USE_ROCM
-#define WARP_SIZE 32
+  #define WARP_SIZE 32
#else
-#define WARP_SIZE warpSize
+  #define WARP_SIZE warpSize
#endif

#define MAX(a, b) ((a) > (b) ? (a) : (b))
@@ -45,7 +46,7 @@
namespace vllm {

// Utility function for attention softmax.
-template<int NUM_WARPS>
+template <int NUM_WARPS>
inline __device__ float block_sum(float* red_smem, float sum) {
  // Decompose the thread index into warp / lane.
  int warp = threadIdx.x / WARP_SIZE;
@@ -82,31 +83,28 @@

// TODO(woosuk): Merge the last two dimensions of the grid.
// Grid: (num_heads, num_seqs, max_num_partitions).
-template<
-  typename scalar_t,
-  typename cache_t,
-  int HEAD_SIZE,
-  int BLOCK_SIZE,
-  int NUM_THREADS,
-  vllm::Fp8KVCacheDataType KV_DTYPE,
-  int PARTITION_SIZE = 0> // Zero means no partitioning.
+template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
+          int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
+          int PARTITION_SIZE = 0>  // Zero means no partitioning.
__device__ void paged_attention_kernel(
-  float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions]
-  float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions]
-  scalar_t* __restrict__ out,             // [num_seqs, num_heads, max_num_partitions, head_size]
-  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
-  const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
-  const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
-  const int num_kv_heads,                 // [num_heads]
-  const float scale,
-  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
-  const int* __restrict__ seq_lens,       // [num_seqs]
-  const int max_num_blocks_per_seq,
-  const float* __restrict__ alibi_slopes, // [num_heads]
-  const int q_stride,
-  const int kv_block_stride,
-  const int kv_head_stride,
-  const float kv_scale) {
+    float* __restrict__ exp_sums,  // [num_seqs, num_heads, max_num_partitions]
+    float* __restrict__ max_logits,  // [num_seqs, num_heads,
+                                     // max_num_partitions]
+    scalar_t* __restrict__ out,    // [num_seqs, num_heads, max_num_partitions,
+                                   // head_size]
+    const scalar_t* __restrict__ q,       // [num_seqs, num_heads, head_size]
+    const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
+                                          // head_size/x, block_size, x]
+    const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
+                                          // head_size, block_size]
+    const int num_kv_heads,               // [num_heads]
+    const float scale,
+    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
+    const int* __restrict__ seq_lens,      // [num_seqs]
+    const int max_num_blocks_per_seq,
+    const float* __restrict__ alibi_slopes,  // [num_heads]
+    const int q_stride, const int kv_block_stride, const int kv_head_stride,
+    const float kv_scale) {
  const int seq_idx = blockIdx.y;
  const int partition_idx = blockIdx.z;
  const int max_num_partitions = gridDim.z;
@@ -118,22 +116,29 @@
  }

  const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
-  const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
+  const int num_blocks_per_partition =
+      USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;

  // [start_block_idx, end_block_idx) is the range of blocks to process.
-  const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
-  const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks);
+  const int start_block_idx =
+      USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
+  const int end_block_idx =
+      MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks);
  const int num_blocks = end_block_idx - start_block_idx;

  // [start_token_idx, end_token_idx) is the range of tokens to process.
  const int start_token_idx = start_block_idx * BLOCK_SIZE;
-  const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len);
+  const int end_token_idx =
+      MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len);
  const int num_tokens = end_token_idx - start_token_idx;

  constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
-  constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS
+  constexpr int NUM_THREAD_GROUPS =
+      NUM_THREADS / THREAD_GROUP_SIZE;  // Note: This assumes THREAD_GROUP_SIZE
+                                        // divides NUM_THREADS
  assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
-  constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
+  constexpr int NUM_TOKENS_PER_THREAD_GROUP =
+      DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  const int thread_idx = threadIdx.x;
  const int warp_idx = thread_idx / WARP_SIZE;
@@ -143,13 +148,14 @@
  const int num_heads = gridDim.x;
  const int num_queries_per_kv = num_heads / num_kv_heads;
  const int kv_head_idx = head_idx / num_queries_per_kv;
-  const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
+  const float alibi_slope =
+      alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];

  // A vector type to store a part of a key or a query.
-  // The vector size is configured in such a way that the threads in a thread group
-  // fetch or compute 16 bytes at a time.
-  // For example, if the size of a thread group is 4 and the data type is half,
-  // then the vector size is 16 / (4 * sizeof(half)) == 2.
+  // The vector size is configured in such a way that the threads in a thread
+  // group fetch or compute 16 bytes at a time. For example, if the size of a
+  // thread group is 4 and the data type is half, then the vector size is 16 /
+  // (4 * sizeof(half)) == 2.
  constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
  using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
  using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
@@ -163,18 +169,21 @@

  // Load the query to registers.
  // Each thread in a thread group has a different part of the query.
-  // For example, if the the thread group size is 4, then the first thread in the group
-  // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
-  // th vectors of the query, and so on.
-  // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
+  // For example, if the the thread group size is 4, then the first thread in
+  // the group has 0, 4, 8, ... th vectors of the query, and the second thread
+  // has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because
+  // q is split from a qkv tensor, it may not be contiguous.
  const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
  __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
#pragma unroll
-  for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
+  for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
+       i += NUM_THREAD_GROUPS) {
    const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
-    q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
+    q_vecs[thread_group_offset][i] =
+        *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
  }
-  __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs
+  __syncthreads();  // TODO(naed90): possible speedup if this is replaced with a
+                    // memory wall right before we use q_vecs

  // Memory planning.
  extern __shared__ char shared_mem[];
@@ -193,44 +202,50 @@
  // Each thread group in a warp fetches a key from the block, and computes
  // dot product with the query.
  const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
-  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
-    // NOTE(woosuk): The block number is stored in int32. However, we cast it to int64
-    // because int32 can lead to overflow when this variable is multiplied by large numbers
-    // (e.g., kv_block_stride).
-    const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
+  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
+       block_idx += NUM_WARPS) {
+    // NOTE(woosuk): The block number is stored in int32. However, we cast it to
+    // int64 because int32 can lead to overflow when this variable is multiplied
+    // by large numbers (e.g., kv_block_stride).
+    const int64_t physical_block_number =
+        static_cast<int64_t>(block_table[block_idx]);

    // Load a key to registers.
    // Each thread in a thread group has a different part of the key.
-    // For example, if the the thread group size is 4, then the first thread in the group
-    // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th
-    // vectors of the key, and so on.
+    // For example, if the the thread group size is 4, then the first thread in
+    // the group has 0, 4, 8, ... th vectors of the key, and the second thread
+    // has 1, 5, 9, ... th vectors of the key, and so on.
    for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
-      const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
+      const int physical_block_offset =
+          (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
      const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
      K_vec k_vecs[NUM_VECS_PER_THREAD];

#pragma unroll
      for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
-        const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride
-                               + kv_head_idx * kv_head_stride
-                               + physical_block_offset * x;
+        const cache_t* k_ptr =
+            k_cache + physical_block_number * kv_block_stride +
+            kv_head_idx * kv_head_stride + physical_block_offset * x;
        const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
        const int offset1 = (vec_idx * VEC_SIZE) / x;
        const int offset2 = (vec_idx * VEC_SIZE) % x;

        if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
-          k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
+          k_vecs[j] = *reinterpret_cast<const K_vec*>(
+              k_ptr + offset1 * BLOCK_SIZE * x + offset2);
        } else {
          // Vector conversion from Quant_vec to K_vec.
          Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
              k_ptr + offset1 * BLOCK_SIZE * x + offset2);
-          k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(k_vec_quant, kv_scale);
+          k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
+              k_vec_quant, kv_scale);
        }
      }

      // Compute dot product.
      // This includes a reduction across the threads in the same thread group.
-      float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
+      float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(
+                             q_vecs[thread_group_offset], k_vecs);
      // Add the ALiBi bias if slopes are given.
      qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;

@@ -285,13 +300,12 @@

  // If partitioning is enabled, store the max logit and exp_sum.
  if (USE_PARTITIONING && thread_idx == 0) {
-    float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
-                            + head_idx * max_num_partitions
-                            + partition_idx;
+    float* max_logits_ptr = max_logits +
+                            seq_idx * num_heads * max_num_partitions +
+                            head_idx * max_num_partitions + partition_idx;
    *max_logits_ptr = qk_max;
-    float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
-                          + head_idx * max_num_partitions
-                          + partition_idx;
+    float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions +
+                          head_idx * max_num_partitions + partition_idx;
    *exp_sums_ptr = exp_sum;
  }

@@ -304,7 +318,8 @@

  constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
  constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
-  constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
+  constexpr int NUM_ROWS_PER_THREAD =
+      DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);

  // NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
  float accs[NUM_ROWS_PER_THREAD];
@@ -315,18 +330,21 @@

  scalar_t zero_value;
  zero(zero_value);
-  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
-    // NOTE(woosuk): The block number is stored in int32. However, we cast it to int64
-    // because int32 can lead to overflow when this variable is multiplied by large numbers
-    // (e.g., kv_block_stride).
-    const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
+  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
+       block_idx += NUM_WARPS) {
+    // NOTE(woosuk): The block number is stored in int32. However, we cast it to
+    // int64 because int32 can lead to overflow when this variable is multiplied
+    // by large numbers (e.g., kv_block_stride).
+    const int64_t physical_block_number =
+        static_cast<int64_t>(block_table[block_idx]);
    const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
    const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
    L_vec logits_vec;
-    from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx));
+    from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx -
+                                                           start_token_idx));

-    const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride
-                           + kv_head_idx * kv_head_stride;
+    const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
+                           kv_head_idx * kv_head_stride;
#pragma unroll
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
@@ -337,14 +355,17 @@
      if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
        v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
      } else {
-        V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
+        V_quant_vec v_quant_vec =
+            *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
        // Vector conversion from V_quant_vec to V_vec.
-        v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec, kv_scale);
+        v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
+                                                                  kv_scale);
      }
      if (block_idx == num_seq_blocks - 1) {
-        // NOTE(woosuk): When v_vec contains the tokens that are out of the context,
-        // we should explicitly zero out the values since they may contain NaNs.
-        // See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
+        // NOTE(woosuk): When v_vec contains the tokens that are out of the
+        // context, we should explicitly zero out the values since they may
+        // contain NaNs. See
+        // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
        scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll
        for (int j = 0; j < V_VEC_SIZE; j++) {
@@ -367,8 +388,8 @@
    accs[i] = acc;
  }

-  // NOTE(woosuk): A barrier is required because the shared memory space for logits
-  // is reused for the output.
+  // NOTE(woosuk): A barrier is required because the shared memory space for
+  // logits is reused for the output.
  __syncthreads();

  // Perform reduction across warps.
@@ -405,9 +426,9 @@

  // Write the final output.
  if (warp_idx == 0) {
-    scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
-                        + head_idx * max_num_partitions * HEAD_SIZE
-                        + partition_idx * HEAD_SIZE;
+    scalar_t* out_ptr =
+        out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
+        head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE;
#pragma unroll
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
@@ -419,79 +440,75 @@
}

// Grid: (num_heads, num_seqs, 1).
-template<
-  typename scalar_t,
-  typename cache_t,
-  int HEAD_SIZE,
-  int BLOCK_SIZE,
-  int NUM_THREADS,
-  vllm::Fp8KVCacheDataType KV_DTYPE>
+template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
+          int NUM_THREADS,
+          vllm::Fp8KVCacheDataType KV_DTYPE>
__global__ void paged_attention_v1_kernel(
-  scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
-  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
-  const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
-  const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
-  const int num_kv_heads,                 // [num_heads]
-  const float scale,
-  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
-  const int* __restrict__ seq_lens,       // [num_seqs]
-  const int max_num_blocks_per_seq,
-  const float* __restrict__ alibi_slopes, // [num_heads]
-  const int q_stride,
-  const int kv_block_stride,
-  const int kv_head_stride,
-  const float kv_scale) {
-  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE>(
-    /* exp_sums */ nullptr, /* max_logits */ nullptr,
-    out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens,
-    max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale);
+    scalar_t* __restrict__ out,           // [num_seqs, num_heads, head_size]
+    const scalar_t* __restrict__ q,       // [num_seqs, num_heads, head_size]
+    const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
+                                          // head_size/x, block_size, x]
+    const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
+                                          // head_size, block_size]
+    const int num_kv_heads,               // [num_heads]
+    const float scale,
+    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
+    const int* __restrict__ seq_lens,      // [num_seqs]
+    const int max_num_blocks_per_seq,
+    const float* __restrict__ alibi_slopes,  // [num_heads]
+    const int q_stride, const int kv_block_stride, const int kv_head_stride,
+    const float kv_scale) {
+  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
+                         KV_DTYPE>(
+      /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
+      v_cache, num_kv_heads, scale, block_tables, seq_lens,
+      max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
+      kv_head_stride, kv_scale);
}

// Grid: (num_heads, num_seqs, max_num_partitions).
-template<
-  typename scalar_t,
-  typename cache_t,
-  int HEAD_SIZE,
-  int BLOCK_SIZE,
-  int NUM_THREADS,
-  vllm::Fp8KVCacheDataType KV_DTYPE,
-  int PARTITION_SIZE>
+template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
+          int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
+          int PARTITION_SIZE>
__global__ void paged_attention_v2_kernel(
-  float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions]
-  float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions]
-  scalar_t* __restrict__ tmp_out,         // [num_seqs, num_heads, max_num_partitions, head_size]
-  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
-  const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
-  const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
-  const int num_kv_heads,                 // [num_heads]
-  const float scale,
-  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
-  const int* __restrict__ seq_lens,       // [num_seqs]
-  const int max_num_blocks_per_seq,
-  const float* __restrict__ alibi_slopes, // [num_heads]
-  const int q_stride,
-  const int kv_block_stride,
-  const int kv_head_stride,
-  const float kv_scale) {
-  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE, PARTITION_SIZE>(
-    exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
-    block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes,
-    q_stride, kv_block_stride, kv_head_stride, kv_scale);
+    float* __restrict__ exp_sums,  // [num_seqs, num_heads, max_num_partitions]
+    float* __restrict__ max_logits,  // [num_seqs, num_heads,
+                                     // max_num_partitions]
+    scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads,
+                                     // max_num_partitions, head_size]
+    const scalar_t* __restrict__ q,       // [num_seqs, num_heads, head_size]
+    const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
+                                          // head_size/x, block_size, x]
+    const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
+                                          // head_size, block_size]
+    const int num_kv_heads,               // [num_heads]
+    const float scale,
+    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
+    const int* __restrict__ seq_lens,      // [num_seqs]
+    const int max_num_blocks_per_seq,
+    const float* __restrict__ alibi_slopes,  // [num_heads]
+    const int q_stride, const int kv_block_stride, const int kv_head_stride,
+    const float kv_scale) {
+  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
+                         KV_DTYPE, PARTITION_SIZE>(
+      exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
+      block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
+      kv_block_stride, kv_head_stride, kv_scale);
}

// Grid: (num_heads, num_seqs).
-template<
-  typename scalar_t,
-  int HEAD_SIZE,
-  int NUM_THREADS,
-  int PARTITION_SIZE>
+template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
+          int PARTITION_SIZE>
__global__ void paged_attention_v2_reduce_kernel(
-  scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
-  const float* __restrict__ exp_sums,     // [num_seqs, num_heads, max_num_partitions]
-  const float* __restrict__ max_logits,   // [num_seqs, num_heads, max_num_partitions]
-  const scalar_t* __restrict__ tmp_out,   // [num_seqs, num_heads, max_num_partitions, head_size]
-  const int* __restrict__ seq_lens,       // [num_seqs]
-  const int max_num_partitions) {
+    scalar_t* __restrict__ out,            // [num_seqs, num_heads, head_size]
+    const float* __restrict__ exp_sums,    // [num_seqs, num_heads,
+                                           // max_num_partitions]
+    const float* __restrict__ max_logits,  // [num_seqs, num_heads,
+                                           // max_num_partitions]
+    const scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads,
+                                           // max_num_partitions, head_size]
+    const int* __restrict__ seq_lens,      // [num_seqs]
+    const int max_num_partitions) {
  const int num_heads = gridDim.x;
  const int head_idx = blockIdx.x;
  const int seq_idx = blockIdx.y;
@@ -499,9 +516,11 @@
  const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
  if (num_partitions == 1) {
    // No need to reduce. Only copy tmp_out to out.
-    scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
-    const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
-                                  + head_idx * max_num_partitions * HEAD_SIZE;
+    scalar_t* out_ptr =
+        out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+    const scalar_t* tmp_out_ptr =
+        tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
+        head_idx * max_num_partitions * HEAD_SIZE;
    for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
      out_ptr[i] = tmp_out_ptr[i];
    }
@@ -520,8 +539,9 @@

  // Load max logits to shared memory.
  float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
-  const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
-                                + head_idx * max_num_partitions;
+  const float* max_logits_ptr = max_logits +
+                                seq_idx * num_heads * max_num_partitions +
+                                head_idx * max_num_partitions;
  float max_logit = -FLT_MAX;
  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
    const float l = max_logits_ptr[i];
@@ -550,9 +570,11 @@
  max_logit = VLLM_SHFL_SYNC(max_logit, 0);

  // Load rescaled exp sums to shared memory.
-  float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
-  const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
-                              + head_idx * max_num_partitions;
+  float* shared_exp_sums =
+      reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
+  const float* exp_sums_ptr = exp_sums +
+                              seq_idx * num_heads * max_num_partitions +
+                              head_idx * max_num_partitions;
  float global_exp_sum = 0.0f;
  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
    float l = shared_max_logits[i];
@@ -565,61 +587,45 @@
  const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);

  // Aggregate tmp_out to out.
-  const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
-                                + head_idx * max_num_partitions * HEAD_SIZE;
-  scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+  const scalar_t* tmp_out_ptr =
+      tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
+      head_idx * max_num_partitions * HEAD_SIZE;
+  scalar_t* out_ptr =
+      out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
#pragma unroll
  for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
    float acc = 0.0f;
    for (int j = 0; j < num_partitions; ++j) {
-      acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum;
+      acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] *
+             inv_global_exp_sum;
    }
    from_float(out_ptr[i], acc);
  }
}

}  // namespace vllm

#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
  VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
-    ((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
-      KV_DTYPE>), shared_mem_size); \
-  vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
-    KV_DTYPE><<<grid, block, shared_mem_size, stream>>>( \
-    out_ptr, \
-    query_ptr, \
-    key_cache_ptr, \
-    value_cache_ptr, \
-    num_kv_heads, \
-    scale, \
-    block_tables_ptr, \
-    seq_lens_ptr, \
-    max_num_blocks_per_seq, \
-    alibi_slopes_ptr, \
-    q_stride, \
-    kv_block_stride, \
-    kv_head_stride, \
-    kv_scale);
+      ((void*)vllm::paged_attention_v1_kernel< \
+           T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE>), \
+      shared_mem_size); \
+  vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
+                                  NUM_THREADS, KV_DTYPE> \
+      <<<grid, block, shared_mem_size, stream>>>( \
+          out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
+          scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
+          alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
+          kv_scale);

// TODO(woosuk): Tune NUM_THREADS.
-template<
-  typename T,
-  typename CACHE_T,
-  int BLOCK_SIZE,
-  vllm::Fp8KVCacheDataType KV_DTYPE,
-  int NUM_THREADS = 128>
+template <typename T, typename CACHE_T, int BLOCK_SIZE,
+          vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS = 128>
void paged_attention_v1_launcher(
-  torch::Tensor& out,
-  torch::Tensor& query,
-  torch::Tensor& key_cache,
-  torch::Tensor& value_cache,
-  int num_kv_heads,
-  float scale,
-  torch::Tensor& block_tables,
-  torch::Tensor& seq_lens,
-  int max_seq_len,
-  const c10::optional<torch::Tensor>& alibi_slopes,
-  float kv_scale) {
+    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
+    torch::Tensor& value_cache, int num_kv_heads, float scale,
+    torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
+    const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale) {
  int num_seqs = query.size(0);
  int num_heads = query.size(1);
  int head_size = query.size(2);
@@ -632,9 +638,10 @@
  assert(head_size % thread_group_size == 0);

  // NOTE: alibi_slopes is optional.
-  const float* alibi_slopes_ptr = alibi_slopes ?
-    reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
-    : nullptr;
+  const float* alibi_slopes_ptr =
+      alibi_slopes
+          ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
+          : nullptr;

  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
@@ -644,7 +651,8 @@
  int* seq_lens_ptr = seq_lens.data_ptr<int>();

  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
-  int padded_max_seq_len = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
+  int padded_max_seq_len =
+      DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
  int logits_size = padded_max_seq_len * sizeof(float);
  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
  // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
@@ -683,19 +691,10 @@
  }
}

#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \
  paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE>( \
-    out, \
-    query, \
-    key_cache, \
-    value_cache, \
-    num_kv_heads, \
-    scale, \
-    block_tables, \
-    seq_lens, \
-    max_seq_len, \
-    alibi_slopes, \
-    kv_scale);
+      out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
+      seq_lens, max_seq_len, alibi_slopes, kv_scale);

// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
@ -716,74 +715,45 @@ void paged_attention_v1_launcher(
|
|||||||
}
|
}
|
||||||
|
|
||||||
void paged_attention_v1(
|
void paged_attention_v1(
|
||||||
torch::Tensor& out, // [num_seqs, num_heads, head_size]
|
torch::Tensor& out, // [num_seqs, num_heads, head_size]
|
||||||
torch::Tensor& query, // [num_seqs, num_heads, head_size]
|
torch::Tensor& query, // [num_seqs, num_heads, head_size]
|
||||||
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
torch::Tensor&
|
||||||
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
|
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||||
int num_kv_heads, // [num_heads]
|
torch::Tensor&
|
||||||
float scale,
|
value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||||
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
|
int num_kv_heads, // [num_heads]
|
||||||
torch::Tensor& seq_lens, // [num_seqs]
|
float scale,
|
||||||
int block_size,
|
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||||
int max_seq_len,
|
torch::Tensor& seq_lens, // [num_seqs]
|
||||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
int block_size, int max_seq_len,
|
||||||
const std::string& kv_cache_dtype,
|
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||||
float kv_scale) {
|
const std::string& kv_cache_dtype, float kv_scale){
|
||||||
|
|
||||||
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, CALL_V1_LAUNCHER_BLOCK_SIZE)
|
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
|
||||||
}
|
CALL_V1_LAUNCHER_BLOCK_SIZE)}
|
||||||
+#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
+vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
+NUM_THREADS, KV_DTYPE, PARTITION_SIZE> \
+<<<grid, block, shared_mem_size, stream>>>( \
+exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
+value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
+seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
+kv_block_stride, kv_head_stride, kv_scale); \
+vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
+PARTITION_SIZE> \
+<<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
+out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
+max_num_partitions);

-#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
-vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
-KV_DTYPE, PARTITION_SIZE> \
-<<<grid, block, shared_mem_size, stream>>>( \
-exp_sums_ptr, \
-max_logits_ptr, \
-tmp_out_ptr, \
-query_ptr, \
-key_cache_ptr, \
-value_cache_ptr, \
-num_kv_heads, \
-scale, \
-block_tables_ptr, \
-seq_lens_ptr, \
-max_num_blocks_per_seq, \
-alibi_slopes_ptr, \
-q_stride, \
-kv_block_stride, \
-kv_head_stride, \
-kv_scale); \
-vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE> \
-<<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
-out_ptr, \
-exp_sums_ptr, \
-max_logits_ptr, \
-tmp_out_ptr, \
-seq_lens_ptr, \
-max_num_partitions);
+template <typename T, typename CACHE_T, int BLOCK_SIZE,
+vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS = 128,
+int PARTITION_SIZE = 512>

-template<
-typename T,
-typename CACHE_T,
-int BLOCK_SIZE,
-vllm::Fp8KVCacheDataType KV_DTYPE,
-int NUM_THREADS = 128,
-int PARTITION_SIZE = 512>
void paged_attention_v2_launcher(
-torch::Tensor& out,
-torch::Tensor& exp_sums,
-torch::Tensor& max_logits,
-torch::Tensor& tmp_out,
-torch::Tensor& query,
-torch::Tensor& key_cache,
-torch::Tensor& value_cache,
-int num_kv_heads,
-float scale,
-torch::Tensor& block_tables,
-torch::Tensor& seq_lens,
-int max_seq_len,
-const c10::optional<torch::Tensor>& alibi_slopes,
-float kv_scale) {
+torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
+torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
+torch::Tensor& value_cache, int num_kv_heads, float scale,
+torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
+const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
@ -796,9 +766,10 @@ void paged_attention_v2_launcher(
assert(head_size % thread_group_size == 0);

// NOTE: alibi_slopes is optional.
-const float* alibi_slopes_ptr = alibi_slopes ?
-reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
-: nullptr;
+const float* alibi_slopes_ptr =
+alibi_slopes
+? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
+: nullptr;

T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
@ -853,59 +824,50 @@ void paged_attention_v2_launcher(
}
}

#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \
paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE>( \
-out, \
-exp_sums, \
-max_logits, \
-tmp_out, \
-query, \
-key_cache, \
-value_cache, \
-num_kv_heads, \
-scale, \
-block_tables, \
-seq_lens, \
-max_seq_len, \
-alibi_slopes, \
-kv_scale);
+out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
+num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
+kv_scale);

// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
switch (block_size) { \
case 8: \
CALL_V2_LAUNCHER(T, CACHE_T, 8, KV_DTYPE); \
break; \
case 16: \
CALL_V2_LAUNCHER(T, CACHE_T, 16, KV_DTYPE); \
break; \
case 32: \
CALL_V2_LAUNCHER(T, CACHE_T, 32, KV_DTYPE); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
}

void paged_attention_v2(
torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions]
-torch::Tensor& tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
-torch::Tensor& query, // [num_seqs, num_heads, head_size]
-torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
-torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
-int num_kv_heads, // [num_heads]
-float scale,
-torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
-torch::Tensor& seq_lens, // [num_seqs]
-int block_size,
-int max_seq_len,
-const c10::optional<torch::Tensor>& alibi_slopes,
-const std::string& kv_cache_dtype,
-float kv_scale) {
-DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, CALL_V2_LAUNCHER_BLOCK_SIZE)
+torch::Tensor&
+tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
+torch::Tensor& query, // [num_seqs, num_heads, head_size]
+torch::Tensor&
+key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
+torch::Tensor&
+value_cache, // [num_blocks, num_heads, head_size, block_size]
+int num_kv_heads, // [num_heads]
+float scale,
+torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
+torch::Tensor& seq_lens, // [num_seqs]
+int block_size, int max_seq_len,
+const c10::optional<torch::Tensor>& alibi_slopes,
+const std::string& kv_cache_dtype, float kv_scale) {
+DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
+CALL_V2_LAUNCHER_BLOCK_SIZE)
}

#undef WARP_SIZE
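The CALL_V1_LAUNCHER_BLOCK_SIZE and CALL_V2_LAUNCHER_BLOCK_SIZE macros above resolve a runtime block size into a compile-time template argument by switching over the supported sizes. The following is a minimal, self-contained sketch of that dispatch pattern, not vLLM's actual code; toy_launcher and dispatch_by_block_size are illustrative stand-ins.

// Illustrative only: maps a runtime block size onto a template parameter,
// the way the CALL_*_LAUNCHER_BLOCK_SIZE macros do for the real launchers.
#include <cstdio>
#include <stdexcept>

template <int BLOCK_SIZE>
void toy_launcher(int num_tokens) {
  // A real launcher would configure and launch a kernel templated on
  // BLOCK_SIZE; here we only report which instantiation was selected.
  std::printf("launch with BLOCK_SIZE=%d for %d tokens\n", BLOCK_SIZE,
              num_tokens);
}

void dispatch_by_block_size(int block_size, int num_tokens) {
  switch (block_size) {
    case 8:
      toy_launcher<8>(num_tokens);
      break;
    case 16:
      toy_launcher<16>(num_tokens);
      break;
    case 32:
      toy_launcher<32>(num_tokens);
      break;
    default:
      throw std::invalid_argument("Unsupported block size");
  }
}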
@ -1,5 +1,6 @@
/*
-* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+* Adapted from
+* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
@ -26,7 +27,7 @@
namespace vllm {

// Q*K^T operation.
-template<int THREAD_GROUP_SIZE, typename Vec, int N>
+template <int THREAD_GROUP_SIZE, typename Vec, int N>
inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
using A_vec = typename FloatVec<Vec>::Type;
// Compute the parallel products for Q*K^T (treat vector lanes separately).
@ -45,12 +46,12 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
return qk;
}

-template<typename T, int THREAD_GROUP_SIZE>
+template <typename T, int THREAD_GROUP_SIZE>
struct Qk_dot {
-template<typename Vec, int N>
+template <typename Vec, int N>
static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) {
return qk_dot_<THREAD_GROUP_SIZE>(q, k);
}
};

} // namespace vllm
@ -1,6 +1,8 @@
/*
-* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
-* and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
+* Adapted from
+* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+* and
+* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
@ -28,8 +30,8 @@
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>

typedef __hip_bfloat162 __nv_bfloat162;
typedef __hip_bfloat16 __nv_bfloat16;
#endif

#include <stdint.h>
@ -50,37 +52,37 @@ struct bf16_8_t {
};

// BF16 vector types for Q, K, V.
-template<>
+template <>
struct Vec<__nv_bfloat16, 1> {
using Type = __nv_bfloat16;
};
-template<>
+template <>
struct Vec<__nv_bfloat16, 2> {
using Type = __nv_bfloat162;
};
-template<>
+template <>
struct Vec<__nv_bfloat16, 4> {
using Type = bf16_4_t;
};
-template<>
+template <>
struct Vec<__nv_bfloat16, 8> {
using Type = bf16_8_t;
};

// FP32 accumulator vector types corresponding to Vec.
-template<>
+template <>
struct FloatVec<__nv_bfloat16> {
using Type = float;
};
-template<>
+template <>
struct FloatVec<__nv_bfloat162> {
using Type = float2;
};
-template<>
+template <>
struct FloatVec<bf16_4_t> {
using Type = Float4_;
};
-template<>
+template <>
struct FloatVec<bf16_8_t> {
using Type = Float8_;
};
@ -108,9 +110,9 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
assert(false);
#else
#ifndef USE_ROCM
return a + b;
#else
return __hadd(a, b);
#endif
#endif
}
@ -161,7 +163,7 @@ inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) {
}

// Vector multiplication.
-template<>
+template <>
inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
@ -170,7 +172,7 @@ inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
#endif
}

-template<>
+template <>
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
@ -179,12 +181,12 @@ inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
#endif
}

-template<>
+template <>
inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
}

-template<>
+template <>
inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) {
bf16_4_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
@ -192,7 +194,7 @@ inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) {
return c;
}

-template<>
+template <>
inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) {
__nv_bfloat162 s = bf162bf162(a);
bf16_4_t c;
@ -201,7 +203,7 @@ inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) {
return c;
}

-template<>
+template <>
inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) {
bf16_8_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
@ -211,7 +213,7 @@ inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) {
return c;
}

-template<>
+template <>
inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) {
__nv_bfloat162 s = bf162bf162(a);
bf16_8_t c;
@ -222,26 +224,26 @@ inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) {
return c;
}

-template<>
+template <>
inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) {
float fa = __bfloat162float(a);
float fb = __bfloat162float(b);
return fa * fb;
}

-template<>
+template <>
inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
float2 fa = bf1622float2(a);
float2 fb = bf1622float2(b);
return mul<float2, float2, float2>(fa, fb);
}

-template<>
+template <>
inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
return mul<float2, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
}

-template<>
+template <>
inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) {
Float4_ fc;
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
@ -249,7 +251,7 @@ inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) {
return fc;
}

-template<>
+template <>
inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) {
__nv_bfloat162 s = bf162bf162(a);
Float4_ fc;
@ -258,7 +260,7 @@ inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) {
return fc;
}

-template<>
+template <>
inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) {
Float8_ fc;
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
@ -268,7 +270,7 @@ inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) {
return fc;
}

-template<>
+template <>
inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
__nv_bfloat162 s = bf162bf162(a);
Float8_ fc;
@ -280,7 +282,8 @@ inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
}

// Vector fused multiply-add.
-inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
+inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b,
+__nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
@ -288,7 +291,8 @@ inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bf
#endif
}

-inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) {
+inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b,
+__nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
@ -379,23 +383,23 @@ inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) {
}

// Vector sum.
-template<>
+template <>
inline __device__ float sum(__nv_bfloat16 v) {
return __bfloat162float(v);
}

-template<>
+template <>
inline __device__ float sum(__nv_bfloat162 v) {
float2 vf = bf1622float2(v);
return vf.x + vf.y;
}

-template<>
+template <>
inline __device__ float sum(bf16_4_t v) {
return sum(v.x) + sum(v.y);
}

-template<>
+template <>
inline __device__ float sum(bf16_8_t v) {
return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w);
}
@ -448,4 +452,4 @@ inline __device__ void zero(__nv_bfloat16& dst) {
#endif
}

} // namespace vllm
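The bf16 header above follows the Vec/FloatVec trait pattern: one trait maps an element type plus a vector width to a packed storage type, and a second trait maps that storage type to an FP32 accumulator type. Below is a simplified, self-contained sketch of the pattern under those assumptions; the specializations and the Float2_ struct are illustrative stand-ins, not vLLM's definitions.

// Illustrative sketch of the Vec / FloatVec trait pattern.
#include <cstdint>

template <typename T, int VEC_SIZE>
struct Vec {};

template <>
struct Vec<std::uint16_t, 2> {  // two packed fp16 lanes
  using Type = std::uint32_t;
};

struct Float2_ { float x, y; };  // stand-in for CUDA's float2

template <typename T>
struct FloatVec {};

template <>
struct FloatVec<std::uint32_t> {  // packed fp16x2 accumulates into a float pair
  using Type = Float2_;
};

// Usage: pick the packed load type and its accumulator at compile time.
using KVec = Vec<std::uint16_t, 2>::Type;  // std::uint32_t
using KAcc = FloatVec<KVec>::Type;         // Float2_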
@ -1,6 +1,8 @@
/*
-* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
-* and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
+* Adapted from
+* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+* and
+* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
@ -30,37 +32,37 @@
namespace vllm {

// FP16 vector types for Q, K, V.
-template<>
+template <>
struct Vec<uint16_t, 1> {
using Type = uint16_t;
};
-template<>
+template <>
struct Vec<uint16_t, 2> {
using Type = uint32_t;
};
-template<>
+template <>
struct Vec<uint16_t, 4> {
using Type = uint2;
};
-template<>
+template <>
struct Vec<uint16_t, 8> {
using Type = uint4;
};

// FP32 accumulator vector types corresponding to Vec.
-template<>
+template <>
struct FloatVec<uint16_t> {
using Type = float;
};
-template<>
+template <>
struct FloatVec<uint32_t> {
using Type = float2;
};
-template<>
+template <>
struct FloatVec<uint2> {
using Type = Float4_;
};
-template<>
+template <>
struct FloatVec<uint4> {
using Type = Float8_;
};
@ -73,8 +75,8 @@ inline __device__ uint32_t h0_h0(uint16_t a) {
return b;
#else
union {
uint32_t u32;
uint16_t u16[2];
} tmp;
tmp.u16[0] = a;
tmp.u16[1] = a;
@ -130,10 +132,12 @@ inline __device__ uint32_t float2_to_half2(float2 f) {
} tmp;
#ifndef USE_ROCM
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
-asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
+asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n"
+: "=r"(tmp.u32)
+: "f"(f.y), "f"(f.x));
#else
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
#endif
#else
tmp.u16[0] = float_to_half(f.x);
@ -201,7 +205,7 @@ inline __device__ Float8_ add(uint4 a, Float8_ fb) {
}

// Vector multiplication.
-template<>
+template <>
inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
uint16_t c;
#ifndef USE_ROCM
@ -212,7 +216,7 @@ inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
return c;
}

-template<>
+template <>
inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
uint32_t c;
#ifndef USE_ROCM
@ -223,12 +227,12 @@ inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
return c;
}

-template<>
+template <>
inline __device__ uint32_t mul(uint16_t a, uint32_t b) {
return mul<uint32_t, uint32_t, uint32_t>(h0_h0(a), b);
}

-template<>
+template <>
inline __device__ uint2 mul(uint2 a, uint2 b) {
uint2 c;
c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
@ -236,7 +240,7 @@ inline __device__ uint2 mul(uint2 a, uint2 b) {
return c;
}

-template<>
+template <>
inline __device__ uint2 mul(uint16_t a, uint2 b) {
uint32_t s = h0_h0(a);
uint2 c;
@ -245,7 +249,7 @@ inline __device__ uint2 mul(uint16_t a, uint2 b) {
return c;
}

-template<>
+template <>
inline __device__ uint4 mul(uint4 a, uint4 b) {
uint4 c;
c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
@ -255,7 +259,7 @@ inline __device__ uint4 mul(uint4 a, uint4 b) {
return c;
}

-template<>
+template <>
inline __device__ uint4 mul(uint16_t a, uint4 b) {
uint32_t s = h0_h0(a);
uint4 c;
@ -266,26 +270,26 @@ inline __device__ uint4 mul(uint16_t a, uint4 b) {
return c;
}

-template<>
+template <>
inline __device__ float mul(uint16_t a, uint16_t b) {
float fa = half_to_float(a);
float fb = half_to_float(b);
return fa * fb;
}

-template<>
+template <>
inline __device__ float2 mul(uint32_t a, uint32_t b) {
float2 fa = half2_to_float2(a);
float2 fb = half2_to_float2(b);
return mul<float2, float2, float2>(fa, fb);
}

-template<>
+template <>
inline __device__ float2 mul(uint16_t a, uint32_t b) {
return mul<float2, uint32_t, uint32_t>(h0_h0(a), b);
}

-template<>
+template <>
inline __device__ Float4_ mul(uint2 a, uint2 b) {
Float4_ fc;
fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
@ -293,7 +297,7 @@ inline __device__ Float4_ mul(uint2 a, uint2 b) {
return fc;
}

-template<>
+template <>
inline __device__ Float4_ mul(uint16_t a, uint2 b) {
uint32_t s = h0_h0(a);
Float4_ fc;
@ -302,7 +306,7 @@ inline __device__ Float4_ mul(uint16_t a, uint2 b) {
return fc;
}

-template<>
+template <>
inline __device__ Float8_ mul(uint4 a, uint4 b) {
Float8_ fc;
fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
@ -312,7 +316,7 @@ inline __device__ Float8_ mul(uint4 a, uint4 b) {
return fc;
}

-template<>
+template <>
inline __device__ Float8_ mul(uint16_t a, uint4 b) {
uint32_t s = h0_h0(a);
Float8_ fc;
@ -327,9 +331,13 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) {
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
uint32_t d;
#ifndef USE_ROCM
-asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
+asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
+: "=r"(d)
+: "r"(a), "r"(b), "r"(c));
#else
-asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c));
+asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n"
+: "=v"(d)
+: "v"(a), "v"(b), "v"(c));
#endif
return d;
}
@ -423,24 +431,24 @@ inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) {
}

// Vector sum.
-template<>
+template <>
inline __device__ float sum(uint16_t v) {
return half_to_float(v);
}

-template<>
+template <>
inline __device__ float sum(uint32_t v) {
float2 tmp = half2_to_float2(v);
return tmp.x + tmp.y;
}

-template<>
+template <>
inline __device__ float sum(uint2 v) {
uint32_t c = add(v.x, v.y);
return sum(c);
}

-template<>
+template <>
inline __device__ float sum(uint4 v) {
uint32_t c = add(v.x, v.y);
c = add(c, v.z);
@ -470,13 +478,9 @@ inline __device__ void from_float(uint4& dst, Float8_ src) {
}

// From float16 to float32.
-inline __device__ float to_float(uint16_t u) {
-return half_to_float(u);
-}
+inline __device__ float to_float(uint16_t u) { return half_to_float(u); }

-inline __device__ float2 to_float(uint32_t u) {
-return half2_to_float2(u);
-}
+inline __device__ float2 to_float(uint32_t u) { return half2_to_float2(u); }

inline __device__ Float4_ to_float(uint2 u) {
Float4_ tmp;
@ -495,8 +499,6 @@ inline __device__ Float8_ to_float(uint4 u) {
}

// Zero-out a variable.
-inline __device__ void zero(uint16_t& dst) {
-dst = uint16_t(0);
-}
+inline __device__ void zero(uint16_t& dst) { dst = uint16_t(0); }

} // namespace vllm
@ -1,6 +1,8 @@
/*
-* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
-* and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
+* Adapted from
+* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+* and
+* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
@ -38,37 +40,35 @@ struct Float8_ {
};

// FP32 vector types for Q, K, V.
-template<>
+template <>
struct Vec<float, 1> {
using Type = float;
};
-template<>
+template <>
struct Vec<float, 2> {
using Type = float2;
};
-template<>
+template <>
struct Vec<float, 4> {
using Type = float4;
};

// FP32 accumulator vector types corresponding to Vec.
-template<>
+template <>
struct FloatVec<float> {
using Type = float;
};
-template<>
+template <>
struct FloatVec<float2> {
using Type = float2;
};
-template<>
+template <>
struct FloatVec<float4> {
using Type = float4;
};

// Vector addition.
-inline __device__ float add(float a, float b) {
-return a + b;
-}
+inline __device__ float add(float a, float b) { return a + b; }

inline __device__ float2 add(float2 a, float2 b) {
float2 c;
@ -87,12 +87,12 @@ inline __device__ float4 add(float4 a, float4 b) {
}

// Vector multiplication.
-template<>
+template <>
inline __device__ float mul<float, float>(float a, float b) {
return a * b;
}

-template<>
+template <>
inline __device__ float2 mul(float2 a, float2 b) {
float2 c;
c.x = a.x * b.x;
@ -100,7 +100,7 @@ inline __device__ float2 mul(float2 a, float2 b) {
return c;
}

-template<>
+template <>
inline __device__ float2 mul(float a, float2 b) {
float2 c;
c.x = a * b.x;
@ -108,7 +108,7 @@ inline __device__ float2 mul(float a, float2 b) {
return c;
}

-template<>
+template <>
inline __device__ float4 mul(float4 a, float4 b) {
float4 c;
c.x = a.x * b.x;
@ -118,7 +118,7 @@ inline __device__ float4 mul(float4 a, float4 b) {
return c;
}

-template<>
+template <>
inline __device__ float4 mul(float a, float4 b) {
float4 c;
c.x = a * b.x;
@ -129,9 +129,7 @@ inline __device__ float4 mul(float a, float4 b) {
}

// Vector fused multiply-add.
-inline __device__ float fma(float a, float b, float c) {
-return a * b + c;
-}
+inline __device__ float fma(float a, float b, float c) { return a * b + c; }

inline __device__ float2 fma(float2 a, float2 b, float2 c) {
float2 d;
@ -182,35 +180,33 @@ inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) {
}

// Vector sum.
-template<>
+template <>
inline __device__ float sum(float v) {
return v;
}

-template<>
+template <>
inline __device__ float sum(float2 v) {
return v.x + v.y;
}

-template<>
+template <>
inline __device__ float sum(float4 v) {
return v.x + v.y + v.z + v.w;
}

-template<>
+template <>
inline __device__ float sum(Float4_ v) {
return v.x.x + v.x.y + v.y.x + v.y.y;
}

-template<>
+template <>
inline __device__ float sum(Float8_ v) {
return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y;
}

// Vector dot product.
-inline __device__ float dot(float a, float b) {
-return a * b;
-}
+inline __device__ float dot(float a, float b) { return a * b; }

inline __device__ float dot(float2 a, float2 b) {
float2 c = mul<float2, float2, float2>(a, b);
@ -232,42 +228,24 @@ inline __device__ float dot(Float8_ a, Float8_ b) {
}

// From float to float.
-inline __device__ void from_float(float& dst, float src) {
-dst = src;
-}
+inline __device__ void from_float(float& dst, float src) { dst = src; }

-inline __device__ void from_float(float2& dst, float2 src) {
-dst = src;
-}
+inline __device__ void from_float(float2& dst, float2 src) { dst = src; }

-inline __device__ void from_float(float4& dst, float4 src) {
-dst = src;
-}
+inline __device__ void from_float(float4& dst, float4 src) { dst = src; }

// From float to float.
-inline __device__ float to_float(float u) {
-return u;
-}
+inline __device__ float to_float(float u) { return u; }

-inline __device__ float2 to_float(float2 u) {
-return u;
-}
+inline __device__ float2 to_float(float2 u) { return u; }

-inline __device__ float4 to_float(float4 u) {
-return u;
-}
+inline __device__ float4 to_float(float4 u) { return u; }

-inline __device__ Float4_ to_float(Float4_ u) {
-return u;
-}
+inline __device__ Float4_ to_float(Float4_ u) { return u; }

-inline __device__ Float8_ to_float(Float8_ u) {
-return u;
-}
+inline __device__ Float8_ to_float(Float8_ u) { return u; }

// Zero-out a variable.
-inline __device__ void zero(float& dst) {
-dst = 0.f;
-}
+inline __device__ void zero(float& dst) { dst = 0.f; }

} // namespace vllm
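The FP32 header above builds its dot products by composing a lane-wise mul with a lane-reducing sum. The short sketch below mirrors that composition on the host with a stand-in float2-like struct; it is illustrative only and not vLLM's device code.

// Stand-in for CUDA's float2; illustrative only.
struct F2 { float x, y; };

inline F2 mul(F2 a, F2 b) { return {a.x * b.x, a.y * b.y}; }  // lane-wise product
inline float sum(F2 v) { return v.x + v.y; }                  // reduce the lanes
inline float dot(F2 a, F2 b) { return sum(mul(a, b)); }       // composition

// Example: dot({1.f, 2.f}, {3.f, 4.f}) evaluates to 11.f.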
@ -4,38 +4,38 @@

#include <stdint.h>
#ifdef ENABLE_FP8
#ifndef USE_ROCM
#include <cuda_fp8.h>
#endif // USE_ROCM
#endif // ENABLE_FP8

namespace vllm {

enum class Fp8KVCacheDataType {
kAuto = 0,
kFp8E4M3 = 1,
kFp8E5M2 = 2,
};

// fp8 vector types for quantization of kv cache
-template<>
+template <>
struct Vec<uint8_t, 1> {
using Type = uint8_t;
};

-template<>
+template <>
struct Vec<uint8_t, 2> {
using Type = uint16_t;
};

-template<>
+template <>
struct Vec<uint8_t, 4> {
using Type = uint32_t;
};

-template<>
+template <>
struct Vec<uint8_t, 8> {
using Type = uint2;
};

} // namespace vllm
44
csrc/cache.h
@ -5,36 +5,24 @@
#include <map>
#include <vector>

-void swap_blocks(
-torch::Tensor& src,
-torch::Tensor& dst,
-const torch::Tensor& block_mapping);
+void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
+const torch::Tensor& block_mapping);

-void copy_blocks(
-std::vector<torch::Tensor>& key_caches,
-std::vector<torch::Tensor>& value_caches,
-const torch::Tensor& block_mapping);
+void copy_blocks(std::vector<torch::Tensor>& key_caches,
+std::vector<torch::Tensor>& value_caches,
+const torch::Tensor& block_mapping);

-void reshape_and_cache(
-torch::Tensor& key,
-torch::Tensor& value,
-torch::Tensor& key_cache,
-torch::Tensor& value_cache,
-torch::Tensor& slot_mapping,
-const std::string& kv_cache_dtype,
-const float kv_scale);
+void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
+torch::Tensor& key_cache, torch::Tensor& value_cache,
+torch::Tensor& slot_mapping,
+const std::string& kv_cache_dtype, const float kv_scale);

-void reshape_and_cache_flash(
-torch::Tensor& key,
-torch::Tensor& value,
-torch::Tensor& key_cache,
-torch::Tensor& value_cache,
-torch::Tensor& slot_mapping,
-const std::string& kv_cache_dtype);
+void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
+torch::Tensor& key_cache,
+torch::Tensor& value_cache,
+torch::Tensor& slot_mapping,
+const std::string& kv_cache_dtype);

// Just for unittest
-void convert_fp8(
-torch::Tensor& dst_cache,
-torch::Tensor& src_cache,
-const float scale,
-const std::string& kv_cache_dtype);
+void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
+const float scale, const std::string& kv_cache_dtype);
@ -6,9 +6,9 @@
|
|||||||
#include "dispatch_utils.h"
|
#include "dispatch_utils.h"
|
||||||
|
|
||||||
#ifdef USE_ROCM
|
#ifdef USE_ROCM
|
||||||
#include "quantization/fp8/amd/quant_utils.cuh"
|
#include "quantization/fp8/amd/quant_utils.cuh"
|
||||||
#else
|
#else
|
||||||
#include "quantization/fp8/nvidia/quant_utils.cuh"
|
#include "quantization/fp8/nvidia/quant_utils.cuh"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
@ -18,20 +18,17 @@
|
|||||||
|
|
||||||
#ifdef USE_ROCM
|
#ifdef USE_ROCM
|
||||||
#include <hip/hip_bf16.h>
|
#include <hip/hip_bf16.h>
|
||||||
typedef __hip_bfloat16 __nv_bfloat16;
|
typedef __hip_bfloat16 __nv_bfloat16;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
void swap_blocks(
|
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
|
||||||
torch::Tensor& src,
|
const torch::Tensor& block_mapping) {
|
||||||
torch::Tensor& dst,
|
|
||||||
const torch::Tensor& block_mapping) {
|
|
||||||
torch::Device src_device = src.device();
|
torch::Device src_device = src.device();
|
||||||
torch::Device dst_device = dst.device();
|
torch::Device dst_device = dst.device();
|
||||||
cudaMemcpyKind memcpy_type;
|
cudaMemcpyKind memcpy_type;
|
||||||
if (src_device.is_cuda() && dst_device.is_cuda()) {
|
if (src_device.is_cuda() && dst_device.is_cuda()) {
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(src_device.index() == dst_device.index(),
|
||||||
src_device.index() == dst_device.index(),
|
"src and dst must be on the same GPU");
|
||||||
"src and dst must be on the same GPU");
|
|
||||||
memcpy_type = cudaMemcpyDeviceToDevice;
|
memcpy_type = cudaMemcpyDeviceToDevice;
|
||||||
} else if (src_device.is_cuda() && dst_device.is_cpu()) {
|
} else if (src_device.is_cuda() && dst_device.is_cpu()) {
|
||||||
memcpy_type = cudaMemcpyDeviceToHost;
|
memcpy_type = cudaMemcpyDeviceToHost;
|
||||||
@ -46,11 +43,12 @@ void swap_blocks(
|
|||||||
// synchronization.
|
// synchronization.
|
||||||
TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU");
|
TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU");
|
||||||
|
|
||||||
char *src_ptr = static_cast<char*>(src.data_ptr());
|
char* src_ptr = static_cast<char*>(src.data_ptr());
|
||||||
char *dst_ptr = static_cast<char*>(dst.data_ptr());
|
char* dst_ptr = static_cast<char*>(dst.data_ptr());
|
||||||
|
|
||||||
const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
|
const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
|
||||||
const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device);
|
const at::cuda::OptionalCUDAGuard device_guard(
|
||||||
|
src_device.is_cuda() ? src_device : dst_device);
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
// NOTE(woosuk): This can be slow if the number of blocks is large.
|
// NOTE(woosuk): This can be slow if the number of blocks is large.
|
||||||
const int64_t num_blocks = block_mapping.size(0);
|
const int64_t num_blocks = block_mapping.size(0);
|
||||||
@ -59,29 +57,25 @@ void swap_blocks(
|
|||||||
int64_t dst_block_number = block_mapping[i][1].item<int64_t>();
|
int64_t dst_block_number = block_mapping[i][1].item<int64_t>();
|
||||||
int64_t src_offset = src_block_number * block_size_in_bytes;
|
int64_t src_offset = src_block_number * block_size_in_bytes;
|
||||||
int64_t dst_offset = dst_block_number * block_size_in_bytes;
|
int64_t dst_offset = dst_block_number * block_size_in_bytes;
|
||||||
cudaMemcpyAsync(
|
cudaMemcpyAsync(dst_ptr + dst_offset, src_ptr + src_offset,
|
||||||
dst_ptr + dst_offset,
|
block_size_in_bytes, memcpy_type, stream);
|
||||||
src_ptr + src_offset,
|
|
||||||
block_size_in_bytes,
|
|
||||||
memcpy_type,
|
|
||||||
stream);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace vllm {
|
namespace vllm {
|
||||||
|
|
||||||
// Grid: (num_layers, num_pairs)
|
// Grid: (num_layers, num_pairs)
|
||||||
template<typename scalar_t>
|
template <typename scalar_t>
|
||||||
__global__ void copy_blocks_kernel(
|
__global__ void copy_blocks_kernel(int64_t* key_cache_ptrs,
|
||||||
int64_t* key_cache_ptrs,
|
int64_t* value_cache_ptrs,
|
||||||
int64_t* value_cache_ptrs,
|
const int64_t* __restrict__ block_mapping,
|
||||||
const int64_t* __restrict__ block_mapping,
|
const int numel_per_block) {
|
||||||
const int numel_per_block) {
|
|
||||||
const int layer_idx = blockIdx.x;
|
const int layer_idx = blockIdx.x;
|
||||||
const int pair_idx = blockIdx.y;
|
const int pair_idx = blockIdx.y;
|
||||||
|
|
||||||
scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
|
scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
|
||||||
scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
|
scalar_t* value_cache =
|
||||||
|
reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
|
||||||
int64_t src_block_number = block_mapping[2 * pair_idx];
|
int64_t src_block_number = block_mapping[2 * pair_idx];
|
||||||
int64_t dst_block_number = block_mapping[2 * pair_idx + 1];
|
int64_t dst_block_number = block_mapping[2 * pair_idx + 1];
|
||||||
|
|
||||||
@ -99,12 +93,11 @@ __global__ void copy_blocks_kernel(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
|
|
||||||
void copy_blocks(
|
void copy_blocks(std::vector<torch::Tensor>& key_caches,
|
||||||
std::vector<torch::Tensor>& key_caches,
|
std::vector<torch::Tensor>& value_caches,
|
||||||
std::vector<torch::Tensor>& value_caches,
|
const torch::Tensor& block_mapping) {
|
||||||
const torch::Tensor& block_mapping) {
|
|
||||||
int num_layers = key_caches.size();
|
int num_layers = key_caches.size();
|
||||||
TORCH_CHECK(num_layers == value_caches.size());
|
TORCH_CHECK(num_layers == value_caches.size());
|
||||||
if (num_layers == 0) {
|
if (num_layers == 0) {
|
||||||
@ -118,8 +111,10 @@ void copy_blocks(
|
|||||||
int64_t key_cache_ptrs[num_layers];
|
int64_t key_cache_ptrs[num_layers];
|
||||||
int64_t value_cache_ptrs[num_layers];
|
int64_t value_cache_ptrs[num_layers];
|
||||||
for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
|
for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
|
||||||
key_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
|
key_cache_ptrs[layer_idx] =
|
||||||
value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
|
reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
|
||||||
|
value_cache_ptrs[layer_idx] =
|
||||||
|
reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
|
||||||
}
|
}
|
||||||
|
|
||||||
// block_mapping is a 2D tensor with shape (num_pairs, 2).
|
// block_mapping is a 2D tensor with shape (num_pairs, 2).
|
||||||
@ -127,10 +122,12 @@ void copy_blocks(
|
|||||||
|
|
||||||
// Move the data structures to the GPU.
|
// Move the data structures to the GPU.
|
||||||
// NOTE: This synchronizes the CPU and GPU.
|
// NOTE: This synchronizes the CPU and GPU.
|
||||||
torch::Tensor key_cache_ptrs_tensor = torch::from_blob(
|
torch::Tensor key_cache_ptrs_tensor =
|
||||||
key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
|
torch::from_blob(key_cache_ptrs, {num_layers}, torch::kInt64)
|
||||||
torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
|
.to(cache_device);
|
||||||
value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
|
torch::Tensor value_cache_ptrs_tensor =
|
||||||
|
torch::from_blob(value_cache_ptrs, {num_layers}, torch::kInt64)
|
||||||
|
.to(cache_device);
|
||||||
|
|
||||||
// Launch the kernel.
|
// Launch the kernel.
|
||||||
const int numel_per_block = key_caches[0][0].numel();
|
const int numel_per_block = key_caches[0][0].numel();
|
||||||
@@ -139,31 +136,28 @@ void copy_blocks(
   const at::cuda::OptionalCUDAGuard device_guard(cache_device);
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
       key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
         vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
             key_cache_ptrs_tensor.data_ptr<int64_t>(),
             value_cache_ptrs_tensor.data_ptr<int64_t>(),
-            block_mapping.data_ptr<int64_t>(),
-            numel_per_block);
-    }));
+            block_mapping.data_ptr<int64_t>(), numel_per_block);
+      }));
 }

 namespace vllm {

-template<typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
+template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
 __global__ void reshape_and_cache_kernel(
     const scalar_t* __restrict__ key,    // [num_tokens, num_heads, head_size]
     const scalar_t* __restrict__ value,  // [num_tokens, num_heads, head_size]
-  cache_t* __restrict__ key_cache,    // [num_blocks, num_heads, head_size/x, block_size, x]
-  cache_t* __restrict__ value_cache,  // [num_blocks, num_heads, head_size, block_size]
-  const int64_t* __restrict__ slot_mapping,  // [num_tokens]
-  const int key_stride,
-  const int value_stride,
-  const int num_heads,
-  const int head_size,
-  const int block_size,
-  const int x,
-  const float kv_scale) {
+    cache_t* __restrict__ key_cache,     // [num_blocks, num_heads, head_size/x,
+                                         //  block_size, x]
+    cache_t* __restrict__ value_cache,   // [num_blocks, num_heads, head_size,
+                                         //  block_size]
+    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
+    const int key_stride, const int value_stride, const int num_heads,
+    const int head_size, const int block_size, const int x,
+    const float kv_scale) {
   const int64_t token_idx = blockIdx.x;
   const int64_t slot_idx = slot_mapping[token_idx];
   if (slot_idx < 0) {
@@ -184,40 +178,39 @@ __global__ void reshape_and_cache_kernel(
     const int x_idx = head_offset / x;
     const int x_offset = head_offset % x;

-    const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
-                                + head_idx * (head_size / x) * block_size * x
-                                + x_idx * block_size * x
-                                + block_offset * x
-                                + x_offset;
-    const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size
-                                  + head_idx * head_size * block_size
-                                  + head_offset * block_size
-                                  + block_offset;
+    const int64_t tgt_key_idx =
+        block_idx * num_heads * (head_size / x) * block_size * x +
+        head_idx * (head_size / x) * block_size * x + x_idx * block_size * x +
+        block_offset * x + x_offset;
+    const int64_t tgt_value_idx =
+        block_idx * num_heads * head_size * block_size +
+        head_idx * head_size * block_size + head_offset * block_size +
+        block_offset;
    scalar_t tgt_key = key[src_key_idx];
    scalar_t tgt_value = value[src_value_idx];
    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
      key_cache[tgt_key_idx] = tgt_key;
      value_cache[tgt_value_idx] = tgt_value;
    } else {
-      key_cache[tgt_key_idx] = fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, kv_scale);
-      value_cache[tgt_value_idx] = fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, kv_scale);
+      key_cache[tgt_key_idx] =
+          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, kv_scale);
+      value_cache[tgt_value_idx] =
+          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, kv_scale);
     }
   }
 }

-template<typename scalar_t>
+template <typename scalar_t>
 __global__ void reshape_and_cache_flash_kernel(
     const scalar_t* __restrict__ key,    // [num_tokens, num_heads, head_size]
     const scalar_t* __restrict__ value,  // [num_tokens, num_heads, head_size]
-  scalar_t* __restrict__ k_cache,    // [num_blocks, block_size, num_heads, head_size]
-  scalar_t* __restrict__ v_cache,    // [num_blocks, block_size, num_heads, head_size]
-  const int64_t* __restrict__ slot_mapping,  // [num_tokens]
-  const int block_stride,
-  const int key_stride,
-  const int value_stride,
-  const int num_heads,
-  const int head_size,
-  const int block_size) {
+    scalar_t* __restrict__ k_cache,      // [num_blocks, block_size, num_heads,
+                                         //  head_size]
+    scalar_t* __restrict__ v_cache,      // [num_blocks, block_size, num_heads,
+                                         //  head_size]
+    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
+    const int block_stride, const int key_stride, const int value_stride,
+    const int num_heads, const int head_size, const int block_size) {
   const int64_t token_idx = blockIdx.x;
   const int64_t slot_idx = slot_mapping[token_idx];
   // NOTE: slot_idx can be -1 if the token is padded
@@ -232,43 +225,37 @@ __global__ void reshape_and_cache_flash_kernel(
     const int64_t src_value_idx = token_idx * value_stride + i;
     const int head_idx = i / head_size;
     const int head_offset = i % head_size;
-    const int64_t tgt_value_idx = block_idx * block_stride
-                                  + block_offset * num_heads * head_size
-                                  + head_idx * head_size
-                                  + head_offset;
+    const int64_t tgt_value_idx = block_idx * block_stride +
+                                  block_offset * num_heads * head_size +
+                                  head_idx * head_size + head_offset;
     k_cache[tgt_value_idx] = key[src_key_idx];
     v_cache[tgt_value_idx] = value[src_value_idx];
   }
 }
 }  // namespace vllm

 // KV_T is the stored data type of kv-cache.
 // CACHE_T is the data type of key and value tensors.
 // KV_DTYPE is the real data type of kv-cache.
 #define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \
-  vllm::reshape_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE><<<grid, block, 0, stream>>>( \
-    reinterpret_cast<KV_T*>(key.data_ptr()), \
-    reinterpret_cast<KV_T*>(value.data_ptr()), \
-    reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
-    reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
-    slot_mapping.data_ptr<int64_t>(), \
-    key_stride, \
-    value_stride, \
-    num_heads, \
-    head_size, \
-    block_size, \
-    x, \
-    kv_scale);
+  vllm::reshape_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE> \
+      <<<grid, block, 0, stream>>>( \
+          reinterpret_cast<KV_T*>(key.data_ptr()), \
+          reinterpret_cast<KV_T*>(value.data_ptr()), \
+          reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
+          reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
+          slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
+          num_heads, head_size, block_size, x, kv_scale);

 void reshape_and_cache(
     torch::Tensor& key,    // [num_tokens, num_heads, head_size]
     torch::Tensor& value,  // [num_tokens, num_heads, head_size]
-  torch::Tensor& key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
-  torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size]
-  torch::Tensor& slot_mapping,  // [num_tokens]
-  const std::string& kv_cache_dtype,
-  const float kv_scale)
-{
+    torch::Tensor&
+        key_cache,  // [num_blocks, num_heads, head_size/x, block_size, x]
+    torch::Tensor&
+        value_cache,  // [num_blocks, num_heads, head_size, block_size]
+    torch::Tensor& slot_mapping,  // [num_tokens]
+    const std::string& kv_cache_dtype, const float kv_scale) {
   int num_tokens = key.size(0);
   int num_heads = key.size(1);
   int head_size = key.size(2);
@@ -283,17 +270,17 @@ void reshape_and_cache(
   const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

-  DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype, CALL_RESHAPE_AND_CACHE)
+  DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
+                             CALL_RESHAPE_AND_CACHE)
 }

 void reshape_and_cache_flash(
     torch::Tensor& key,      // [num_tokens, num_heads, head_size]
     torch::Tensor& value,    // [num_tokens, num_heads, head_size]
     torch::Tensor& k_cache,  // [num_blocks, block_size, num_heads, head_size]
     torch::Tensor& v_cache,  // [num_blocks, block_size, num_heads, head_size]
     torch::Tensor& slot_mapping,  // [num_tokens]
-    const std::string& kv_cache_dtype)
-{
+    const std::string& kv_cache_dtype) {
   // FIXME: only support auto datatype, does not support fp8
   if (kv_cache_dtype != "auto") {
     TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
@@ -313,62 +300,47 @@ void reshape_and_cache_flash(
   const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
-      key.scalar_type(),
-      "reshape_and_cache_flash",
-      [&] {
-        vllm::reshape_and_cache_flash_kernel<scalar_t><<<grid, block, 0, stream>>>(
-          key.data_ptr<scalar_t>(),
-          value.data_ptr<scalar_t>(),
-          k_cache.data_ptr<scalar_t>(),
-          v_cache.data_ptr<scalar_t>(),
-          slot_mapping.data_ptr<int64_t>(),
-          block_stride,
-          key_stride,
-          value_stride,
-          num_heads,
-          head_size,
-          block_size);
-      });
+      key.scalar_type(), "reshape_and_cache_flash", [&] {
+        vllm::reshape_and_cache_flash_kernel<scalar_t>
+            <<<grid, block, 0, stream>>>(
+                key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
+                k_cache.data_ptr<scalar_t>(), v_cache.data_ptr<scalar_t>(),
+                slot_mapping.data_ptr<int64_t>(), block_stride, key_stride,
+                value_stride, num_heads, head_size, block_size);
+      });
 }

 namespace vllm {

-template<typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
-__global__ void convert_fp8_kernel(
-  const Tin* __restrict__ src_cache,
-  Tout* __restrict__ dst_cache,
-  const float kv_scale,
-  const int64_t block_stride) {
+template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
+__global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache,
+                                   Tout* __restrict__ dst_cache,
+                                   const float kv_scale,
+                                   const int64_t block_stride) {
   const int64_t block_idx = blockIdx.x;
   for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
     int64_t idx = block_idx * block_stride + i;
-    dst_cache[idx] = fp8::scaled_convert<Tout, Tin, kv_dt>(src_cache[idx], kv_scale);
+    dst_cache[idx] =
+        fp8::scaled_convert<Tout, Tin, kv_dt>(src_cache[idx], kv_scale);
   }
 }

 }  // namespace vllm

 #define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \
   vllm::convert_fp8_kernel<Tout, Tin, KV_DTYPE><<<grid, block, 0, stream>>>( \
       reinterpret_cast<Tin*>(src_cache.data_ptr()), \
-      reinterpret_cast<Tout*>(dst_cache.data_ptr()), \
-      kv_scale, \
-      block_stride);
+      reinterpret_cast<Tout*>(dst_cache.data_ptr()), kv_scale, block_stride);

 // Only for testing.
-void convert_fp8(
-  torch::Tensor& dst_cache,
-  torch::Tensor& src_cache,
-  const float kv_scale,
-  const std::string& kv_cache_dtype)
-{
+void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
+                 const float kv_scale, const std::string& kv_cache_dtype) {
   torch::Device src_device = src_cache.device();
   torch::Device dst_device = dst_cache.device();
   TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU")
   TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU")
-  TORCH_CHECK(
-    src_device.index() == dst_device.index(),
-    "src and dst must be on the same GPU");
+  TORCH_CHECK(src_device.index() == dst_device.index(),
+              "src and dst must be on the same GPU");
   at::cuda::OptionalCUDAGuard device_guard(src_device);

   int64_t num_blocks = src_cache.size(0);
@@ -398,13 +370,15 @@ void convert_fp8(
     } else if (src_cache.dtype() == at::ScalarType::Half) {
       CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
     } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
-      CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, vllm::Fp8KVCacheDataType::kFp8E4M3);
+      CALL_CONVERT_FP8(uint8_t, __nv_bfloat16,
+                       vllm::Fp8KVCacheDataType::kFp8E4M3);
     } else if (dst_cache.dtype() == at::ScalarType::Float) {
       CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
     } else if (dst_cache.dtype() == at::ScalarType::Half) {
       CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
     } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
-      CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
+      CALL_CONVERT_FP8(__nv_bfloat16, uint8_t,
+                       vllm::Fp8KVCacheDataType::kFp8E4M3);
     }
   } else {
     TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype);
@@ -1,10 +1,10 @@
 #include "cpu_types.hpp"

 namespace {
-template <typename scalar_t, vec_op::FP32Vec8 (*func)(const vec_op::FP32Vec8 &),
+template <typename scalar_t, vec_op::FP32Vec8 (*func)(const vec_op::FP32Vec8&),
           bool is_gated>
-void activation_kernel(int num_tokens, int d, scalar_t *__restrict__ input,
-                       scalar_t *__restrict__ output) {
+void activation_kernel(int num_tokens, int d, scalar_t* __restrict__ input,
+                       scalar_t* __restrict__ output) {
   using scalar_vec_t = vec_op::vec_t<scalar_t>;
   constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();

@@ -34,13 +34,13 @@ void activation_kernel(int num_tokens, int d, scalar_t *__restrict__ input,
   }
 }

-FORCE_INLINE vec_op::FP32Vec8 silu_act(const vec_op::FP32Vec8 &x) {
+FORCE_INLINE vec_op::FP32Vec8 silu_act(const vec_op::FP32Vec8& x) {
   const vec_op::FP32Vec8 zeros(0.0);
   const vec_op::FP32Vec8 ones(1.0);
   return x / (ones + (zeros - x).exp());
 }

-FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8 &x) {
+FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8& x) {
   const vec_op::FP32Vec8 ones(1.0);
   const vec_op::FP32Vec8 w1(0.79788456f);
   const vec_op::FP32Vec8 w2(0.044715f);
@@ -50,7 +50,7 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8 &x) {
   return w3 * x * (ones + t);
 }

-FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8 &x) {
+FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8& x) {
   const vec_op::FP32Vec8 ones(1.0);
   const vec_op::FP32Vec8 w1(0.79788456f);
   const vec_op::FP32Vec8 w2(0.044715f);
@@ -59,14 +59,14 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8 &x) {
   return w3 * x * (ones + t);
 }

-FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8 &x) {
+FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8& x) {
   const vec_op::FP32Vec8 ones(1.0);
   const vec_op::FP32Vec8 w1(M_SQRT1_2);
   const vec_op::FP32Vec8 w2(0.5);
   return x * w2 * (ones + (x * w1).er());
 }

-FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8 &x) {
+FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8& x) {
   const vec_op::FP32Vec8 ones(1.0);
   const vec_op::FP32Vec8 w1(M_SQRT2 * M_2_SQRTPI * 0.5);
   const vec_op::FP32Vec8 w2(0.5);
@@ -75,40 +75,36 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8 &x) {
   const vec_op::FP32Vec8 inner = w1 * (x + x_3 * w3);
   return x * w2 * (ones + inner.tanh());
 }
 };  // namespace

-void silu_and_mul(torch::Tensor &out, torch::Tensor &input) {
+void silu_and_mul(torch::Tensor& out, torch::Tensor& input) {
   int num_tokens = input.numel() / input.size(-1);
   int d = input.size(-1) / 2;

-  VLLM_DISPATCH_FLOATING_TYPES(
-      input.scalar_type(), "silu_and_mul_impl", [&] {
-        CPU_KERNEL_GUARD_IN(silu_and_mul_impl)
-        activation_kernel<scalar_t, silu_act, true>(num_tokens, d,
-                                                    input.data_ptr<scalar_t>(),
-                                                    out.data_ptr<scalar_t>());
-        CPU_KERNEL_GUARD_OUT(silu_and_mul_impl)
-      });
+  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "silu_and_mul_impl", [&] {
+    CPU_KERNEL_GUARD_IN(silu_and_mul_impl)
+    activation_kernel<scalar_t, silu_act, true>(
+        num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
+    CPU_KERNEL_GUARD_OUT(silu_and_mul_impl)
+  });
 }

-void gelu_and_mul(torch::Tensor &out,    // [..., d]
-                  torch::Tensor &input)  // [..., 2 * d]
+void gelu_and_mul(torch::Tensor& out,    // [..., d]
+                  torch::Tensor& input)  // [..., 2 * d]
 {
   int num_tokens = input.numel() / input.size(-1);
   int d = input.size(-1) / 2;

-  VLLM_DISPATCH_FLOATING_TYPES(
-      input.scalar_type(), "gelu_and_mul_impl", [&] {
-        CPU_KERNEL_GUARD_IN(gelu_and_mul_impl)
-        activation_kernel<scalar_t, gelu_act, true>(num_tokens, d,
-                                                    input.data_ptr<scalar_t>(),
-                                                    out.data_ptr<scalar_t>());
-        CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl)
-      });
+  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_and_mul_impl", [&] {
+    CPU_KERNEL_GUARD_IN(gelu_and_mul_impl)
+    activation_kernel<scalar_t, gelu_act, true>(
+        num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
+    CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl)
+  });
 }

-void gelu_tanh_and_mul(torch::Tensor &out,    // [..., d]
-                       torch::Tensor &input)  // [..., 2 * d]
+void gelu_tanh_and_mul(torch::Tensor& out,    // [..., d]
+                       torch::Tensor& input)  // [..., 2 * d]
 {
   int num_tokens = input.numel() / input.size(-1);
   int d = input.size(-1) / 2;
@@ -123,7 +119,7 @@ void gelu_tanh_and_mul(torch::Tensor &out,    // [..., d]
   });
 }

-void gelu_new(torch::Tensor &out, torch::Tensor &input) {
+void gelu_new(torch::Tensor& out, torch::Tensor& input) {
   int num_tokens = input.numel() / input.size(-1);
   int d = input.size(-1);

@@ -135,7 +131,7 @@ void gelu_new(torch::Tensor &out, torch::Tensor &input) {
   });
 }

-void gelu_fast(torch::Tensor &out, torch::Tensor &input) {
+void gelu_fast(torch::Tensor& out, torch::Tensor& input) {
   int num_tokens = input.numel() / input.size(-1);
   int d = input.size(-1);

@@ -2,7 +2,8 @@

 namespace {

-template <typename scalar_t> struct KernelVecType {
+template <typename scalar_t>
+struct KernelVecType {
   using q_load_vec_type = void;
   using q_vec_type = void;
   using k_load_vec_type = void;
@@ -11,7 +12,8 @@ template <typename scalar_t> struct KernelVecType {
   using v_load_vec_type = void;
 };

-template <> struct KernelVecType<float> {
+template <>
+struct KernelVecType<float> {
   using q_load_vec_type = vec_op::FP32Vec4;
   using q_vec_type = vec_op::FP32Vec16;
   using k_load_vec_type = vec_op::FP32Vec16;
@@ -21,7 +23,8 @@ template <> struct KernelVecType<float> {
 };

 #ifdef __AVX512BF16__
-template <> struct KernelVecType<c10::BFloat16> {
+template <>
+struct KernelVecType<c10::BFloat16> {
   using q_load_vec_type = vec_op::BF16Vec8;
   using q_vec_type = vec_op::BF16Vec32;
   using k_load_vec_type = vec_op::BF16Vec32;
@@ -30,7 +33,8 @@ template <> struct KernelVecType<c10::BFloat16> {
   using v_load_vec_type = vec_op::BF16Vec16;
 };
 #else
-template <> struct KernelVecType<c10::BFloat16> {
+template <>
+struct KernelVecType<c10::BFloat16> {
   using q_load_vec_type = vec_op::BF16Vec8;
   using q_vec_type = vec_op::FP32Vec16;
   using k_load_vec_type = vec_op::BF16Vec16;
@@ -41,7 +45,7 @@ template <> struct KernelVecType<c10::BFloat16> {
 #endif

 template <typename T>
-FORCE_INLINE std::pair<T, T> reduceSoftmax(T *data, const int size,
+FORCE_INLINE std::pair<T, T> reduceSoftmax(T* data, const int size,
                                            const int capacity) {
   T max = data[0];
   for (int i = 1; i < size; ++i) {
@@ -67,10 +71,11 @@ FORCE_INLINE std::pair<T, T> reduceSoftmax(T *data, const int size,
 }

 template <typename T>
-FORCE_INLINE std::pair<T, T>
-reduceSoftmaxAlibi(T *data, const int size, const int capacity,
-                   const float alibi_slope, const int start_index,
-                   const int seq_len) {
+FORCE_INLINE std::pair<T, T> reduceSoftmaxAlibi(T* data, const int size,
+                                                const int capacity,
+                                                const float alibi_slope,
+                                                const int start_index,
+                                                const int seq_len) {
   data[0] += alibi_slope * (start_index - seq_len + 1);
   T max = data[0];
   for (int i = 1; i < size; ++i) {
@@ -98,7 +103,7 @@ reduceSoftmaxAlibi(T *data, const int size, const int capacity,
 }

 template <typename T>
-FORCE_INLINE void reducePartitonSoftmax(const T *max_data, T *sum_data,
+FORCE_INLINE void reducePartitonSoftmax(const T* max_data, T* sum_data,
                                         const int size) {
   T max = max_data[0];
   for (int i = 1; i < size; ++i) {
@@ -132,9 +137,9 @@ struct reduceQKBlockKernel {
   static_assert(k_load_vec_type::get_elem_num() % x == 0);
   static_assert(q_load_vec_type::get_elem_num() * sizeof(scalar_t) == 16);

-  FORCE_INLINE static void call(const scalar_t *__restrict__ q,
-                                const scalar_t *__restrict__ k_block,
-                                float *__restrict__ logits, float scale,
+  FORCE_INLINE static void call(const scalar_t* __restrict__ q,
+                                const scalar_t* __restrict__ k_block,
+                                float* __restrict__ logits, float scale,
                                 const int token_num) {
     const int group_num = (token_num + TOKEN_PER_GROUP - 1) / TOKEN_PER_GROUP;

@@ -196,8 +201,8 @@ struct reduceQKBlockKernel {

 template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE,
           int HEAD_PARTITION_SIZE, typename acc_t>
-FORCE_INLINE void reduceValueBlock(const float *prob, const scalar_t *v_block,
-                                   acc_t &&acc) {
+FORCE_INLINE void reduceValueBlock(const float* prob, const scalar_t* v_block,
+                                   acc_t&& acc) {
   using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
   constexpr int ELEM_NUM = v_load_vec_type::get_elem_num();
   static_assert(BLOCK_SIZE == ELEM_NUM);
@@ -209,27 +214,27 @@ FORCE_INLINE void reduceValueBlock(const float *prob, const scalar_t *v_block,
     acc[head_elem_idx] = acc[head_elem_idx] + prob_vec * fp32_v_vec;
   });
 }
 };  // namespace

 // Paged attention v1
 namespace {
 template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE>
 struct paged_attention_v1_impl {
-  static void
-  call(scalar_t *__restrict__ out,            // [num_seqs, num_heads, head_size]
-       const scalar_t *__restrict__ q,        // [num_seqs, num_heads, head_size]
-       const scalar_t *__restrict__ k_cache,  // [num_blocks, num_kv_heads,
+  static void call(
+      scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
+      const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
+      const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads,
                                               // head_size/x, block_size, x]
-       const scalar_t *__restrict__ v_cache,  // [num_blocks, num_kv_heads,
+      const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads,
                                               // head_size, block_size]
       const int num_kv_heads, const float scale,
-       const int
-           *__restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
-       const int *__restrict__ seq_lens,  // [num_seqs]
+      const int* __restrict__ block_tables,   // [num_seqs,
+                                              // max_num_blocks_per_seq]
+      const int* __restrict__ seq_lens,       // [num_seqs]
       const int max_num_blocks_per_seq,
-       const float *__restrict__ alibi_slopes,  // [num_heads]
+      const float* __restrict__ alibi_slopes,  // [num_heads]
       const int q_stride, const int kv_block_stride, const int kv_head_stride,
       const int num_seqs, const int num_heads) {
     constexpr int x = 16 / sizeof(scalar_t);
     const int num_queries_per_kv = num_heads / num_kv_heads;

@@ -243,32 +248,31 @@ struct paged_attention_v1_impl {

     size_t logits_bytes =
         parallel_work_item_num * max_seq_len_padded * sizeof(float);
-    float *logits = (float *)std::aligned_alloc(
+    float* logits = (float*)std::aligned_alloc(
         64, logits_bytes);  // Cacheline alignment for each context token.
                             // [parallel_work_item_num, max_seq_len_padded]

 #pragma omp parallel for collapse(2) schedule(dynamic, 1)
     for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
       for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
         int seq_len = seq_lens[seq_idx];
-        const int *seq_block_table =
+        const int* seq_block_table =
             block_tables + max_num_blocks_per_seq * seq_idx;
         const int block_num = (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
         const int64_t kv_head_idx = head_idx / num_queries_per_kv;
-        const scalar_t *__restrict__ q_vec_ptr =
+        const scalar_t* __restrict__ q_vec_ptr =
             q + seq_idx * q_stride + head_idx * HEAD_SIZE;
-        const int last_block_token_num =
-            seq_len - (block_num - 1) * BLOCK_SIZE;
-        float *__restrict__ thread_block_logits =
+        const int last_block_token_num = seq_len - (block_num - 1) * BLOCK_SIZE;
+        float* __restrict__ thread_block_logits =
             logits + omp_get_thread_num() * max_seq_len_padded;

         // Compute logits
         for (int block_idx = 0; block_idx < block_num; ++block_idx) {
           const int64_t physical_block_idx = seq_block_table[block_idx];
-          const scalar_t *__restrict__ k_block_cache_ptr =
+          const scalar_t* __restrict__ k_block_cache_ptr =
               k_cache + physical_block_idx * kv_block_stride +
               kv_head_idx * kv_head_stride;
-          float *__restrict__ head_block_logits =
+          float* __restrict__ head_block_logits =
               thread_block_logits + block_idx * BLOCK_SIZE;

           reduceQKBlockKernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, x>::call(
@@ -282,8 +286,7 @@ struct paged_attention_v1_impl {
                              block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0,
                              seq_len);
         } else {
-          reduceSoftmax(thread_block_logits, seq_len,
-                        block_num * BLOCK_SIZE);
+          reduceSoftmax(thread_block_logits, seq_len, block_num * BLOCK_SIZE);
         }

         // Compute value
@@ -293,14 +296,14 @@ struct paged_attention_v1_impl {
         for (int head_part_idx = 0; head_part_idx < head_partition_num;
              ++head_part_idx) {
           vec_op::FP32Vec16 accums[head_elem_num_per_partition];
-          scalar_t *__restrict__ out_ptr =
+          scalar_t* __restrict__ out_ptr =
               out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE +
               head_part_idx * head_elem_num_per_partition;
           for (int block_idx = 0; block_idx < block_num; ++block_idx) {
             const int64_t physical_block_idx = seq_block_table[block_idx];
-            const float *__restrict__ prob_vec_ptr =
+            const float* __restrict__ prob_vec_ptr =
                 thread_block_logits + block_idx * BLOCK_SIZE;
-            const scalar_t *__restrict__ v_block_cache_ptr =
+            const scalar_t* __restrict__ v_block_cache_ptr =
                 v_cache + physical_block_idx * kv_block_stride +
                 kv_head_idx * kv_head_stride +
                 BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
@@ -311,7 +314,7 @@ struct paged_attention_v1_impl {
             if (block_idx != block_num - 1) {
               const int64_t next_physical_block_idx =
                   seq_block_table[block_idx + 1];
-              const scalar_t *__restrict__ next_v_block_cache_ptr =
+              const scalar_t* __restrict__ next_v_block_cache_ptr =
                   v_cache + next_physical_block_idx * kv_block_stride +
                   kv_head_idx * kv_head_stride +
                   BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
@@ -340,16 +343,16 @@ struct paged_attention_v1_impl {
 #define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \
   paged_attention_v1_impl<T, HEAD_SIZE, BLOCK_SIZE>::call( \
       out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
       block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
       alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \
       num_heads);

 template <typename T, int BLOCK_SIZE>
 void paged_attention_v1_impl_launcher(
-    torch::Tensor &out, torch::Tensor &query, torch::Tensor &key_cache,
-    torch::Tensor &value_cache, int num_kv_heads, float scale,
-    torch::Tensor &block_tables, torch::Tensor &seq_lens,
-    int max_seq_len, const c10::optional<torch::Tensor> &alibi_slopes) {
+    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
+    torch::Tensor& value_cache, int num_kv_heads, float scale,
+    torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
+    const c10::optional<torch::Tensor>& alibi_slopes) {
   int num_seqs = query.size(0);
   int num_heads = query.size(1);
   int head_size = query.size(2);
@@ -359,67 +362,66 @@ void paged_attention_v1_impl_launcher(
   int kv_head_stride = key_cache.stride(1);

   // NOTE: alibi_slopes is optional.
-  const float *alibi_slopes_ptr =
+  const float* alibi_slopes_ptr =
       alibi_slopes
-          ? reinterpret_cast<const float *>(alibi_slopes.value().data_ptr())
+          ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
           : nullptr;

-  T *out_ptr = reinterpret_cast<T *>(out.data_ptr());
-  T *query_ptr = reinterpret_cast<T *>(query.data_ptr());
-  T *key_cache_ptr = reinterpret_cast<T *>(key_cache.data_ptr());
-  T *value_cache_ptr = reinterpret_cast<T *>(value_cache.data_ptr());
-  int *block_tables_ptr = block_tables.data_ptr<int>();
-  int *seq_lens_ptr = seq_lens.data_ptr<int>();
+  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
+  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
+  T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
+  T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
+  int* block_tables_ptr = block_tables.data_ptr<int>();
+  int* seq_lens_ptr = seq_lens.data_ptr<int>();

   switch (head_size) {
     case 64:
       LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
       break;
     case 80:
       LAUNCH_V1_ATTENTION_KERNEL(T, 80, BLOCK_SIZE);
       break;
     case 96:
       LAUNCH_V1_ATTENTION_KERNEL(T, 96, BLOCK_SIZE);
       break;
     case 112:
       LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE);
       break;
     case 128:
       LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
       break;
     case 256:
       LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
       break;
     default:
       TORCH_CHECK(false, "Unsupported head size: ", head_size);
       break;
   }
 }

 #define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
   paged_attention_v1_impl_launcher<T, BLOCK_SIZE>( \
       out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
       seq_lens, max_seq_len, alibi_slopes);

 #define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
   switch (block_size) { \
     case 16: \
       CALL_V1_KERNEL_LAUNCHER(T, 16); \
       break; \
     default: \
       TORCH_CHECK(false, "Unsupported block size: ", block_size); \
       break; \
   }
 }  // namespace

-void paged_attention_v1(torch::Tensor &out, torch::Tensor &query,
-                        torch::Tensor &key_cache, torch::Tensor &value_cache,
+void paged_attention_v1(torch::Tensor& out, torch::Tensor& query,
+                        torch::Tensor& key_cache, torch::Tensor& value_cache,
                         int num_kv_heads, float scale,
-                        torch::Tensor &block_tables,
-                        torch::Tensor &seq_lens, int block_size,
-                        int max_seq_len,
-                        const c10::optional<torch::Tensor> &alibi_slopes,
-                        const std::string &kv_cache_dtype, float kv_scale) {
+                        torch::Tensor& block_tables, torch::Tensor& seq_lens,
+                        int block_size, int max_seq_len,
+                        const c10::optional<torch::Tensor>& alibi_slopes,
+                        const std::string& kv_cache_dtype, float kv_scale) {
   TORCH_CHECK(kv_scale == 1.0f);
   VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl",
                                [&] {
@@ -434,23 +436,24 @@ namespace {
 template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int PARTITION_SIZE>
 struct paged_attention_v2_impl {
   static void call(
-      scalar_t *__restrict__ out,      // [num_seqs, num_heads, head_size]
-      float *__restrict__ exp_sums,    // [num_seqs, num_heads, max_num_partitions]
-      float
-          *__restrict__ max_logits,    // [num_seqs, num_heads, max_num_partitions]
-      scalar_t *__restrict__ tmp_out,  // [num_seqs, num_heads,
-                                       // max_num_partitions, head_size]
-      const scalar_t *__restrict__ q,  // [num_seqs, num_heads, head_size]
-      const scalar_t *__restrict__ k_cache,  // [num_blocks, num_kv_heads,
-                                             // head_size/x, block_size, x]
-      const scalar_t *__restrict__ v_cache,  // [num_blocks, num_kv_heads,
-                                             // head_size, block_size]
+      scalar_t* __restrict__ out,      // [num_seqs, num_heads, head_size]
+      float* __restrict__ exp_sums,    // [num_seqs, num_heads,
+                                       // max_num_partitions]
+      float* __restrict__ max_logits,  // [num_seqs, num_heads,
+                                       // max_num_partitions]
+      scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads,
+                                       // max_num_partitions, head_size]
+      const scalar_t* __restrict__ q,  // [num_seqs, num_heads, head_size]
+      const scalar_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
+                                             // head_size/x, block_size, x]
+      const scalar_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
+                                             // head_size, block_size]
       const int num_kv_heads, const float scale,
-      const int
-          *__restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
-      const int *__restrict__ seq_lens,  // [num_seqs]
+      const int* __restrict__ block_tables,  // [num_seqs,
+                                             // max_num_blocks_per_seq]
+      const int* __restrict__ seq_lens,      // [num_seqs]
       const int max_num_blocks_per_seq,
-      const float *__restrict__ alibi_slopes,  // [num_heads]
+      const float* __restrict__ alibi_slopes,  // [num_heads]
       const int q_stride, const int kv_block_stride, const int kv_head_stride,
       const int num_seqs, const int num_heads, const int max_num_partitions) {
     constexpr int x = 16 / sizeof(scalar_t);
@@ -468,8 +471,7 @@ struct paged_attention_v2_impl {
         const int seq_len = seq_lens[seq_idx];
         const int start_token_idx = partition_idx * PARTITION_SIZE;

-        if (start_token_idx >= seq_len)
-          continue;
+        if (start_token_idx >= seq_len) continue;

         const int partition_num =
             (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
@@ -477,15 +479,14 @@ struct paged_attention_v2_impl {
         const int token_num =
             (std::min(seq_len, start_token_idx + PARTITION_SIZE) -
              start_token_idx);
-        const int block_num =
-            (token_num + BLOCK_SIZE - 1) / BLOCK_SIZE;
+        const int block_num = (token_num + BLOCK_SIZE - 1) / BLOCK_SIZE;
         const int last_block_token_num =
             token_num - (block_num - 1) * BLOCK_SIZE;
-        const int *seq_block_table = block_tables +
+        const int* seq_block_table = block_tables +
                                      max_num_blocks_per_seq * seq_idx +
                                      start_token_idx / BLOCK_SIZE;
         const int64_t kv_head_idx = head_idx / num_queries_per_kv;
-        const scalar_t *__restrict__ q_vec_ptr =
+        const scalar_t* __restrict__ q_vec_ptr =
             q + seq_idx * q_stride + head_idx * HEAD_SIZE;

         float logits[PARTITION_SIZE] __attribute__((aligned(64))) = {0};
@@ -493,10 +494,10 @@ struct paged_attention_v2_impl {
         // Compute logits
         for (int block_idx = 0; block_idx < block_num; ++block_idx) {
           const int64_t physical_block_idx = seq_block_table[block_idx];
-          const scalar_t *__restrict__ k_block_cache_ptr =
+          const scalar_t* __restrict__ k_block_cache_ptr =
               k_cache + physical_block_idx * kv_block_stride +
               kv_head_idx * kv_head_stride;
-          float *__restrict__ head_block_logits =
+          float* __restrict__ head_block_logits =
               logits + block_idx * BLOCK_SIZE;

           reduceQKBlockKernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, x>::call(
@@ -510,13 +511,13 @@ struct paged_attention_v2_impl {
                              logits, token_num, block_num * BLOCK_SIZE,
                              alibi_slopes[head_idx], start_token_idx, seq_len);
         } else {
-          max_and_sum = reduceSoftmax(logits, token_num,
-                                      block_num * BLOCK_SIZE);
+          max_and_sum =
+              reduceSoftmax(logits, token_num, block_num * BLOCK_SIZE);
         }

-        auto &&[max_logit, exp_sum] = max_and_sum;
+        auto&& [max_logit, exp_sum] = max_and_sum;

-        scalar_t *__restrict__ output_buffer = nullptr;
+        scalar_t* __restrict__ output_buffer = nullptr;
         if (!no_reduce) {
           auto idx = seq_idx * num_heads * max_num_partitions +
                      head_idx * max_num_partitions + partition_idx;
@@ -538,13 +539,13 @@ struct paged_attention_v2_impl {
         for (int head_part_idx = 0; head_part_idx < head_partition_num;
              ++head_part_idx) {
           vec_op::FP32Vec16 accums[head_elem_num_per_partition];
-          scalar_t *__restrict__ out_ptr =
+          scalar_t* __restrict__ out_ptr =
               output_buffer + head_part_idx * head_elem_num_per_partition;
           for (int block_idx = 0; block_idx < block_num; ++block_idx) {
             const int64_t physical_block_idx = seq_block_table[block_idx];
-            const float *__restrict__ prob_vec_ptr =
+            const float* __restrict__ prob_vec_ptr =
                 logits + block_idx * BLOCK_SIZE;
-            const scalar_t *__restrict__ v_block_cache_ptr =
+            const scalar_t* __restrict__ v_block_cache_ptr =
                 v_cache + physical_block_idx * kv_block_stride +
                 kv_head_idx * kv_head_stride +
                 BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
@@ -555,7 +556,7 @@ struct paged_attention_v2_impl {
             if (block_idx != block_num - 1) {
               const int64_t next_physical_block_idx =
                   seq_block_table[block_idx + 1];
-              const scalar_t *__restrict__ next_v_block_cache_ptr =
+              const scalar_t* __restrict__ next_v_block_cache_ptr =
                   v_cache + next_physical_block_idx * kv_block_stride +
                   kv_head_idx * kv_head_stride +
                   BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
@@ -587,8 +588,7 @@ struct paged_attention_v2_impl {
         const int partition_num =
             (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;

-        if (partition_num == 1)
-          continue;
+        if (partition_num == 1) continue;

         reducePartitonSoftmax(
             max_logits + seq_idx * num_heads * max_num_partitions +
@@ -603,11 +603,11 @@ struct paged_attention_v2_impl {
     using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
     static_assert(v_load_vec_type::get_elem_num() == BLOCK_SIZE);
     constexpr int head_elem_num_per_group =
-        16;  // Note: didn't align with the cacheline size, due to some HEAD_SIZE
-             // didn't align with 64 bytes
+        16;  // Note: didn't align with the cacheline size, due to some
+             // HEAD_SIZE didn't align with 64 bytes
     static_assert(HEAD_SIZE % head_elem_num_per_group == 0);
     constexpr int head_group_num = HEAD_SIZE / head_elem_num_per_group;
-    const float *__restrict__ rescale_factors = exp_sums;
+    const float* __restrict__ rescale_factors = exp_sums;
 #pragma omp parallel for collapse(3) schedule(static, 1)
     for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
       for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
@@ -616,17 +616,16 @@ struct paged_attention_v2_impl {
         const int partition_num =
             (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;

-        if (partition_num == 1)
-          continue;
+        if (partition_num == 1) continue;

-        const float *__restrict__ seq_head_rescale_factors =
+        const float* __restrict__ seq_head_rescale_factors =
            rescale_factors + seq_idx * num_heads * max_num_partitions +
            head_idx * max_num_partitions;
-        const scalar_t *__restrict__ seq_head_tmp_out =
+        const scalar_t* __restrict__ seq_head_tmp_out =
            tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
            head_idx * max_num_partitions * HEAD_SIZE +
            group_idx * head_elem_num_per_group;
-        scalar_t *__restrict__ seq_head_output =
+        scalar_t* __restrict__ seq_head_output =
            out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE +
            group_idx * head_elem_num_per_group;

@ -645,21 +644,21 @@ struct paged_attention_v2_impl {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
#define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \
|
#define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \
|
||||||
paged_attention_v2_impl<T, HEAD_SIZE, BLOCK_SIZE, PARTITION_SIZE>::call( \
|
paged_attention_v2_impl<T, HEAD_SIZE, BLOCK_SIZE, PARTITION_SIZE>::call( \
|
||||||
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \
|
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \
|
||||||
key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
|
key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
|
||||||
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
|
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
|
||||||
kv_block_stride, kv_head_stride, num_seqs, num_heads, \
|
kv_block_stride, kv_head_stride, num_seqs, num_heads, \
|
||||||
max_num_partitions);
|
max_num_partitions);
|
||||||
|
|
||||||
template <typename T, int BLOCK_SIZE, int PARTITION_SIZE = 512>
|
template <typename T, int BLOCK_SIZE, int PARTITION_SIZE = 512>
|
||||||
void paged_attention_v2_impl_launcher(
|
void paged_attention_v2_impl_launcher(
|
||||||
torch::Tensor &out, torch::Tensor &exp_sums, torch::Tensor &max_logits,
|
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
|
||||||
torch::Tensor &tmp_out, torch::Tensor &query, torch::Tensor &key_cache,
|
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||||
torch::Tensor &value_cache, int num_kv_heads, float scale,
|
torch::Tensor& value_cache, int num_kv_heads, float scale,
|
||||||
torch::Tensor &block_tables, torch::Tensor &seq_lens, int block_size,
|
torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
|
||||||
int max_seq_len, const c10::optional<torch::Tensor> &alibi_slopes) {
|
int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes) {
|
||||||
int num_seqs = query.size(0);
|
int num_seqs = query.size(0);
|
||||||
int num_heads = query.size(1);
|
int num_heads = query.size(1);
|
||||||
int head_size = query.size(2);
|
int head_size = query.size(2);
|
||||||
@ -670,72 +669,72 @@ void paged_attention_v2_impl_launcher(
|
|||||||
int max_num_partitions = exp_sums.size(-1);
|
int max_num_partitions = exp_sums.size(-1);
|
||||||
|
|
||||||
// NOTE: alibi_slopes is optional.
|
// NOTE: alibi_slopes is optional.
|
||||||
const float *alibi_slopes_ptr =
|
const float* alibi_slopes_ptr =
|
||||||
alibi_slopes
|
alibi_slopes
|
||||||
? reinterpret_cast<const float *>(alibi_slopes.value().data_ptr())
|
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
|
||||||
: nullptr;
|
: nullptr;
|
||||||
|
|
||||||
T *out_ptr = reinterpret_cast<T *>(out.data_ptr());
|
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
|
||||||
float *exp_sums_ptr = reinterpret_cast<float *>(exp_sums.data_ptr());
|
float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
|
||||||
float *max_logits_ptr = reinterpret_cast<float *>(max_logits.data_ptr());
|
float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
|
||||||
T *tmp_out_ptr = reinterpret_cast<T *>(tmp_out.data_ptr());
|
T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
|
||||||
T *query_ptr = reinterpret_cast<T *>(query.data_ptr());
|
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
|
||||||
T *key_cache_ptr = reinterpret_cast<T *>(key_cache.data_ptr());
|
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
|
||||||
T *value_cache_ptr = reinterpret_cast<T *>(value_cache.data_ptr());
|
T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
|
||||||
int *block_tables_ptr = block_tables.data_ptr<int>();
|
int* block_tables_ptr = block_tables.data_ptr<int>();
|
||||||
int *seq_lens_ptr = seq_lens.data_ptr<int>();
|
int* seq_lens_ptr = seq_lens.data_ptr<int>();
|
||||||
|
|
||||||
switch (head_size) {
|
switch (head_size) {
|
||||||
case 64:
|
case 64:
|
||||||
LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
|
LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
|
||||||
break;
|
break;
|
||||||
case 80:
|
case 80:
|
||||||
LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE);
|
LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE);
|
||||||
break;
|
break;
|
||||||
case 96:
|
case 96:
|
||||||
LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE);
|
LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE);
|
||||||
break;
|
break;
|
||||||
case 112:
|
case 112:
|
||||||
LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE);
|
LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE);
|
||||||
break;
|
break;
|
||||||
case 128:
|
case 128:
|
||||||
LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
|
LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
|
||||||
break;
|
break;
|
||||||
case 256:
|
case 256:
|
||||||
LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
|
LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
TORCH_CHECK(false, "Unsupported head size: ", head_size);
|
TORCH_CHECK(false, "Unsupported head size: ", head_size);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
|
#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
|
||||||
paged_attention_v2_impl_launcher<T, BLOCK_SIZE>( \
|
paged_attention_v2_impl_launcher<T, BLOCK_SIZE>( \
|
||||||
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
|
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
|
||||||
num_kv_heads, scale, block_tables, seq_lens, block_size, \
|
num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, \
|
||||||
max_seq_len, alibi_slopes);
|
alibi_slopes);
|
||||||
|
|
||||||
#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
|
#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
|
||||||
switch (block_size) { \
|
switch (block_size) { \
|
||||||
case 16: \
|
case 16: \
|
||||||
CALL_V2_KERNEL_LAUNCHER(T, 16); \
|
CALL_V2_KERNEL_LAUNCHER(T, 16); \
|
||||||
break; \
|
break; \
|
||||||
default: \
|
default: \
|
||||||
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
|
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
|
||||||
break; \
|
break; \
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums,
|
void paged_attention_v2(torch::Tensor& out, torch::Tensor& exp_sums,
|
||||||
torch::Tensor &max_logits, torch::Tensor &tmp_out,
|
torch::Tensor& max_logits, torch::Tensor& tmp_out,
|
||||||
torch::Tensor &query, torch::Tensor &key_cache,
|
torch::Tensor& query, torch::Tensor& key_cache,
|
||||||
torch::Tensor &value_cache, int num_kv_heads,
|
torch::Tensor& value_cache, int num_kv_heads,
|
||||||
float scale, torch::Tensor &block_tables,
|
float scale, torch::Tensor& block_tables,
|
||||||
torch::Tensor &seq_lens, int block_size,
|
torch::Tensor& seq_lens, int block_size,
|
||||||
int max_seq_len,
|
int max_seq_len,
|
||||||
const c10::optional<torch::Tensor> &alibi_slopes,
|
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||||
const std::string &kv_cache_dtype, float kv_scale) {
|
const std::string& kv_cache_dtype, float kv_scale) {
|
||||||
TORCH_CHECK(kv_scale == 1.0f);
|
TORCH_CHECK(kv_scale == 1.0f);
|
||||||
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl",
|
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl",
|
||||||
[&] {
|
[&] {
|
||||||
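The launcher above turns the runtime head size into a compile-time template argument through a switch. A minimal sketch of the same dispatch pattern, with hypothetical names (run_attention, dispatch_head_size) that are not part of vLLM:

#include <stdexcept>
#include <string>

// Toy kernel: HEAD_SIZE is a template parameter so the compiler can
// specialize and unroll loops over the head dimension.
template <int HEAD_SIZE>
void run_attention(const float* q, float* out) {
  for (int i = 0; i < HEAD_SIZE; ++i) out[i] = q[i];
}

// Runtime value -> compile-time constant, mirroring the switch in the
// launcher above; unsupported sizes fail loudly instead of silently.
void dispatch_head_size(int head_size, const float* q, float* out) {
  switch (head_size) {
    case 64:
      run_attention<64>(q, out);
      break;
    case 128:
      run_attention<128>(q, out);
      break;
    default:
      throw std::runtime_error("Unsupported head size: " +
                               std::to_string(head_size));
  }
}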
@@ -5,25 +5,26 @@

namespace {
template <typename scalar_t>
-void copy_blocks_cpu_impl(
-    std::vector<torch::Tensor> &key_caches,
-    std::vector<torch::Tensor> &value_caches,
-    const torch::Tensor& mapping_pairs,
-    const int element_num_per_block, const int layer_num) {
+void copy_blocks_cpu_impl(std::vector<torch::Tensor>& key_caches,
+                          std::vector<torch::Tensor>& value_caches,
+                          const torch::Tensor& mapping_pairs,
+                          const int element_num_per_block,
+                          const int layer_num) {
  const size_t pair_num = mapping_pairs.size(0);
  const size_t block_bytes = sizeof(scalar_t) * element_num_per_block;
#pragma omp parallel for collapse(2)
  for (int layer = 0; layer < layer_num; ++layer) {
    for (size_t pair = 0; pair < pair_num; ++pair) {
-     int64_t source_offset = element_num_per_block * mapping_pairs[pair][0].item<int64_t>();
+     int64_t source_offset =
+         element_num_per_block * mapping_pairs[pair][0].item<int64_t>();
      int64_t target_offset =
          element_num_per_block * mapping_pairs[pair][1].item<int64_t>();
-     scalar_t *key_cache_ptr = key_caches[layer].data_ptr<scalar_t>();
+     scalar_t* key_cache_ptr = key_caches[layer].data_ptr<scalar_t>();
-     scalar_t *source_ptr = key_cache_ptr + source_offset;
+     scalar_t* source_ptr = key_cache_ptr + source_offset;
-     scalar_t *target_ptr = key_cache_ptr + target_offset;
+     scalar_t* target_ptr = key_cache_ptr + target_offset;
      std::memcpy(target_ptr, source_ptr, block_bytes);

-     scalar_t *value_cache_ptr = value_caches[layer].data_ptr<scalar_t>();
+     scalar_t* value_cache_ptr = value_caches[layer].data_ptr<scalar_t>();
      source_ptr = value_cache_ptr + source_offset;
      target_ptr = value_cache_ptr + target_offset;
      std::memcpy(target_ptr, source_ptr, block_bytes);

@@ -33,9 +34,9 @@ void copy_blocks_cpu_impl(

template <typename scalar_t>
void reshape_and_cache_cpu_impl(
-   const scalar_t *__restrict__ key, const scalar_t *__restrict__ value,
-   scalar_t *__restrict__ key_cache, scalar_t *__restrict__ value_cache,
-   const int64_t *__restrict__ slot_mapping, const int num_tokens,
+   const scalar_t* __restrict__ key, const scalar_t* __restrict__ value,
+   scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
+   const int64_t* __restrict__ slot_mapping, const int num_tokens,
    const int key_stride, const int value_stride, const int num_heads,
    const int head_size, const int block_size, const int x) {
  const int block_elem_num = num_heads * head_size * block_size;

@@ -48,14 +49,14 @@ void reshape_and_cache_cpu_impl(
      int src_key_head_idx = token_idx * key_stride + head_idx * head_size;
      int src_value_head_idx =
          token_idx * value_stride + head_idx * head_size;
-     const scalar_t *src_key_head_ptr = key + src_key_head_idx;
+     const scalar_t* src_key_head_ptr = key + src_key_head_idx;
-     const scalar_t *src_value_head_ptr = value + src_value_head_idx;
+     const scalar_t* src_value_head_ptr = value + src_value_head_idx;
      const int64_t block_index = slot_idx / block_size;
      const int64_t block_offset = slot_idx % block_size;
-     scalar_t *target_key_head_ptr = key_cache +
+     scalar_t* target_key_head_ptr = key_cache +
          block_elem_num * block_index +
          head_idx * block_size * head_size;
-     scalar_t *target_value_head_ptr = value_cache +
+     scalar_t* target_value_head_ptr = value_cache +
          block_elem_num * block_index +
          head_idx * block_size * head_size;

@@ -79,10 +80,10 @@ void reshape_and_cache_cpu_impl(
      }
    }
  }
};  // namespace

-void copy_blocks(std::vector<torch::Tensor> &key_caches,
-                 std::vector<torch::Tensor> &value_caches,
+void copy_blocks(std::vector<torch::Tensor>& key_caches,
+                 std::vector<torch::Tensor>& value_caches,
                  const torch::Tensor& block_mapping) {
  unsigned num_layers = key_caches.size();
  TORCH_CHECK(num_layers == value_caches.size());

@@ -100,10 +101,10 @@ void copy_blocks(std::vector<torch::Tensor> &key_caches,
  });
}

-void reshape_and_cache(torch::Tensor &key, torch::Tensor &value,
-                       torch::Tensor &key_cache, torch::Tensor &value_cache,
-                       torch::Tensor &slot_mapping,
-                       const std::string &kv_cache_dtype, float kv_scale) {
+void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
+                       torch::Tensor& key_cache, torch::Tensor& value_cache,
+                       torch::Tensor& slot_mapping,
+                       const std::string& kv_cache_dtype, float kv_scale) {
  TORCH_CHECK(kv_scale == 1.0f);

  int num_tokens = key.size(0);

@@ -127,7 +128,7 @@ void reshape_and_cache(torch::Tensor &key, torch::Tensor &value,
  });
}

-void swap_blocks(torch::Tensor &src, torch::Tensor &dst,
-                 const torch::Tensor&block_mapping) {
+void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
+                 const torch::Tensor& block_mapping) {
  TORCH_CHECK(false, "swap_blocks is unsupported on CPU.")
}
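copy_blocks_cpu_impl above walks a list of (source block, target block) pairs and memcpy's whole blocks inside each layer's flat cache. A self-contained sketch of that block-copy idea on a plain array (copy_cache_blocks and its parameters are illustrative names, not the vLLM API):

#include <cstdint>
#include <cstring>
#include <utility>
#include <vector>

// Copy whole cache blocks within one flat buffer according to
// (source block, target block) pairs, as the CPU kernel above does per layer.
void copy_cache_blocks(float* cache, int64_t block_elems,
                       const std::vector<std::pair<int64_t, int64_t>>& mapping) {
  const size_t block_bytes = sizeof(float) * block_elems;
  for (const auto& [src, dst] : mapping) {
    std::memcpy(cache + dst * block_elems, cache + src * block_elems,
                block_bytes);
  }
}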
@@ -2,10 +2,10 @@

namespace {
template <typename scalar_t>
-void rms_norm_impl(scalar_t *__restrict__ out,
-                   const scalar_t *__restrict__ input,
-                   const scalar_t *__restrict__ weight, const float epsilon,
+void rms_norm_impl(scalar_t* __restrict__ out,
+                   const scalar_t* __restrict__ input,
+                   const scalar_t* __restrict__ weight, const float epsilon,
                    const int num_tokens, const int hidden_size) {
  using scalar_vec_t = vec_op::vec_t<scalar_t>;
  constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
  TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0);

@@ -41,11 +41,11 @@ void rms_norm_impl(scalar_t *__restrict__ out,
}

template <typename scalar_t>
-void fused_add_rms_norm_impl(scalar_t *__restrict__ input,
-                             scalar_t *__restrict__ residual,
-                             const scalar_t *__restrict__ weight,
+void fused_add_rms_norm_impl(scalar_t* __restrict__ input,
+                             scalar_t* __restrict__ residual,
+                             const scalar_t* __restrict__ weight,
                              const float epsilon, const int num_tokens,
                              const int hidden_size) {
  using scalar_vec_t = vec_op::vec_t<scalar_t>;
  constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
  TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0);

@@ -85,24 +85,24 @@ void fused_add_rms_norm_impl(scalar_t *__restrict__ input,
      }
    }
  }
}  // namespace

-void rms_norm(torch::Tensor &out, torch::Tensor &input,
-              torch::Tensor &weight, float epsilon) {
+void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
+              float epsilon) {
  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;

  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_impl", [&] {
    CPU_KERNEL_GUARD_IN(rms_norm_impl)
    rms_norm_impl(out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
                  weight.data_ptr<scalar_t>(), epsilon, num_tokens,
                  hidden_size);
    CPU_KERNEL_GUARD_OUT(rms_norm_impl)
  });
}

-void fused_add_rms_norm(torch::Tensor &input, torch::Tensor &residual,
-                        torch::Tensor &weight, float epsilon) {
+void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
+                        torch::Tensor& weight, float epsilon) {
  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;

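For reference, rms_norm computes y = x / sqrt(mean(x^2) + eps) * w per token; the vLLM kernel above does this with vectorized loads. A scalar sketch of the same computation (rms_norm_ref is a made-up name, not the vectorized kernel):

#include <cmath>

// Scalar reference for RMS normalization over one token of `hidden_size`
// elements: scale by the reciprocal root-mean-square, then by the weight.
void rms_norm_ref(float* out, const float* in, const float* weight,
                  float epsilon, int hidden_size) {
  float sum_sq = 0.0f;
  for (int i = 0; i < hidden_size; ++i) sum_sq += in[i] * in[i];
  const float inv_rms = 1.0f / std::sqrt(sum_sq / hidden_size + epsilon);
  for (int i = 0; i < hidden_size; ++i) out[i] = in[i] * inv_rms * weight[i];
}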
@@ -4,16 +4,16 @@
namespace {
template <typename scalar_t>
void rotary_embedding_impl(
-   const int64_t
-       *__restrict__ positions,  // [batch_size, seq_len] or [num_tokens]
-   scalar_t
-       *__restrict__ query,  /// [batch_size, seq_len, num_heads, head_size] or
-                             /// [num_tokens, num_heads, head_size]
-   scalar_t
-       *__restrict__ key,  // [batch_size, seq_len, num_kv_heads, head_size] or
-                           // [num_tokens, num_kv_heads, head_size]
-   const scalar_t
-       *__restrict__ cos_sin_cache,  // [max_position, 2, rot_dim // 2]
+   const int64_t* __restrict__ positions,  // [batch_size, seq_len] or
+                                           // [num_tokens]
+   scalar_t* __restrict__ query,  /// [batch_size, seq_len, num_heads,
+                                  /// head_size] or [num_tokens, num_heads,
+                                  /// head_size]
+   scalar_t* __restrict__ key,  // [batch_size, seq_len, num_kv_heads,
+                                // head_size] or [num_tokens, num_kv_heads,
+                                // head_size]
+   const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
+                                                // 2]
    const int rot_dim, const int64_t query_stride, const int64_t key_stride,
    const int num_heads, const int num_kv_heads, const int head_size,
    const int num_tokens) {

@@ -26,7 +26,7 @@ void rotary_embedding_impl(
#pragma omp parallel for
  for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
    int64_t pos = positions[token_idx];
-   const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim;
+   const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;

    for (int i = 0; i < num_heads; ++i) {
      const int head_idx = i;

@@ -94,16 +94,16 @@ void rotary_embedding_impl(

template <typename scalar_t>
void rotary_embedding_gptj_impl(
-   const int64_t
-       *__restrict__ positions,  // [batch_size, seq_len] or [num_tokens]
-   scalar_t
-       *__restrict__ query,  /// [batch_size, seq_len, num_heads, head_size] or
-                             /// [num_tokens, num_heads, head_size]
-   scalar_t
-       *__restrict__ key,  // [batch_size, seq_len, num_kv_heads, head_size] or
-                           // [num_tokens, num_kv_heads, head_size]
-   const scalar_t
-       *__restrict__ cos_sin_cache,  // [max_position, 2, rot_dim // 2]
+   const int64_t* __restrict__ positions,  // [batch_size, seq_len] or
+                                           // [num_tokens]
+   scalar_t* __restrict__ query,  /// [batch_size, seq_len, num_heads,
+                                  /// head_size] or [num_tokens, num_heads,
+                                  /// head_size]
+   scalar_t* __restrict__ key,  // [batch_size, seq_len, num_kv_heads,
+                                // head_size] or [num_tokens, num_kv_heads,
+                                // head_size]
+   const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
+                                                // 2]
    const int rot_dim, const int64_t query_stride, const int64_t key_stride,
    const int num_heads, const int num_kv_heads, const int head_size,
    const int num_tokens) {

@@ -113,13 +113,13 @@ void rotary_embedding_gptj_impl(
  for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
    for (int i = 0; i < num_heads; ++i) {
      int64_t pos = positions[token_idx];
-     const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim;
+     const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
-     const scalar_t *cos_cache_ptr = cache_ptr;
+     const scalar_t* cos_cache_ptr = cache_ptr;
-     const scalar_t *sin_cache_ptr = cache_ptr + embed_dim;
+     const scalar_t* sin_cache_ptr = cache_ptr + embed_dim;
      const int head_idx = i;
      const int64_t token_head =
          token_idx * query_stride + head_idx * head_size;
-     scalar_t *head_query = token_head + query;
+     scalar_t* head_query = token_head + query;
      for (int j = 0; j < embed_dim; j += 1) {
        const int rot_offset = j;
        const int x_index = 2 * rot_offset;

@@ -141,12 +141,12 @@ void rotary_embedding_gptj_impl(
  for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
    for (int i = 0; i < num_kv_heads; ++i) {
      int64_t pos = positions[token_idx];
-     const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim;
+     const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
-     const scalar_t *cos_cache_ptr = cache_ptr;
+     const scalar_t* cos_cache_ptr = cache_ptr;
-     const scalar_t *sin_cache_ptr = cache_ptr + embed_dim;
+     const scalar_t* sin_cache_ptr = cache_ptr + embed_dim;
      const int head_idx = i;
      const int64_t token_head = token_idx * key_stride + head_idx * head_size;
-     scalar_t *head_key = key + token_head;
+     scalar_t* head_key = key + token_head;
      for (int j = 0; j < embed_dim; j += 1) {
        const int rot_offset = j;
        const int x_index = 2 * rot_offset;

@@ -164,11 +164,11 @@ void rotary_embedding_gptj_impl(
      }
    }
  }
};  // namespace

-void rotary_embedding(torch::Tensor &positions, torch::Tensor &query,
-                      torch::Tensor &key, int head_size,
-                      torch::Tensor &cos_sin_cache, bool is_neox) {
+void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
+                      torch::Tensor& key, int head_size,
+                      torch::Tensor& cos_sin_cache, bool is_neox) {
  int num_tokens = query.numel() / query.size(-1);
  int rot_dim = cos_sin_cache.size(1);
  int num_heads = query.size(-1) / head_size;
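Both rotary kernels ultimately rotate pairs of query/key elements by the cached (cos, sin) values for the token's position; only the pairing layout differs between the NeoX and GPT-J variants. A scalar sketch of that core step (rotate_pair is an illustrative helper, not a vLLM entry point):

// Rotate one (x1, x2) pair by an angle theta, given cos(theta) and sin(theta)
// looked up from the cos/sin cache; this is the inner step of rotary embedding.
inline void rotate_pair(float& x1, float& x2, float cos_v, float sin_v) {
  const float r1 = x1 * cos_v - x2 * sin_v;
  const float r2 = x2 * cos_v + x1 * sin_v;
  x1 = r1;
  x2 = r2;
}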
@@ -8,66 +8,37 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  pybind11::module ops = m.def_submodule("ops", "vLLM custom operators");

  // Attention ops
- ops.def(
-   "paged_attention_v1",
-   &paged_attention_v1,
-   "Compute the attention between an input query and the cached keys/values using PagedAttention.");
- ops.def(
-   "paged_attention_v2",
-   &paged_attention_v2,
-   "PagedAttention V2.");
+ ops.def("paged_attention_v1", &paged_attention_v1,
+         "Compute the attention between an input query and the cached "
+         "keys/values using PagedAttention.");
+ ops.def("paged_attention_v2", &paged_attention_v2, "PagedAttention V2.");

  // Activation ops
- ops.def(
-   "silu_and_mul",
-   &silu_and_mul,
-   "Activation function used in SwiGLU.");
- ops.def(
-   "gelu_and_mul",
-   &gelu_and_mul,
-   "Activation function used in GeGLU with `none` approximation.");
- ops.def(
-   "gelu_tanh_and_mul",
-   &gelu_tanh_and_mul,
-   "Activation function used in GeGLU with `tanh` approximation.");
- ops.def(
-   "gelu_new",
-   &gelu_new,
-   "GELU implementation used in GPT-2.");
- ops.def(
-   "gelu_fast",
-   &gelu_fast,
-   "Approximate GELU implementation.");
+ ops.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU.");
+ ops.def("gelu_and_mul", &gelu_and_mul,
+         "Activation function used in GeGLU with `none` approximation.");
+ ops.def("gelu_tanh_and_mul", &gelu_tanh_and_mul,
+         "Activation function used in GeGLU with `tanh` approximation.");
+ ops.def("gelu_new", &gelu_new, "GELU implementation used in GPT-2.");
+ ops.def("gelu_fast", &gelu_fast, "Approximate GELU implementation.");

  // Layernorm
- ops.def(
-   "rms_norm",
-   &rms_norm,
-   "Apply Root Mean Square (RMS) Normalization to the input tensor.");
+ ops.def("rms_norm", &rms_norm,
+         "Apply Root Mean Square (RMS) Normalization to the input tensor.");

- ops.def(
-   "fused_add_rms_norm",
-   &fused_add_rms_norm,
-   "In-place fused Add and RMS Normalization");
+ ops.def("fused_add_rms_norm", &fused_add_rms_norm,
+         "In-place fused Add and RMS Normalization");

  // Rotary embedding
- ops.def(
-   "rotary_embedding",
-   &rotary_embedding,
-   "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
+ ops.def("rotary_embedding", &rotary_embedding,
+         "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");

  // Cache ops
  pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
- cache_ops.def(
-   "swap_blocks",
-   &swap_blocks,
-   "Swap in (out) the cache blocks from src to dst");
- cache_ops.def(
-   "copy_blocks",
-   &copy_blocks,
-   "Copy the cache blocks from src to dst");
- cache_ops.def(
-   "reshape_and_cache",
-   &reshape_and_cache,
-   "Reshape the key and value tensors and cache them");
+ cache_ops.def("swap_blocks", &swap_blocks,
+               "Swap in (out) the cache blocks from src to dst");
+ cache_ops.def("copy_blocks", &copy_blocks,
+               "Copy the cache blocks from src to dst");
+ cache_ops.def("reshape_and_cache", &reshape_and_cache,
+               "Reshape the key and value tensors and cache them");
}
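The bindings above group kernels into pybind11 submodules of one extension module. A minimal, standalone sketch of that registration pattern (the module name toy_ext and the function add_one are hypothetical, used only to show the shape of the code):

#include <pybind11/pybind11.h>

int add_one(int x) { return x + 1; }

// Same registration pattern as above: a top-level module with a named
// submodule, each op bound together with a docstring.
PYBIND11_MODULE(toy_ext, m) {
  pybind11::module ops = m.def_submodule("ops", "toy operators");
  ops.def("add_one", &add_one, "Add one to an integer.");
}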
@@ -1,7 +1,7 @@
#pragma once

#ifdef USE_ROCM
#include <hip/hip_runtime.h>
#endif

#ifndef USE_ROCM

@@ -17,7 +17,8 @@
#endif

#ifndef USE_ROCM
-#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask)
+#define VLLM_SHFL_XOR_SYNC(var, lane_mask) \
+  __shfl_xor_sync(uint32_t(-1), var, lane_mask)
#else
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
#endif

@@ -29,7 +30,8 @@
#endif

#ifndef USE_ROCM
-#define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down_sync(uint32_t(-1), var, lane_delta)
+#define VLLM_SHFL_DOWN_SYNC(var, lane_delta) \
+  __shfl_down_sync(uint32_t(-1), var, lane_delta)
#else
#define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta)
#endif

@@ -41,4 +43,3 @@
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
  hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
#endif

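VLLM_SHFL_XOR_SYNC wraps __shfl_xor_sync on CUDA and __shfl_xor on ROCm; its typical use is a butterfly reduction within a warp. A sketch using the raw CUDA intrinsic directly (warp_reduce_sum is an illustrative helper, not taken from this diff):

// Butterfly reduction across a 32-lane warp: after five xor-shuffle steps
// every lane holds the sum contributed by the whole warp.
__device__ float warp_reduce_sum(float val) {
  for (int mask = 16; mask > 0; mask >>= 1) {
    val += __shfl_xor_sync(0xffffffffu, val, mask);
  }
  return val;
}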
@@ -2,9 +2,6 @@

#include <torch/extension.h>

-int get_device_attribute(
-    int attribute,
-    int device_id);
+int get_device_attribute(int attribute, int device_id);

-int get_max_shared_memory_per_block_device_attribute(
-    int device_id);
+int get_max_shared_memory_per_block_device_attribute(int device_id);
@@ -2,34 +2,28 @@
#include <hip/hip_runtime.h>
#include <hip/hip_runtime_api.h>
#endif
-int get_device_attribute(
-    int attribute,
-    int device_id)
-{
-    int device, value;
-    if (device_id < 0) {
-        cudaGetDevice(&device);
-    }
-    else {
-        device = device_id;
-    }
-    cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute), device);
-    return value;
+int get_device_attribute(int attribute, int device_id) {
+  int device, value;
+  if (device_id < 0) {
+    cudaGetDevice(&device);
+  } else {
+    device = device_id;
+  }
+  cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute),
+                         device);
+  return value;
}

-int get_max_shared_memory_per_block_device_attribute(
-    int device_id)
-{
-int attribute;
-// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
-// cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
+int get_max_shared_memory_per_block_device_attribute(int device_id) {
+  int attribute;
+  // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
+  // cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74

#ifdef USE_ROCM
  attribute = hipDeviceAttributeMaxSharedMemoryPerBlock;
#else
  attribute = cudaDevAttrMaxSharedMemoryPerBlockOptin;
#endif

  return get_device_attribute(attribute, device_id);
}
@@ -7,11 +7,11 @@

// fake pointer type
using fptr_t = uint64_t;
-static_assert(sizeof(void *) == sizeof(fptr_t));
+static_assert(sizeof(void*) == sizeof(fptr_t));

-fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
-                      const std::vector<std::string> &handles,
-                      const std::vector<int64_t> &offsets, int rank,
+fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
+                      const std::vector<std::string>& handles,
+                      const std::vector<int64_t>& offsets, int rank,
                       bool full_nvlink) {
  int world_size = offsets.size();
  if (world_size > 8)

@@ -29,7 +29,7 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
    std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t));
  }
  return (fptr_t) new vllm::CustomAllreduce(
-     reinterpret_cast<vllm::Signal *>(meta.data_ptr()), rank_data.data_ptr(),
+     reinterpret_cast<vllm::Signal*>(meta.data_ptr()), rank_data.data_ptr(),
      rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
}

@@ -49,13 +49,13 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
 * 5. A[None].expand(2, -1, -1, -1): Not OK
 * 6. A[:, 1:, 1:]: Not OK
 */
-bool _is_weak_contiguous(torch::Tensor &t) {
+bool _is_weak_contiguous(torch::Tensor& t) {
  return t.is_contiguous() ||
         (t.storage().nbytes() - t.storage_offset() * t.element_size() ==
          t.numel() * t.element_size());
}

-bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
+bool should_custom_ar(torch::Tensor& inp, int max_size, int world_size,
                      bool full_nvlink) {
  auto inp_size = inp.numel() * inp.element_size();
  // custom allreduce requires input byte size to be multiples of 16

@@ -67,28 +67,27 @@ bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
  return false;
}

-void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out,
+void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
                 cudaStream_t stream) {
- auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
+ auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
  TORCH_CHECK(_is_weak_contiguous(out));
  switch (out.scalar_type()) {
    case at::ScalarType::Float: {
-     fa->allreduce<float>(stream, reinterpret_cast<float *>(inp.data_ptr()),
-                          reinterpret_cast<float *>(out.data_ptr()),
+     fa->allreduce<float>(stream, reinterpret_cast<float*>(inp.data_ptr()),
+                          reinterpret_cast<float*>(out.data_ptr()),
                           out.numel());
      break;
    }
    case at::ScalarType::Half: {
-     fa->allreduce<half>(stream, reinterpret_cast<half *>(inp.data_ptr()),
-                         reinterpret_cast<half *>(out.data_ptr()),
-                         out.numel());
+     fa->allreduce<half>(stream, reinterpret_cast<half*>(inp.data_ptr()),
+                         reinterpret_cast<half*>(out.data_ptr()), out.numel());
      break;
    }
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
    case at::ScalarType::BFloat16: {
      fa->allreduce<nv_bfloat16>(
-         stream, reinterpret_cast<nv_bfloat16 *>(inp.data_ptr()),
-         reinterpret_cast<nv_bfloat16 *>(out.data_ptr()), out.numel());
+         stream, reinterpret_cast<nv_bfloat16*>(inp.data_ptr()),
+         reinterpret_cast<nv_bfloat16*>(out.data_ptr()), out.numel());
      break;
    }
#endif

@@ -98,7 +97,7 @@ void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out,
  }
}

-void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) {
+void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
  auto stream = c10::cuda::getCurrentCUDAStream().stream();
  TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());

@@ -106,8 +105,8 @@ void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) {
  _all_reduce(_fa, inp, out, stream);
}

-void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &reg_buffer,
-                      torch::Tensor &out) {
+void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
+                      torch::Tensor& out) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
  auto stream = c10::cuda::getCurrentCUDAStream().stream();

@@ -122,27 +121,27 @@ void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &reg_buffer,
}

void dispose(fptr_t _fa) {
- auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
+ auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
  delete fa;
}

int meta_size() { return sizeof(vllm::Signal); }

-void register_buffer(fptr_t _fa, torch::Tensor &t,
-                     const std::vector<std::string> &handles,
-                     const std::vector<int64_t> &offsets) {
+void register_buffer(fptr_t _fa, torch::Tensor& t,
+                     const std::vector<std::string>& handles,
+                     const std::vector<int64_t>& offsets) {
- auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
+ auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
  fa->register_buffer(handles, offsets, t.data_ptr());
}

std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
    fptr_t _fa) {
- auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
+ auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
  return fa->get_graph_buffer_ipc_meta();
}

-void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles,
-                            const std::vector<std::vector<int64_t>> &offsets) {
+void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
+                            const std::vector<std::vector<int64_t>>& offsets) {
- auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
+ auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
  fa->register_graph_buffers(handles, offsets);
}
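_is_weak_contiguous above accepts tensors whose storage, from storage_offset() to the end, spans exactly numel() * element_size() bytes, even when is_contiguous() is false. A small usage sketch with libtorch, assuming a build linked against libtorch; a transposed dense tensor is the classic case the check is meant to admit:

#include <torch/torch.h>
#include <iostream>

int main() {
  // A transposed dense tensor is not contiguous, but it still covers its
  // whole storage, so the "weak contiguous" predicate above accepts it.
  torch::Tensor a = torch::randn({4, 8});
  torch::Tensor t = a.t();
  bool weak = t.is_contiguous() ||
              (t.storage().nbytes() - t.storage_offset() * t.element_size() ==
               static_cast<size_t>(t.numel()) * t.element_size());
  std::cout << "is_contiguous=" << t.is_contiguous()
            << " weak_contiguous=" << weak << std::endl;
  return 0;
}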
@@ -31,9 +31,9 @@ struct Signal {
  alignas(128) uint32_t end[kMaxBlocks][8];
};

-struct __align__(16) RankData { const void *__restrict__ ptrs[8]; };
+struct __align__(16) RankData { const void* __restrict__ ptrs[8]; };

-struct __align__(16) RankSignals { volatile Signal *signals[8]; };
+struct __align__(16) RankSignals { volatile Signal* signals[8]; };

// like std::array, but aligned
template <typename T, int sz>

@@ -68,11 +68,11 @@ DINLINE half downcast_s(float val) {
// scalar add functions
// for some reason when compiling with Pytorch, the + operator for half and
// bfloat is disabled so we call the intrinsics directly
-DINLINE half &assign_add(half &a, half b) {
+DINLINE half& assign_add(half& a, half b) {
  a = __hadd(a, b);
  return a;
}
-DINLINE float &assign_add(float &a, float b) { return a += b; }
+DINLINE float& assign_add(float& a, float b) { return a += b; }

#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); }

@@ -80,14 +80,14 @@ template <>
DINLINE nv_bfloat16 downcast_s(float val) {
  return __float2bfloat16(val);
}
-DINLINE nv_bfloat16 &assign_add(nv_bfloat16 &a, nv_bfloat16 b) {
+DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) {
  a = __hadd(a, b);
  return a;
}
#endif

template <typename T, int N>
-DINLINE array_t<T, N> &packed_assign_add(array_t<T, N> &a, array_t<T, N> b) {
+DINLINE array_t<T, N>& packed_assign_add(array_t<T, N>& a, array_t<T, N> b) {
#pragma unroll
  for (int i = 0; i < N; i++) {
    assign_add(a.data[i], b.data[i]);

@@ -128,7 +128,7 @@ DINLINE O downcast(array_t<float, O::size> val) {
// prior memory accesses. Note: volatile writes will not be reordered against
// other volatile writes.
template <int ngpus>
-DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg,
+DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg,
                        int rank) {
  if (threadIdx.x < ngpus) {
    // reset flag for next time

@@ -137,8 +137,7 @@ DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg,
    // Latency = 1 p2p write
    sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1;
    // wait until we got true from all ranks
-   while (!self_sg->start[blockIdx.x][threadIdx.x])
-     ;
+   while (!self_sg->start[blockIdx.x][threadIdx.x]);
  }
  __syncthreads();
}

@@ -147,7 +146,7 @@ DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg,
// barrier in the all reduce kernel. If it's the final synchronization barrier,
// we don't need to make any visibility guarantees for prior memory accesses.
template <int ngpus, bool final_sync = false>
-DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg,
+DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg,
                      int rank) {
  __syncthreads();
  // eliminate the case that prior writes are not visible after signals become

@@ -162,14 +161,13 @@ DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg,
    // Latency = 1 p2p write
    sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1;
    // wait until we got true from all ranks
-   while (!self_sg->end[blockIdx.x][threadIdx.x])
-     ;
+   while (!self_sg->end[blockIdx.x][threadIdx.x]);
  }
  if constexpr (!final_sync) __syncthreads();
}

template <typename P, int ngpus, typename A>
-DINLINE P packed_reduce(const P *ptrs[], int idx) {
+DINLINE P packed_reduce(const P* ptrs[], int idx) {
  A tmp = upcast(ptrs[0][idx]);
#pragma unroll
  for (int i = 1; i < ngpus; i++) {

@@ -180,8 +178,8 @@ DINLINE P packed_reduce(const P *ptrs[], int idx) {

template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
-   cross_device_reduce_1stage(RankData *_dp, RankSignals sg,
-                              volatile Signal *self_sg, T *__restrict__ result,
+   cross_device_reduce_1stage(RankData* _dp, RankSignals sg,
+                              volatile Signal* self_sg, T* __restrict__ result,
                               int rank, int size) {
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;

@@ -192,21 +190,20 @@ __global__ void __launch_bounds__(512, 1)
  // do the actual reduction
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += gridDim.x * blockDim.x) {
-   ((P *)result)[idx] =
-       packed_reduce<P, ngpus, A>((const P **)&dp.ptrs[0], idx);
+   ((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
  }
  end_sync<ngpus, true>(sg, self_sg, rank);
}

template <typename P>
-DINLINE P *get_tmp_buf(volatile Signal *sg) {
-  return (P *)(((Signal *)sg) + 1);
+DINLINE P* get_tmp_buf(volatile Signal* sg) {
+  return (P*)(((Signal*)sg) + 1);
}

template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
-   cross_device_reduce_2stage(RankData *_dp, RankSignals sg,
-                              volatile Signal *self_sg, T *__restrict__ result,
+   cross_device_reduce_2stage(RankData* _dp, RankSignals sg,
+                              volatile Signal* self_sg, T* __restrict__ result,
                               int rank, int size) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = gridDim.x * blockDim.x;

@@ -216,12 +213,12 @@ __global__ void __launch_bounds__(512, 1)
  int start = rank * part;
  int end = rank == ngpus - 1 ? size : start + part;
  int largest_part = part + size % ngpus;
- const P *ptrs[ngpus];
- P *tmps[ngpus];
+ const P* ptrs[ngpus];
+ P* tmps[ngpus];
#pragma unroll
  for (int i = 0; i < ngpus; i++) {
    int target = (rank + i) % ngpus;
-   ptrs[i] = (const P *)_dp->ptrs[target];
+   ptrs[i] = (const P*)_dp->ptrs[target];
    tmps[i] = get_tmp_buf<P>(sg.signals[target]);
  }
  auto tmp_out = tmps[0];

@@ -243,7 +240,7 @@ __global__ void __launch_bounds__(512, 1)
      int gather_from_rank = ((rank + i) % ngpus);
      if (gather_from_rank == ngpus - 1 || idx < part) {
        int dst_idx = gather_from_rank * part + idx;
-       ((P *)result)[dst_idx] = tmps[i][idx];
+       ((P*)result)[dst_idx] = tmps[i][idx];
      }
    }
  }

@@ -261,14 +258,14 @@ class CustomAllreduce {

  // below are device pointers
  RankSignals sg_;
- std::unordered_map<void *, RankData *> buffers_;
- Signal *self_sg_;
+ std::unordered_map<void*, RankData*> buffers_;
+ Signal* self_sg_;

  // stores the registered device pointers from all ranks
  RankData *d_rank_data_base_, *d_rank_data_end_;
- std::vector<void *> graph_unreg_buffers_;
+ std::vector<void*> graph_unreg_buffers_;
  // a map from IPC handles to opened IPC pointers
- std::map<IPC_KEY, char *> ipc_handles_;
+ std::map<IPC_KEY, char*> ipc_handles_;

  /**
   * meta is a pointer to device metadata and temporary buffer for allreduce.

@@ -279,22 +276,22 @@ class CustomAllreduce {
   * note: this class does not own any device memory. Any required buffers
   * are passed in from the constructor
   */
- CustomAllreduce(Signal *meta, void *rank_data, size_t rank_data_sz,
-                 const cudaIpcMemHandle_t *handles,
-                 const std::vector<int64_t> &offsets, int rank,
+ CustomAllreduce(Signal* meta, void* rank_data, size_t rank_data_sz,
+                 const cudaIpcMemHandle_t* handles,
+                 const std::vector<int64_t>& offsets, int rank,
                  bool full_nvlink = true)
      : rank_(rank),
        world_size_(offsets.size()),
        full_nvlink_(full_nvlink),
        self_sg_(meta),
-       d_rank_data_base_(reinterpret_cast<RankData *>(rank_data)),
+       d_rank_data_base_(reinterpret_cast<RankData*>(rank_data)),
        d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
    for (int i = 0; i < world_size_; i++) {
-     Signal *rank_sg;
+     Signal* rank_sg;
      if (i != rank_) {
-       char *handle = open_ipc_handle(&handles[i]);
+       char* handle = open_ipc_handle(&handles[i]);
        handle += offsets[i];
-       rank_sg = (Signal *)handle;
+       rank_sg = (Signal*)handle;
      } else {
        rank_sg = self_sg_;
      }

@@ -302,13 +299,13 @@ class CustomAllreduce {
    }
  }

- char *open_ipc_handle(const void *ipc_handle) {
+ char* open_ipc_handle(const void* ipc_handle) {
    auto [it, new_handle] =
-       ipc_handles_.insert({*((IPC_KEY *)ipc_handle), nullptr});
+       ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
    if (new_handle) {
-     char *ipc_ptr;
-     CUDACHECK(cudaIpcOpenMemHandle((void **)&ipc_ptr,
-                                    *((const cudaIpcMemHandle_t *)ipc_handle),
+     char* ipc_ptr;
+     CUDACHECK(cudaIpcOpenMemHandle((void**)&ipc_ptr,
+                                    *((const cudaIpcMemHandle_t*)ipc_handle),
                                     cudaIpcMemLazyEnablePeerAccess));
      it->second = ipc_ptr;
    }

@@ -323,7 +320,7 @@ class CustomAllreduce {
    std::vector<int64_t> offsets(num_buffers);
    for (int i = 0; i < num_buffers; i++) {
      auto ptr = graph_unreg_buffers_[i];
-     void *base_ptr;
+     void* base_ptr;
      // note: must share the base address of each allocation, or we get wrong
      // address
      if (cuPointerGetAttribute(&base_ptr,

@@ -331,8 +328,8 @@ class CustomAllreduce {
                                (CUdeviceptr)ptr) != CUDA_SUCCESS)
        throw std::runtime_error("failed to get pointer attr");
      CUDACHECK(cudaIpcGetMemHandle(
-         (cudaIpcMemHandle_t *)&handles[i * handle_sz], base_ptr));
-     offsets[i] = ((char *)ptr) - ((char *)base_ptr);
+         (cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
+     offsets[i] = ((char*)ptr) - ((char*)base_ptr);
    }
    return std::make_pair(handles, offsets);
  }

@@ -344,13 +341,13 @@ class CustomAllreduce {
          std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
  }

- void register_buffer(const std::vector<std::string> &handles,
-                      const std::vector<int64_t> &offsets, void *self) {
+ void register_buffer(const std::vector<std::string>& handles,
+                      const std::vector<int64_t>& offsets, void* self) {
    check_rank_data_capacity();
    RankData data;
    for (int i = 0; i < world_size_; i++) {
      if (i != rank_) {
-       char *handle = open_ipc_handle(handles[i].data());
+       char* handle = open_ipc_handle(handles[i].data());
        handle += offsets[i];
        data.ptrs[i] = handle;
      } else {

@@ -371,17 +368,17 @@ class CustomAllreduce {
  // got a different address. IPC handles have internal reference counting
  // mechanism so overhead should be small.
  void register_graph_buffers(
-     const std::vector<std::string> &handles,
-     const std::vector<std::vector<int64_t>> &offsets) {
+     const std::vector<std::string>& handles,
+     const std::vector<std::vector<int64_t>>& offsets) {
    auto num_buffers = graph_unreg_buffers_.size();
    check_rank_data_capacity(num_buffers);
    std::vector<RankData> rank_data(num_buffers);
    for (int i = 0; i < num_buffers; i++) {
      auto self_ptr = graph_unreg_buffers_[i];
-     auto &rd = rank_data[i];
+     auto& rd = rank_data[i];
      for (int j = 0; j < world_size_; j++) {
        if (j != rank_) {
-         char *handle =
+         char* handle =
              open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]);
          handle += offsets[j][i];
          rd.ptrs[j] = handle;

@@ -405,7 +402,7 @@ class CustomAllreduce {
   * will cause contention on NVLink bus.
   */
  template <typename T>
- void allreduce(cudaStream_t stream, T *input, T *output, int size,
+ void allreduce(cudaStream_t stream, T* input, T* output, int size,
                 int threads = 512, int block_limit = 36) {
    auto d = packed_t<T>::P::size;
    if (size % d != 0)

@@ -418,7 +415,7 @@ class CustomAllreduce {
                            std::to_string(kMaxBlocks) + ". Got " +
                            std::to_string(block_limit));

-   RankData *ptrs;
+   RankData* ptrs;
    cudaStreamCaptureStatus status;
    CUDACHECK(cudaStreamIsCapturing(stream, &status));
    if (status == cudaStreamCaptureStatusActive) {
|
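The register_buffer / open_ipc_handle logic above is only reformatted by this commit. For readers unfamiliar with the pattern, a minimal standalone sketch follows (not part of this commit; map_peer_buffer, peer_handle and peer_offset are hypothetical names): a peer rank's base allocation is mapped once from its IPC handle, and the usable buffer is recovered by adding the offset that was exchanged out of band.

#include <cstdint>
#include <cuda_runtime.h>

// Maps a peer rank's allocation and returns a pointer to the sub-buffer
// inside it. Assumes peer_handle and peer_offset were exchanged between
// processes beforehand (e.g. over MPI), as the class above does.
void* map_peer_buffer(const cudaIpcMemHandle_t& peer_handle,
                      int64_t peer_offset) {
  void* base = nullptr;
  // cudaIpcMemLazyEnablePeerAccess defers peer mapping until first access.
  cudaIpcOpenMemHandle(&base, peer_handle, cudaIpcMemLazyEnablePeerAccess);
  return static_cast<char*>(base) + peer_offset;
}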
@@ -48,7 +48,7 @@ __global__ void dummy_kernel() {
}

template <typename T>
__global__ void set_data(T* data, int size, int myRank) {
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += gridDim.x * blockDim.x) {
    data[idx] = myRank * 0.11f;
@@ -56,8 +56,8 @@ __global__ void set_data(T *data, int size, int myRank) {
}

template <typename T>
__global__ void convert_data(const T* data1, const T* data2, double* fdata1,
                             double* fdata2, int size) {
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += gridDim.x * blockDim.x) {
    fdata1[idx] = data1[idx];
@@ -65,7 +65,7 @@ __global__ void convert_data(const T *data1, const T *data2, double *fdata1,
  }
}

__global__ void init_rand(curandState_t* state, int size, int nRanks) {
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += gridDim.x * blockDim.x) {
    for (int i = 0; i < nRanks; i++) {
@@ -75,7 +75,7 @@ __global__ void init_rand(curandState_t *state, int size, int nRanks) {
}

template <typename T>
__global__ void gen_data(curandState_t* state, T* data, double* ground_truth,
                         int myRank, int nRanks, int size) {
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += gridDim.x * blockDim.x) {
@@ -91,9 +91,9 @@ __global__ void gen_data(curandState_t *state, T *data, double *ground_truth,
}

template <typename T>
void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
         int data_size, bool performance_test) {
  T* result;
  cudaStream_t stream;
  CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
  CUDACHECK(cudaMalloc(&result, data_size * sizeof(T)));
@@ -101,8 +101,8 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,

  cudaIpcMemHandle_t self_data_handle;
  cudaIpcMemHandle_t data_handles[8];
  vllm::Signal* buffer;
  T* self_data_copy;
  /**
   * Allocate IPC buffer
   *
@@ -125,22 +125,22 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
                          MPI_BYTE, data_handles, sizeof(cudaIpcMemHandle_t),
                          MPI_BYTE, MPI_COMM_WORLD));

  void* rank_data;
  size_t rank_data_sz = 16 * 1024 * 1024;
  CUDACHECK(cudaMalloc(&rank_data, rank_data_sz));
  std::vector<int64_t> offsets(nRanks, 0);
  vllm::CustomAllreduce fa(buffer, rank_data, rank_data_sz, data_handles,
                           offsets, myRank);
  auto* self_data =
      reinterpret_cast<T*>(reinterpret_cast<char*>(buffer) +
                           sizeof(vllm::Signal) + data_size * sizeof(T));
  // hack buffer registration
  {
    std::vector<std::string> handles;
    handles.reserve(nRanks);
    for (int i = 0; i < nRanks; i++) {
      char* begin = (char*)&data_handles[i];
      char* end = (char*)&data_handles[i + 1];
      handles.emplace_back(begin, end);
    }
    std::vector<int64_t> offsets(nRanks,
@@ -148,9 +148,9 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
    fa.register_buffer(handles, offsets, self_data);
  }

  double* ground_truth;
  CUDACHECK(cudaMallocHost(&ground_truth, data_size * sizeof(double)));
  curandState_t* states;
  CUDACHECK(cudaMalloc(&states, sizeof(curandState_t) * nRanks * data_size));
  init_rand<<<108, 1024, 0, stream>>>(states, data_size, nRanks);
  gen_data<T><<<108, 1024, 0, stream>>>(states, self_data, ground_truth, myRank,
@@ -287,7 +287,7 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
  CUDACHECK(cudaStreamDestroy(stream));
}

int main(int argc, char** argv) {
  int nRanks, myRank;
  MPICHECK(MPI_Init(&argc, &argv));
  MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &myRank));
@@ -296,7 +296,7 @@ int main(int argc, char **argv) {
  ncclUniqueId id;
  ncclComm_t comm;
  if (myRank == 0) ncclGetUniqueId(&id);
  MPICHECK(MPI_Bcast(static_cast<void*>(&id), sizeof(id), MPI_BYTE, 0,
                     MPI_COMM_WORLD));
  NCCLCHECK(ncclCommInitRank(&comm, nRanks, id, myRank));

@@ -6,32 +6,30 @@

#include <torch/extension.h>

#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)         \
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)

#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...)   \
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)    \
  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)     \
  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)

#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME,                               \
                     VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))

#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...)         \
  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)   \
  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)

#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
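Only the wrapping of the AT_DISPATCH_SWITCH calls changes in this hunk. For context, a hypothetical usage sketch (not from this commit; copy_kernel and copy_tensor are made-up names) of how these macros are consumed: the macro switches on the runtime dtype and binds scalar_t inside the lambda.

#include <torch/extension.h>

#include "dispatch_utils.h"

// Made-up example kernel: element-wise copy.
template <typename scalar_t>
__global__ void copy_kernel(scalar_t* out, const scalar_t* in, int64_t n) {
  for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x)
    out[i] = in[i];
}

void copy_tensor(torch::Tensor& out, torch::Tensor& in) {
  VLLM_DISPATCH_FLOATING_TYPES(in.scalar_type(), "copy_kernel", [&] {
    copy_kernel<scalar_t><<<64, 256>>>(out.data_ptr<scalar_t>(),
                                       in.data_ptr<scalar_t>(), in.numel());
  });
}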
@@ -11,26 +11,24 @@
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>

using __nv_bfloat16 = __hip_bfloat16;
using __nv_bfloat162 = __hip_bfloat162;
#endif

namespace vllm {

// TODO(woosuk): Further optimize this kernel.
template <typename scalar_t>
__global__ void rms_norm_kernel(
    scalar_t* __restrict__ out,           // [..., hidden_size]
    const scalar_t* __restrict__ input,   // [..., hidden_size]
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float epsilon, const int num_tokens, const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    const float x = (float)input[blockIdx.x * hidden_size + idx];
    variance += x * x;
  }
  variance = blockReduceSum<float>(variance);
@@ -40,12 +38,12 @@ __global__ void rms_norm_kernel(
  __syncthreads();

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float)input[blockIdx.x * hidden_size + idx];
    out[blockIdx.x * hidden_size + idx] =
        ((scalar_t)(x * s_variance)) * weight[idx];
  }
}
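For reference, the computation implemented by rms_norm_kernel (and by the fused variants below) is plain RMSNorm; with H = hidden_size, w the weight vector and x the row for token t:

$$ \mathrm{out}_{t,i} = \frac{x_{t,i}}{\sqrt{\frac{1}{H}\sum_{j=1}^{H} x_{t,j}^{2} + \epsilon}} \, w_i $$

The shared s_variance in the kernel holds the reciprocal square-root factor, computed once per block and broadcast to all threads.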

/* Converter structs for the conversion from torch types to HIP/CUDA types,
   and the associated type conversions within HIP/CUDA. These helpers need
   to be implemented for now because the relevant type conversion
@@ -56,46 +54,63 @@ __global__ void rms_norm_kernel(
   If false, the optimized kernel is not used for the corresponding torch type.
   If true, the struct should be fully defined as shown in the examples below.
 */
template <typename torch_type>
struct _typeConvert {
  static constexpr bool exists = false;
};

#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
// CUDA < 12.0 runs into issues with packed type conversion
template <>
struct _typeConvert<c10::Half> {
  static constexpr bool exists = true;
  using hip_type = __half;
  using packed_hip_type = __half2;

  __device__ static inline float convert(hip_type x) { return __half2float(x); }
  __device__ static inline float2 convert(packed_hip_type x) {
    return __half22float2(x);
  }
  __device__ static inline hip_type convert(float x) {
    return __float2half_rn(x);
  }
  __device__ static inline packed_hip_type convert(float2 x) {
    return __float22half2_rn(x);
  }
};

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// CUDA_ARCH < 800 does not have BF16 support
// TODO: Add in ROCm support once public headers handle bf16 maturely
template <>
struct _typeConvert<c10::BFloat16> {
  static constexpr bool exists = true;
  using hip_type = __nv_bfloat16;
  using packed_hip_type = __nv_bfloat162;

  __device__ static inline float convert(hip_type x) {
    return __bfloat162float(x);
  }
  __device__ static inline float2 convert(packed_hip_type x) {
    return __bfloat1622float2(x);
  }
  __device__ static inline hip_type convert(float x) {
    return __float2bfloat16(x);
  }
  __device__ static inline packed_hip_type convert(float2 x) {
    return __float22bfloat162_rn(x);
  }
};
#endif  // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#endif  // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >=
        // 12000))

/* Vector POD struct to generate vectorized and packed FP16/BF16 ops
   for appropriate specializations of fused_add_rms_norm_kernel.
   Only functions that are necessary in that kernel are implemented.
   Alignment to 16 bytes is required to use 128-bit global memory ops.
 */
template <typename scalar_t, int width>
struct alignas(16) _f16Vec {
  /* Not theoretically necessary that width is a power of 2 but should
     almost always be the case for optimization purposes */
@@ -108,51 +123,49 @@ struct alignas(16) _f16Vec {

  __device__ _f16Vec& operator+=(const _f16Vec<scalar_t, width>& other) {
    if constexpr (width % 2 == 0) {
#pragma unroll
      for (int i = 0; i < width; i += 2) {
        T2 temp{data[i], data[i + 1]};
        temp += T2{other.data[i], other.data[i + 1]};
        data[i] = temp.x;
        data[i + 1] = temp.y;
      }
    } else {
#pragma unroll
      for (int i = 0; i < width; ++i) data[i] += other.data[i];
    }
    return *this;
  }

  __device__ _f16Vec& operator*=(const _f16Vec<scalar_t, width>& other) {
    if constexpr (width % 2 == 0) {
#pragma unroll
      for (int i = 0; i < width; i += 2) {
        T2 temp{data[i], data[i + 1]};
        temp *= T2{other.data[i], other.data[i + 1]};
        data[i] = temp.x;
        data[i + 1] = temp.y;
      }
    } else {
#pragma unroll
      for (int i = 0; i < width; ++i) data[i] *= other.data[i];
    }
    return *this;
  }

  __device__ _f16Vec& operator*=(const float scale) {
    if constexpr (width % 2 == 0) {
#pragma unroll
      for (int i = 0; i < width; i += 2) {
        float2 temp_f = Converter::convert(T2{data[i], data[i + 1]});
        temp_f.x *= scale;
        temp_f.y *= scale;
        T2 temp = Converter::convert(temp_f);
        data[i] = temp.x;
        data[i + 1] = temp.y;
      }
    } else {
#pragma unroll
      for (int i = 0; i < width; ++i) {
        float temp = Converter::convert(data[i]) * scale;
        data[i] = Converter::convert(temp);
@@ -164,13 +177,13 @@ struct alignas(16) _f16Vec {
  __device__ float sum_squares() const {
    float result = 0.0f;
    if constexpr (width % 2 == 0) {
#pragma unroll
      for (int i = 0; i < width; i += 2) {
        float2 z = Converter::convert(T2{data[i], data[i + 1]});
        result += z.x * z.x + z.y * z.y;
      }
    } else {
#pragma unroll
      for (int i = 0; i < width; ++i) {
        float x = Converter::convert(data[i]);
        result += x * x;
@@ -184,15 +197,13 @@ struct alignas(16) _f16Vec {
   Additional optimizations we can make in this case are
   packed and vectorized operations, which help with the
   memory latency bottleneck. */
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
fused_add_rms_norm_kernel(
    scalar_t* __restrict__ input,         // [..., hidden_size]
    scalar_t* __restrict__ residual,      // [..., hidden_size]
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float epsilon, const int num_tokens, const int hidden_size) {
  // Sanity checks on our vector struct and type-punned pointer arithmetic
  static_assert(std::is_pod_v<_f16Vec<scalar_t, width>>);
  static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);
@@ -203,9 +214,12 @@ __global__ std::enable_if_t<
  /* These and the argument pointers are all declared `restrict` as they are
     not aliased in practice. Argument pointers should not be dereferenced
     in this kernel as that would be undefined behavior */
  auto* __restrict__ input_v =
      reinterpret_cast<_f16Vec<scalar_t, width>*>(input);
  auto* __restrict__ residual_v =
      reinterpret_cast<_f16Vec<scalar_t, width>*>(residual);
  auto* __restrict__ weight_v =
      reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);

  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
    int id = blockIdx.x * vec_hidden_size + idx;
@@ -218,7 +232,8 @@ __global__ std::enable_if_t<
     calculation of max_block_size in fused_add_rms_norm */
  if (num_tokens < 256) {
    variance = blockReduceSum<float, 1024>(variance);
  } else
    variance = blockReduceSum<float, 256>(variance);
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
@@ -233,26 +248,23 @@ __global__ std::enable_if_t<
  }
}
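The width % 2 == 0 branches above lean on packed __half2 / __nv_bfloat162 math so that two lanes are handled per instruction. A minimal standalone sketch of that pattern for fp16 (not part of this commit; sum_squares_half2 is a made-up name):

#include <cuda_fp16.h>

// Accumulates the squared magnitude of both fp16 lanes of `v` in fp32,
// mirroring what _f16Vec::sum_squares() does per packed pair.
__device__ float sum_squares_half2(__half2 v) {
  float2 f = __half22float2(v);  // unpack both lanes to float
  return f.x * f.x + f.y * f.y;
}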

/* Generic fused_add_rms_norm_kernel
   The width field is not used here but necessary for other specializations.
 */
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
fused_add_rms_norm_kernel(
    scalar_t* __restrict__ input,         // [..., hidden_size]
    scalar_t* __restrict__ residual,      // [..., hidden_size]
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float epsilon, const int num_tokens, const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    scalar_t z = input[blockIdx.x * hidden_size + idx];
    z += residual[blockIdx.x * hidden_size + idx];
    float x = (float)z;
    variance += x * x;
    residual[blockIdx.x * hidden_size + idx] = z;
  }
@@ -260,25 +272,26 @@ __global__ std::enable_if_t<
     calculation of max_block_size in fused_add_rms_norm */
  if (num_tokens < 256) {
    variance = blockReduceSum<float, 1024>(variance);
  } else
    variance = blockReduceSum<float, 256>(variance);
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float)residual[blockIdx.x * hidden_size + idx];
    input[blockIdx.x * hidden_size + idx] =
        ((scalar_t)(x * s_variance)) * weight[idx];
  }
}

}  // namespace vllm

void rms_norm(torch::Tensor& out,     // [..., hidden_size]
              torch::Tensor& input,   // [..., hidden_size]
              torch::Tensor& weight,  // [hidden_size]
              float epsilon) {
  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;

@@ -286,40 +299,27 @@ void rms_norm(
  dim3 block(std::min(hidden_size, 1024));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
    vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
        out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
        weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
  });
}

#define LAUNCH_FUSED_ADD_RMS_NORM(width)                                       \
  VLLM_DISPATCH_FLOATING_TYPES(                                                \
      input.scalar_type(), "fused_add_rms_norm_kernel", [&] {                  \
        vllm::fused_add_rms_norm_kernel<scalar_t, width>                       \
            <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),           \
                                         residual.data_ptr<scalar_t>(),        \
                                         weight.data_ptr<scalar_t>(), epsilon, \
                                         num_tokens, hidden_size);             \
      });

void fused_add_rms_norm(torch::Tensor& input,     // [..., hidden_size]
                        torch::Tensor& residual,  // [..., hidden_size]
                        torch::Tensor& weight,    // [hidden_size]
                        float epsilon) {
  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;

@@ -342,8 +342,8 @@ void fused_add_rms_norm(
  auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
  auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
  auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
  bool ptrs_are_aligned =
      inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
  if (ptrs_are_aligned && hidden_size % 8 == 0) {
    LAUNCH_FUSED_ADD_RMS_NORM(8);
  } else {
@@ -3,5 +3,6 @@
#include <torch/extension.h>

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("topk_softmax", &topk_softmax,
        "Apply topk softmax to the gating outputs.");
}
@@ -2,8 +2,6 @@

#include <torch/extension.h>

void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
                  torch::Tensor& token_expert_indices,
                  torch::Tensor& gating_output);
@@ -7,119 +7,128 @@
#include "cuda_compat.h"
#include "dispatch_utils.h"

#define CEILDIV(x, y) (((x) + (y) - 1) / (y))

namespace vllm {

namespace {
__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
                                         int32_t col) {
  // don't worry about overflow because num_experts is relatively small
  return row * total_col + col;
}
}  // namespace

template <typename scalar_t>
__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
                                            int32_t* sorted_token_ids,
                                            int32_t* expert_ids,
                                            int32_t* total_tokens_post_pad,
                                            int32_t num_experts,
                                            int32_t block_size, size_t numel) {
  const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
  const size_t start_idx = threadIdx.x * tokens_per_thread;

  extern __shared__ int32_t shared_mem[];

  int32_t* tokens_cnts =
      shared_mem;  // 2d tensor with shape (num_experts + 1, num_experts)
  int32_t* cumsum =
      shared_mem + (num_experts + 1) *
                       num_experts;  // 1d tensor with shape (num_experts + 1)

  for (int i = 0; i < num_experts; ++i) {
    tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
  }

  /**
   * In the first step we compute token_cnts[thread_index + 1][expert_index],
   * which counts how many tokens in the token shard of thread_index are
   * assigned to expert expert_index.
   */
  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])];
  }

  __syncthreads();

  // For each expert we accumulate the token counts from the different threads.
  tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
  for (int i = 1; i <= blockDim.x; ++i) {
    tokens_cnts[index(num_experts, i, threadIdx.x)] +=
        tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
  }

  __syncthreads();

  // We accumulate the token counts of all experts in thread 0.
  if (threadIdx.x == 0) {
    cumsum[0] = 0;
    for (int i = 1; i <= num_experts; ++i) {
      cumsum[i] = cumsum[i - 1] +
                  CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)],
                          block_size) *
                      block_size;
    }
    *total_tokens_post_pad = cumsum[num_experts];
  }

  __syncthreads();

  /**
   * For each expert, each thread processes the tokens of the corresponding
   * blocks and stores the corresponding expert_id for each block.
   */
  for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
       i += block_size) {
    expert_ids[i / block_size] = threadIdx.x;
  }

  /**
   * Each thread processes a token shard, calculating the index of each token
   * after sorting by expert number. Given the example topk_ids =
   * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *,
   * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a
   * padding value(preset in python).
   */
  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    int32_t expert_id = topk_ids[i];
    /** The cumsum[expert_id] stores the starting index of the tokens that the
     * expert with expert_id needs to process, and
     * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
     * processed by the expert with expert_id within the current thread's token
     * shard.
     */
    int32_t rank_post_pad =
        tokens_cnts[index(num_experts, threadIdx.x, expert_id)] +
        cumsum[expert_id];
    sorted_token_ids[rank_post_pad] = i;
    ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
  }
}
}  // namespace vllm

void moe_align_block_size(torch::Tensor topk_ids, int num_experts,
                          int block_size, torch::Tensor sorted_token_ids,
                          torch::Tensor experts_ids,
                          torch::Tensor num_tokens_post_pad) {
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_INTEGRAL_TYPES(
      topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
        // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
        // tensors
        const int32_t shared_mem =
            ((num_experts + 1) * num_experts + (num_experts + 1)) *
            sizeof(int32_t);

        // set dynamic shared mem
        auto kernel = vllm::moe_align_block_size_kernel<scalar_t>;
        AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
            (void*)kernel, shared_mem));
        kernel<<<1, num_experts, shared_mem, stream>>>(
            topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
            experts_ids.data_ptr<int32_t>(),
            num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
            topk_ids.numel());
      });
}
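As a sanity check on the padding scheme, a small host-side sketch (not part of this commit) reproduces the worked example from the kernel comment: topk_ids = [0,1,2,1,2,3,0,3,4] with block_size = 4 and num_experts = 5 gives per-expert counts {2, 2, 2, 2, 1}; each expert's range is rounded up to one block of 4 slots, so total_tokens_post_pad = 20, matching the 20-slot output shown in the comment.

#include <cstdio>
#include <vector>

int main() {
  const int block_size = 4, num_experts = 5;
  std::vector<int> topk_ids = {0, 1, 2, 1, 2, 3, 0, 3, 4};
  std::vector<int> cnt(num_experts, 0);
  for (int e : topk_ids) ++cnt[e];       // tokens routed to each expert
  std::vector<int> cumsum(num_experts + 1, 0);
  for (int e = 0; e < num_experts; ++e)  // pad each expert to block_size
    cumsum[e + 1] =
        cumsum[e] + (cnt[e] + block_size - 1) / block_size * block_size;
  std::printf("total_tokens_post_pad = %d\n", cumsum[num_experts]);  // 20
  return 0;
}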
284
csrc/ops.h
284
csrc/ops.h
@ -2,224 +2,136 @@
|
|||||||
|
|
||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
|
|
||||||
void paged_attention_v1(
|
void paged_attention_v1(torch::Tensor& out, torch::Tensor& query,
|
||||||
torch::Tensor& out,
|
torch::Tensor& key_cache, torch::Tensor& value_cache,
|
||||||
torch::Tensor& query,
|
int num_kv_heads, float scale,
|
||||||
torch::Tensor& key_cache,
|
torch::Tensor& block_tables, torch::Tensor& seq_lens,
|
||||||
torch::Tensor& value_cache,
|
int block_size, int max_seq_len,
|
||||||
int num_kv_heads,
|
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||||
float scale,
|
const std::string& kv_cache_dtype, float kv_scale);
|
||||||
torch::Tensor& block_tables,
|
|
||||||
torch::Tensor& seq_lens,
|
|
||||||
int block_size,
|
|
||||||
int max_seq_len,
|
|
||||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
|
||||||
const std::string& kv_cache_dtype,
|
|
||||||
float kv_scale);
|
|
||||||
|
|
||||||
void paged_attention_v2(
|
void paged_attention_v2(torch::Tensor& out, torch::Tensor& exp_sums,
|
||||||
torch::Tensor& out,
|
torch::Tensor& max_logits, torch::Tensor& tmp_out,
|
||||||
torch::Tensor& exp_sums,
|
torch::Tensor& query, torch::Tensor& key_cache,
|
||||||
torch::Tensor& max_logits,
|
torch::Tensor& value_cache, int num_kv_heads,
|
||||||
torch::Tensor& tmp_out,
|
float scale, torch::Tensor& block_tables,
|
||||||
torch::Tensor& query,
|
torch::Tensor& seq_lens, int block_size,
|
||||||
torch::Tensor& key_cache,
|
int max_seq_len,
|
||||||
torch::Tensor& value_cache,
|
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||||
int num_kv_heads,
|
const std::string& kv_cache_dtype, float kv_scale);
|
||||||
float scale,
|
|
||||||
torch::Tensor& block_tables,
|
|
||||||
torch::Tensor& seq_lens,
|
|
||||||
int block_size,
|
|
||||||
int max_seq_len,
|
|
||||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
|
||||||
const std::string& kv_cache_dtype,
|
|
||||||
float kv_scale);
|
|
||||||
|
|
||||||
void rms_norm(
|
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
|
||||||
torch::Tensor& out,
|
float epsilon);
|
||||||
torch::Tensor& input,
|
|
||||||
torch::Tensor& weight,
|
|
||||||
float epsilon);
|
|
||||||
|
|
||||||
void fused_add_rms_norm(
|
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
|
||||||
torch::Tensor& input,
|
torch::Tensor& weight, float epsilon);
|
||||||
torch::Tensor& residual,
|
|
||||||
torch::Tensor& weight,
|
|
||||||
float epsilon);
|
|
||||||
|
|
||||||
void rotary_embedding(
|
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
|
||||||
torch::Tensor& positions,
|
torch::Tensor& key, int head_size,
|
||||||
torch::Tensor& query,
|
torch::Tensor& cos_sin_cache, bool is_neox);
|
||||||
torch::Tensor& key,
|
|
||||||
int head_size,
|
|
||||||
torch::Tensor& cos_sin_cache,
|
|
||||||
bool is_neox);
|
|
||||||
|
|
||||||
void batched_rotary_embedding(
|
void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
|
||||||
torch::Tensor& positions,
|
torch::Tensor& key, int head_size,
|
||||||
torch::Tensor& query,
|
torch::Tensor& cos_sin_cache, bool is_neox,
|
||||||
torch::Tensor& key,
|
int rot_dim,
|
||||||
int head_size,
|
torch::Tensor& cos_sin_cache_offsets);
|
||||||
torch::Tensor& cos_sin_cache,
|
|
||||||
bool is_neox,
|
|
||||||
int rot_dim,
|
|
||||||
torch::Tensor& cos_sin_cache_offsets);
|
|
||||||
|
|
||||||
void silu_and_mul(
|
void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
|
||||||
torch::Tensor& out,
|
|
||||||
torch::Tensor& input);
|
|
||||||
|
|
||||||
void gelu_and_mul(
|
void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);
|
||||||
torch::Tensor& out,
|
|
||||||
torch::Tensor& input);
|
|
||||||
|
|
||||||
void gelu_tanh_and_mul(
|
void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
|
||||||
torch::Tensor& out,
|
|
||||||
torch::Tensor& input);
|
|
||||||
|
|
||||||
void gelu_new(
|
void gelu_new(torch::Tensor& out, torch::Tensor& input);
|
||||||
torch::Tensor& out,
|
|
||||||
torch::Tensor& input);
|
|
||||||
|
|
||||||
void gelu_fast(
|
void gelu_fast(torch::Tensor& out, torch::Tensor& input);
|
||||||
torch::Tensor& out,
|
|
||||||
torch::Tensor& input);
|
|
||||||
|
|
||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
torch::Tensor aqlm_gemm(
|
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
|
||||||
const torch::Tensor& input,
|
const torch::Tensor& codebooks,
|
||||||
const torch::Tensor& codes,
|
const torch::Tensor& scales,
|
||||||
const torch::Tensor& codebooks,
|
const torch::Tensor& codebook_partition_sizes,
|
||||||
const torch::Tensor& scales,
|
const std::optional<torch::Tensor>& bias);
|
||||||
const torch::Tensor& codebook_partition_sizes,
|
|
||||||
const std::optional<torch::Tensor>& bias
|
|
||||||
);
|
|
||||||
|
|
||||||
torch::Tensor aqlm_dequant(
|
torch::Tensor aqlm_dequant(const torch::Tensor& codes,
|
||||||
const torch::Tensor& codes,
|
const torch::Tensor& codebooks,
|
||||||
const torch::Tensor& codebooks,
|
const torch::Tensor& codebook_partition_sizes);
|
||||||
const torch::Tensor& codebook_partition_sizes
|
|
||||||
);
|
|
||||||
|
|
||||||
torch::Tensor awq_gemm(
|
torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
|
||||||
torch::Tensor _in_feats,
|
torch::Tensor _scaling_factors, torch::Tensor _zeros,
|
||||||
torch::Tensor _kernel,
|
int split_k_iters);
|
||||||
torch::Tensor _scaling_factors,
|
|
||||||
torch::Tensor _zeros,
|
|
||||||
int split_k_iters);
|
|
||||||
|
|
||||||
torch::Tensor awq_dequantize(
|
torch::Tensor awq_dequantize(torch::Tensor _kernel,
|
||||||
torch::Tensor _kernel,
|
torch::Tensor _scaling_factors,
|
||||||
torch::Tensor _scaling_factors,
|
torch::Tensor _zeros, int split_k_iters, int thx,
|
||||||
torch::Tensor _zeros,
|
int thy);
|
||||||
int split_k_iters,
|
|
||||||
int thx,
|
|
||||||
int thy);
|
|
||||||
|
|
||||||
torch::Tensor marlin_gemm(
|
torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||||
torch::Tensor& a,
|
torch::Tensor& b_scales, torch::Tensor& workspace,
|
||||||
torch::Tensor& b_q_weight,
|
int64_t size_m, int64_t size_n, int64_t size_k);
|
||||||
torch::Tensor& b_scales,
|
|
||||||
torch::Tensor& workspace,
|
|
||||||
int64_t size_m,
|
|
||||||
int64_t size_n,
|
|
||||||
int64_t size_k);
|
|
||||||
|
|
||||||
torch::Tensor gptq_marlin_24_gemm(
|
torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||||
torch::Tensor &a,
|
                                  torch::Tensor& b_meta,
                                  torch::Tensor& b_scales,
                                  torch::Tensor& workspace, int64_t num_bits,
                                  int64_t size_m, int64_t size_n,
                                  int64_t size_k);

torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                               torch::Tensor& b_scales, torch::Tensor& g_idx,
                               torch::Tensor& perm, torch::Tensor& workspace,
                               int64_t num_bits, int64_t size_m, int64_t size_n,
                               int64_t size_k, bool is_k_full);

torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
                                 int64_t size_k, int64_t size_n,
                                 int64_t num_bits);

int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a,
                         torch::Tensor const& b, torch::Tensor const& a_scales,
                         torch::Tensor const& b_scales);

#endif

void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
                     torch::Tensor lookup_table);

torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
                        torch::Tensor b_gptq_qzeros,
                        torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
                        bool use_exllama, int bit);

void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int bit);

void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
                             torch::Tensor& scale);

void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
                              torch::Tensor& scale);

void moe_align_block_size(torch::Tensor topk_ids, int num_experts,
                          int block_size, torch::Tensor sorted_token_ids,
                          torch::Tensor experts_ids,
                          torch::Tensor num_tokens_post_pad);

#ifndef USE_ROCM
using fptr_t = uint64_t;
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
                      const std::vector<std::string>& handles,
                      const std::vector<int64_t>& offsets, int rank,
                      bool full_nvlink);
bool should_custom_ar(torch::Tensor& inp, int max_size, int world_size,
                      bool full_nvlink);
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
                      torch::Tensor& out);
void dispose(fptr_t _fa);
int meta_size();
void register_buffer(fptr_t _fa, torch::Tensor& t,
                     const std::vector<std::string>& handles,
                     const std::vector<int64_t>& offsets);
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
    fptr_t _fa);
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
                            const std::vector<std::vector<int64_t>>& offsets);
#endif
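The declarations above are only reformatted; their signatures are unchanged. As a quick illustration of how one of them is invoked from C++ once the extension sources are compiled, here is a minimal, hypothetical sketch around gptq_shuffle. The include path, tensor shapes, and dtypes are placeholders for illustration, not a real GPTQ layout.

#include <torch/torch.h>
#include "ops.h"  // assumed include path for the declarations above

void gptq_shuffle_example() {
  auto opts = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA);
  torch::Tensor q_weight = torch::zeros({1024, 4096}, opts);  // packed weights (placeholder shape)
  torch::Tensor q_perm = torch::arange(1024, opts);           // identity permutation
  gptq_shuffle(q_weight, q_perm, /*bit=*/4);                  // declared in this header
}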
@@ -7,14 +7,10 @@

namespace vllm {

template <typename scalar_t, bool IS_NEOX>
inline __device__ void apply_token_rotary_embedding(
    scalar_t* __restrict__ arr, const scalar_t* __restrict__ cos_ptr,
    const scalar_t* __restrict__ sin_ptr, int rot_offset, int embed_dim) {
  int x_index, y_index;
  scalar_t cos, sin;
  if (IS_NEOX) {
@@ -37,19 +33,17 @@ inline __device__ void apply_token_rotary_embedding(
    arr[y_index] = y * cos + x * sin;
}

template <typename scalar_t, bool IS_NEOX>
inline __device__ void apply_rotary_embedding(
    scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,
                                   // head_size] or [num_tokens, num_heads,
                                   // head_size]
    scalar_t* __restrict__ key,    // [batch_size, seq_len, num_kv_heads,
                                   // head_size] or [num_tokens, num_kv_heads,
                                   // head_size]
    const scalar_t* cache_ptr, const int head_size, const int num_heads,
    const int num_kv_heads, const int rot_dim, const int token_idx,
    const int64_t query_stride, const int64_t key_stride) {
  const int embed_dim = rot_dim / 2;
  const scalar_t* cos_ptr = cache_ptr;
  const scalar_t* sin_ptr = cache_ptr + embed_dim;
@@ -59,8 +53,8 @@ inline __device__ void apply_rotary_embedding(
    const int head_idx = i / embed_dim;
    const int64_t token_head = token_idx * query_stride + head_idx * head_size;
    const int rot_offset = i % embed_dim;
    apply_token_rotary_embedding<scalar_t, IS_NEOX>(
        query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
  }

  const int nk = num_kv_heads * embed_dim;
@@ -68,62 +62,74 @@ inline __device__ void apply_rotary_embedding(
    const int head_idx = i / embed_dim;
    const int64_t token_head = token_idx * key_stride + head_idx * head_size;
    const int rot_offset = i % embed_dim;
    apply_token_rotary_embedding<scalar_t, IS_NEOX>(
        key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
  }
}

template <typename scalar_t, bool IS_NEOX>
__global__ void rotary_embedding_kernel(
    const int64_t* __restrict__ positions,  // [batch_size, seq_len] or
                                            // [num_tokens]
    scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,
                                   // head_size] or [num_tokens, num_heads,
                                   // head_size]
    scalar_t* __restrict__ key,    // [batch_size, seq_len, num_kv_heads,
                                   // head_size] or [num_tokens, num_kv_heads,
                                   // head_size]
    const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
                                                 // 2]
    const int rot_dim, const int64_t query_stride, const int64_t key_stride,
    const int num_heads, const int num_kv_heads, const int head_size) {
  // Each thread block is responsible for one token.
  const int token_idx = blockIdx.x;
  int64_t pos = positions[token_idx];
  const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;

  apply_rotary_embedding<scalar_t, IS_NEOX>(
      query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
      token_idx, query_stride, key_stride);
}

template <typename scalar_t, bool IS_NEOX>
__global__ void batched_rotary_embedding_kernel(
    const int64_t* __restrict__ positions,  // [batch_size, seq_len] or
                                            // [num_tokens]
    scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,
                                   // head_size] or [num_tokens, num_heads,
                                   // head_size]
    scalar_t* __restrict__ key,    // [batch_size, seq_len, num_kv_heads,
                                   // head_size] or [num_tokens, num_kv_heads,
                                   // head_size]
    const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
                                                 // 2]
    const int64_t* __restrict__ cos_sin_cache_offsets,  // [batch_size, seq_len]
                                                        // or [num_tokens]
    const int rot_dim, const int64_t query_stride, const int64_t key_stride,
    const int num_heads, const int num_kv_heads, const int head_size) {
  // Each thread block is responsible for one token.
  const int token_idx = blockIdx.x;
  int64_t pos = positions[token_idx];
  int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx];
  const scalar_t* cache_ptr =
      cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim;

  apply_rotary_embedding<scalar_t, IS_NEOX>(
      query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
      token_idx, query_stride, key_stride);
}

}  // namespace vllm

void rotary_embedding(
    torch::Tensor& positions,  // [batch_size, seq_len] or [num_tokens]
    torch::Tensor& query,  // [batch_size, seq_len, num_heads * head_size] or
                           // [num_tokens, num_heads * head_size]
    torch::Tensor& key,    // [batch_size, seq_len, num_kv_heads * head_size] or
                           // [num_tokens, num_kv_heads * head_size]
    int head_size,
    torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
    bool is_neox) {
  int64_t num_tokens = query.numel() / query.size(-1);
  int rot_dim = cos_sin_cache.size(1);
  int num_heads = query.size(-1) / head_size;
@@ -135,36 +141,21 @@ void rotary_embedding(
  dim3 block(std::min(num_heads * rot_dim / 2, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
    if (is_neox) {
      vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
          positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
          key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(), rot_dim,
          query_stride, key_stride, num_heads, num_kv_heads, head_size);
    } else {
      vllm::rotary_embedding_kernel<scalar_t, false>
          <<<grid, block, 0, stream>>>(
              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
              key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
              rot_dim, query_stride, key_stride, num_heads, num_kv_heads,
              head_size);
    }
  });
}

/*
@@ -172,14 +163,15 @@ Batched version of rotary embedding, pack multiple LoRAs together
and process in batched manner.
*/
void batched_rotary_embedding(
    torch::Tensor& positions,  // [batch_size, seq_len] or [num_tokens]
    torch::Tensor& query,  // [batch_size, seq_len, num_heads * head_size] or
                           // [num_tokens, num_heads * head_size]
    torch::Tensor& key,    // [batch_size, seq_len, num_kv_heads * head_size] or
                           // [num_tokens, num_kv_heads * head_size]
    int head_size,
    torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
    bool is_neox, int rot_dim,
    torch::Tensor& cos_sin_cache_offsets  // [num_tokens]
) {
  int64_t num_tokens = cos_sin_cache_offsets.size(0);
  int num_heads = query.size(-1) / head_size;
@@ -191,36 +183,21 @@ void batched_rotary_embedding(
  dim3 block(std::min(num_heads * rot_dim / 2, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
    if (is_neox) {
      vllm::batched_rotary_embedding_kernel<scalar_t, true>
          <<<grid, block, 0, stream>>>(
              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
              key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
              cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
              key_stride, num_heads, num_kv_heads, head_size);
    } else {
      vllm::batched_rotary_embedding_kernel<scalar_t, false>
          <<<grid, block, 0, stream>>>(
              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
              key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
              cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
              key_stride, num_heads, num_kv_heads, head_size);
    }
  });
}
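For reference, the per-pair rotation that apply_token_rotary_embedding performs can be written out on the host. This is a sketch of the NeoX layout only (element i paired with element i + embed_dim, matching the IS_NEOX branch above); the function and variable names here are illustrative, not code from the kernel.

#include <vector>

// Rotates one head of 2 * embed_dim values in place, NeoX-style pairing.
void rotate_head_neox(std::vector<float>& head,
                      const std::vector<float>& cos_vals,
                      const std::vector<float>& sin_vals, int embed_dim) {
  for (int i = 0; i < embed_dim; ++i) {
    const float x = head[i];
    const float y = head[i + embed_dim];
    head[i] = x * cos_vals[i] - y * sin_vals[i];              // arr[x_index]
    head[i + embed_dim] = y * cos_vals[i] + x * sin_vals[i];  // arr[y_index] = y * cos + x * sin
  }
}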
140
csrc/pybind.cpp
@@ -8,116 +8,87 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  pybind11::module ops = m.def_submodule("ops", "vLLM custom operators");

  // Attention ops
  ops.def("paged_attention_v1", &paged_attention_v1,
          "Compute the attention between an input query and the cached "
          "keys/values using PagedAttention.");
  ops.def("paged_attention_v2", &paged_attention_v2, "PagedAttention V2.");

  // Activation ops
  ops.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU.");
  ops.def("gelu_and_mul", &gelu_and_mul,
          "Activation function used in GeGLU with `none` approximation.");
  ops.def("gelu_tanh_and_mul", &gelu_tanh_and_mul,
          "Activation function used in GeGLU with `tanh` approximation.");
  ops.def("gelu_new", &gelu_new, "GELU implementation used in GPT-2.");
  ops.def("gelu_fast", &gelu_fast, "Approximate GELU implementation.");

  // Layernorm
  ops.def("rms_norm", &rms_norm,
          "Apply Root Mean Square (RMS) Normalization to the input tensor.");

  ops.def("fused_add_rms_norm", &fused_add_rms_norm,
          "In-place fused Add and RMS Normalization");

  // Rotary embedding
  ops.def("rotary_embedding", &rotary_embedding,
          "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");

  ops.def("batched_rotary_embedding", &batched_rotary_embedding,
          "Apply GPT-NeoX or GPT-J style rotary embedding to query and key "
          "(supports multiple loras)");

  // Quantization ops
#ifndef USE_ROCM
  ops.def("aqlm_gemm", &aqlm_gemm, "Quantized GEMM for AQLM");
  ops.def("aqlm_dequant", &aqlm_dequant, "Decompression method for AQLM");
  ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
  ops.def("marlin_gemm", &marlin_gemm,
          "Marlin (Dense) Optimized Quantized GEMM for GPTQ");
  ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm,
          "Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ");
  ops.def("gptq_marlin_gemm", &gptq_marlin_gemm,
          "gptq_marlin Optimized Quantized GEMM for GPTQ");
  ops.def("gptq_marlin_repack", &gptq_marlin_repack,
          "gptq_marlin repack from GPTQ");
  ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
  ops.def("cutlass_scaled_mm_dq", &cutlass_scaled_mm_dq,
          "CUTLASS w8a8 GEMM, supporting symmetric per-tensor or "
          "per-row/column quantization.");
#endif

  ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
  ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
  ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
  ops.def("static_scaled_fp8_quant", &static_scaled_fp8_quant,
          "Compute FP8 quantized tensor for given scaling factor");
  ops.def("dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant,
          "Compute FP8 quantized tensor and scaling factor");
  ops.def("moe_align_block_size", &moe_align_block_size,
          "Aligning the number of tokens to be processed by each expert such "
          "that it is divisible by the block size.");

  // Cache ops
  pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
  cache_ops.def("swap_blocks", &swap_blocks,
                "Swap in (out) the cache blocks from src to dst");
  cache_ops.def("copy_blocks", &copy_blocks,
                "Copy the cache blocks from src to dst");
  cache_ops.def("reshape_and_cache", &reshape_and_cache,
                "Reshape the key and value tensors and cache them");
  cache_ops.def("reshape_and_cache_flash", &reshape_and_cache_flash,
                "Reshape the key and value tensors and cache them");
  cache_ops.def("convert_fp8", &convert_fp8,
                "Convert the key and value cache to fp8 data type");

  // Cuda utils
  pybind11::module cuda_utils =
      m.def_submodule("cuda_utils", "vLLM cuda utils");
  cuda_utils.def("get_device_attribute", &get_device_attribute,
                 "Gets the specified device attribute.");

  cuda_utils.def("get_max_shared_memory_per_block_device_attribute",
                 &get_max_shared_memory_per_block_device_attribute,
                 "Gets the maximum shared memory per block device attribute.");

#ifndef USE_ROCM
  // Custom all-reduce kernels
@@ -134,5 +105,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  custom_ar.def("register_graph_buffers", &register_graph_buffers,
                "register_graph_buffers");
#endif
}
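The bindings above follow one pattern throughout: group related kernels into a pybind11 submodule and register each with def, passing the function pointer and a docstring. A stripped-down, hypothetical module following the same pattern might look like the sketch below; the op name and docstring are illustrative and are not part of vLLM.

#include <torch/extension.h>

// Hypothetical op used only to illustrate the registration pattern.
void noop(torch::Tensor& t) { (void)t; }

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  pybind11::module ops = m.def_submodule("ops", "example custom operators");
  ops.def("noop", &noop, "Does nothing; shows the ops.def registration style.");
}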
@ -25,32 +25,28 @@
|
|||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
|
|
||||||
|
|
||||||
namespace vllm {
|
namespace vllm {
|
||||||
namespace aqlm {
|
namespace aqlm {
|
||||||
|
|
||||||
__global__ void Code1x16MatVec(
|
__global__ void Code1x16MatVec(
|
||||||
const int4* __restrict__ A,
|
const int4* __restrict__ A, const int4* __restrict__ B,
|
||||||
const int4* __restrict__ B,
|
int4* __restrict__ C, const int4* __restrict__ codebook, const int prob_m,
|
||||||
int4* __restrict__ C,
|
const int prob_k,
|
||||||
const int4* __restrict__ codebook,
|
const int4 codebook_a_sizes, // cumulative sizes of A spanning each
|
||||||
const int prob_m,
|
// codebook, at most 3 long.
|
||||||
const int prob_k,
|
const int codebook_stride // as int4.
|
||||||
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long.
|
|
||||||
const int codebook_stride // as int4.
|
|
||||||
) {
|
) {
|
||||||
int a_gl_stride = prob_k / 8 / 8;
|
int a_gl_stride = prob_k / 8 / 8;
|
||||||
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
|
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
|
||||||
bool pred = a_gl_rd < prob_m;
|
bool pred = a_gl_rd < prob_m;
|
||||||
|
|
||||||
if (pred)
|
if (pred) {
|
||||||
{
|
// advance to the correct codebook, this easy because we only multiply one
|
||||||
// advance to the correct codebook, this easy because we only multiply one column of the codebook.
|
// column of the codebook.
|
||||||
auto codebook_size = &codebook_a_sizes.x;
|
auto codebook_size = &codebook_a_sizes.x;
|
||||||
while (a_gl_rd >= *codebook_size)
|
while (a_gl_rd >= *codebook_size) {
|
||||||
{
|
codebook += codebook_stride;
|
||||||
codebook += codebook_stride;
|
++codebook_size;
|
||||||
++codebook_size;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -67,8 +63,7 @@ __global__ void Code1x16MatVec(
|
|||||||
// We pad shared memory to avoid bank conflicts during reads
|
// We pad shared memory to avoid bank conflicts during reads
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) {
|
for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) {
|
||||||
if (b_gl_rd + i < prob_k / 8)
|
if (b_gl_rd + i < prob_k / 8) sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
|
||||||
sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
|
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
b_gl_rd += 32 * 8;
|
b_gl_rd += 32 * 8;
|
||||||
@ -76,22 +71,19 @@ __global__ void Code1x16MatVec(
|
|||||||
int b_sh_rd = 9 * (threadIdx.x % 32);
|
int b_sh_rd = 9 * (threadIdx.x % 32);
|
||||||
if (pred && a_gl_rd < a_gl_end) {
|
if (pred && a_gl_rd < a_gl_end) {
|
||||||
const uint16_t* enc = reinterpret_cast<const uint16_t*>(&A[a_gl_rd]);
|
const uint16_t* enc = reinterpret_cast<const uint16_t*>(&A[a_gl_rd]);
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < 8; i++) {
|
for (int i = 0; i < 8; i++) {
|
||||||
uint32_t dec[4];
|
uint32_t dec[4];
|
||||||
// We bypass the L1 cache to avoid massive amounts of memory streaming that doesn't
|
// We bypass the L1 cache to avoid massive amounts of memory streaming
|
||||||
// actually help us; this brings > 2x speedup.
|
// that doesn't actually help us; this brings > 2x speedup.
|
||||||
asm volatile (
|
asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
|
||||||
"ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
|
: "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3])
|
||||||
: "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3])
|
: "l"((void*)&codebook[enc[i]]));
|
||||||
: "l"((void*) &codebook[enc[i]])
|
|
||||||
);
|
|
||||||
half2* a = reinterpret_cast<half2*>(&dec);
|
half2* a = reinterpret_cast<half2*>(&dec);
|
||||||
half2* b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]);
|
half2* b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]);
|
||||||
half2 res2 = {};
|
half2 res2 = {};
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < 4; j++)
|
for (int j = 0; j < 4; j++) res2 = __hfma2(a[j], b[j], res2);
|
||||||
res2 = __hfma2(a[j], b[j], res2);
|
|
||||||
res += __half2float(res2.x) + __half2float(res2.y);
|
res += __half2float(res2.x) + __half2float(res2.y);
|
||||||
b_sh_rd++;
|
b_sh_rd++;
|
||||||
}
|
}
|
||||||
@ -100,37 +92,33 @@ __global__ void Code1x16MatVec(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (pred) {
|
if (pred) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 16; i > 0; i /= 2)
|
for (int i = 16; i > 0; i /= 2) res += __shfl_down_sync(0xffffffff, res, i);
|
||||||
res += __shfl_down_sync(0xffffffff, res, i);
|
|
||||||
if (threadIdx.x % 32 == 0)
|
if (threadIdx.x % 32 == 0)
|
||||||
reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res);
|
reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
__global__ void Code2x8MatVec(
|
__global__ void Code2x8MatVec(
|
||||||
const int4* __restrict__ A,
|
const int4* __restrict__ A, const int4* __restrict__ B,
|
||||||
const int4* __restrict__ B,
|
int4* __restrict__ C, const int4* __restrict__ codebook, int prob_m,
|
||||||
int4* __restrict__ C,
|
int prob_k,
|
||||||
const int4* __restrict__ codebook,
|
const int4 codebook_a_sizes, // cumulative sizes of A spanning each
|
||||||
int prob_m,
|
// codebook, at most 3 long.
|
||||||
int prob_k,
|
const int codebook_stride // as int4.
|
||||||
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long.
|
|
||||||
const int codebook_stride // as int4.
|
|
||||||
|
|
||||||
) {
|
) {
|
||||||
int a_gl_stride = prob_k / 8 / 8;
|
int a_gl_stride = prob_k / 8 / 8;
|
||||||
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
|
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
|
||||||
bool pred = a_gl_rd < prob_m;
|
bool pred = a_gl_rd < prob_m;
|
||||||
|
|
||||||
if (pred)
|
if (pred) {
|
||||||
{
|
// advance to the correct codebook, this easy because we only multiply one
|
||||||
// advance to the correct codebook, this easy because we only multiply one column of the codebook.
|
// column of the codebook.
|
||||||
auto codebook_size = &codebook_a_sizes.x;
|
auto codebook_size = &codebook_a_sizes.x;
|
||||||
while (a_gl_rd >= *codebook_size)
|
while (a_gl_rd >= *codebook_size) {
|
||||||
{
|
codebook += codebook_stride;
|
||||||
codebook += codebook_stride;
|
++codebook_size;
|
||||||
++codebook_size;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -148,9 +136,8 @@ __global__ void Code2x8MatVec(
|
|||||||
|
|
||||||
for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) {
|
for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) {
|
||||||
int4 dec = codebook[i];
|
int4 dec = codebook[i];
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < 8; j++)
|
for (int j = 0; j < 8; j++) sh_code[8 * i + (j + lane) % 8] = dec;
|
||||||
sh_code[8 * i + (j + lane) % 8] = dec;
|
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
@ -161,8 +148,7 @@ __global__ void Code2x8MatVec(
|
|||||||
// We pad shared memory to avoid bank conflicts during reads
|
// We pad shared memory to avoid bank conflicts during reads
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) {
|
for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) {
|
||||||
if (b_gl_rd + i < prob_k / 8)
|
if (b_gl_rd + i < prob_k / 8) sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
|
||||||
sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
|
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
b_gl_rd += 32 * 8;
|
b_gl_rd += 32 * 8;
|
||||||
@ -170,13 +156,15 @@ __global__ void Code2x8MatVec(
|
|||||||
int b_sh_rd = 9 * (threadIdx.x % 32);
|
int b_sh_rd = 9 * (threadIdx.x % 32);
|
||||||
if (pred && a_gl_rd < a_gl_end) {
|
if (pred && a_gl_rd < a_gl_end) {
|
||||||
const uint8_t* enc = reinterpret_cast<const uint8_t*>(&A[a_gl_rd]);
|
const uint8_t* enc = reinterpret_cast<const uint8_t*>(&A[a_gl_rd]);
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < 8; i++) {
|
for (int i = 0; i < 8; i++) {
|
||||||
half2* a0 = reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]);
|
half2* a0 =
|
||||||
half2* a1 = reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]);
|
reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]);
|
||||||
half2* b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]);
|
half2* a1 =
|
||||||
|
reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]);
|
||||||
|
half2* b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]);
|
||||||
half2 res2 = {};
|
half2 res2 = {};
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < 4; j++)
|
for (int j = 0; j < 4; j++)
|
||||||
res2 = __hfma2(__hadd2(a0[j], a1[j]), b[j], res2);
|
res2 = __hfma2(__hadd2(a0[j], a1[j]), b[j], res2);
|
||||||
res += __half2float(res2.x) + __half2float(res2.y);
|
res += __half2float(res2.x) + __half2float(res2.y);
|
||||||
@ -187,36 +175,31 @@ __global__ void Code2x8MatVec(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (pred) {
|
if (pred) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 16; i > 0; i /= 2)
|
for (int i = 16; i > 0; i /= 2) res += __shfl_down_sync(0xffffffff, res, i);
|
||||||
res += __shfl_down_sync(0xffffffff, res, i);
|
|
||||||
if (threadIdx.x % 32 == 0)
|
if (threadIdx.x % 32 == 0)
|
||||||
reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res);
|
reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
__global__ void Code1x16Dequant(
|
__global__ void Code1x16Dequant(
|
||||||
const int4* __restrict__ A,
|
const int4* __restrict__ A, int4* __restrict__ C,
|
||||||
int4* __restrict__ C,
|
const int4* __restrict__ codebook, int prob_m, int prob_k,
|
||||||
const int4* __restrict__ codebook,
|
const int4 codebook_a_sizes, // cumulative sizes of A spanning each
|
||||||
int prob_m,
|
// codebook, at most 3 long, sums to m.
|
||||||
int prob_k,
|
const int codebook_stride // as int4
|
||||||
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, sums to m.
|
|
||||||
const int codebook_stride // as int4
|
|
||||||
) {
|
) {
|
||||||
int a_gl_stride = prob_k / 8 / 8;
|
int a_gl_stride = prob_k / 8 / 8;
|
||||||
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
|
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
|
||||||
bool pred = a_gl_rd < prob_m;
|
bool pred = a_gl_rd < prob_m;
|
||||||
|
|
||||||
if (pred)
|
if (pred) {
|
||||||
{
|
// advance to the correct codebook, this easy because we only multiply one
|
||||||
// advance to the correct codebook, this easy because we only multiply one column of the codebook.
|
// column of the codebook.
|
||||||
auto codebook_size = &codebook_a_sizes.x;
|
auto codebook_size = &codebook_a_sizes.x;
|
||||||
while (a_gl_rd >= *codebook_size)
|
while (a_gl_rd >= *codebook_size) {
|
||||||
{
|
codebook += codebook_stride;
|
||||||
codebook += codebook_stride;
|
++codebook_size;
|
||||||
++codebook_size;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -231,17 +214,15 @@ __global__ void Code1x16Dequant(
|
|||||||
while (iters--) {
|
while (iters--) {
|
||||||
if (pred && a_gl_rd < a_gl_end) {
|
if (pred && a_gl_rd < a_gl_end) {
|
||||||
const uint16_t* enc = reinterpret_cast<const uint16_t*>(&A[a_gl_rd]);
|
const uint16_t* enc = reinterpret_cast<const uint16_t*>(&A[a_gl_rd]);
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < 8; i++) {
|
for (int i = 0; i < 8; i++) {
|
||||||
int4 chunk;
|
int4 chunk;
|
||||||
auto dec = reinterpret_cast<uint32_t*>(&chunk);
|
auto dec = reinterpret_cast<uint32_t*>(&chunk);
|
||||||
// We bypass the L1 cache to avoid massive amounts of memory streaming that doesn't
|
// We bypass the L1 cache to avoid massive amounts of memory streaming
|
||||||
// actually help us; this brings > 2x speedup.
|
// that doesn't actually help us; this brings > 2x speedup.
|
||||||
asm volatile (
|
asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
|
||||||
"ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
|
: "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3])
|
||||||
: "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3])
|
: "l"((void*)&codebook[enc[i]]));
|
||||||
: "l"((void*) &codebook[enc[i]])
|
|
||||||
);
|
|
||||||
|
|
||||||
C[a_gl_rd * 8 + i] = chunk;
|
C[a_gl_rd * 8 + i] = chunk;
|
||||||
}
|
}
|
||||||
@ -250,28 +231,25 @@ __global__ void Code1x16Dequant(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
__global__ void Code2x8Dequant(
|
__global__ void Code2x8Dequant(
|
||||||
const int4* __restrict__ A,
|
const int4* __restrict__ A, int4* __restrict__ C,
|
||||||
int4* __restrict__ C,
|
const int4* __restrict__ codebook, int prob_m, int prob_k,
|
||||||
const int4* __restrict__ codebook,
|
const int4
|
||||||
int prob_m,
|
codebook_a_sizes, // cumulative sizes of A spanning each codebook, at
|
||||||
int prob_k,
|
// most 3 long, corresponds to cols.
|
||||||
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, corresponds to cols.
|
const int codebook_stride // as int4
|
||||||
const int codebook_stride // as int4
|
|
||||||
) {
|
) {
|
||||||
int a_gl_stride = prob_k / 8 / 8;
|
int a_gl_stride = prob_k / 8 / 8;
|
||||||
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
|
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
|
||||||
bool pred = a_gl_rd < prob_m;
|
bool pred = a_gl_rd < prob_m;
|
||||||
|
|
||||||
if (pred)
|
if (pred) {
|
||||||
{
|
// advance to the correct codebook, this easy because we only multiply one
|
||||||
// advance to the correct codebook, this easy because we only multiply one column of the codebook.
|
// column of the codebook.
|
||||||
auto codebook_size = &codebook_a_sizes.x;
|
auto codebook_size = &codebook_a_sizes.x;
|
||||||
while (a_gl_rd >= *codebook_size)
|
while (a_gl_rd >= *codebook_size) {
|
||||||
{
|
codebook += codebook_stride;
|
||||||
codebook += codebook_stride;
|
++codebook_size;
|
||||||
++codebook_size;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -290,9 +268,8 @@ __global__ void Code2x8Dequant(
|
|||||||
|
|
||||||
for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) {
|
for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) {
|
||||||
int4 dec = codebook[i];
|
int4 dec = codebook[i];
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < 8; j++)
|
for (int j = 0; j < 8; j++) sh_code[8 * i + (j + lane) % 8] = dec;
|
||||||
sh_code[8 * i + (j + lane) % 8] = dec;
|
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
@ -302,12 +279,14 @@ __global__ void Code2x8Dequant(
|
|||||||
while (iters--) {
|
while (iters--) {
|
||||||
if (pred && a_gl_rd < a_gl_end) {
|
if (pred && a_gl_rd < a_gl_end) {
|
||||||
const uint8_t* enc = reinterpret_cast<const uint8_t*>(&A[a_gl_rd]);
|
const uint8_t* enc = reinterpret_cast<const uint8_t*>(&A[a_gl_rd]);
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < 8; i++) {
|
for (int i = 0; i < 8; i++) {
|
||||||
int4 chunk;
|
int4 chunk;
|
||||||
half2* a0 = reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]);
|
half2* a0 =
|
||||||
half2* a1 = reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]);
|
reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]);
|
||||||
#pragma unroll
|
half2* a1 =
|
||||||
|
reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]);
|
||||||
|
#pragma unroll
|
||||||
for (int j = 0; j < 4; j++)
|
for (int j = 0; j < 4; j++)
|
||||||
reinterpret_cast<half2*>(&chunk)[j] = __hadd2(a0[j], a1[j]);
|
reinterpret_cast<half2*>(&chunk)[j] = __hadd2(a0[j], a1[j]);
|
||||||
C[a_gl_rd * 8 + i] = chunk;
|
C[a_gl_rd * 8 + i] = chunk;
|
||||||
@ -317,22 +296,15 @@ __global__ void Code2x8Dequant(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inline int ceildiv(int a, int b) {
|
inline int ceildiv(int a, int b) { return (a + b - 1) / b; }
|
||||||
return (a + b - 1) / b;
|
|
||||||
}
|
|
||||||
|
|
||||||
const int THREAD_M = 16;
|
const int THREAD_M = 16;
|
||||||
|
|
||||||
void code1x16_matvec_cuda(
|
void code1x16_matvec_cuda(const void* __restrict__ A,
|
||||||
const void* __restrict__ A,
|
const void* __restrict__ B, void* __restrict__ C,
|
||||||
const void* __restrict__ B,
|
const void* __restrict__ codebook, int prob_m,
|
||||||
void* __restrict__ C,
|
int prob_k, const int4 codebook_a_sizes,
|
||||||
const void* __restrict__ codebook,
|
const int codebook_stride) {
|
||||||
int prob_m,
|
|
||||||
int prob_k,
|
|
||||||
const int4 codebook_a_sizes,
|
|
||||||
const int codebook_stride
|
|
||||||
) {
|
|
||||||
int sms;
|
int sms;
|
||||||
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
|
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
|
||||||
int waves = 0;
|
int waves = 0;
|
||||||
@ -345,28 +317,16 @@ void code1x16_matvec_cuda(
|
|||||||
int blocks = ceildiv(prob_m, thread_m);
|
int blocks = ceildiv(prob_m, thread_m);
|
||||||
int threads = 32 * thread_m;
|
int threads = 32 * thread_m;
|
||||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
|
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
|
||||||
Code1x16MatVec<<<blocks, threads, 16*32*9, stream>>>(
|
Code1x16MatVec<<<blocks, threads, 16 * 32 * 9, stream>>>(
|
||||||
(const int4*) A,
|
(const int4*)A, (const int4*)B, (int4*)C, (const int4*)codebook, prob_m,
|
||||||
(const int4*) B,
|
prob_k, codebook_a_sizes, codebook_stride);
|
||||||
(int4*) C,
|
|
||||||
(const int4*) codebook,
|
|
||||||
prob_m,
|
|
||||||
prob_k,
|
|
||||||
codebook_a_sizes,
|
|
||||||
codebook_stride
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void code2x8_matvec_cuda(
|
void code2x8_matvec_cuda(const void* __restrict__ A, const void* __restrict__ B,
|
||||||
const void* __restrict__ A,
|
void* __restrict__ C,
|
||||||
const void* __restrict__ B,
|
const void* __restrict__ codebook, int prob_m,
|
||||||
void* __restrict__ C,
|
int prob_k, const int4 codebook_a_sizes,
|
||||||
const void* __restrict__ codebook,
|
const int codebook_stride) {
|
||||||
int prob_m,
|
|
||||||
int prob_k,
|
|
||||||
const int4 codebook_a_sizes,
|
|
||||||
const int codebook_stride
|
|
||||||
) {
|
|
||||||
int sms;
|
int sms;
|
||||||
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
|
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
|
||||||
int waves = 0;
|
int waves = 0;
|
||||||
@ -379,30 +339,20 @@ void code2x8_matvec_cuda(
|
|||||||
int blocks = ceildiv(prob_m, thread_m);
|
int blocks = ceildiv(prob_m, thread_m);
|
||||||
int threads = 32 * thread_m;
|
int threads = 32 * thread_m;
|
||||||
int shared = 16 * (2 * 256 * 8 + 32 * 9);
|
int shared = 16 * (2 * 256 * 8 + 32 * 9);
|
||||||
cudaFuncSetAttribute(
|
cudaFuncSetAttribute(Code2x8MatVec,
|
||||||
Code2x8MatVec, cudaFuncAttributeMaxDynamicSharedMemorySize, shared
|
cudaFuncAttributeMaxDynamicSharedMemorySize, shared);
|
||||||
);
|
|
||||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
|
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
|
||||||
Code2x8MatVec<<<blocks, threads, shared, stream>>>(
|
Code2x8MatVec<<<blocks, threads, shared, stream>>>(
|
||||||
(const int4*) A,
|
(const int4*)A, (const int4*)B, (int4*)C, (const int4*)codebook, prob_m,
|
||||||
(const int4*) B,
|
prob_k, codebook_a_sizes, codebook_stride);
|
||||||
(int4*) C,
|
|
||||||
(const int4*) codebook,
|
|
||||||
prob_m,
|
|
||||||
prob_k,
|
|
||||||
codebook_a_sizes,
|
|
||||||
codebook_stride
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void code1x16_dequant_cuda(
|
void code1x16_dequant_cuda(
|
||||||
const void* __restrict__ A,
|
const void* __restrict__ A, void* __restrict__ C,
|
||||||
void* __restrict__ C,
|
const void* __restrict__ codebook, int prob_m, int prob_k,
|
||||||
const void* __restrict__ codebook,
|
const int4 codebook_a_sizes, // cumulative sizes of A spanning each
|
||||||
int prob_m,
|
// codebook, at most 3 long.
|
||||||
int prob_k,
|
const int codebook_stride // as int4.
|
||||||
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long.
|
|
||||||
const int codebook_stride // as int4.
|
|
||||||
) {
|
) {
|
||||||
int sms;
|
int sms;
|
||||||
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
|
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
|
||||||
@ -417,25 +367,21 @@ void code1x16_dequant_cuda(
|
|||||||
int threads = 32 * thread_m;
|
int threads = 32 * thread_m;
|
||||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
|
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
|
||||||
Code1x16Dequant<<<blocks, threads, 0, stream>>>(
|
Code1x16Dequant<<<blocks, threads, 0, stream>>>(
|
||||||
(const int4*) A,
|
(const int4*)A, (int4*)C, (const int4*)codebook, prob_m, prob_k,
|
||||||
(int4*) C,
|
codebook_a_sizes, // cumulative sizes of A spanning each codebook, at
|
||||||
(const int4*) codebook,
|
// most 3 long.
|
||||||
prob_m,
|
codebook_stride // as int4.
|
||||||
prob_k,
|
|
||||||
codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long.
|
|
||||||
codebook_stride // as int4.
|
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Dequantizes the code and codebook into weights.
|
// Dequantizes the code and codebook into weights.
|
||||||
void code2x8_dequant_cuda(
|
void code2x8_dequant_cuda(
|
||||||
const void* __restrict__ A,
|
const void* __restrict__ A, void* __restrict__ C,
|
||||||
void* __restrict__ C,
|
const void* __restrict__ codebook, int prob_m, int prob_k,
|
||||||
const void* __restrict__ codebook,
|
const int4
|
||||||
int prob_m,
|
codebook_a_sizes, // cumulative sizes of A spanning each codebook, at
|
||||||
int prob_k,
|
// most 3 long, corresponds to cols.
|
||||||
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, corresponds to cols.
|
const int codebook_stride // as int4
|
||||||
const int codebook_stride // as int4
|
|
||||||
) {
|
) {
|
||||||
int sms;
|
int sms;
|
||||||
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
|
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
|
||||||
@ -451,74 +397,50 @@ void code2x8_dequant_cuda(
|
|||||||
int shared = 16 * (2 * 256 * 8 + 32 * 9);
|
int shared = 16 * (2 * 256 * 8 + 32 * 9);
|
||||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
|
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
|
||||||
|
|
||||||
cudaFuncSetAttribute(
|
cudaFuncSetAttribute(Code2x8Dequant,
|
||||||
Code2x8Dequant, cudaFuncAttributeMaxDynamicSharedMemorySize, shared
|
cudaFuncAttributeMaxDynamicSharedMemorySize, shared);
|
||||||
);
|
|
||||||
Code2x8Dequant<<<blocks, threads, shared, stream>>>(
|
Code2x8Dequant<<<blocks, threads, shared, stream>>>(
|
||||||
(const int4*) A,
|
(const int4*)A, (int4*)C, (const int4*)codebook, prob_m, prob_k,
|
||||||
(int4*) C,
|
codebook_a_sizes, codebook_stride);
|
||||||
(const int4*) codebook,
|
|
||||||
prob_m,
|
|
||||||
prob_k,
|
|
||||||
codebook_a_sizes,
|
|
||||||
codebook_stride
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
int codebook_stride(const torch::Tensor& codebooks)
|
int codebook_stride(const torch::Tensor& codebooks) {
|
||||||
{
|
|
||||||
return codebooks.stride(0) * codebooks.element_size() / sizeof(int4);
|
return codebooks.stride(0) * codebooks.element_size() / sizeof(int4);
|
||||||
}
|
}
|
||||||
|
|
||||||
void code1x16_matvec(
|
void code1x16_matvec(
|
||||||
const torch::Tensor& A,
|
const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& C,
|
||||||
const torch::Tensor& B,
|
const torch::Tensor& codebook,
|
||||||
torch::Tensor& C,
|
const int4 codebook_a_sizes // cumulative sizes of A spanning each
|
||||||
const torch::Tensor& codebook,
|
// codebook, at most 3 long.
|
||||||
const int4 codebook_a_sizes // cumulative sizes of A spanning each codebook, at most 3 long.
|
|
||||||
) {
|
) {
|
||||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
|
||||||
int prob_m = C.size(0);
|
int prob_m = C.size(0);
|
||||||
int prob_k = B.size(0);
|
int prob_k = B.size(0);
|
||||||
|
|
||||||
code1x16_matvec_cuda(
|
code1x16_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(),
|
||||||
A.data_ptr(),
|
codebook.data_ptr(), prob_m, prob_k, codebook_a_sizes,
|
||||||
B.data_ptr(),
|
codebook_stride(codebook));
|
||||||
C.data_ptr(),
|
|
||||||
codebook.data_ptr(),
|
|
||||||
prob_m,
|
|
||||||
prob_k,
|
|
||||||
codebook_a_sizes,
|
|
||||||
codebook_stride(codebook)
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
torch::Tensor code1x16_matmat(
|
torch::Tensor code1x16_matmat(const torch::Tensor& input,
|
||||||
const torch::Tensor& input,
|
const torch::Tensor& codes,
|
||||||
const torch::Tensor& codes,
|
const torch::Tensor& codebooks,
|
||||||
const torch::Tensor& codebooks,
|
const torch::Tensor& scales,
|
||||||
const torch::Tensor& scales,
|
const int4 codebook_a_sizes,
|
||||||
const int4 codebook_a_sizes,
|
const std::optional<torch::Tensor>& bias) {
|
||||||
const std::optional<torch::Tensor>& bias) {
|
|
||||||
auto input_sizes = input.sizes();
|
auto input_sizes = input.sizes();
|
||||||
auto out_features = codes.size(0) * codebooks.size(2);
|
auto out_features = codes.size(0) * codebooks.size(2);
|
||||||
auto flat_input = input.reshape({-1, input.size(-1)});
|
auto flat_input = input.reshape({-1, input.size(-1)});
|
||||||
auto flat_output = torch::empty({flat_input.size(0), out_features},
|
auto flat_output = torch::empty(
|
||||||
torch::TensorOptions()
|
{flat_input.size(0), out_features},
|
||||||
.dtype(input.dtype())
|
torch::TensorOptions().dtype(input.dtype()).device(input.device()));
|
||||||
.device(input.device())
|
|
||||||
);
|
|
||||||
|
|
||||||
for (int i = 0; i < flat_input.size(0); ++i) {
|
for (int i = 0; i < flat_input.size(0); ++i) {
|
||||||
auto input_vec = flat_input.index({i});
|
auto input_vec = flat_input.index({i});
|
||||||
auto output_vec = flat_output.index({i});
|
auto output_vec = flat_output.index({i});
|
||||||
code1x16_matvec(
|
code1x16_matvec(codes.squeeze(2), input_vec, output_vec, codebooks,
|
||||||
codes.squeeze(2),
|
codebook_a_sizes);
|
||||||
input_vec,
|
|
||||||
output_vec,
|
|
||||||
codebooks,
|
|
||||||
codebook_a_sizes
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
flat_output *= scales.flatten().unsqueeze(0);
|
flat_output *= scales.flatten().unsqueeze(0);
|
||||||
|
|
||||||
@ -533,55 +455,35 @@ torch::Tensor code1x16_matmat(
|
|||||||
return output;
|
return output;
|
||||||
}
|
}
|
||||||
|
|
||||||
void code2x8_matvec(
|
void code2x8_matvec(const torch::Tensor& A, const torch::Tensor& B,
|
||||||
const torch::Tensor& A,
|
torch::Tensor& C, const torch::Tensor& codebook,
|
||||||
const torch::Tensor& B,
|
const int4 codebook_a_sizes) {
|
||||||
torch::Tensor& C,
|
|
||||||
const torch::Tensor& codebook,
|
|
||||||
const int4 codebook_a_sizes
|
|
||||||
) {
|
|
||||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
|
||||||
int prob_m = C.size(0);
|
int prob_m = C.size(0);
|
||||||
int prob_k = B.size(0);
|
int prob_k = B.size(0);
|
||||||
code2x8_matvec_cuda(
|
code2x8_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(),
|
||||||
A.data_ptr(),
|
codebook.data_ptr(), prob_m, prob_k, codebook_a_sizes,
|
||||||
B.data_ptr(),
|
2 * codebook_stride(codebook));
|
||||||
C.data_ptr(),
|
|
||||||
codebook.data_ptr(),
|
|
||||||
prob_m,
|
|
||||||
prob_k,
|
|
||||||
codebook_a_sizes,
|
|
||||||
2 * codebook_stride(codebook)
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||

torch::Tensor code2x8_matmat(const torch::Tensor& input,
                             const torch::Tensor& codes,
                             const torch::Tensor& codebooks,
                             const torch::Tensor& scales,
                             const int4 codebook_a_sizes,
                             const std::optional<torch::Tensor>& bias) {
  auto input_sizes = input.sizes();
  auto out_features = codes.size(0) * codebooks.size(2);
  auto flat_input = input.reshape({-1, input.size(-1)});
  auto flat_output = torch::empty(
      {flat_input.size(0), out_features},
      torch::TensorOptions().dtype(input.dtype()).device(input.device()));

  for (int i = 0; i < flat_input.size(0); ++i) {
    auto input_vec = flat_input.index({i});
    auto output_vec = flat_output.index({i});
    code2x8_matvec(codes.squeeze(2), input_vec, output_vec, codebooks,
                   codebook_a_sizes);
  }
  flat_output *= scales.flatten().unsqueeze(0);
  if (bias.has_value()) {
@ -596,64 +498,56 @@ torch::Tensor code2x8_matmat(
}

// Accumulate the partition sizes.
int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes) {
  int4 cumulative_sizes;
  auto cumulative_size = &cumulative_sizes.x;
  int i = 0;
  int last = 0;
  assert(codebook_partition_sizes.size(0) <= 4);
  for (; i < codebook_partition_sizes.size(0); ++i, ++cumulative_size) {
    *cumulative_size = codebook_partition_sizes[i].item<int>() + last;
    last = *cumulative_size;
  }
  // fill in the rest with unreachable.
  for (; i < 4; ++i, ++cumulative_size) {
    *cumulative_size = last * 10;
  }
  return cumulative_sizes;
}

}  // namespace aqlm
}  // namespace vllm
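The helper above packs up to four cumulative partition sizes into an int4 and pads the unused slots with an out-of-range filler (last * 10). A minimal sketch of how a consumer could map an output column to a codebook index with that int4 is shown below; codebook_id is a hypothetical illustration, not a function from this commit.

__device__ __forceinline__ int codebook_id(const int4 cumulative_sizes,
                                            int out_col) {
  // Walk the cumulative boundaries; the padded slots sit beyond every real
  // partition, so they are never selected.
  if (out_col < cumulative_sizes.x) return 0;
  if (out_col < cumulative_sizes.y) return 1;
  if (out_col < cumulative_sizes.z) return 2;
  return 3;
}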

torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
                        const torch::Tensor& codebooks,
                        const torch::Tensor& scales,
                        const torch::Tensor& codebook_partition_sizes,
                        const std::optional<torch::Tensor>& bias) {
  int4 cumulative_sizes =
      vllm::aqlm::accumulate_sizes(codebook_partition_sizes);

  int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0);
  int const entries = codebooks.size(1);

  if (nbooks == 1 && entries == (1 << 16)) {
    return vllm::aqlm::code1x16_matmat(input, codes, codebooks, scales,
                                       cumulative_sizes, bias);
  }
  if (nbooks == 2 && entries == (1 << 8)) {
    return vllm::aqlm::code2x8_matmat(input, codes, codebooks, scales,
                                      cumulative_sizes, bias);
  }

  TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries,
              " entries is not currently supported.")
  return {};
}
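For reference, the two layouts accepted by the dispatch above can be read off the codebook tensor: codebooks.size(0) counts codebooks across all partitions and codebooks.size(1) is the per-codebook entry count. The sketch below works that arithmetic through for an assumed 4-partition "2x8" model; the shapes are illustrative, not values taken from this commit.

#include <cstdio>

int main() {
  const long long num_partitions = 4;    // codebook_partition_sizes.size(0)
  const long long codebooks_dim0 = 8;    // 4 partitions * 2 codebooks each
  const long long codebooks_dim1 = 256;  // 2^8 entries per codebook
  const long long nbooks = codebooks_dim0 / num_partitions;
  const long long entries = codebooks_dim1;
  // nbooks == 2 && entries == (1 << 8) -> code2x8_matmat
  std::printf("nbooks=%lld entries=%lld\n", nbooks, entries);
  return 0;
}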

torch::Tensor aqlm_dequant(const torch::Tensor& codes,
                           const torch::Tensor& codebooks,
                           const torch::Tensor& codebook_partition_sizes) {
  int4 cumulative_sizes =
      vllm::aqlm::accumulate_sizes(codebook_partition_sizes);

  int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0);
  int const entries = codebooks.size(1);
@ -668,45 +562,37 @@ torch::Tensor aqlm_dequant(
  assert(out_features = codebook_partition_sizes.sum().item<int>());

  auto weights = torch::empty({out_features, in_features},
                              torch::TensorOptions()
                                  .dtype(codebooks.dtype())
                                  .device(codebooks.device()));

  if (nbooks == 1 && entries == (1 << 16)) {
    vllm::aqlm::code1x16_dequant_cuda(codes.data_ptr(), weights.data_ptr(),
                                      codebooks.data_ptr(), out_features,
                                      in_features, cumulative_sizes,
                                      vllm::aqlm::codebook_stride(codebooks));

    // if you wanted to flip to scaling the weights, (though it's 30%-ish slower
    // and not consistent with gemv implementation.)
    // weights *= scales.index({"...", 0, 0});

    return weights;
  }

  if (nbooks == 2 && entries == (1 << 8)) {
    vllm::aqlm::code2x8_dequant_cuda(codes.data_ptr(), weights.data_ptr(),
                                     codebooks.data_ptr(), out_features,
                                     in_features, cumulative_sizes,
                                     vllm::aqlm::codebook_stride(codebooks));

    // if you wanted to flip to scaling the weights, (though it's 30%-ish slower
    // and not consistent with gemv implementation)
    // weights *= scales.index({"...", 0, 0});

    return weights;
  }

  TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries,
              " entries is not currently supported.")
  return {};
}
@ -1,11 +1,11 @@
/*
Adapted from https://github.com/mit-han-lab/llm-awq
Modified from NVIDIA FasterTransformer:
https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
@article{lin2023awq,
  title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
  author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
  journal={arXiv},
  year={2023}
}
*/

@ -14,74 +14,88 @@ Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransfor
namespace vllm {
namespace awq {

__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
  assert(false);
#else
  uint4 result;

  uint32_t* h = reinterpret_cast<uint32_t*>(&result);
  uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);

  // First, we extract the i4s and construct an intermediate fp16 number.
  static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
  static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
  static constexpr uint32_t TOP_MASK = 0x00f000f0;
  static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;

  // Note that the entire sequence only requires 1 shift instruction. This is
  // thanks to the register packing format and the fact that we force our
  // integers to be unsigned, and account for this in the fp16 subtractions. In
  // addition, I exploit the fact that sub and fma have the same throughput in
  // order to convert elt_23 and elt_67 to fp16 without having to shift them to
  // the bottom bits before hand.

  // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW
  // dependency if we issue immediately before required.
  const uint32_t top_i4s = i4s >> 8;
  // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(h[0])
               : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
                 "n"(immLut));
  // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(h[1])
               : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
                 "n"(immLut));
  // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(h[2])
               : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
                 "n"(immLut));
  // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(h[3])
               : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
                 "n"(immLut));

  // I use inline PTX below because I am not sure if the compiler will emit
  // float2half instructions if I use the half2 ctor. In this case, I chose
  // performance reliability over code readability.

  // This is the half2 {1032, 1032} represented as an integer.
  // static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
  // Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7]
  static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
  // This is the half2 {1 / 16, 1 / 16} represented as an integer.
  static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
  // This is the half2 {-72, -72} represented as an integer.
  // static constexpr uint32_t NEG_72 = 0xd480d480;
  // Haotian: Let's use {-64, -64}.
  static constexpr uint32_t NEG_64 = 0xd400d400;

  // Finally, we construct the output numbers.
  // Convert elt_01
  asm volatile("sub.f16x2 %0, %1, %2;\n"
               : "=r"(h[0])
               : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
  // Convert elt_23
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
               : "=r"(h[1])
               : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
  // Convert elt_45
  asm volatile("sub.f16x2 %0, %1, %2;\n"
               : "=r"(h[2])
               : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
  // Convert elt_67
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
               : "=r"(h[3])
               : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));

  return result;
#endif
}

}  // namespace awq
}  // namespace vllm
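The magic constants above lean on the fp16 layout: oring a 4-bit value q into the low mantissa bits of 0x6400 produces the half value 1024 + q, so a single sub recovers q, while a nibble left in bits 4-7 encodes 1024 + 16*q and is recovered by an fma with 1/16 and -64. A standalone host-side check of that arithmetic is sketched below; half_bits_to_float is a hypothetical helper written for the sketch, not code from this commit.

#include <cassert>
#include <cmath>
#include <cstdint>

// Decode an IEEE fp16 bit pattern (normal numbers only).
static float half_bits_to_float(uint16_t bits) {
  int exponent = (bits >> 10) & 0x1f;
  int mantissa = bits & 0x3ff;
  return std::ldexp(1.0f + mantissa / 1024.0f, exponent - 15);
}

int main() {
  for (uint16_t q = 0; q < 16; ++q) {
    // Low nibble: (q | 0x6400) is the half value 1024 + q.
    float low = half_bits_to_float(0x6400 | q) - 1024.0f;
    // High nibble: ((q << 4) | 0x6400) is 1024 + 16*q.
    float high = half_bits_to_float(0x6400 | (q << 4)) * (1.0f / 16.0f) - 64.0f;
    assert(low == float(q) && high == float(q));
  }
  return 0;
}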
@ -1,14 +1,12 @@
/*
Adapted from https://github.com/mit-han-lab/llm-awq
@article{lin2023awq,
  title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
  author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
  journal={arXiv},
  year={2023}
}
*/

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>

@ -20,26 +18,20 @@ namespace vllm {
namespace awq {

// Pack two half values.
static inline __device__ __host__ unsigned __pack_half2(const half x,
                                                        const half y) {
  unsigned v0 = *((unsigned short*)&x);
  unsigned v1 = *((unsigned short*)&y);
  return (v1 << 16) | v0;
}

template <int N>
__global__ void __launch_bounds__(64)
    gemm_forward_4bit_cuda_m16nXk32(int G, int split_k_iters,
                                    half* __restrict__ A, int* __restrict__ B,
                                    half* __restrict__ scaling_factors,
                                    int* __restrict__ zeros, int M, int IC,
                                    int OC, half* __restrict__ C) {
  // Only support matrix n = 64 or 128
  assert(N == 64 || N == 128);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
@ -70,43 +62,46 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
  static constexpr int row_stride = 2 * 32 * 8 / N;
  bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < N;
  // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
  bool ld_A_flag =
      (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp +
       threadIdx.x * 8 / 32) < M;  // threadIdx.y is warp_id
  // bool wb_C_flag = (threadIdx.x / 4) < M;

  half* A_ptr =
      A +
      (((int)blockIdx_y) / j_factors1 * 16 +
       (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) *
          IC +
      (((int)threadIdx.x) % (32 / 8)) * 8;

  int* B_ptr = B + ((int)threadIdx.y) * (OC / 8) * (256 / N) +
               (((int)threadIdx.x) / (N / 8)) * (OC / 8) +
               (((int)blockIdx_y) % j_factors1) * (N / 8) +
               (((int)threadIdx.x) % (N / 8)) * 1;
  // Why * 1 in the above line?

  half* A_shared_ptr = A_shared +
                       ((int)threadIdx.y) * row_stride_warp * (32 + 8) +
                       (((int)threadIdx.x) / (32 / 8)) * (32 + 8) +
                       (((int)threadIdx.x) % (32 / 8)) * 8;

  half* B_shared_ptr = B_shared +
                       ((int)threadIdx.y) * (row_stride / 2) * (N + 8) +
                       (((int)threadIdx.x) / (N / 8)) * (N + 8) +
                       (((int)threadIdx.x) % (N / 8)) * 8;

  int* zeros_ptr = zeros + (((int)blockIdx_y) % j_factors1) * (N / 8) +
                   ((int)threadIdx.x) % (N / 8);

  half* scaling_factors_ptr = scaling_factors +
                              (((int)blockIdx_y) % j_factors1) * N +
                              (((int)threadIdx.x) % (N / 8)) * 8;

  half* C_ptr =
      C +
      static_cast<long long>(blockIdx_z) * M * OC  // blockIdx.z -> split_k dim
      + (((int)blockIdx_y) % j_factors1) * N + ((int)threadIdx.y) * (N / 2) +
      (((int)threadIdx.x) % 4) * 2;

  // preload s.f. and zeros
  int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
@ -115,57 +110,83 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
    int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
    __syncthreads();
    // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
    if (ld_A_flag) {
      *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
    } else {
      *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
    }

    // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
    uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
    uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
    uint4 B_loaded_scale =
        *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
    /*
    if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
      printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
    }
    */
    // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
    int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);

    for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < N / 16; ++ax0_ax1_fused_0) {
      // B: 32 x 136 (128+8) float16
      // each warp: 32 x 4
      // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus
      // zero -> WB UINT4
      // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) *
      // 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15)
      // * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 *
      // 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) *
      // 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
      // row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
      uint32_t B_loaded =
          *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
      uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
      // uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N /
      // 8)) * 8);

      // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x
      // % (cta_N / 8)) * 8);
      // - zero and * scale
      // TODO (Haotian): can save 4 assembly instructions if formulated as deq =
      // q * scale - zero * scale.
      asm volatile("sub.f16x2 %0, %1, %2;\n"
                   : "=r"(B_loaded_fp16.x)
                   : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
                   : "=r"(B_loaded_fp16.x)
                   : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
      asm volatile("sub.f16x2 %0, %1, %2;\n"
                   : "=r"(B_loaded_fp16.y)
                   : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
                   : "=r"(B_loaded_fp16.y)
                   : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
      asm volatile("sub.f16x2 %0, %1, %2;\n"
                   : "=r"(B_loaded_fp16.z)
                   : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
                   : "=r"(B_loaded_fp16.z)
                   : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
      asm volatile("sub.f16x2 %0, %1, %2;\n"
                   : "=r"(B_loaded_fp16.w)
                   : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
                   : "=r"(B_loaded_fp16.w)
                   : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
      /*
      if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
        printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
      }
      */

      // write back
      *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) =
          B_loaded_fp16;
    }
    __syncthreads();

@ -173,112 +194,179 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
      {
        unsigned int addr;
        __asm__ __volatile__(
            "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, "
            "addr; }\n"
            : "=r"(addr)
            : "l"((void*)((&(A_shared[(k_0_1 * 16)])) +
                          (((((int)threadIdx.x) & 15) * 40) +
                           ((((int)threadIdx.x) >> 4) * 8)))));

        __asm__ __volatile__(
            "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
            "{%0, %1, %2, %3}, [%4];\n"
            : "=r"(((unsigned*)(A_shared_warp + 0))[0]),
              "=r"(((unsigned*)(A_shared_warp + 0))[1]),
              "=r"(((unsigned*)(A_shared_warp + 0))[2]),
              "=r"(((unsigned*)(A_shared_warp + 0))[3])
            : "r"(addr));
      }

      for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) {
        {
          unsigned int addr;
          __asm__ __volatile__(
              "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, "
              "addr; }\n"
              : "=r"(addr)
              : "l"((void*)((&(B_shared[(((k_0_1 * (N * 16 + 128)) +
                                          (((int)threadIdx.y) * (N / 2))) +
                                         (ax1_0 * 16))])) +
                            (((((int)threadIdx.x) & 15) * (N + 8)) +
                             ((((int)threadIdx.x) >> 4) * 8)))));
          __asm__ __volatile__(
              "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
              "{%0, %1, %2, %3}, [%4];\n"
              : "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[0]),
                "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[1]),
                "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[2]),
                "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[3])
              : "r"(addr));
        }
      }
      for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
        {
          __asm__ __volatile__(
              "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
              "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
              : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
                "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
                "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
                "=f"(((float*)(C_warp + (j_0_4 * 8)))[3])
              : "r"(((unsigned*)(A_shared_warp + 0))[0]),
                "r"(((unsigned*)(A_shared_warp + 0))[1]),
                "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]),
                "f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
                "f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
                "f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
                "f"(((float*)(C_warp + (j_0_4 * 8)))[3]));
        }

        {
          __asm__ __volatile__(
              "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
              "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
              : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
                "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
                "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
                "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])
              : "r"(((unsigned*)(A_shared_warp + 0))[0]),
                "r"(((unsigned*)(A_shared_warp + 0))[1]),
                "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]),
                "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
                "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
                "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
                "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]));
        }

        {
          __asm__ __volatile__(
              "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
              "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
              : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
                "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
                "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
                "=f"(((float*)(C_warp + (j_0_4 * 8)))[3])
              : "r"(((unsigned*)(A_shared_warp + 0))[2]),
                "r"(((unsigned*)(A_shared_warp + 0))[3]),
                "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]),
                "f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
                "f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
                "f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
                "f"(((float*)(C_warp + (j_0_4 * 8)))[3]));
        }

        {
          __asm__ __volatile__(
              "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
              "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
              : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
                "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
                "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
                "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])
              : "r"(((unsigned*)(A_shared_warp + 0))[2]),
                "r"(((unsigned*)(A_shared_warp + 0))[3]),
                "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]),
                "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
                "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
                "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
                "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]));
        }
#else
        {
          __asm__ __volatile__(
              "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
              "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, "
              "%13};\n"
              : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
                "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
                "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
                "=f"(((float*)(C_warp + (j_0_4 * 8)))[3])
              : "r"(((unsigned*)(A_shared_warp + 0))[0]),
                "r"(((unsigned*)(A_shared_warp + 0))[1]),
                "r"(((unsigned*)(A_shared_warp + 0))[2]),
                "r"(((unsigned*)(A_shared_warp + 0))[3]),
                "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]),
                "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]),
                "f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
                "f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
                "f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
                "f"(((float*)(C_warp + (j_0_4 * 8)))[3]));
        }

        {
          __asm__ __volatile__(
              "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
              "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, "
              "%13};\n"
              : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
                "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
                "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
                "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])
              : "r"(((unsigned*)(A_shared_warp + 0))[0]),
                "r"(((unsigned*)(A_shared_warp + 0))[1]),
                "r"(((unsigned*)(A_shared_warp + 0))[2]),
                "r"(((unsigned*)(A_shared_warp + 0))[3]),
                "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]),
                "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]),
                "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
                "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
                "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
                "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]));
        }

#endif
      }
    }
  }

  // TODO: Shang: Hoist loop invariance.
  for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) {
    for (int local_id = 0; local_id < 8; ++local_id) {
      int row_offset = (((int)blockIdx_y) / j_factors1) * 16 +
                       ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
      if (row_offset < M) {
        *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 +
          local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
      }
    }
  }
#endif
}

__global__ void __launch_bounds__(64)
    dequantize_weights(int* __restrict__ B, half* __restrict__ scaling_factors,
                       int* __restrict__ zeros, half* __restrict__ C, int G) {
  int j_factors1 = 4;
  int row_stride2 = 4;
  int split_k_iters = 1;
@ -310,14 +398,30 @@ __global__ void __launch_bounds__(64) dequantize_weights(

  uint32_t B_loaded = *(uint32_t*)B_ptr2;
  uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
  asm volatile("sub.f16x2 %0, %1, %2;\n"
               : "=r"(B_loaded_fp16.x)
               : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
               : "=r"(B_loaded_fp16.x)
               : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
  asm volatile("sub.f16x2 %0, %1, %2;\n"
               : "=r"(B_loaded_fp16.y)
               : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
               : "=r"(B_loaded_fp16.y)
               : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
  asm volatile("sub.f16x2 %0, %1, %2;\n"
               : "=r"(B_loaded_fp16.z)
               : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
               : "=r"(B_loaded_fp16.z)
               : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
  asm volatile("sub.f16x2 %0, %1, %2;\n"
               : "=r"(B_loaded_fp16.w)
               : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
               : "=r"(B_loaded_fp16.w)
               : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));

  *(uint4*)B_shared_ptr2 = B_loaded_fp16;

@ -326,58 +430,57 @@ __global__ void __launch_bounds__(64) dequantize_weights(
  }
}

}  // namespace awq
}  // namespace vllm

torch::Tensor awq_dequantize(torch::Tensor _kernel,
                             torch::Tensor _scaling_factors,
                             torch::Tensor _zeros, int split_k_iters, int thx,
                             int thy) {
  int in_c = _kernel.size(0);
  int qout_c = _kernel.size(1);
  int out_c = qout_c * 8;
  int G = in_c / _scaling_factors.size(0);

  int x_thread = thx;
  int y_thread = thy;

  int x_blocks = 1;
  int y_blocks = 1;
  if (thx == 0) {
    x_thread = qout_c;
  }
  if (thy == 0) {
    y_thread = in_c;
  }
  if (thx == 0 && thy == 0) {
    x_thread = 8;
    y_thread = 8;
    x_blocks = (int)(qout_c / 8);
    y_blocks = (int)(in_c / 8);
  }

  const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors));

  auto options = torch::TensorOptions()
                     .dtype(_scaling_factors.dtype())
                     .device(_scaling_factors.device());
  at::Tensor _de_kernel = torch::empty({in_c, out_c}, options);

  auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
  auto de_kernel = reinterpret_cast<half*>(_de_kernel.data_ptr<at::Half>());
  auto scaling_factors =
      reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
  auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());

  dim3 num_blocks(x_blocks, y_blocks);
  dim3 threads_per_block(x_thread, y_thread);

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  vllm::awq::dequantize_weights<<<num_blocks, threads_per_block, 0, stream>>>(
      kernel, scaling_factors, zeros, de_kernel, G);

  return _de_kernel;
}

// in_feats: M, IC [float16]
@ -386,61 +489,61 @@ torch::Tensor awq_dequantize(
// zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b]
// assume that batch_size < 16 for now

torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
                       torch::Tensor _scaling_factors, torch::Tensor _zeros,
                       int split_k_iters) {
  int num_in_feats = _in_feats.size(0);
  int num_in_channels = _in_feats.size(1);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));

  auto options = torch::TensorOptions()
                     .dtype(_in_feats.dtype())
                     .device(_in_feats.device());
  at::Tensor _out_feats =
      torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options);
  int num_out_feats = _out_feats.size(-2);
  int num_out_channels = _out_feats.size(-1);

  auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>());
  auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
  auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
  auto scaling_factors =
      reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
  auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
  int group_size = num_in_channels / _scaling_factors.size(0);

  if (num_out_channels % 64 != 0)
    throw std::invalid_argument("OC is not multiple of cta_N = 64");
  if (num_out_channels % 8 != 0)
    throw std::invalid_argument("OC is not multiple of pack_num = 8");
  if (group_size % 32 != 0)
    throw std::invalid_argument("Group size should be a multiple of 32");
  if (num_out_channels % group_size != 0)
    throw std::invalid_argument("OC is not multiple of Group size");

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  if (num_out_channels % 128 == 0) {
    int j_factors1 = num_out_channels / 128 / 1;
    dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
    // threadIdx.x: 32
    // threadIdx.y: i_factors[2] * j_factors[2]
    dim3 threads_per_block(32, 2);
    vllm::awq::gemm_forward_4bit_cuda_m16nXk32<128>
        <<<num_blocks, threads_per_block, 0, stream>>>(
            group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros,
            num_in_feats, num_in_channels, num_out_channels, out_feats);
  } else if (num_out_channels % 64 == 0) {
    int j_factors1 = num_out_channels / 64 / 1;
    dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 *
                    split_k_iters);

    // threadIdx.x: 32
    // threadIdx.y: i_factors[2] * j_factors[2]
    dim3 threads_per_block(32, 2);
    vllm::awq::gemm_forward_4bit_cuda_m16nXk32<64>
        <<<num_blocks, threads_per_block, 0, stream>>>(
            group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros,
            num_in_feats, num_in_channels, num_out_channels, out_feats);
  }
  return _out_feats.sum(0);
}
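The launch configuration in awq_gemm is easy to sanity-check by hand: each block covers a 16-row by 128-column (or 64-column) output tile, replicated split_k_iters times along the reduction split. A small sketch of that arithmetic, using assumed sizes rather than anything from this commit, is:

#include <cstdio>

int main() {
  const int num_out_feats = 16;      // M
  const int num_out_channels = 128;  // OC, divisible by 128
  const int split_k_iters = 8;
  const int j_factors1 = num_out_channels / 128 / 1;
  const int blocks = (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters;
  std::printf("%d blocks of dim3(32, 2) threads\n", blocks);  // 8 blocks
  return 0;
}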
@ -117,10 +117,10 @@ struct cutlass_2x_gemm {
};

template <typename Gemm>
void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
                                     torch::Tensor const& b,
                                     torch::Tensor const& a_scales,
                                     torch::Tensor const& b_scales) {
  using ElementAB = typename Gemm::ElementAB;
  using ElementD = typename Gemm::ElementD;

@ -136,9 +136,9 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a,
|
|||||||
using StrideC = Stride<int64_t, Int<1>, Int<0>>;
|
using StrideC = Stride<int64_t, Int<1>, Int<0>>;
|
||||||
StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
|
StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
|
||||||
|
|
||||||
auto a_ptr = static_cast<ElementAB const *>(a.data_ptr());
|
auto a_ptr = static_cast<ElementAB const*>(a.data_ptr());
|
||||||
auto b_ptr = static_cast<ElementAB const *>(b.data_ptr());
|
auto b_ptr = static_cast<ElementAB const*>(b.data_ptr());
|
||||||
auto c_ptr = static_cast<ElementD *>(out.data_ptr());
|
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
|
||||||
|
|
||||||
auto a_scales_ptr = a_scales.data_ptr<float>();
|
auto a_scales_ptr = a_scales.data_ptr<float>();
|
||||||
auto b_scales_ptr = b_scales.data_ptr<float>();
|
auto b_scales_ptr = b_scales.data_ptr<float>();
|
||||||
@ -196,10 +196,10 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a,
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void cutlass_scaled_mm_dq_sm75(torch::Tensor &out, torch::Tensor const &a,
|
void cutlass_scaled_mm_dq_sm75(torch::Tensor& out, torch::Tensor const& a,
|
||||||
torch::Tensor const &b,
|
torch::Tensor const& b,
|
||||||
torch::Tensor const &a_scales,
|
torch::Tensor const& a_scales,
|
||||||
torch::Tensor const &b_scales) {
|
torch::Tensor const& b_scales) {
|
||||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||||
@ -223,10 +223,10 @@ void cutlass_scaled_mm_dq_sm75(torch::Tensor &out, torch::Tensor const &a,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void cutlass_scaled_mm_dq_sm80(torch::Tensor &out, torch::Tensor const &a,
|
void cutlass_scaled_mm_dq_sm80(torch::Tensor& out, torch::Tensor const& a,
|
||||||
torch::Tensor const &b,
|
torch::Tensor const& b,
|
||||||
torch::Tensor const &a_scales,
|
torch::Tensor const& a_scales,
|
||||||
torch::Tensor const &b_scales) {
|
torch::Tensor const& b_scales) {
|
||||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||||
@ -250,10 +250,10 @@ void cutlass_scaled_mm_dq_sm80(torch::Tensor &out, torch::Tensor const &a,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void cutlass_scaled_mm_dq_sm89(torch::Tensor &out, torch::Tensor const &a,
|
void cutlass_scaled_mm_dq_sm89(torch::Tensor& out, torch::Tensor const& a,
|
||||||
torch::Tensor const &b,
|
torch::Tensor const& b,
|
||||||
torch::Tensor const &a_scales,
|
torch::Tensor const& a_scales,
|
||||||
torch::Tensor const &b_scales) {
|
torch::Tensor const& b_scales) {
|
||||||
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
|
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
|
||||||
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
|
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
|
||||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||||
|
@ -120,10 +120,10 @@ struct cutlass_3x_gemm {
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <typename Gemm>
|
template <typename Gemm>
|
||||||
void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a,
|
void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
|
||||||
torch::Tensor const &b,
|
torch::Tensor const& b,
|
||||||
torch::Tensor const &a_scales,
|
torch::Tensor const& a_scales,
|
||||||
torch::Tensor const &b_scales) {
|
torch::Tensor const& b_scales) {
|
||||||
using ElementAB = typename Gemm::ElementAB;
|
using ElementAB = typename Gemm::ElementAB;
|
||||||
using ElementD = typename Gemm::ElementD;
|
using ElementD = typename Gemm::ElementD;
|
||||||
|
|
||||||
@ -146,12 +146,12 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a,
|
|||||||
using GemmKernel = typename Gemm::GemmKernel;
|
using GemmKernel = typename Gemm::GemmKernel;
|
||||||
typename GemmKernel::ProblemShape prob_shape{m, n, k, 1};
|
typename GemmKernel::ProblemShape prob_shape{m, n, k, 1};
|
||||||
|
|
||||||
auto a_ptr = static_cast<ElementAB *>(a.data_ptr());
|
auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
|
||||||
auto b_ptr = static_cast<ElementAB *>(b.data_ptr());
|
auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
|
||||||
typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr,
|
typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr,
|
||||||
b_stride};
|
b_stride};
|
||||||
|
|
||||||
auto c_ptr = static_cast<ElementD *>(out.data_ptr());
|
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
|
||||||
typename GemmKernel::EpilogueArguments epilogue_args{
|
typename GemmKernel::EpilogueArguments epilogue_args{
|
||||||
{}, c_ptr, c_stride, c_ptr, c_stride};
|
{}, c_ptr, c_stride, c_ptr, c_stride};
|
||||||
|
|
||||||
@ -183,10 +183,10 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a,
|
|||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void cutlass_scaled_mm_dq_sm90(torch::Tensor &out, torch::Tensor const &a,
|
void cutlass_scaled_mm_dq_sm90(torch::Tensor& out, torch::Tensor const& a,
|
||||||
torch::Tensor const &b,
|
torch::Tensor const& b,
|
||||||
torch::Tensor const &a_scales,
|
torch::Tensor const& a_scales,
|
||||||
torch::Tensor const &b_scales) {
|
torch::Tensor const& b_scales) {
|
||||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||||
|
|
||||||
|
@ -2,29 +2,29 @@
#include <cuda_runtime.h>
#include <torch/extension.h>

void cutlass_scaled_mm_dq_sm75(torch::Tensor& c, torch::Tensor const& a,
                               torch::Tensor const& b,
                               torch::Tensor const& a_scales,
                               torch::Tensor const& b_scales);

void cutlass_scaled_mm_dq_sm80(torch::Tensor& c, torch::Tensor const& a,
                               torch::Tensor const& b,
                               torch::Tensor const& a_scales,
                               torch::Tensor const& b_scales);

void cutlass_scaled_mm_dq_sm89(torch::Tensor& c, torch::Tensor const& a,
                               torch::Tensor const& b,
                               torch::Tensor const& a_scales,
                               torch::Tensor const& b_scales);

void cutlass_scaled_mm_dq_sm90(torch::Tensor& c, torch::Tensor const& a,
                               torch::Tensor const& b,
                               torch::Tensor const& a_scales,
                               torch::Tensor const& b_scales);

void cutlass_scaled_mm_dq(torch::Tensor& c, torch::Tensor const& a,
                          torch::Tensor const& b, torch::Tensor const& a_scales,
                          torch::Tensor const& b_scales) {
  int32_t major_capability;
  int32_t minor_capability;
  cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,

@ -36,14 +36,15 @@ void cutlass_scaled_mm_dq(torch::Tensor &c, torch::Tensor const &a,
  // Checks for conformality
  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
  TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
              b.size(1) == c.size(1));
  TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
  TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));

  // Check for strides and alignment
  TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major
  TORCH_CHECK(b.stride(0) == 1);                      // Column-major
  TORCH_CHECK(c.stride(0) % 16 == 0 &&
              b.stride(1) % 16 == 0);  // 16 Byte Alignment
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());

  at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
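A minimal caller-side sketch (not part of this commit) of the layout contract the checks above enforce: row-major A and C, column-major B, per-row/per-column or scalar float32 scales. Shapes, dtypes, and the helper name are assumptions for illustration; k and n are assumed multiples of 16 so the 16-byte alignment checks on b.stride(1) and c.stride(0) pass.

// Illustrative only: preparing tensors that satisfy cutlass_scaled_mm_dq's
// conformality and stride checks.
#include <torch/torch.h>

void cutlass_scaled_mm_dq(torch::Tensor& c, torch::Tensor const& a,
                          torch::Tensor const& b, torch::Tensor const& a_scales,
                          torch::Tensor const& b_scales);

torch::Tensor example_call(int64_t m = 32, int64_t n = 64, int64_t k = 128) {
  auto i8 = torch::dtype(torch::kInt8).device(torch::kCUDA);
  auto f32 = torch::dtype(torch::kFloat32).device(torch::kCUDA);
  auto a = torch::zeros({m, k}, i8);      // row-major int8 activations
  auto b = torch::zeros({n, k}, i8).t();  // column-major int8 weights, {k, n}
  auto a_scales = torch::ones({m}, f32);  // one scale per row of a
  auto b_scales = torch::ones({n}, f32);  // one scale per column of b
  auto c = torch::empty(
      {m, n}, torch::dtype(torch::kFloat16).device(torch::kCUDA));  // assumed
  cutlass_scaled_mm_dq(c, a, b, a_scales, b_scales);
  return c;
}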
@ -1,167 +1,137 @@
#pragma once

#ifdef __HIPCC__
  #include <hip/hip_runtime.h>
#else
  #include <type_traits>
  #include <stdint.h>
  #include <math.h>
  #include <iostream>
#endif

#include "hip_float8_impl.h"

struct alignas(1) hip_fp8 {
  struct from_bits_t {};
  HIP_FP8_HOST_DEVICE static constexpr from_bits_t from_bits() {
    return from_bits_t();
  }
  uint8_t data;

  hip_fp8() = default;
  HIP_FP8_HOST_DEVICE constexpr hip_fp8(const hip_fp8&) = default;
  HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v) = delete;
  explicit HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v, from_bits_t)
      : data(v) {}

#ifdef __HIP__MI300__
  // NOTE: ON-DEVICE... always optimal bias
  explicit HIP_FP8_DEVICE hip_fp8(float v)
      : data(hip_fp8_impl::to_fp8_from_fp32(v)) {}

  explicit HIP_FP8_DEVICE hip_fp8(_Float16 v)
      : hip_fp8(static_cast<float>(v)) {}

  // Host only implementation using s/w simulation
  explicit HIP_FP8_HOST
#else  // __HIP__MI300__
  // both Host and DEVICE for non-MI300 using s/w simulation
  explicit HIP_FP8_HOST_DEVICE
#endif  // __HIP__MI300__
  hip_fp8(float v) {
    data = hip_fp8_impl::to_float8<4, 3, float, true /*negative_zero_nan*/,
                                   true /*clip*/>(v);
  }

  explicit HIP_FP8_HOST_DEVICE hip_fp8(double v)
      : hip_fp8(static_cast<float>(v)) {}

#ifdef __HIP__MI300__
  // upcast using device specific intrinsic
  explicit inline HIP_FP8_DEVICE operator float() const {
    float fval;
    uint32_t i32val = static_cast<uint32_t>(data);

    // upcast
    asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0"
                 : "=v"(fval)
                 : "v"(i32val));

    return fval;
  }

  explicit inline HIP_FP8_HOST operator float() const
#else  // __HIP__MI300__
  explicit inline HIP_FP8_HOST_DEVICE operator float() const
#endif  // __HIP__MI300__
  {
    return hip_fp8_impl::from_float8<4, 3, float, true /*negative_zero_nan*/>(
        data);
  }
};

namespace std {
inline hip_fp8 sin(hip_fp8 a) { return hip_fp8(sinf(float(a))); }
inline hip_fp8 cos(hip_fp8 a) { return hip_fp8(cosf(float(a))); }
HIP_FP8_HOST_DEVICE constexpr hip_fp8 real(const hip_fp8& a) { return a; }
}  // namespace std

// Special operator overloading
inline std::ostream& operator<<(std::ostream& os, const hip_fp8& f8) {
  return os << float(f8);
}

// all + operator overloading with mixed types
// mixed types, always converts to f32, does computation in f32, and returns
// float
inline HIP_FP8_HOST_DEVICE float operator+(const float fa, hip_fp8 b) {
  return (fa + float(b));
}

inline HIP_FP8_HOST_DEVICE float operator+(hip_fp8 a, const float fb) {
  return (float(a) + fb);
}

inline HIP_FP8_HOST_DEVICE hip_fp8 operator+(hip_fp8 a, hip_fp8 b) {
  return hip_fp8(float(a) + float(b));
}

inline HIP_FP8_HOST_DEVICE hip_fp8& operator+=(hip_fp8& a, hip_fp8 b) {
  return a = hip_fp8(float(a) + float(b));
}

// overloading multiplication, always returns float,
inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, hip_fp8 b) {
  return float(a) * float(b);
}

inline HIP_FP8_HOST_DEVICE float operator*(float a, hip_fp8 b) {
  return (a * float(b));
}

inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, float b) {
  return (float(a) * b);
}

inline HIP_FP8_HOST_DEVICE float operator*(int32_t a, hip_fp8 b) {
  return ((float)a * float(b));
}

inline HIP_FP8_HOST_DEVICE float operator*(double a, hip_fp8 b) {
  return ((float)a * float(b));
}

// overloading for compare
inline HIP_FP8_HOST_DEVICE bool operator==(hip_fp8 a, hip_fp8 b) {
  return (a.data == b.data);
}
inline HIP_FP8_HOST_DEVICE bool operator!=(hip_fp8 a, hip_fp8 b) {
  return (a.data != b.data);
}

inline HIP_FP8_HOST_DEVICE bool operator>=(hip_fp8 a, hip_fp8 b) {
  return static_cast<float>(a) >= static_cast<float>(b);
}
inline HIP_FP8_HOST_DEVICE bool operator>(hip_fp8 a, hip_fp8 b) {
  return static_cast<float>(a) > static_cast<float>(b);
}
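A small host-side sketch (not part of this commit) of the round trip this struct provides on non-MI300 builds, where both directions go through the software emulation in hip_fp8_impl. The include path and the input value are assumptions for illustration.

// Illustrative only: quantize a float to the e4m3-style fp8 encoding above and
// dequantize it back. On a plain host compiler the HIP_FP8_* macros are empty,
// so this is ordinary C++.
#include <iostream>

#include "hip_float8.h"  // assumed include path for the header in this diff

int main() {
  float x = 0.3f;                    // assumed input
  hip_fp8 q(x);                      // float -> fp8 (round to nearest, clipped)
  float y = static_cast<float>(q);   // fp8 -> float
  std::cout << "stored bits: " << int(q.data)
            << "  dequantized: " << y << "\n";  // y is roughly 0.3125 here
  return 0;
}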
@ -1,316 +1,316 @@
#pragma once

#if defined(__HIPCC__) && \
    (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
  #define __HIP__MI300__
#endif

#ifdef __HIPCC__
  #define HIP_FP8_HOST_DEVICE __host__ __device__
  #define HIP_FP8_HOST __host__
  #define HIP_FP8_DEVICE __device__
#else
  #define HIP_FP8_HOST_DEVICE
  #define HIP_FP8_HOST
  #define HIP_FP8_DEVICE
#endif

namespace hip_fp8_impl {

#ifdef __HIP__MI300__
HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v) {
  uint8_t i8data;
  union {
    float fval;
    uint32_t i32val;
    uint8_t i8val[4];  // NOTE: not endian independent
  } val;

  uint32_t ival = 0;
  val.fval = v;

  if ((val.i32val & 0x7F800000) !=
      0x7F800000) {  /// propagate NAN/INF, no clipping
    val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
  }

  ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival,
                                         false);  // false -> WORD0
  val.i32val = ival;
  i8data = val.i8val[0];

  return i8data;
}
#endif  // __HIP__MI300__

HIP_FP8_HOST inline int clz(uint32_t x) { return __builtin_clz(x); }
#if defined(__HIPCC__) || defined(__CUDA_ARCH__)
HIP_FP8_DEVICE inline int clz(uint32_t x) { return __clz(x); }
#endif

template <int we, int wm, typename T, bool negative_zero_nan, bool clip>
HIP_FP8_HOST_DEVICE uint8_t to_float8(T _x, bool stoch = false,
                                      uint32_t rng = 0) {
#ifdef __HIPCC__
  constexpr bool is_half = std::is_same<T, _Float16>::value;
#else
  constexpr bool is_half = false;
#endif
  constexpr bool is_float = std::is_same<T, float>::value;
  static_assert(wm + we == 7, "wm+we==7");
  static_assert(is_half || is_float, "Only half and float can be cast to f8");

  const int mfmt = (sizeof(T) == 4) ? 23 : 10;
  uint32_t x;
  if (sizeof(T) == 4) {
    x = reinterpret_cast<uint32_t&>(_x);
  } else {
    x = reinterpret_cast<uint16_t&>(_x);
  }

  uint32_t head, mantissa;
  int exponent, bias;
  uint32_t sign;

  if (sizeof(T) == 4) {
    head = x & 0xFF800000;
    mantissa = x & 0x7FFFFF;
    exponent = (head >> 23) & 0xFF;
    sign = head >> 31;
    bias = 127;
  } else {
    head = x & 0xFC00;
    mantissa = x & 0x3FF;
    exponent = (head >> 10) & 0x1F;
    sign = head >> 15;
    bias = 15;
  }

  uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm);

  // Deal with inf and NaNs
  if (negative_zero_nan) {
    if (sizeof(T) == 4) {
      if ((x & 0x7F800000) == 0x7F800000) {
        return 0x80;
      }
    } else {
      // if(__hisinf(x) || __hisnan(x))
      if ((x & 0x7C00) == 0x7C00) {
        return 0x80;
      }
    }
  } else {
    if (sizeof(T) == 4) {
      if ((x & 0x7F800000) == 0x7F800000) {
        return signed_inf + (mantissa != 0 ? 1 : 0);
      }
    } else {
      if ((x & 0x7C00) == 0x7C00) {
        return signed_inf + (mantissa != 0 ? 1 : 0);
      }
    }
  }
  if (x == 0) {
    return 0;
  }

  // First need to check if it is normal or denorm as there is a difference of
  // implicit 1 Then need to adjust the exponent to align with the F8 exponent,
  // in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng
  // to mantissa and truncate. And for RNE, no need to add rng. Then probably
  // need to check whether there is carry and adjust exponent and mantissa again

  // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent
  // bits
  const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0);
  const int f8_denormal_act_exponent =
      1 - f8_bias;  // actual exponent of f8 denormal
  // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
  // f8_exponent is the converted f8 exponent with bias encoding
  // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
  // the difference needs to be adjusted and mantissa shifted
  int act_exponent, f8_exponent, exponent_diff;

  if (exponent == 0) {  // fp32/fp16 is in denormal.
    /* fp32 denormal is below 2^-127 so it is usually not a concern here, we
mostly concern fp16 here. In this case, f8 is usually in denormal. But there
could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has
exponent bias 16. It means that there are some numbers in fp16 denormal but they
are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
(NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */
    act_exponent = exponent - bias + 1;
    exponent_diff =
        f8_denormal_act_exponent -
        act_exponent;  // actual exponent is exponent-bias+1 as it is denormal
  } else {             // fp32/fp16 is normal with implicit 1
    act_exponent = exponent - bias;
    if (act_exponent <= f8_denormal_act_exponent) {
      /* This is the case where fp32/fp16 is normal but it is in f8 denormal
      range. For example fp8 nanoo mode, denormal exponent is -7, but if the
      fp32/fp16 actual exponent is -7, it is actually larger due to the implicit 1,
      Therefore it needs to be adjust to -6 and mantissa shift right by 1.
      So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
      exponent_diff = f8_denormal_act_exponent - act_exponent;
    } else {              // both fp32/fp16 and f8 are in normal range
      exponent_diff = 0;  // exponent_diff=0 does not mean there is no
                          // difference for this case, act_exponent could be
                          // larger. Just that it does not need shift mantissa
    }
    mantissa += (1 << mfmt);  // Add the implicit 1 into mantissa
  }

  bool midpoint = (mantissa & ((1 << (mfmt - wm + exponent_diff)) - 1)) ==
                  static_cast<uint32_t>(1 << (mfmt - wm + exponent_diff - 1));
  /* This part is a bit tricky. The judgment of whether it is a tie needs to be
  done before we shift right as shift right could rip off some residual part
  and make something not midpoint look like midpoint. For example, the fp16
  number 0x1002 (0 00100 0000000010), it is larger than midpoint, but after
  shift right by 4 bits, it would look like midpoint.
  */

  if (exponent_diff > 0) {
    mantissa >>= exponent_diff;
  } else if (exponent_diff == -1) {
    mantissa <<= -exponent_diff;
  }
  bool implicit_one = mantissa & (1 << mfmt);
  // if there is no implicit 1, it means the f8 is denormal and need to adjust
  // to denorm exponent
  f8_exponent = (act_exponent + exponent_diff) /*actual f8 exponent*/ +
                f8_bias - (implicit_one ? 0 : 1);

  // Now we have the exponent and mantissa adjusted
  uint32_t drop_mask = (1 << (mfmt - wm)) - 1;
  bool odd = mantissa & (1 << (mfmt - wm));  // if the least significant bit
                                             // that is not truncated is 1
  mantissa +=
      (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) &
      drop_mask;

  // Now we deal with overflow
  if (f8_exponent == 0) {
    if ((1 << mfmt) & mantissa) {
      f8_exponent = 1;  // denormal overflow to become normal, promote exponent
    }
  } else {
    if ((1 << (mfmt + 1)) & mantissa) {
      mantissa >>= 1;
      f8_exponent++;
    }
  }

  mantissa >>= (mfmt - wm);

  // above range: quantize to maximum possible float of the same sign
  const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2);
  if (f8_exponent > max_exp) {
    if (clip) {
      mantissa = (1 << wm) - 1;
      f8_exponent = max_exp;
    } else {
      return signed_inf;
    }
  }

  if (f8_exponent == 0 && mantissa == 0) {
    return negative_zero_nan ? 0 : (sign << 7);
  }
  mantissa &= (1 << wm) - 1;
  return (sign << 7) | (f8_exponent << wm) | mantissa;
}

template <int we, int wm, typename T = float, bool negative_zero_nan = true>
inline HIP_FP8_HOST_DEVICE T from_float8(uint8_t x) {
#ifdef __HIPCC__
  constexpr bool is_half = std::is_same<T, _Float16>::value;
#else
  constexpr bool is_half = false;
#endif
  constexpr bool is_float = std::is_same<T, float>::value;
  static_assert(is_half || is_float, "only half and float are supported");

  constexpr int weo = is_half ? 5 : 8;
  constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7);

  T fInf, fNegInf, fNaN, fNeg0;

#ifdef __HIPCC__
  if (is_half) {
    const uint16_t ihInf = 0x7C00;
    const uint16_t ihNegInf = 0xFC00;
    const uint16_t ihNaN = 0x7C01;
    const uint16_t ihNeg0 = 0x8000;
    fInf = reinterpret_cast<const _Float16&>(ihInf);
    fNegInf = reinterpret_cast<const _Float16&>(ihNegInf);
    fNaN = reinterpret_cast<const _Float16&>(ihNaN);
    fNeg0 = reinterpret_cast<const _Float16&>(ihNeg0);
  } else
#endif
      if (is_float) {
    const uint32_t ifInf = 0x7F800000;
    const uint32_t ifNegInf = 0xFF800000;
    const uint32_t ifNaN = 0x7F800001;
    const uint32_t ifNeg0 = 0x80000000;
    fInf = reinterpret_cast<const float&>(ifInf);
    fNegInf = reinterpret_cast<const float&>(ifNegInf);
    fNaN = reinterpret_cast<const float&>(ifNaN);
    fNeg0 = reinterpret_cast<const float&>(ifNeg0);
  }

  if (x == 0) {
    return 0;
  }

  uint32_t sign = x >> 7;
  uint32_t mantissa = x & ((1 << wm) - 1);
  int exponent = (x & 0x7F) >> wm;
  if (negative_zero_nan) {
    if (x == 0x80) {
      return fNaN;
    }
  } else {
    if (x == 0x80) {
      return fNeg0;
    }
    if (exponent == ((1 << we) - 1)) {
      return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN;
    }
  }
  typename std::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type retval;
  if (we == 5 && is_half && !negative_zero_nan) {
    retval = x << 8;
    return reinterpret_cast<const T&>(retval);
  }

  const int exp_low_cutoff =
      (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0);

  // subnormal input
  if (exponent == 0) {
    // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
    int sh = 1 + clz(mantissa) - (32 - wm);
    mantissa <<= sh;
    exponent += 1 - sh;
    mantissa &= ((1 << wm) - 1);
  }
  exponent += exp_low_cutoff - 1;
  mantissa <<= wmo - wm;

  // subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
  if (exponent <= 0) {
    mantissa |= 1 << wmo;
    mantissa >>= 1 - exponent;
    exponent = 0;
  }

  if (sizeof(T) == 2) {
    retval = (sign << 15) | (exponent << 10) | mantissa;
  } else {
    retval = (sign << 31) | (exponent << 23) | mantissa;
  }
  return reinterpret_cast<const T&>(retval);
}

}  // namespace hip_fp8_impl
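To make the tie-breaking comments above concrete, here is a hedged host-only sketch (not part of this commit) that calls the conversion templates directly; the include path and the chosen value are assumptions. With we=4, wm=3 the step near 1.0 is 1/8, so 1.0625 is an exact tie between 1.0 and 1.125 and round-to-nearest-even keeps the even mantissa, 1.0.

// Illustrative only: exercising to_float8/from_float8 on the host to show
// round-to-nearest-even handling of a midpoint value.
#include <cstdint>
#include <cstdio>

#include "hip_float8.h"  // assumed include path; pulls in hip_float8_impl.h

int main() {
  // we=4, wm=3, negative_zero_nan=true, clip=true: the fp8 encoding used above.
  uint8_t tie = hip_fp8_impl::to_float8<4, 3, float, true, true>(1.0625f);
  float back = hip_fp8_impl::from_float8<4, 3, float, true>(tie);
  std::printf("0x%02x -> %g\n", unsigned(tie), back);  // dequantizes to 1
  return 0;
}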
@ -9,566 +9,567 @@
#include "../../../attention/dtype_float32.cuh"
#include "../../../attention/dtype_bfloat16.cuh"

namespace vllm {
#ifdef USE_ROCM

namespace fp8 {
#ifdef ENABLE_FP8

template <typename Tout, typename Tin>
__inline__ __device__ Tout vec_conversion(const Tin& x) {
  return x;
}

template <typename Tout, typename Tin>
__inline__ __device__ Tout scaled_vec_conversion(const Tin& x,
                                                 const float scale) {
  return x;
}

// fp8 -> half
template <>
__inline__ __device__ uint16_t
vec_conversion<uint16_t, uint8_t>(const uint8_t& a) {
  hip_fp8 f8{a, hip_fp8::from_bits()};
  __half_raw res;
  res.data = static_cast<float>(f8);
  return res.x;
}

// fp8x2 -> half2
template <>
__inline__ __device__ uint32_t
vec_conversion<uint32_t, uint16_t>(const uint16_t& a) {
#if defined(__HIP__MI300__) && \
    defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
  const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
  union {
    __half2_raw h2r;
    uint32_t ui32;
  } tmp;
  tmp.h2r.x.data = f2[0];
  tmp.h2r.y.data = f2[1];
  return tmp.ui32;
#else
  union {
    uint16_t u16[2];
    uint32_t u32;
  } tmp;

  tmp.u16[0] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a));
  tmp.u16[1] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a >> 8U));
  return tmp.u32;
#endif
}

// fp8x4 -> half2x2
template <>
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a) {
  union {
    uint2 u32x2;
    uint32_t u32[2];
  } tmp;
  tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
  tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
  return tmp.u32x2;
}

// fp8x8 -> half2x4
template <>
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a) {
  union {
    uint4 u64x2;
    uint2 u64[2];
  } tmp;
  tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
  tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
  return tmp.u64x2;
}

using __nv_bfloat16 = __hip_bfloat16;

// fp8 -> __nv_bfloat16
template <>
__inline__ __device__ __nv_bfloat16
vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) {
  hip_fp8 f8{a, hip_fp8::from_bits()};
  float f{f8};
  return __float2bfloat16(f);
}

using __nv_bfloat162 = __hip_bfloat162;

// fp8x2 -> __nv_bfloat162
template <>
__inline__ __device__ __nv_bfloat162
vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) {
  __nv_bfloat162 res;
  res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
  res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
  return res;
}

// fp8x4 -> bf16_4_t
template <>
__inline__ __device__ bf16_4_t
vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a) {
  bf16_4_t res;
  res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
  res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
  return res;
}

// fp8x8 -> bf16_8_t
template <>
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a) {
  bf16_4_t tmp1, tmp2;
  tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
  tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
  bf16_8_t res;
  res.x = tmp1.x;
  res.y = tmp1.y;
  res.z = tmp2.x;
  res.w = tmp2.y;
  return res;
}

// fp8 -> float
template <>
__inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a) {
  hip_fp8 fp8{a, hip_fp8::from_bits()};
  return static_cast<float>(fp8);
}

// fp8x2 -> float2
template <>
__inline__ __device__ float2
vec_conversion<float2, uint16_t>(const uint16_t& a) {
#if defined(__HIP__MI300__) && \
    defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
  float2 res;
  const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
  res.x = f2[0];
  res.y = f2[1];
  return res;
#else
  float2 res;
  res.x = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a));
  res.y = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U));
  return res;
#endif
}

// fp8x4 -> float4
template <>
__inline__ __device__ Float4_
vec_conversion<Float4_, uint32_t>(const uint32_t& a) {
  Float4_ res;
  res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
  res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
  return res;
}

// fp8x8 -> float8
template <>
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a) {
  Float4_ tmp1, tmp2;
  tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
  tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
  Float8_ res;
  res.x = tmp1.x;
  res.y = tmp1.y;
  res.z = tmp2.x;
  res.w = tmp2.y;
  return res;
}

// half -> fp8
template <>
__inline__ __device__ uint8_t
vec_conversion<uint8_t, uint16_t>(const uint16_t& a) {
  __half_raw tmp;
  tmp.x = a;

  hip_fp8 f8{static_cast<float>(tmp.data)};
  return f8.data;
}

// bf16 -> fp8
template <>
__inline__ __device__ uint8_t
vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a) {
  hip_fp8 res{__bfloat162float(a)};
  return res.data;
}

// float -> fp8
template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a) {
  hip_fp8 f8(a);
  return f8.data;
}

// fp8x4 -> float4
template <>
__inline__ __device__ float4
vec_conversion<float4, uint32_t>(const uint32_t& a) {
  Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
  float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
  return res;
}

// float2 -> half2
template <>
__inline__ __device__ uint32_t
vec_conversion<uint32_t, float2>(const float2& a) {
  union {
    half2 float16;
    uint32_t uint32;
  };

  float16 = __float22half2_rn(a);
  return uint32;
}

// Float4 -> half2x2
template <>
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a) {
  uint2 b;
  float2 val;
  val.x = a.x.x;
  val.y = a.x.y;
  b.x = vec_conversion<uint32_t, float2>(val);

  val.x = a.y.x;
  val.y = a.y.y;
  b.y = vec_conversion<uint32_t, float2>(val);
  return b;
}

// Float4 -> float4
template <>
__inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a) {
  float4 b;
  b.x = a.x.x;
  b.y = a.x.y;
  b.z = a.y.x;
  b.w = a.y.y;
  return b;
}

// Float8 -> half2x4
template <>
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a) {
  uint4 b;
  b.x = vec_conversion<uint32_t, float2>(a.x);
  b.y = vec_conversion<uint32_t, float2>(a.y);
  b.z = vec_conversion<uint32_t, float2>(a.z);
  b.w = vec_conversion<uint32_t, float2>(a.w);
  return b;
}

// float2 -> bfloat162
template <>
__inline__ __device__ __nv_bfloat162
vec_conversion<__nv_bfloat162, float2>(const float2& a) {
  __nv_bfloat162 b = __float22bfloat162_rn(a);
  return b;
}

// Float4 -> bfloat162x2
template <>
__inline__ __device__ bf16_4_t
vec_conversion<bf16_4_t, Float4_>(const Float4_& a) {
  bf16_4_t b;
  b.x = __float22bfloat162_rn(a.x);
  b.y = __float22bfloat162_rn(a.y);
  return b;
}

// Float8 -> bfloat162x4
template <>
__inline__ __device__ bf16_8_t
vec_conversion<bf16_8_t, Float8_>(const Float8_& a) {
  bf16_8_t b;
  b.x = __float22bfloat162_rn(a.x);
  b.y = __float22bfloat162_rn(a.y);
  b.z = __float22bfloat162_rn(a.z);
  b.w = __float22bfloat162_rn(a.w);
  return b;
}

/* Scaled and vectorized conversions, for data exchange between high and low
   precision domains

   Convention of the scale in API, e.g: FP8_data = Quantization(
   High_Precision_data / scale ) s.t. Quantize(HP / scale) => FP8 Dequant(FP8) *
   scale => HP
|
||||||
|
|
||||||
*/
|
*/
|
||||||
|
|
||||||
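Editor's note, not part of the diff: a worked instance of the convention above with illustrative numbers. With scale = 0.25 and a high-precision value HP = 104.0, quantization stores FP8(104.0 / 0.25) = FP8(416.0), which e4m3 represents exactly; dequantization returns 416.0 * 0.25 = 104.0. The scale keeps the stored fp8 code inside the e4m3 range while restoring the original magnitude on the way back.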
// fp8 -> half
template <>
__inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, const float scale) {
  hip_fp8 f8{a, hip_fp8::from_bits()};
  __half_raw res;
  res.data = static_cast<float>(f8) * scale;
  return res.x;
}

// fp8x2 -> half2
template <>
__inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(
    const uint16_t& a, const float scale) {
#if defined(__HIP__MI300__) && \
    defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
  const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
  union {
    __half2_raw h2r;
    uint32_t ui32;
  } tmp;
  tmp.h2r.x.data = f2[0] * scale;
  tmp.h2r.y.data = f2[1] * scale;
  return tmp.ui32;
#else
  union {
    uint16_t u16[2];
    uint32_t u32;
  } tmp;

  tmp.u16[0] =
      scaled_vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a), scale);
  tmp.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>(
      static_cast<uint8_t>(a >> 8U), scale);
  return tmp.u32;
#endif
}

// fp8x4 -> half2x2
template <>
__inline__ __device__ uint2
scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, const float scale) {
  union {
    uint2 u32x2;
    uint32_t u32[2];
  } tmp;
  tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale);
  tmp.u32[1] =
      scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale);
  return tmp.u32x2;
}

// fp8x8 -> half2x4
template <>
__inline__ __device__ uint4
scaled_vec_conversion<uint4, uint2>(const uint2& a, const float scale) {
  union {
    uint4 u64x2;
    uint2 u64[2];
  } tmp;
  tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale);
  tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale);
  return tmp.u64x2;
}

using __nv_bfloat16 = __hip_bfloat16;

// fp8 -> __nv_bfloat16
template <>
__inline__ __device__ __nv_bfloat16
scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a,
                                              const float scale) {
  hip_fp8 f8{a, hip_fp8::from_bits()};
  float f{f8};
  return __float2bfloat16(f * scale);
}

using __nv_bfloat162 = __hip_bfloat162;

// fp8x2 -> __nv_bfloat162
template <>
__inline__ __device__ __nv_bfloat162
scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a,
                                                const float scale) {
  __nv_bfloat162 res;
  res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale);
  res.y =
      scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale);
  return res;
}

// fp8x4 -> bf16_4_t
template <>
__inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
    const uint32_t& a, const float scale) {
  bf16_4_t res;
  res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale);
  res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U),
                                                          scale);
  return res;
}

// fp8x8 -> bf16_8_t
template <>
__inline__ __device__ bf16_8_t
scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, const float scale) {
  bf16_4_t tmp1, tmp2;
  tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale);
  tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale);
  bf16_8_t res;
  res.x = tmp1.x;
  res.y = tmp1.y;
  res.z = tmp2.x;
  res.w = tmp2.y;
  return res;
}

// fp8 -> float
template <>
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
    const uint8_t& a, const float scale) {
  hip_fp8 fp8{a, hip_fp8::from_bits()};
  return static_cast<float>(fp8) * scale;
}

// fp8x2 -> float2
template <>
__inline__ __device__ float2
scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, const float scale) {
#if defined(__HIP__MI300__) && \
    defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
  float2 res;
  const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
  res.x = f2[0] * scale;
  res.y = f2[1] * scale;
  return res;
#else
  float2 res;
  res.x = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a), scale);
  res.y = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U),
                                                scale);
  return res;
#endif
}

// fp8x4 -> float4
template <>
__inline__ __device__ Float4_
scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale) {
  Float4_ res;
  res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale);
  res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale);
  return res;
}

// fp8x8 -> float8
template <>
__inline__ __device__ Float8_
scaled_vec_conversion<Float8_, uint2>(const uint2& a, const float scale) {
  Float4_ tmp1, tmp2;
  tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale);
  tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale);
  Float8_ res;
  res.x = tmp1.x;
  res.y = tmp1.y;
  res.z = tmp2.x;
  res.w = tmp2.y;
  return res;
}

/* Quantize(HP / scale) => FP8 */

// TODO(Hai): vectorized to add

// half -> fp8
template <>
__inline__ __device__ uint8_t
scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, const float scale) {
  __half_raw tmp;
  tmp.x = a;

  hip_fp8 f8{static_cast<float>(tmp.data) / scale};
  return f8.data;
}

// bf16 -> fp8
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
    const __nv_bfloat16& a, const float scale) {
  hip_fp8 res{__bfloat162float(a) / scale};
  return res.data;
}

// float -> fp8
template <>
__inline__ __device__ uint8_t
scaled_vec_conversion<uint8_t, float>(const float& a, const float scale) {
  hip_fp8 f8(a / scale);
  return f8.data;
}

// fp8x4 -> float4
template <>
__inline__ __device__ float4
scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, const float scale) {
  Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale);
  float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
  return res;
}
#endif  // ENABLE_FP8

template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout convert(const Tin& x) {
#ifdef ENABLE_FP8
  if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
    return vec_conversion<Tout, Tin>(x);
  }
#endif
  assert(false);
}

template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
#ifdef ENABLE_FP8
  if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
    return scaled_vec_conversion<Tout, Tin>(x, scale);
  }
#endif
  assert(false);
}

// The following macro is used to dispatch the conversion function based on
// the data type of the key and value cache. The FN is a macro that calls a
// function with template<typename scalar_t, typename cache_t,
// Fp8KVCacheDataType kv_dt>.
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
  if (KV_DTYPE == "auto") { \
    if (SRC_DTYPE == at::ScalarType::Float) { \
      FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \
    } else if (SRC_DTYPE == at::ScalarType::Half) { \
      FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \
    } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
      FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \
    } else { \
      TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
    } \
  } else { \
    if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \
      if (SRC_DTYPE == at::ScalarType::Float) { \
        FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
      } else if (SRC_DTYPE == at::ScalarType::Half) { \
        FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
      } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
        FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
      } else { \
        TORCH_CHECK(false, \
                    "Unsupported input type of kv cache: ", SRC_DTYPE); \
      } \
    } else { \
      TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
    } \
  }

}  // namespace fp8
#endif  // USE_ROCM
}  // namespace vllm
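Editor's note, not part of the diff: a minimal usage sketch for the helpers above, assuming an fp8-e4m3 KV cache on ROCm. The wrapper name and its arguments are illustrative only.

// Dequantize one cached fp8 byte back to a packed half (uint16_t).
// scaled_convert routes kFp8E4M3 to scaled_vec_conversion<uint16_t, uint8_t>,
// which applies half(float(fp8) * scale).
__device__ uint16_t dequant_kv_entry(const uint8_t* kv_cache, int64_t idx,
                                     float kv_scale) {
  return vllm::fp8::scaled_convert<uint16_t, uint8_t,
                                   vllm::Fp8KVCacheDataType::kFp8E4M3>(
      kv_cache[idx], kv_scale);
}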
@ -10,17 +10,20 @@
namespace vllm {

__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
  float old;
  old = (value >= 0)
            ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
            : __uint_as_float(
                  atomicMin((unsigned int*)addr, __float_as_uint(value)));

  return old;
}

#define FP8_E4M3_MAX std::numeric_limits<c10::Float8_e4m3fn>::max()

template <typename scalar_t>
__device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(
    const scalar_t val, const float scale) {
  float x = static_cast<float>(val) / scale;
  float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
  return static_cast<c10::Float8_e4m3fn>(r);
@ -32,11 +35,10 @@ __device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(const scalar
// So to get the right answer, *scale needs to be initialized to
// a value <= 0.0 and we need to wait for all thread blocks to
// finish before consuming *scale.
template <typename scalar_t>
__global__ void segmented_max_reduction(float* __restrict__ scale,
                                        const scalar_t* __restrict__ input,
                                        int64_t num_elems) {
  __shared__ float cache[1024];
  int i = blockDim.x * blockIdx.x + threadIdx.x;
@ -56,7 +58,7 @@ __global__ void segmented_max_reduction(
  int ib = blockDim.x / 2;
  while (ib != 0) {
    if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) {
      cache[threadIdx.x] = cache[threadIdx.x + ib];
    }
    __syncthreads();
    ib /= 2;
@ -64,16 +66,16 @@
  // Finally, since cache[0] contains the maximum for this thread block,
  // atomically write the max to the target location
  if (threadIdx.x == 0) {
    atomicMaxFloat(scale,
                   cache[0] / std::numeric_limits<c10::Float8_e4m3fn>::max());
  }
}

template <typename scalar_t>
__global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
                                        const scalar_t* __restrict__ input,
                                        const float* __restrict__ scale,
                                        int64_t num_elems) {
  int i = blockDim.x * blockIdx.x + threadIdx.x;
  while (i < num_elems) {
    out[i] = scaled_fp8_conversion(input[i], *scale);
@ -81,12 +83,11 @@ __global__ void scaled_fp8_quant_kernel(
  }
}

}  // namespace vllm

void static_scaled_fp8_quant(torch::Tensor& out,    // [..., d]
                             torch::Tensor& input,  // [..., d]
                             torch::Tensor& scale)  // [1]
{
  int64_t num_tokens = input.numel() / input.size(-1);
  int64_t num_elems = input.numel();
@ -95,21 +96,16 @@ void static_scaled_fp8_quant(
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
        vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
            out.data_ptr<c10::Float8_e4m3fn>(), input.data_ptr<scalar_t>(),
            scale.data_ptr<float>(), num_elems);
      });
}

void dynamic_scaled_fp8_quant(torch::Tensor& out,    // [..., d]
                              torch::Tensor& input,  // [..., d]
                              torch::Tensor& scale)  // [1]
{
  int64_t num_tokens = input.numel() / input.size(-1);
  int64_t num_elems = input.numel();
@ -118,18 +114,11 @@ void dynamic_scaled_fp8_quant(
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
        vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>(
            scale.data_ptr<float>(), input.data_ptr<scalar_t>(), num_elems);
        vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
            out.data_ptr<c10::Float8_e4m3fn>(), input.data_ptr<scalar_t>(),
            scale.data_ptr<float>(), num_elems);
      });
}
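Editor's note, not part of the diff: the scale convention used by segmented_max_reduction and scaled_fp8_quant_kernel, restated as a small host-side sketch; the function and variable names here are illustrative.

#include <algorithm>
#include <cmath>
#include <vector>

// Mirrors the two-kernel flow of dynamic_scaled_fp8_quant on the CPU:
// first scale = max|x| / FP8_E4M3_MAX, then q = clamp(x / scale) per element.
inline float dynamic_fp8_scale(const std::vector<float>& x,
                               float fp8_e4m3_max = 448.0f) {
  float max_abs = 0.0f;
  for (float v : x) max_abs = std::max(max_abs, std::fabs(v));
  return max_abs / fp8_e4m3_max;
}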
@ -10,9 +10,9 @@ namespace vllm {
#ifndef USE_ROCM

namespace fp8 {
#ifdef ENABLE_FP8

#if 0  // Disable the following code to reduce the binary size.
template <typename Tout, typename Tin>
__inline__ __device__ Tout
vec_conversion(const Tin &x, const __nv_fp8_interpretation_t fp8_type) {
@ -177,13 +177,13 @@ __inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(
template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(
    const __nv_bfloat16 &a, const __nv_fp8_interpretation_t fp8_type) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  __nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(
      __nv_bfloat16_raw(a), __NV_SATFINITE, fp8_type);
  return (uint8_t)res;
#endif
}

// float -> fp8
@ -276,7 +276,7 @@ __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(
  from_float(b, a);
  return b;
}
#endif

/* Scaled and vectorized conversions, for data exchange between high and low
   precision domains Convention of the scale in API, e.g: FP8_data =
@ -286,14 +286,14 @@ __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(

template <typename Tout, typename Tin>
__inline__ __device__ Tout scaled_vec_conversion(
    const Tin& x, const float scale, const __nv_fp8_interpretation_t fp8_type) {
  return x;
}

// fp8 -> half
template <>
__inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, uint8_t>(
    const uint8_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  __half_raw tmp = __nv_cvt_fp8_to_halfraw(a, fp8_type);
  return float_to_half(half_to_float(tmp.x) * scale);
@ -302,7 +302,7 @@ __inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, uint8_t>(
// fp8x2 -> half2
template <>
__inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(
    const uint16_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  union {
    uint16_t u16[2];
@ -317,7 +317,7 @@ __inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(
// fp8x4 -> half2x2
template <>
__inline__ __device__ uint2 scaled_vec_conversion<uint2, uint32_t>(
    const uint32_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  union {
    uint2 u32x2;
@ -333,7 +333,7 @@ __inline__ __device__ uint2 scaled_vec_conversion<uint2, uint32_t>(
// fp8x8 -> half2x4
template <>
__inline__ __device__ uint4
scaled_vec_conversion<uint4, uint2>(const uint2& a, const float scale,
                                    const __nv_fp8_interpretation_t fp8_type) {
  union {
    uint4 u64x2;
@ -348,7 +348,7 @@ scaled_vec_conversion<uint4, uint2>(const uint2 &a, const float scale,
template <>
__inline__ __device__ __nv_bfloat16
scaled_vec_conversion<__nv_bfloat16, uint8_t>(
    const uint8_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  // Note there is no direct convert function from fp8 to bf16.
  // fp8 -> half
@ -362,7 +362,7 @@ scaled_vec_conversion<__nv_bfloat16, uint8_t>(
template <>
__inline__ __device__ __nv_bfloat162
scaled_vec_conversion<__nv_bfloat162, uint16_t>(
    const uint16_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  __nv_bfloat162 res;
  res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale,
@ -375,7 +375,7 @@ scaled_vec_conversion<__nv_bfloat162, uint16_t>(
// fp8x4 -> bf16_4_t
template <>
__inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
    const uint32_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  bf16_4_t res;
  res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale,
@ -388,7 +388,7 @@ __inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
// fp8x8 -> bf16_8_t
template <>
__inline__ __device__ bf16_8_t scaled_vec_conversion<bf16_8_t, uint2>(
    const uint2& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  bf16_4_t tmp1, tmp2;
  tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale, fp8_type);
@ -404,9 +404,8 @@ __inline__ __device__ bf16_8_t scaled_vec_conversion<bf16_8_t, uint2>(
// fp8 -> float
template <>
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
    const uint8_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  // fp8 -> half
  __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
  uint16_t tmp = res.x;
@ -418,7 +417,7 @@ __inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
// fp8x2 -> float2
template <>
__inline__ __device__ float2 scaled_vec_conversion<float2, uint16_t>(
    const uint16_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  // fp8x2 -> half2
  uint32_t tmp = scaled_vec_conversion<uint32_t, uint16_t>(a, scale, fp8_type);
@ -429,7 +428,7 @@ __inline__ __device__ float2 scaled_vec_conversion<float2, uint16_t>(
// fp8x4 -> float4
template <>
__inline__ __device__ Float4_ scaled_vec_conversion<Float4_, uint32_t>(
    const uint32_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  Float4_ res;
  res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale, fp8_type);
@ -441,7 +440,7 @@ __inline__ __device__ Float4_ scaled_vec_conversion<Float4_, uint32_t>(
// fp8x8 -> float8
template <>
__inline__ __device__ Float8_ scaled_vec_conversion<Float8_, uint2>(
    const uint2& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  Float4_ tmp1, tmp2;
  tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale, fp8_type);
@ -457,7 +456,7 @@ __inline__ __device__ Float8_ scaled_vec_conversion<Float8_, uint2>(
// half -> fp8
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, uint16_t>(
    const uint16_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  __nv_fp8_storage_t res =
      __nv_cvt_float_to_fp8(half_to_float(a) / scale, __NV_SATFINITE, fp8_type);
@ -467,21 +466,21 @@ __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, uint16_t>(
// bf16 -> fp8
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
    const __nv_bfloat16& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(__bfloat162float(a) / scale,
                                                 __NV_SATFINITE, fp8_type);
  return (uint8_t)res;
#endif
}

// float -> fp8
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, float>(
    const float& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  __nv_fp8_storage_t res =
      __nv_cvt_float_to_fp8(a / scale, __NV_SATFINITE, fp8_type);
@ -491,78 +490,81 @@ __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, float>(
// fp8x4 -> float4
template <>
__inline__ __device__ float4 scaled_vec_conversion<float4, uint32_t>(
    const uint32_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale, fp8_type);
  float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
  return res;
}
#endif  // ENABLE_FP8

template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout convert(const Tin& x) {
#if 0  // Disable the following code to reduce the binary size.
  if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
    return vec_conversion<Tout, Tin>(x, __NV_E4M3);
  } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
    return vec_conversion<Tout, Tin>(x, __NV_E5M2);
  }
#endif
  assert(false);
}

template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
#ifdef ENABLE_FP8
  if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
    return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E4M3);
  } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
    return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E5M2);
  }
#endif
  assert(false);
}

// The following macro is used to dispatch the conversion function based on
// the data type of the key and value cache. The FN is a macro that calls a
// function with template<typename scalar_t, typename cache_t,
// Fp8KVCacheDataType kv_dt>.
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
  if (KV_DTYPE == "auto") { \
    if (SRC_DTYPE == at::ScalarType::Float) { \
      FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \
    } else if (SRC_DTYPE == at::ScalarType::Half) { \
      FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \
    } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
      FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \
    } else { \
      TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
    } \
  } else { \
    if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \
      if (SRC_DTYPE == at::ScalarType::Float) { \
        FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
      } else if (SRC_DTYPE == at::ScalarType::Half) { \
        FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
      } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
        FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
      } else { \
        TORCH_CHECK(false, \
                    "Unsupported input type of kv cache: ", SRC_DTYPE); \
      } \
    } else if (KV_DTYPE == "fp8_e5m2") { \
      if (SRC_DTYPE == at::ScalarType::Float) { \
        FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
      } else if (SRC_DTYPE == at::ScalarType::Half) { \
        FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
      } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
        FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
      } else { \
        TORCH_CHECK(false, \
                    "Unsupported input type of kv cache: ", SRC_DTYPE); \
      } \
    } else { \
      TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
    } \
  }

}  // namespace fp8
#endif  // not USE_ROCM
}  // namespace vllm
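Editor's note, not part of the diff: a sketch of how a dispatch macro like DISPATCH_BY_KV_CACHE_DTYPE is typically consumed. The kernel name, launch configuration, and pointer variables below are assumptions, not code from this commit.

// Hypothetical caller-side macro; FN receives <scalar_t, cache_t, kv_dt>.
#define CALL_COPY_KV(SCALAR_T, CACHE_T, KV_DT)        \
  copy_to_kv_cache_kernel<SCALAR_T, CACHE_T, KV_DT>   \
      <<<grid, block, 0, stream>>>(src_ptr, cache_ptr, kv_scale);

// SRC_DTYPE is the torch scalar type of the source tensor; KV_DTYPE is the
// user-facing string ("auto", "fp8", "fp8_e4m3", or "fp8_e5m2").
DISPATCH_BY_KV_CACHE_DTYPE(src.scalar_type(), kv_cache_dtype, CALL_COPY_KV);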
@ -9,54 +9,54 @@ namespace vllm {
namespace gptq {
// atomicAdd for half types, to support CC < 7.x

__device__ __forceinline__ void atomicAdd_half(half* address, half val) {
  unsigned int* address_as_ui =
      (unsigned int*)((char*)address - ((size_t)address & 2));
  unsigned int old = *address_as_ui;
  unsigned int assumed;

  do {
    assumed = old;
    __half_raw hsum;
    hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
    half tmpres = __hadd(hsum, val);
    hsum = __half_raw(tmpres);
    old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16)
                              : (old & 0xffff0000) | hsum.x;
    old = atomicCAS(address_as_ui, assumed, old);
  } while (assumed != old);
}

// atomicAdd for half2 types

__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) {
  unsigned int* address_as_ui = (unsigned int*)address;
  unsigned int old = *address_as_ui;
  unsigned int assumed;
  do {
    assumed = old;
    half2 old_val = *((half2*)&old);
    half2 new_val = __hadd2(old_val, val);
    old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
  } while (assumed != old);
}

//

#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)

__device__ __forceinline__ void atomicAdd(half* address, half val) {
  atomicAdd_half(address, val);
}

#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) {
  atomicAdd_half2(address, val);
}
#endif

#endif
#endif

}  // namespace gptq
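Editor's note, not part of the diff: the same compare-and-swap retry loop used by atomicAdd_half above, shown for double where the bit reinterpretation is simpler. This is a generic CUDA pattern, not code from this commit.

__device__ double atomicAdd_double_cas(double* address, double val) {
  unsigned long long* address_as_ull = (unsigned long long*)address;
  unsigned long long old = *address_as_ull, assumed;
  do {
    assumed = old;
    // Only commit the new sum if no other thread updated *address meanwhile.
    old = atomicCAS(address_as_ull, assumed,
                    __double_as_longlong(val + __longlong_as_double(assumed)));
  } while (assumed != old);
  return __longlong_as_double(old);
}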
@ -1,5 +1,6 @@
/*
Adapted from https://github.com/turboderp/exllamav2 and
https://github.com/turboderp/exllama
*/

#ifndef _matrix_view_cuh
@ -13,260 +14,280 @@ Adapted from https://github.com/turboderp/exllamav2 and https://github.com/turbo
namespace vllm {
namespace gptq {

class MatrixView_half {
 public:
  const half* data;
  const int height;
  const int width;

  __device__ __forceinline__ MatrixView_half(const half* data, const int height,
                                             const int width)
      : data(data), height(height), width(width) {}

  __device__ __forceinline__ half item(int row, int column) const {
    return data[row * width + column];
  }
  __device__ __forceinline__ half2 item_half2(int row, int column) const {
    return ((half2*)data)[(row * width + column) / 2];
  }
  __device__ __forceinline__ half2 item_half2half2(int row, int column) const {
    return __half2half2(data[row * width + column]);
  }
  __device__ __forceinline__ const half* item_ptr(int row, int column) const {
    return &data[row * width + column];
  }

  __device__ __forceinline__ void item4(half (&items)[4], int row,
                                        int column) const {
    half2* ptr = (half2*)item_ptr(row, column);
    half2 i01 = ptr[0];
    half2 i23 = ptr[1];
    items[0] = __low2half(i01);
    items[1] = __high2half(i01);
    items[2] = __low2half(i23);
    items[3] = __high2half(i23);
  }
  __device__ __forceinline__ void item4_f(float (&items)[4], int row,
                                          int column) const {
    half2* ptr = (half2*)item_ptr(row, column);
    half2 i01 = ptr[0];
    half2 i23 = ptr[1];
    items[0] = __half2float(__low2half(i01));
    items[1] = __half2float(__high2half(i01));
    items[2] = __half2float(__low2half(i23));
    items[3] = __half2float(__high2half(i23));
  }

  __device__ __forceinline__ void item4_h2(half2 (&items)[4], int row,
                                           int column) const {
    half2* ptr = (half2*)item_ptr(row, column);
    half2 i01 = ptr[0];
    half2 i23 = ptr[1];
    items[0] = __half2half2(__low2half(i01));
    items[1] = __half2half2(__high2half(i01));
    items[2] = __half2half2(__low2half(i23));
    items[3] = __half2half2(__high2half(i23));
  }
};

class MatrixView_half_rw {
 public:
  half* data;
  const int height;
  const int width;

  __device__ __forceinline__ MatrixView_half_rw(half* data, const int height,
                                                const int width)
      : data(data), height(height), width(width) {}

  __device__ __forceinline__ half item(int row, int column) const {
    return data[row * width + column];
  }
  __device__ __forceinline__ half2 item_half2(int row, int column) const {
    return ((half2*)data)[(row * width + column) / 2];
  }
  __device__ __forceinline__ half* item_ptr(int row, int column) {
    return &data[row * width + column];
  }
  __device__ __forceinline__ void set(int row, int column, half value) {
    data[row * width + column] = value;
  }
  __device__ __forceinline__ void set_half2(int row, int column, half2 value) {
    ((half2*)data)[(row * width + column) / 2] = value;
  }

  __device__ __forceinline__ void set4(int row, int column, half v0, half v1,
                                       half v2, half v3) {
    half2 v01 = __halves2half2(v0, v1);
    half2 v23 = __halves2half2(v2, v3);
    half2* ptr = (half2*)item_ptr(row, column);
    ptr[0] = v01;
    ptr[1] = v23;
  }
};

class MatrixView_q4_row {
 public:
  const uint32_t* data;
  const int height;
  const int width;
|
||||||
const int width;
|
|
||||||
|
|
||||||
__device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
|
__device__ __forceinline__ MatrixView_q4_row(const uint32_t* data,
|
||||||
: data(data), height(height), width(width)
|
const int height,
|
||||||
{ }
|
const int width)
|
||||||
|
: data(data), height(height), width(width) {}
|
||||||
|
|
||||||
__device__ __forceinline__ int item(int row, int column) const
|
__device__ __forceinline__ int item(int row, int column) const {
|
||||||
{
|
int shift = (column & 0x07) * 4;
|
||||||
int shift = (column & 0x07) * 4;
|
return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
|
||||||
return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
|
}
|
||||||
}
|
|
||||||
|
|
||||||
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
|
__device__ __forceinline__ void item2(int (&items)[2], int row,
|
||||||
{
|
int column) const {
|
||||||
int shift = (column & 0x07) * 4;
|
int shift = (column & 0x07) * 4;
|
||||||
uint32_t d = data[row * width / 8 + column / 8] >> shift;
|
uint32_t d = data[row * width / 8 + column / 8] >> shift;
|
||||||
items[0] = d & 0x0f;
|
items[0] = d & 0x0f;
|
||||||
items[1] = (d >> 4) & 0x0f;
|
items[1] = (d >> 4) & 0x0f;
|
||||||
}
|
}
|
||||||
|
|
||||||
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
|
__device__ __forceinline__ void item4(int (&items)[4], int row,
|
||||||
{
|
int column) const {
|
||||||
int shift = (column & 0x07) * 4;
|
int shift = (column & 0x07) * 4;
|
||||||
uint32_t d = data[row * width / 8 + column / 8] >> shift;
|
uint32_t d = data[row * width / 8 + column / 8] >> shift;
|
||||||
items[0] = d & 0x0f;
|
items[0] = d & 0x0f;
|
||||||
items[1] = (d >> 4) & 0x0f;
|
items[1] = (d >> 4) & 0x0f;
|
||||||
items[2] = (d >> 8) & 0x0f;
|
items[2] = (d >> 8) & 0x0f;
|
||||||
items[3] = (d >> 12) & 0x0f;
|
items[3] = (d >> 12) & 0x0f;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
class MatrixView_q4_column
|
class MatrixView_q4_column {
|
||||||
{
|
public:
|
||||||
public:
|
const uint32_t* data;
|
||||||
const uint32_t* data;
|
const int height;
|
||||||
const int height;
|
const int width;
|
||||||
const int width;
|
|
||||||
|
|
||||||
__device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width)
|
__device__ __forceinline__ MatrixView_q4_column(const uint32_t* data,
|
||||||
: data(data), height(height), width(width)
|
const int height,
|
||||||
{ }
|
const int width)
|
||||||
|
: data(data), height(height), width(width) {}
|
||||||
|
|
||||||
__device__ __forceinline__ int item(int row, int column) const
|
__device__ __forceinline__ int item(int row, int column) const {
|
||||||
{
|
int shift = (row & 0x07) * 4;
|
||||||
int shift = (row & 0x07) * 4;
|
return (data[row / 8 * width + column] >> shift) & 0x0f;
|
||||||
return (data[row / 8 * width + column] >> shift) & 0x0f;
|
}
|
||||||
}
|
|
||||||
|
|
||||||
__device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; }
|
__device__ __forceinline__ uint32_t item_uint32_t(int row, int column) {
|
||||||
__device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; }
|
return data[row / 8 * width + column];
|
||||||
|
}
|
||||||
|
__device__ __forceinline__ const uint32_t* item_uint32_ptr(int row,
|
||||||
|
int column) {
|
||||||
|
return &data[row / 8 * width + column];
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
class MatrixView_q2_row
|
class MatrixView_q2_row {
|
||||||
{
|
public:
|
||||||
public:
|
const uint32_t* data;
|
||||||
const uint32_t* data;
|
const int height;
|
||||||
const int height;
|
const int width;
|
||||||
const int width;
|
|
||||||
|
|
||||||
__device__ __forceinline__ MatrixView_q2_row(const uint32_t* data, const int height, const int width)
|
__device__ __forceinline__ MatrixView_q2_row(const uint32_t* data,
|
||||||
: data(data), height(height), width(width)
|
const int height,
|
||||||
{ }
|
const int width)
|
||||||
|
: data(data), height(height), width(width) {}
|
||||||
|
|
||||||
__device__ __forceinline__ int item(int row, int column) const
|
__device__ __forceinline__ int item(int row, int column) const {
|
||||||
{
|
int shift = (column & 0x0f) * 2;
|
||||||
int shift = (column & 0x0f) * 2;
|
return (data[row * width / 16 + column / 16] >> shift) & 0x03;
|
||||||
return (data[row * width / 16 + column / 16] >> shift) & 0x03;
|
}
|
||||||
}
|
|
||||||
|
|
||||||
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
|
__device__ __forceinline__ void item2(int (&items)[2], int row,
|
||||||
{
|
int column) const {
|
||||||
int shift = (column & 0x0f) * 2;
|
int shift = (column & 0x0f) * 2;
|
||||||
uint32_t d = data[row * width / 16 + column / 16] >> shift;
|
uint32_t d = data[row * width / 16 + column / 16] >> shift;
|
||||||
items[0] = d & 0x03;
|
items[0] = d & 0x03;
|
||||||
items[1] = (d >> 2) & 0x03;
|
items[1] = (d >> 2) & 0x03;
|
||||||
}
|
}
|
||||||
|
|
||||||
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
|
__device__ __forceinline__ void item4(int (&items)[4], int row,
|
||||||
{
|
int column) const {
|
||||||
int shift = (column & 0x0f) * 2;
|
int shift = (column & 0x0f) * 2;
|
||||||
uint32_t d = data[row * width / 16 + column / 16] >> shift;
|
uint32_t d = data[row * width / 16 + column / 16] >> shift;
|
||||||
items[0] = d & 0x03;
|
items[0] = d & 0x03;
|
||||||
items[1] = (d >> 2) & 0x03;
|
items[1] = (d >> 2) & 0x03;
|
||||||
items[2] = (d >> 4) & 0x03;
|
items[2] = (d >> 4) & 0x03;
|
||||||
items[3] = (d >> 6) & 0x03;
|
items[3] = (d >> 6) & 0x03;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
class MatrixView_q3_row
|
class MatrixView_q3_row {
|
||||||
{
|
public:
|
||||||
public:
|
const uint32_t* data;
|
||||||
const uint32_t* data;
|
const int height;
|
||||||
const int height;
|
const int width;
|
||||||
const int width;
|
|
||||||
|
|
||||||
__device__ __forceinline__ MatrixView_q3_row(const uint32_t* data, const int height, const int width)
|
__device__ __forceinline__ MatrixView_q3_row(const uint32_t* data,
|
||||||
: data(data), height(height), width(width)
|
const int height,
|
||||||
{ }
|
const int width)
|
||||||
|
: data(data), height(height), width(width) {}
|
||||||
|
|
||||||
__device__ __forceinline__ int item(int row, int column) const
|
__device__ __forceinline__ int item(int row, int column) const {
|
||||||
{
|
int z_w = column * 3 / 32;
|
||||||
int z_w = column * 3 / 32;
|
int z_mod = column & 0x1f;
|
||||||
int z_mod = column & 0x1f;
|
|
||||||
|
|
||||||
if (z_mod == 10) {
|
if (z_mod == 10) {
|
||||||
return (data[row * width * 3 / 32 + z_w] >> 30) | ((data[row * width * 3 / 32 + (z_w + 1)] << 2) & 0x4);
|
return (data[row * width * 3 / 32 + z_w] >> 30) |
|
||||||
} else if (z_mod == 21) {
|
((data[row * width * 3 / 32 + (z_w + 1)] << 2) & 0x4);
|
||||||
return (data[row * width * 3 / 32 + z_w] >> 31) | ((data[row * width * 3 / 32 + (z_w + 1)] << 1) & 0x6);
|
} else if (z_mod == 21) {
|
||||||
} else if (z_mod < 10) {
|
return (data[row * width * 3 / 32 + z_w] >> 31) |
|
||||||
return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3)) & 0x07;
|
((data[row * width * 3 / 32 + (z_w + 1)] << 1) & 0x6);
|
||||||
} else if (z_mod < 21) {
|
} else if (z_mod < 10) {
|
||||||
return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 32)) & 0x07;
|
return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3)) & 0x07;
|
||||||
} else {
|
} else if (z_mod < 21) {
|
||||||
return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 64)) & 0x07;
|
return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 32)) & 0x07;
|
||||||
}
|
} else {
|
||||||
|
return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 64)) & 0x07;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
|
__device__ __forceinline__ void item4(int (&items)[4], int row,
|
||||||
{
|
int column) const {
|
||||||
int shift = (column & 0x1f);
|
int shift = (column & 0x1f);
|
||||||
uint32_t d;
|
uint32_t d;
|
||||||
if (shift <= 4) {
|
if (shift <= 4) {
|
||||||
d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3);
|
d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3);
|
||||||
} else if (shift == 8) {
|
} else if (shift == 8) {
|
||||||
d = (data[row * width / 32 * 3 + column * 3 / 32] >> 24) | ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0x0f) << 8);
|
d = (data[row * width / 32 * 3 + column * 3 / 32] >> 24) |
|
||||||
} else if (shift <= 16) {
|
((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0x0f) << 8);
|
||||||
d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 32);
|
} else if (shift <= 16) {
|
||||||
} else if (shift == 20) {
|
d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 32);
|
||||||
d = (data[row * width / 32 * 3 + column * 3 / 32] >> 28) | ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0xff) << 4);
|
} else if (shift == 20) {
|
||||||
} else {
|
d = (data[row * width / 32 * 3 + column * 3 / 32] >> 28) |
|
||||||
d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 64);
|
((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0xff) << 4);
|
||||||
}
|
} else {
|
||||||
items[0] = d & 0x07;
|
d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 64);
|
||||||
items[1] = (d >> 3) & 0x07;
|
|
||||||
items[2] = (d >> 6) & 0x07;
|
|
||||||
items[3] = (d >> 9) & 0x07;
|
|
||||||
}
|
}
|
||||||
|
items[0] = d & 0x07;
|
||||||
|
items[1] = (d >> 3) & 0x07;
|
||||||
|
items[2] = (d >> 6) & 0x07;
|
||||||
|
items[3] = (d >> 9) & 0x07;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
class MatrixView_q8_row
|
class MatrixView_q8_row {
|
||||||
{
|
public:
|
||||||
public:
|
const uint32_t* data;
|
||||||
const uint32_t* data;
|
const int height;
|
||||||
const int height;
|
const int width;
|
||||||
const int width;
|
|
||||||
|
|
||||||
__device__ __forceinline__ MatrixView_q8_row(const uint32_t* data, const int height, const int width)
|
__device__ __forceinline__ MatrixView_q8_row(const uint32_t* data,
|
||||||
: data(data), height(height), width(width)
|
const int height,
|
||||||
{ }
|
const int width)
|
||||||
|
: data(data), height(height), width(width) {}
|
||||||
|
|
||||||
__device__ __forceinline__ int item(int row, int column) const
|
__device__ __forceinline__ int item(int row, int column) const {
|
||||||
{
|
int shift = (column & 0x03) * 8;
|
||||||
int shift = (column & 0x03) * 8;
|
return (data[row * width / 4 + column / 4] >> shift) & 0xff;
|
||||||
return (data[row * width / 4 + column / 4] >> shift) & 0xff;
|
}
|
||||||
}
|
|
||||||
|
|
||||||
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
|
__device__ __forceinline__ void item2(int (&items)[2], int row,
|
||||||
{
|
int column) const {
|
||||||
int shift = (column & 0x03) * 8;
|
int shift = (column & 0x03) * 8;
|
||||||
uint32_t d = data[row * width / 4 + column / 4] >> shift;
|
uint32_t d = data[row * width / 4 + column / 4] >> shift;
|
||||||
items[0] = d & 0xff;
|
items[0] = d & 0xff;
|
||||||
items[1] = (d >> 8) & 0xff;
|
items[1] = (d >> 8) & 0xff;
|
||||||
}
|
}
|
||||||
|
|
||||||
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
|
__device__ __forceinline__ void item4(int (&items)[4], int row,
|
||||||
{
|
int column) const {
|
||||||
int shift = (column & 0x03) * 2;
|
int shift = (column & 0x03) * 2;
|
||||||
uint32_t d = data[row * width / 4 + column / 4] >> shift;
|
uint32_t d = data[row * width / 4 + column / 4] >> shift;
|
||||||
items[0] = d & 0xff;
|
items[0] = d & 0xff;
|
||||||
items[1] = (d >> 8) & 0xff;
|
items[1] = (d >> 8) & 0xff;
|
||||||
items[2] = (d >> 16) & 0xff;
|
items[2] = (d >> 16) & 0xff;
|
||||||
items[3] = (d >> 24) & 0xff;
|
items[3] = (d >> 24) & 0xff;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace gptq
|
} // namespace gptq
|
||||||
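For reference, a minimal host-side C++ sketch of the row-major 4-bit packing that MatrixView_q4_row::item/item4 assume: eight values per uint32_t, with value (row, col) in word row * width / 8 + col / 8 at bit offset (col % 8) * 4. The pack_q4_row/unpack_q4_row names are illustrative and not part of this diff.

#include <cassert>
#include <cstdint>
#include <vector>

// Pack a row-major matrix of 4-bit values, eight per uint32_t, in the layout
// read by MatrixView_q4_row.
std::vector<uint32_t> pack_q4_row(const std::vector<int>& v, int height,
                                  int width) {
  std::vector<uint32_t> out(height * width / 8, 0);
  for (int r = 0; r < height; r++)
    for (int c = 0; c < width; c++)
      out[r * width / 8 + c / 8] |= uint32_t(v[r * width + c] & 0x0f)
                                    << ((c & 0x07) * 4);
  return out;
}

// Same arithmetic as MatrixView_q4_row::item, on the host.
int unpack_q4_row(const std::vector<uint32_t>& data, int width, int row,
                  int column) {
  int shift = (column & 0x07) * 4;
  return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
}

int main() {
  const int height = 2, width = 16;  // width must be a multiple of 8
  std::vector<int> v(height * width);
  for (int i = 0; i < height * width; i++) v[i] = i % 16;
  auto packed = pack_q4_row(v, height, width);
  for (int r = 0; r < height; r++)
    for (int c = 0; c < width; c++)
      assert(unpack_q4_row(packed, width, r, c) == v[r * width + c]);
  return 0;
}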
File diff suppressed because it is too large
@@ -14,71 +14,60 @@ namespace gptq {
//
// ffddbb99 77553311 eeccaa88 66442200

__forceinline__ __device__ void shuffle_2bit_16(uint32_t* q, int stride) {
  uint32_t qa = q[0];
  uint32_t qb = 0;

#pragma unroll
  for (int i = 0; i < 8; i++) {
    uint32_t qa0 = qa & 0x03;
    uint32_t qa1 = (qa & 0x0c) >> 2;
    qa >>= 4;
    qb |= (qa1 << (i * 2 + 16));
    qb |= (qa0 << (i * 2));
  }
  q[0] = qb;
}

__forceinline__ __device__ void dequant_2bit_16(const uint32_t q_0,
                                                half2 (&dq)[8], int stride,
                                                const uint32_t zero) {
  const uint32_t c0 = 0x64006400;
  const half y4_ = __float2half_rn(1.0f / 4.0f);
  const half y16_ = __float2half_rn(1.0f / 16.0f);
  const half y64_ = __float2half_rn(1.0f / 64.0f);
  const half2 y4 = __halves2half2(y4_, y4_);
  const half2 y16 = __halves2half2(y16_, y16_);
  const half2 y64 = __halves2half2(y64_, y64_);

  const half_uint16 z1_(0xe400 | zero);  // half(-1024.0f - zero);
  const half z4_ = __hsub(__int2half_rn(-256), __int2half_rn(zero));
  const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero));
  const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero));
  const half2 z1 = __half2half2(z1_.as_half);
  const half2 z4 = __half2half2(z4_);
  const half2 z16 = __half2half2(z16_);
  const half2 z64 = __half2half2(z64_);

  uint32_t qa = q_0;
  half2_uint32 q0((qa & 0x00030003) | c0);  // half2(q[ 0], q[ 1])      + 1024
  half2_uint32 q1((qa & 0x000c000c) | c0);  // half2(q[ 2], q[ 3]) *  4 + 1024
  half2_uint32 q2((qa & 0x00300030) | c0);  // half2(q[ 4], q[ 5]) * 16 + 1024
  half2_uint32 q3((qa & 0x00c000c0) | c0);  // half2(q[ 6], q[ 7]) * 64 + 1024
  qa >>= 8;
  half2_uint32 q4((qa & 0x00030003) | c0);  // half2(q[ 8], q[ 8])      + 1024
  half2_uint32 q5((qa & 0x000c000c) | c0);  // half2(q[10], q[11]) *  4 + 1024
  half2_uint32 q6((qa & 0x00300030) | c0);  // half2(q[12], q[13]) * 16 + 1024
  half2_uint32 q7((qa & 0x00c000c0) | c0);  // half2(q[14], q[15]) * 64 + 1024

  dq[0] = __hadd2(q0.as_half2, z1);
  dq[1] = __hfma2(q1.as_half2, y4, z4);
  dq[2] = __hfma2(q2.as_half2, y16, z16);
  dq[3] = __hfma2(q3.as_half2, y64, z64);
  dq[4] = __hadd2(q4.as_half2, z1);
  dq[5] = __hfma2(q5.as_half2, y4, z4);
  dq[6] = __hfma2(q6.as_half2, y16, z16);
  dq[7] = __hfma2(q7.as_half2, y64, z64);
}

}  // namespace gptq
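The dequant kernels lean on a binary16 bit trick: 0x6400 is the bit pattern of 1024.0, and because the exponent of 1024 is large, OR-ing a small integer q into the low mantissa bits yields exactly 1024 + q; likewise 0xe400 | zero is -(1024 + zero), so a single half addition leaves q - zero. A small host-side C++ check of that identity (half_bits_to_float is a hand-rolled helper for normal binary16 values, not part of the diff):

#include <cassert>
#include <cmath>
#include <cstdint>

// Decode a binary16 bit pattern to float (normal numbers only, which is all
// this illustration needs).
float half_bits_to_float(uint16_t h) {
  int sign = (h >> 15) & 0x1;
  int exponent = (h >> 10) & 0x1f;
  int mantissa = h & 0x3ff;
  float value = std::ldexp(1.0f + mantissa / 1024.0f, exponent - 15);
  return sign ? -value : value;
}

int main() {
  for (uint16_t q = 0; q < 16; q++) {
    // (q | 0x6400) reinterpreted as half is exactly 1024 + q.
    assert(half_bits_to_float(0x6400 | q) == 1024.0f + q);
    for (uint16_t zero = 0; zero < 16; zero++) {
      // Adding half(0xe400 | zero) == -(1024 + zero) leaves q - zero, which
      // is what dequant_2bit_16 / dequant_4bit_8 compute via z1.
      float dq = half_bits_to_float(0x6400 | q) +
                 half_bits_to_float(0xe400 | zero);
      assert(dq == float(int(q) - int(zero)));
    }
  }
  return 0;
}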
@@ -11,128 +11,136 @@ namespace gptq {
// vjjjhhhf ffdddbbb uiiiggge eecccaaa
// vtttrrrp ppnnnlll usssqqqo oommmkkk

__forceinline__ __device__ void shuffle_3bit_32(uint32_t* q, int stride) {
  uint32_t qa = q[0 * stride];
  uint32_t qb = q[1 * stride];
  uint32_t qc = q[2 * stride];

  // qa: aa999888 77766655 54443332 22111000
  // qb: lkkkjjji iihhhggg fffeeedd dcccbbba
  // qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll

  uint32_t qd = qc >> 26;
  qc <<= 4;
  qc |= qb >> 28;
  qb <<= 2;
  qb |= qa >> 30;

  // qa: ..999888 77766655 54443332 22111000
  // qb: ..jjjiii hhhgggff feeedddc ccbbbaaa
  // qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk
  // qd:                          vvvuuu

  uint32_t za = 0;
  uint32_t zb = 0;
  uint32_t zc = 0;

  for (int i = 0; i < 5; i++) {
    uint32_t t0 = qa & 0x07;
    uint32_t t1 = (qa & 0x38) >> 3;
    qa >>= 6;
    za |= (t0 << (i * 3));
    za |= (t1 << (i * 3 + 16));
  }
  for (int i = 0; i < 5; i++) {
    uint32_t t0 = qb & 0x07;
    uint32_t t1 = (qb & 0x38) >> 3;
    qb >>= 6;
    zb |= (t0 << (i * 3));
    zb |= (t1 << (i * 3 + 16));
  }
  for (int i = 0; i < 5; i++) {
    uint32_t t0 = qc & 0x07;
    uint32_t t1 = (qc & 0x38) >> 3;
    qc >>= 6;
    zc |= (t0 << (i * 3));
    zc |= (t1 << (i * 3 + 16));
  }

  // za:  9997775 55333111  8886664 44222000
  // zb:  jjjhhhf ffdddbbb  iiiggge eecccaaa
  // zc:  tttrrrp ppnnnlll  sssqqqo oommmkkk
  // qd:                         vvvuuu

  za |= ((qd & 0x01) >> 0) << 15;
  zb |= ((qd & 0x02) >> 1) << 15;
  zc |= ((qd & 0x04) >> 2) << 15;
  za |= ((qd & 0x08) >> 3) << 31;
  zb |= ((qd & 0x10) >> 4) << 31;
  zc |= ((qd & 0x20) >> 5) << 31;

  // za: v9997775 55333111 u8886664 44222000  (u, v lsb)
  // zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa
  // zc: vtttrrrp ppnnnlll usssqqqo oommmkkk

  q[0 * stride] = za;
  q[1 * stride] = zb;
  q[2 * stride] = zc;
}

__forceinline__ __device__ void dequant_3bit_32(const uint32_t q_0,
                                                const uint32_t q_1,
                                                const uint32_t q_2,
                                                half2 (&dq)[16], int stride,
                                                const uint32_t zero) {
  const uint32_t c0 = 0x64006400;
  const half y8_ = __float2half_rn(1.0f / 8.0f);
  const half y64_ = __float2half_rn(1.0f / 64.0f);
  const half2 y8 = __halves2half2(y8_, y8_);
  const half2 y64 = __halves2half2(y64_, y64_);
  const half_uint16 z1_(0xe400 | zero);  // half(-1024.0f - zero);
  const half z8_ = __hsub(__int2half_rn(-128), __int2half_rn(zero));
  const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero));
  const half2 z1 = __halves2half2(z1_.as_half, z1_.as_half);
  const half2 z8 = __halves2half2(z8_, z8_);
  const half2 z64 = __halves2half2(z64_, z64_);

  uint32_t qa = q_0;
  uint32_t qb = q_1;
  uint32_t qc = q_2;

  half2_uint32 q0((qa & 0x00070007) | c0);  // half2(q[ 0], q[ 1])      + 1024
  half2_uint32 q1((qa & 0x00380038) | c0);  // half2(q[ 2], q[ 3]) *  8 + 1024
  qa >>= 6;
  half2_uint32 q2((qa & 0x00070007) | c0);  // half2(q[ 4], q[ 5])      + 1024
  half2_uint32 q3((qa & 0x00380038) | c0);  // half2(q[ 6], q[ 7]) *  8 + 1024
  half2_uint32 q4((qa & 0x01c001c0) | c0);  // half2(q[ 8], q[ 9]) * 64 + 1024
  qa >>= 9;
  qa &= 0x00010001;
  half2_uint32 q5((qb & 0x00070007) | c0);  // half2(q[10], q[11])      + 1024
  half2_uint32 q6((qb & 0x00380038) | c0);  // half2(q[12], q[13]) *  8 + 1024
  qb >>= 6;
  half2_uint32 q7((qb & 0x00070007) | c0);  // half2(q[14], q[15])      + 1024
  half2_uint32 q8((qb & 0x00380038) | c0);  // half2(q[16], q[17]) *  8 + 1024
  half2_uint32 q9((qb & 0x01c001c0) | c0);  // half2(q[18], q[19]) * 64 + 1024
  qb >>= 8;
  qb &= 0x00020002;
  half2_uint32 q10((qc & 0x00070007) | c0);  // half2(q[20], q[21])      + 1024
  half2_uint32 q11((qc & 0x00380038) | c0);  // half2(q[22], q[23]) *  8 + 1024
  qc >>= 6;
  half2_uint32 q12((qc & 0x00070007) | c0);  // half2(q[24], q[25])      + 1024
  half2_uint32 q13((qc & 0x00380038) | c0);  // half2(q[26], q[27]) *  8 + 1024
  half2_uint32 q14((qc & 0x01c001c0) | c0);  // half2(q[28], q[29]) * 64 + 1024
  qc >>= 7;
  qc &= 0x00040004;
  half2_uint32 q15((qa | qb | qc) | c0);

  dq[0] = __hadd2(q0.as_half2, z1);
  dq[1] = __hfma2(q1.as_half2, y8, z8);
  dq[2] = __hadd2(q2.as_half2, z1);
  dq[3] = __hfma2(q3.as_half2, y8, z8);
  dq[4] = __hfma2(q4.as_half2, y64, z64);
  dq[5] = __hadd2(q5.as_half2, z1);
  dq[6] = __hfma2(q6.as_half2, y8, z8);
  dq[7] = __hadd2(q7.as_half2, z1);
  dq[8] = __hfma2(q8.as_half2, y8, z8);
  dq[9] = __hfma2(q9.as_half2, y64, z64);
  dq[10] = __hadd2(q10.as_half2, z1);
  dq[11] = __hfma2(q11.as_half2, y8, z8);
  dq[12] = __hadd2(q12.as_half2, z1);
  dq[13] = __hfma2(q13.as_half2, y8, z8);
  dq[14] = __hfma2(q14.as_half2, y64, z64);
  dq[15] = __hadd2(q15.as_half2, z1);
}

}  // namespace gptq
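For context, 32 3-bit values occupy exactly three 32-bit words, so without the pre-shuffle some values straddle a word boundary. A host-side C++ sketch of the naive extraction that shuffle_3bit_32 rearranges away; get3 is an illustrative helper and assumes the packed buffer is padded with one extra readable word so the 64-bit window never reads out of bounds:

#include <cassert>
#include <cstdint>
#include <vector>

// Naive extraction of the idx-th 3-bit field from a little-endian bitstream
// packed into 32-bit words. The fields at idx 10 and 21 straddle a word
// boundary, which is exactly what the shuffle avoids at dequant time.
uint32_t get3(const std::vector<uint32_t>& w, int idx) {
  int bit = idx * 3;
  uint64_t window = uint64_t(w[bit / 32]) | (uint64_t(w[bit / 32 + 1]) << 32);
  return uint32_t((window >> (bit % 32)) & 0x7);
}

int main() {
  // Pack values 0,1,...,7,0,1,... into 32 consecutive 3-bit fields
  // (3 data words + 1 padding word for the 64-bit window above).
  std::vector<uint32_t> w(4, 0);
  for (int i = 0; i < 32; i++) {
    int bit = i * 3;
    uint64_t v = uint64_t(i & 0x7) << (bit % 32);
    w[bit / 32] |= uint32_t(v);
    w[bit / 32 + 1] |= uint32_t(v >> 32);
  }
  for (int i = 0; i < 32; i++) assert(get3(w, i) == uint32_t(i & 0x7));
  return 0;
}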
@@ -13,133 +13,112 @@ namespace gptq {
//
// 77775555 33331111 66664444 22220000

__forceinline__ __device__ void shuffle_4bit_8(uint32_t* q, int stride) {
  uint32_t qa = q[0];
  uint32_t qb = 0;

#pragma unroll
  for (int i = 0; i < 4; i++) {
    uint32_t qa0 = qa & 0x0f;
    uint32_t qa1 = (qa & 0xf0) >> 4;
    qa >>= 8;
    qb |= (qa1 << (i * 4 + 16));
    qb |= (qa0 << (i * 4));
  }
  q[0] = qb;
}

__forceinline__ __device__ void dequant_4bit_8(const uint32_t q_0,
                                               half2 (&dq)[4], int stride,
                                               const uint32_t zero) {
  const uint32_t c0 = 0x64006400;
  const half y16_ = __float2half_rn(1.0f / 16.0f);
  const half2 y16 = __halves2half2(y16_, y16_);
  const half_uint16 z1_(0xe400 | zero);  // half(-1024.0f - zero);
  const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero));
  const half2 z1 = __half2half2(z1_.as_half);
  const half2 z16 = __half2half2(z16_);

  uint32_t qa = q_0;
  half2_uint32 q0((qa & 0x000f000f) | c0);  // half2(q[ 0], q[ 1])      + 1024
  half2_uint32 q1((qa & 0x00f000f0) | c0);  // half2(q[ 2], q[ 3]) * 16 + 1024
  qa >>= 8;
  half2_uint32 q2((qa & 0x000f000f) | c0);  // half2(q[ 4], q[ 5])      + 1024
  half2_uint32 q3((qa & 0x00f000f0) | c0);  // half2(q[ 6], q[ 7]) * 16 + 1024

  dq[0] = __hadd2(q0.as_half2, z1);
  dq[1] = __hfma2(q1.as_half2, y16, z16);
  dq[2] = __hadd2(q2.as_half2, z1);
  dq[3] = __hfma2(q3.as_half2, y16, z16);
}

__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale(
    const uint32_t zero, const half scale, half2 (&z1z16)[2],
    half2 (&y1y16)[2]) {
  half_uint16 z1(0xe400 | zero);  // half(-1024.0f - zero);
  half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));

  half2 scale2 = __half2half2(scale);

  z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half));
  z1z16[1] = __hmul2(scale2, __half2half2(z16));

  const half y1 = __float2half_rn(1.0f);
  const half y16 = __float2half_rn(1.0f / 16.0f);

  y1y16[0] = __hmul2(scale2, __half2half2(y1));
  y1y16[1] = __hmul2(scale2, __half2half2(y16));
}

__forceinline__ __device__ void dequant_4bit_8_prep_zero(const uint32_t zero,
                                                         half2 (&z1z16)[2],
                                                         half2 (&y1y16)[2]) {
  half_uint16 z1(0xe400 | zero);  // half(-1024.0f - zero);
  half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));

  z1z16[0] = __half2half2(z1.as_half);
  z1z16[1] = __half2half2(z16);

  const half y1 = __float2half_rn(1.0f);
  const half y16 = __float2half_rn(1.0f / 16.0f);

  y1y16[0] = __half2half2(y1);
  y1y16[1] = __half2half2(y16);
}

__forceinline__ __device__ void dequant_4bit_8_gptq(const uint32_t q_0,
                                                    half2 (&dq)[4],
                                                    half2 (&z1z16)[2],
                                                    half2 (&y1y16)[2],
                                                    int stride, bool scaled) {
  const uint32_t c0 = 0x64006400;

  uint32_t qa = q_0;
  half2_uint32 q0((qa & 0x000f000f) |
                  c0);  // half2( q[0]      + 1024, q[1]      + 1024 )
  half2_uint32 q1((qa & 0x00f000f0) |
                  c0);  // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 )
  qa >>= 8;
  half2_uint32 q2((qa & 0x000f000f) |
                  c0);  // half2( q[4]      + 1024, q[5]      + 1024 )
  half2_uint32 q3((qa & 0x00f000f0) |
                  c0);  // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 )

  if (scaled) {
    dq[0] = __hfma2(q0.as_half2, y1y16[0],
                    z1z16[0]);  // half2( q[0] * s - z * s, q[1] * s - z * s)
    dq[1] = __hfma2(q1.as_half2, y1y16[1],
                    z1z16[1]);  // half2( q[2] * s - z * s, q[3] * s - z * s)
    dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]);
    dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]);
  } else {
    dq[0] = __hadd2(q0.as_half2, z1z16[0]);  // half2( q[0] - z, q[1] - z )
    dq[1] = __hfma2(q1.as_half2, y1y16[1],
                    z1z16[1]);  // half2( q[2] - z, q[3] - z )
    dq[2] = __hadd2(q2.as_half2, z1z16[0]);  // half2( q[4] - z, q[5] - z )
    dq[3] = __hfma2(q3.as_half2, y1y16[1],
                    z1z16[1]);  // half2( q[6] - z, q[7] - z )
  }
}
}  // namespace gptq
}  // namespace vllm
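The y16/z16 constants above implement (q - zero) * scale without ever isolating q: the packed half holds either q + 1024 or 16*q + 1024, and one fma with 1 or 1/16 plus -(1024 + zero) (or the pre-scaled z1z16/y1y16 variants) recovers q - zero. A minimal host-side float sketch of that arithmetic, with plain float standing in for half2; the values are illustrative:

#include <cassert>
#include <cmath>

int main() {
  const float scale = 0.125f;
  for (int q = 0; q < 16; q++) {
    for (int zero = 0; zero < 16; zero++) {
      // Even pairs: field is q + 1024, y = 1, z = -(1024 + zero).
      float even = (q + 1024.0f) * 1.0f + (-(1024.0f + zero));
      // Odd pairs: field is 16*q + 1024, y = 1/16, z = -64 - zero.
      float odd = (16.0f * q + 1024.0f) * (1.0f / 16.0f) + (-64.0f - zero);
      assert(even == float(q - zero));
      assert(odd == float(q - zero));

      // dequant_4bit_8_prep_zero_scale folds the scale into both constants,
      // so a single fma yields (q - zero) * scale directly.
      float scaled =
          (16.0f * q + 1024.0f) * (scale / 16.0f) + scale * (-64.0f - zero);
      assert(std::fabs(scaled - (q - zero) * scale) < 1e-4f);
    }
  }
  return 0;
}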
@@ -10,28 +10,18 @@ Copied from https://github.com/turboderp/exllamav2
namespace vllm {
namespace gptq {

__forceinline__ __device__ void shuffle_8bit_4(uint32_t* q, int stride) {}

__forceinline__ __device__ void dequant_8bit_8(const uint32_t q_0,
                                               const uint32_t q_1,
                                               half2 (&dq)[4], int stride,
                                               const uint32_t zero) {
  half dqh[8];
  for (int i = 0; i < 4; i++) dqh[i] = dq_ns(exb(q_0, i * 8, 0xff), zero);
  for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), zero);

  for (int i = 0; i < 4; i++)
    dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}

}  // namespace gptq
@@ -8,51 +8,47 @@ Copied from https://github.com/turboderp/exllamav2
namespace vllm {
namespace gptq {

union half2_uint32 {
  uint32_t as_uint32;
  half2 as_half2;
  __device__ half2_uint32(uint32_t val) : as_uint32(val) {}
  __device__ half2_uint32(half2 val) : as_half2(val) {}
};

union half_uint16 {
  uint16_t as_uint16;
  half as_half;
  __device__ half_uint16(uint16_t val) : as_uint16(val) {}
  __device__ half_uint16(half val) : as_half(val) {}
};

// Max_scale premultiplied by 1/256

__forceinline__ __device__ half dq_scale(const int qs, const half max_scale) {
  int qs_i = qs + 1;
  half qs_h = __int2half_rn(qs_i * qs_i);
  qs_h = __hmul(qs_h, max_scale);
  return qs_h;
}

__forceinline__ __device__ half dq(const int q, const int qzero,
                                   const half scale) {
  return __hmul(__int2half_rn(q - qzero), scale);
}

__forceinline__ __device__ half dq_ns(const int q, const int qzero) {
  // return __hsub(__int2half_rn(q), __int2half_rn(qzero));
  return __int2half_rn(q - qzero);
}

__forceinline__ __device__ int exb(const uint32_t q, const int shift,
                                   const int mask) {
  return (int)((q >> shift) & mask);
}

__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0,
                                   const int shift, const int mask) {
  return (int)(__funnelshift_rc(q0, q1, shift) & mask);
}

}  // namespace gptq
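The two-word exb overload uses __funnelshift_rc to pull a bit field that may straddle q0 and q1. For shifts below 32 this matches a 64-bit concatenate-and-shift, sketched here on the host (extract_straddling and the sample values are illustrative, not part of the diff):

#include <cassert>
#include <cstdint>

// Host equivalent of exb(q1, q0, shift, mask) for 0 <= shift < 32:
// right-shift the 64-bit value {hi: q1, lo: q0} and mask the result.
int extract_straddling(uint32_t q1, uint32_t q0, int shift, int mask) {
  uint64_t concat = (uint64_t(q1) << 32) | q0;
  return int((concat >> shift) & uint64_t(mask));
}

int main() {
  uint32_t q0 = 0x89abcdefu;
  uint32_t q1 = 0x01234567u;
  // A byte fully inside q0.
  assert(extract_straddling(q1, q0, 8, 0xff) == 0xcd);
  // A byte straddling the boundary: low nibble from q0's top bits, high
  // nibble from q1's bottom bits.
  assert(extract_straddling(q1, q0, 28, 0xff) == 0x78);
  return 0;
}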
File diff suppressed because it is too large
@@ -11,22 +11,23 @@

namespace gptq_marlin {

// 8 warps are a good choice since every SM has 4 schedulers and having more
// than 1 warp per schedule allows some more latency hiding. At the same time,
// we want relatively few warps to have many registers per warp and small tiles.
static constexpr int default_threads = 256;

static constexpr int pipe_stages =
    4;  // 4 pipeline stages fit into shared memory

static constexpr int min_thread_n = 64;
static constexpr int min_thread_k = 64;

static constexpr int tile_size = 16;
static constexpr int max_par = 16;

template <typename T, int n>
struct Vec {
  T elems[n];
  __device__ T& operator[](int i) { return elems[i]; }
};

@@ -35,30 +36,35 @@ using I4 = Vec<int, 4>;
constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// No support for async
#else

__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
                                      bool pred = true) {
  const int BYTES = 16;
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   .reg .pred p;\n"
      "   setp.ne.b32 p, %0, 0;\n"
      "   @p cp.async.cg.shared.global [%1], [%2], %3;\n"
      "}\n" ::"r"((int)pred),
      "r"(smem), "l"(glob_ptr), "n"(BYTES));
}

__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
  const int BYTES = 16;
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   cp.async.cg.shared.global [%0], [%1], %2;\n"
      "}\n" ::"r"(smem),
      "l"(glob_ptr), "n"(BYTES));
}

__device__ inline void cp_async_fence() {
  asm volatile("cp.async.commit_group;\n" ::);
}

template <int n>
__device__ inline void cp_async_wait() {
@@ -67,4 +73,4 @@ __device__ inline void cp_async_wait() {

#endif

}  // namespace gptq_marlin
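cp_async4, cp_async_fence and cp_async_wait are the building blocks of a multi-stage software pipeline: stage pipe_stages tiles in shared memory and start the copy for tile i + stages while computing on tile i. A host-side C++ sketch of the same double-buffering idea, with std::memcpy standing in for cp.async and a plain array standing in for shared memory; all names here are illustrative, not part of the diff:

#include <cstring>
#include <numeric>
#include <vector>

int main() {
  constexpr int kStages = 2;  // stand-in for pipe_stages
  constexpr int kTile = 4;    // elements per tile
  std::vector<int> global(32);
  std::iota(global.begin(), global.end(), 0);
  const int tiles = int(global.size()) / kTile;

  int staging[kStages][kTile];  // stand-in for the shared-memory stages
  long long sum = 0;

  // Prologue: fill the pipeline.
  for (int s = 0; s < kStages && s < tiles; s++)
    std::memcpy(staging[s], &global[s * kTile], sizeof(staging[s]));

  // Steady state: consume stage (t % kStages), then refill it with the tile
  // that is kStages ahead -- the schedule the kernels express with
  // cp_async4 + cp_async_fence + cp_async_wait.
  for (int t = 0; t < tiles; t++) {
    int* tile = staging[t % kStages];
    for (int i = 0; i < kTile; i++) sum += tile[i];
    int next = t + kStages;
    if (next < tiles)
      std::memcpy(tile, &global[next * kTile], sizeof(staging[0]));
  }

  // 0 + 1 + ... + 31 = 496
  return sum == 496 ? 0 : 1;
}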
@@ -5,58 +5,73 @@
#include <cuda_fp16.h>
#include <cuda_bf16.h>

namespace gptq_marlin {

template <typename scalar_t>
class ScalarType {};

template <>
class ScalarType<half> {
 public:
  using scalar_t = half;
  using scalar_t2 = half2;

  // Matrix fragments for tensor core instructions; their precise layout is
  // documented here:
  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
  using FragA = Vec<half2, 4>;
  using FragB = Vec<half2, 2>;
  using FragC = Vec<float, 4>;
  using FragS = Vec<half2, 1>;

  static __device__ float inline num2float(const half x) {
    return __half2float(x);
  }

  static __device__ half2 inline num2num2(const half x) {
    return __half2half2(x);
  }

  static __device__ half2 inline nums2num2(const half x1, const half x2) {
    return __halves2half2(x1, x2);
  }

  static __host__ __device__ half inline float2num(const float x) {
    return __float2half(x);
  }
};

template <>
class ScalarType<nv_bfloat16> {
 public:
  using scalar_t = nv_bfloat16;
  using scalar_t2 = nv_bfloat162;

  using FragA = Vec<nv_bfloat162, 4>;
  using FragB = Vec<nv_bfloat162, 2>;
  using FragC = Vec<float, 4>;
  using FragS = Vec<nv_bfloat162, 1>;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  static __device__ float inline num2float(const nv_bfloat16 x) {
    return __bfloat162float(x);
  }

  static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) {
    return __bfloat162bfloat162(x);
  }

  static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1,
                                                  const nv_bfloat16 x2) {
    return __halves2bfloat162(x1, x2);
  }

  static __host__ __device__ nv_bfloat16 inline float2num(const float x) {
    return __float2bfloat16(x);
  }
#endif
};

}  // namespace gptq_marlin

#endif
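ScalarType<half> / ScalarType<nv_bfloat16> is a traits class: kernels templated on scalar_t pull the paired vector type and the conversion helpers from one specialization instead of branching on the element type. A tiny host-side C++ analogue of the same pattern, using float/double so it compiles without CUDA; Traits, accum_t and sum3 are illustrative names:

#include <cassert>

// Primary template left empty, like ScalarType<scalar_t>.
template <typename scalar_t>
struct Traits {};

// Each specialization picks an accumulator type and a conversion helper,
// mirroring ScalarType<half>::num2float / float2num.
template <>
struct Traits<float> {
  using accum_t = double;
  static accum_t to_accum(float x) { return double(x); }
};

template <>
struct Traits<double> {
  using accum_t = double;
  static accum_t to_accum(double x) { return x; }
};

// A "kernel" written once against the traits, the way the marlin kernels are
// written once against ScalarType<scalar_t>.
template <typename scalar_t>
typename Traits<scalar_t>::accum_t sum3(scalar_t a, scalar_t b, scalar_t c) {
  using T = Traits<scalar_t>;
  return T::to_accum(a) + T::to_accum(b) + T::to_accum(c);
}

int main() {
  assert(sum3<float>(1.0f, 2.0f, 3.0f) == 6.0);
  assert(sum3<double>(1.0, 2.0, 3.0) == 6.0);
  return 0;
}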
@@ -12,14 +12,14 @@ static constexpr int tile_n_size = tile_k_size * 4;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800

template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void marlin_repack_kernel(
    uint32_t const* __restrict__ b_q_weight_ptr,
    uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
    int size_k, int size_n) {}

}  // namespace gptq_marlin

torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
                                 int64_t size_k, int64_t size_n,
                                 int64_t num_bits) {
  TORCH_CHECK_NOT_IMPLEMENTED(
@@ -30,10 +30,10 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
#else

template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void marlin_repack_kernel(
    uint32_t const* __restrict__ b_q_weight_ptr,
    uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
    int size_k, int size_n) {
  constexpr int pack_factor = 32 / num_bits;

  int k_tiles = size_k / tile_k_size;
@@ -61,8 +61,8 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,

  constexpr int perm_size = tile_k_size / 4;

  int4* sh_perm_ptr = sh;
  int4* sh_pipe_ptr = sh_perm_ptr;
  if constexpr (has_perm) {
    sh_pipe_ptr += perm_size;
  }
@@ -76,7 +76,7 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
  auto load_perm_to_shared = [&](int k_tile_id) {
    int first_k_int4 = (k_tile_id * tile_k_size) / 4;

    int4 const* perm_int4_ptr = reinterpret_cast<int4 const*>(perm_ptr);

    if (threadIdx.x < perm_size) {
      sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x];
@@ -92,22 +92,22 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,

    int first_n = n_tile_id * tile_n_size;

    int4* sh_ptr = sh_pipe_ptr + stage_size * pipe;

    if constexpr (has_perm) {
      if (threadIdx.x < stage_size) {
        int k_id = threadIdx.x / stage_n_threads;
        int n_id = threadIdx.x % stage_n_threads;

        uint32_t const* sh_perm_int_ptr =
            reinterpret_cast<uint32_t const*>(sh_perm_ptr);

        int src_k = sh_perm_int_ptr[k_id];
        int src_k_packed = src_k / pack_factor;

        cp_async4(
            &sh_ptr[k_id * stage_n_threads + n_id],
            reinterpret_cast<int4 const*>(&(
                b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)])));
      }

@@ -120,7 +120,7 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
        int first_k_packed = first_k / pack_factor;

        cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
                  reinterpret_cast<int4 const*>(
                      &(b_q_weight_ptr[(first_k_packed + k_id) * size_n +
                                       first_n + (n_id * 4)])));
      }
@@ -151,10 +151,10 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
    constexpr int sh_stride = 64;
    constexpr uint32_t mask = (1 << num_bits) - 1;

    int4* sh_stage_ptr = sh_pipe_ptr + stage_size * pipe;
    uint32_t* sh_stage_int_ptr = reinterpret_cast<uint32_t*>(sh_stage_ptr);

    uint32_t* sh_perm_int_ptr = reinterpret_cast<uint32_t*>(sh_perm_ptr);

    uint32_t vals[8];

@@ -176,17 +176,16 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
      }

    } else {
      uint32_t b1_vals[tile_ints];
      uint32_t b2_vals[tile_ints];

#pragma unroll
      for (int i = 0; i < tile_ints; i++) {
        b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i];
        b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i];
      }

#pragma unroll
      for (int i = 0; i < 4; i++) {
        int cur_elem = tc_row + tc_offsets[i];
        int cur_int = cur_elem / pack_factor;
@@ -206,7 +205,7 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
    constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};

    uint32_t res = 0;
#pragma unroll
    for (int i = 0; i < 8; i++) {
      res |= vals[pack_idx[i]] << (i * 4);
}
|
}
|
||||||
@ -218,7 +217,7 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
|
|||||||
|
|
||||||
uint32_t res1 = 0;
|
uint32_t res1 = 0;
|
||||||
uint32_t res2 = 0;
|
uint32_t res2 = 0;
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < 4; i++) {
|
for (int i = 0; i < 4; i++) {
|
||||||
res1 |= vals[pack_idx[i]] << (i * 8);
|
res1 |= vals[pack_idx[i]] << (i * 8);
|
||||||
res2 |= vals[4 + pack_idx[i]] << (i * 8);
|
res2 |= vals[4 + pack_idx[i]] << (i * 8);
|
||||||
@ -230,14 +229,14 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
|
|||||||
};
|
};
|
||||||
|
|
||||||
auto start_pipes = [&](int k_tile_id, int n_tile_id) {
|
auto start_pipes = [&](int k_tile_id, int n_tile_id) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
|
for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
|
||||||
fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
|
fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
|
||||||
}
|
}
|
||||||
|
|
||||||
wait_for_stage();
|
wait_for_stage();
|
||||||
};
|
};
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
|
for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
|
||||||
int n_tile_id = 0;
|
int n_tile_id = 0;
|
||||||
|
|
||||||
@ -248,7 +247,7 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
|
|||||||
start_pipes(k_tile_id, n_tile_id);
|
start_pipes(k_tile_id, n_tile_id);
|
||||||
|
|
||||||
while (n_tile_id < n_tiles) {
|
while (n_tile_id < n_tiles) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int pipe = 0; pipe < repack_stages; pipe++) {
|
for (int pipe = 0; pipe < repack_stages; pipe++) {
|
||||||
fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
|
fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
|
||||||
n_tile_id + pipe + repack_stages - 1);
|
n_tile_id + pipe + repack_stages - 1);
|
||||||
@@ -260,21 +259,21 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
     }
   }
 }

 }  // namespace gptq_marlin

 #define CALL_IF(NUM_BITS, HAS_PERM) \
   else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
     cudaFuncSetAttribute( \
         gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, \
                                           NUM_BITS, HAS_PERM>, \
         cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
     gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, NUM_BITS, \
                                       HAS_PERM> \
         <<<blocks, gptq_marlin::repack_threads, max_shared_mem, stream>>>( \
             b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
   }

-torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
+torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
                                  int64_t size_k, int64_t size_n,
                                  int64_t num_bits) {
   // Verify compatibility with marlin tile of 16x64
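The CALL_IF macro above shows the usual two-step launch for kernels that want more than the default 48 KB of dynamic shared memory: raise cudaFuncAttributeMaxDynamicSharedMemorySize, then pass the requested size at launch. A stand-alone sketch of the same pattern with a dummy kernel (illustrative only; the attribute and launch calls are standard CUDA runtime API):

#include <cstdio>
#include <cuda_runtime.h>

__global__ void dummy_kernel(int* out) {
  extern __shared__ int smem[];
  smem[threadIdx.x] = static_cast<int>(threadIdx.x);
  __syncthreads();
  if (threadIdx.x == 0) out[blockIdx.x] = smem[blockDim.x - 1];
}

int main() {
  int dev = 0;
  int max_shared_mem = 0;
  cudaDeviceGetAttribute(&max_shared_mem,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  // Without this opt-in, launches that request more than 48 KB of dynamic
  // shared memory fail with an invalid-argument error.
  cudaFuncSetAttribute(dummy_kernel,
                       cudaFuncAttributeMaxDynamicSharedMemorySize,
                       max_shared_mem);
  int* d_out;
  cudaMalloc(&d_out, sizeof(int));
  dummy_kernel<<<1, 256, max_shared_mem>>>(d_out);
  cudaDeviceSynchronize();
  std::printf("launched with %d bytes of dynamic shared memory\n",
              max_shared_mem);
  cudaFree(d_out);
  return 0;
}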
@ -318,11 +317,10 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
|
|||||||
bool has_perm = perm.size(0) != 0;
|
bool has_perm = perm.size(0) != 0;
|
||||||
|
|
||||||
// Get ptrs
|
// Get ptrs
|
||||||
uint32_t const *b_q_weight_ptr =
|
uint32_t const* b_q_weight_ptr =
|
||||||
reinterpret_cast<uint32_t const *>(b_q_weight.data_ptr());
|
reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());
|
||||||
uint32_t const *perm_ptr =
|
uint32_t const* perm_ptr = reinterpret_cast<uint32_t const*>(perm.data_ptr());
|
||||||
reinterpret_cast<uint32_t const *>(perm.data_ptr());
|
uint32_t* out_ptr = reinterpret_cast<uint32_t*>(out.data_ptr());
|
||||||
uint32_t *out_ptr = reinterpret_cast<uint32_t *>(out.data_ptr());
|
|
||||||
|
|
||||||
// Get dev info
|
// Get dev info
|
||||||
int dev = b_q_weight.get_device();
|
int dev = b_q_weight.get_device();
|
||||||
@@ -25,7 +25,10 @@

 #include <iostream>

-template <typename T> inline std::string str(T x) { return std::to_string(x); }
+template <typename T>
+inline std::string str(T x) {
+  return std::to_string(x);
+}

 namespace marlin {

@@ -38,9 +41,10 @@ constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }
 // corresponding index accesses must be compile-time constants, which is why we
 // extensively use `#pragma unroll` throughout the kernel code to guarantee
 // this.
-template <typename T, int n> struct Vec {
+template <typename T, int n>
+struct Vec {
   T elems[n];
-  __device__ T &operator[](int i) { return elems[i]; }
+  __device__ T& operator[](int i) { return elems[i]; }
 };

 using I4 = Vec<int, 4>;
@@ -51,29 +55,32 @@ using I4 = Vec<int, 4>;
 using FragA = Vec<half2, 4>;
 using FragB = Vec<half2, 2>;
 using FragC = Vec<float, 4>;
 using FragS = Vec<half2, 1>;  // quantization scales

 // Predicated asynchronous global->shared copy; used for inputs A where we apply
 // predication to handle batchsizes that are not multiples of 16.
-__device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr,
+__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
                                       bool pred = true) {
   const int BYTES = 16;
   uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
-  asm volatile("{\n"
-               "   .reg .pred p;\n"
-               "   setp.ne.b32 p, %0, 0;\n"
-               "   @p cp.async.cg.shared.global [%1], [%2], %3;\n"
-               "}\n" ::"r"((int)pred),
-               "r"(smem), "l"(glob_ptr), "n"(BYTES));
+  asm volatile(
+      "{\n"
+      "   .reg .pred p;\n"
+      "   setp.ne.b32 p, %0, 0;\n"
+      "   @p cp.async.cg.shared.global [%1], [%2], %3;\n"
+      "}\n" ::"r"((int)pred),
+      "r"(smem), "l"(glob_ptr), "n"(BYTES));
 }

 // Asynchronous global->shared copy
-__device__ inline void cp_async4(void *smem_ptr, const void *glob_ptr) {
+__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
   const int BYTES = 16;
   uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
-  asm volatile("{\n"
-               "   cp.async.cg.shared.global [%0], [%1], %2;\n"
-               "}\n" :: "r"(smem), "l"(glob_ptr), "n"(BYTES));
+  asm volatile(
+      "{\n"
+      "   cp.async.cg.shared.global [%0], [%1], %2;\n"
+      "}\n" ::"r"(smem),
+      "l"(glob_ptr), "n"(BYTES));
 }

 // Async copy fence.
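These helpers are the building blocks of the software pipeline used further down: issue cp.async copies into one shared-memory buffer, commit them as a group, and only wait when the data is actually needed. A stand-alone toy kernel showing the same double-buffering pattern (hypothetical example, requires sm_80 or newer; only the PTX helpers are taken from the diff above):

#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

// Minimal copies of the cp.async helpers shown above (16-byte copy, commit,
// wait until at most n groups are in flight).
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile("cp.async.cg.shared.global [%0], [%1], 16;\n" ::"r"(smem),
               "l"(glob_ptr));
}
__device__ inline void cp_async_fence() {
  asm volatile("cp.async.commit_group;\n");
}
template <int n>
__device__ inline void cp_async_wait() {
  asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
}

// Toy double-buffered reduction over n_tiles tiles of 128 int4 values each:
// while one shared-memory buffer is consumed, the next tile is already being
// fetched asynchronously into the other buffer. Launch with 128 threads.
__global__ void double_buffer_sum(const int4* in, long long* out, int n_tiles) {
  __shared__ int4 sh[2][128];
  long long acc = 0;
  cp_async4(&sh[0][threadIdx.x], &in[threadIdx.x]);
  cp_async_fence();
  for (int t = 0; t < n_tiles; t++) {
    int buf = t & 1;
    if (t + 1 < n_tiles) {
      cp_async4(&sh[buf ^ 1][threadIdx.x], &in[(t + 1) * 128 + threadIdx.x]);
      cp_async_fence();
      cp_async_wait<1>();  // tile t has landed; tile t+1 may still be in flight
    } else {
      cp_async_wait<0>();  // last tile: drain all outstanding copies
    }
    __syncthreads();
    int4 v = sh[buf][threadIdx.x];
    acc += v.x + v.y + v.z + v.w;
    __syncthreads();  // everyone is done reading before the buffer is reused
  }
  atomicAdd(reinterpret_cast<unsigned long long*>(out),
            static_cast<unsigned long long>(acc));
}

int main() {
  const int n_tiles = 4;
  int4* d_in;
  long long* d_out;
  cudaMalloc(&d_in, n_tiles * 128 * sizeof(int4));
  cudaMalloc(&d_out, sizeof(long long));
  cudaMemset(d_in, 0, n_tiles * 128 * sizeof(int4));
  cudaMemset(d_out, 0, sizeof(long long));
  double_buffer_sum<<<1, 128>>>(d_in, d_out, n_tiles);
  long long h_out = 0;
  cudaMemcpy(&h_out, d_out, sizeof(long long), cudaMemcpyDeviceToHost);
  std::printf("sum = %lld (expect 0 for zero-filled input)\n", h_out);
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}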
@@ -82,28 +89,30 @@ __device__ inline void cp_async_fence() {
 }

 // Wait until at most `n` async copy stages are still pending.
-template <int n> __device__ inline void cp_async_wait() {
+template <int n>
+__device__ inline void cp_async_wait() {
   asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
 }

 // m16n8k16 tensor core mma instruction with fp16 inputs and fp32
 // output/accumulation.
-__device__ inline void mma(const FragA &a_frag, const FragB &frag_b,
-                           FragC &frag_c) {
-  const uint32_t *a = reinterpret_cast<const uint32_t *>(&a_frag);
-  const uint32_t *b = reinterpret_cast<const uint32_t *>(&frag_b);
-  float *c = reinterpret_cast<float *>(&frag_c);
-  asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
-               "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
-               : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
-               : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]),
-                 "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
+__device__ inline void mma(const FragA& a_frag, const FragB& frag_b,
+                           FragC& frag_c) {
+  const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
+  const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
+  float* c = reinterpret_cast<float*>(&frag_c);
+  asm volatile(
+      "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
+      "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
+      : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
+      : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
+        "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
 }

 // Instruction for loading a full 16x16 matrix fragment of operand A from shared
 // memory, directly in tensor core layout.
-__device__ inline void ldsm4(FragA &frag_a, const void *smem_ptr) {
-  uint32_t *a = reinterpret_cast<uint32_t *>(&frag_a);
+__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
+  uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
   uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
   asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
                : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
@@ -113,7 +122,8 @@ __device__ inline void ldsm4(FragA &frag_a, const void *smem_ptr) {
 // Lookup-table based 3-input logical operation; explicitly used for
 // dequantization as the compiler does not seem to automatically recognize it in
 // all cases.
-template <int lut> __device__ inline int lop3(int a, int b, int c) {
+template <int lut>
+__device__ inline int lop3(int a, int b, int c) {
   int res;
   asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                : "=r"(res)
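The lut template argument is the 8-bit truth table of the desired three-input function, conventionally obtained by evaluating that function bitwise on a = 0xF0, b = 0xCC, c = 0xAA (the same convention the dequantizer in this file uses to build its mask immediates). A small host-side check with an example function:

#include <cstdio>

// lop3.b32 computes an arbitrary f(a, b, c); its 8-bit immediate is simply
// f evaluated on the three constants below.
constexpr int lop3_lut(int a, int b, int c) { return (a & b) | c; }  // example f

int main() {
  constexpr int lut = lop3_lut(0xF0, 0xCC, 0xAA);
  std::printf("immediate for (a & b) | c = 0x%02x\n", lut);  // prints 0xea
  return 0;
}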
@@ -138,24 +148,24 @@ __device__ inline FragB dequant(int q) {
   const int MUL = 0x2c002c00;
   const int ADD = 0xd480d480;
   FragB frag_b;
-  frag_b[0] = __hsub2(*reinterpret_cast<half2 *>(&lo),
-                      *reinterpret_cast<const half2 *>(&SUB));
-  frag_b[1] = __hfma2(*reinterpret_cast<half2 *>(&hi),
-                      *reinterpret_cast<const half2 *>(&MUL),
-                      *reinterpret_cast<const half2 *>(&ADD));
+  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
+                      *reinterpret_cast<const half2*>(&SUB));
+  frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
+                      *reinterpret_cast<const half2*>(&MUL),
+                      *reinterpret_cast<const half2*>(&ADD));
   return frag_b;
 }

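The magic constants above are fp16 bit patterns: 0x2c00 is 1/16 and 0xd480 is -72, which together map the "1024 + 16*q" halves produced by the mask trick back to q - 8. A self-contained decode of those constants (hand-rolled fp16 conversion so nothing else is needed; the 0x6408 line assumes the SUB constant defined just above this hunk is 0x64086408):

#include <cmath>
#include <cstdint>
#include <cstdio>

// Decode an IEEE fp16 bit pattern to float (normal values only, which covers
// the constants used by dequant()).
float half_bits_to_float(uint16_t h) {
  int sign = (h >> 15) & 1;
  int exp = (h >> 10) & 0x1f;
  int mant = h & 0x3ff;
  float v = std::ldexp(1.0f + mant / 1024.0f, exp - 15);
  return sign ? -v : v;
}

int main() {
  // 0x2c00 = 1/16 and 0xd480 = -72 = -(1024/16 + 8): scaling the high-nibble
  // halves (1024 + 16*q) by 1/16 and adding -72 gives q - 8.
  std::printf("0x2c00 = %g\n", half_bits_to_float(0x2c00));
  std::printf("0xd480 = %g\n", half_bits_to_float(0xd480));
  std::printf("0x6408 = %g (low-nibble offset, 1024 + 8)\n",
              half_bits_to_float(0x6408));
  return 0;
}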
 // Multiply dequantized values by the corresponding quantization scale; used
 // only for grouped quantization.
-__device__ inline void scale(FragB &frag_b, FragS &frag_s, int i) {
-  half2 s = __half2half2(reinterpret_cast<__half *>(&frag_s)[i]);
+__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
+  half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]);
   frag_b[0] = __hmul2(frag_b[0], s);
   frag_b[1] = __hmul2(frag_b[1], s);
 }

 // Wait until barrier reaches `count`, then lock for current threadblock.
-__device__ inline void barrier_acquire(int *lock, int count) {
+__device__ inline void barrier_acquire(int* lock, int count) {
   if (threadIdx.x == 0) {
     int state = -1;
     do
@ -170,7 +180,7 @@ __device__ inline void barrier_acquire(int *lock, int count) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Release barrier and increment visitation count.
|
// Release barrier and increment visitation count.
|
||||||
__device__ inline void barrier_release(int *lock, bool reset = false) {
|
__device__ inline void barrier_release(int* lock, bool reset = false) {
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
if (threadIdx.x == 0) {
|
if (threadIdx.x == 0) {
|
||||||
if (reset) {
|
if (reset) {
|
||||||
@@ -187,26 +197,27 @@ __device__ inline void barrier_release(int *lock, bool reset = false) {
   }
 }

-template <const int threads, // number of threads in a threadblock
-          const int thread_m_blocks, // number of 16x16 blocks in the m
-                                     // dimension (batchsize) of the threadblock
-          const int thread_n_blocks, // same for n dimension (output)
-          const int thread_k_blocks, // same for k dimension (reduction)
-          const int stages, // number of stages for the async global->shared
-                            // fetch pipeline
-          const int group_blocks = -1 // number of consecutive 16x16 blocks with
-                                      // a separate quantization scale
-          >
-__global__ void
-Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
-       const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn
-       int4 *__restrict__ C,       // fp16 output buffer of shape mxn
-       const int4
-           *__restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn
-       int prob_m,  // batch dimension m
-       int prob_n,  // output dimension n
-       int prob_k,  // reduction dimension k
-       int *locks   // extra global storage for barrier synchronization
-) {
+template <const int threads,          // number of threads in a threadblock
+          const int thread_m_blocks,  // number of 16x16 blocks in the m
+                                      // dimension (batchsize) of the
+                                      // threadblock
+          const int thread_n_blocks,  // same for n dimension (output)
+          const int thread_k_blocks,  // same for k dimension (reduction)
+          const int stages,  // number of stages for the async global->shared
+                             // fetch pipeline
+          const int group_blocks = -1  // number of consecutive 16x16 blocks
+                                       // with a separate quantization scale
+          >
+__global__ void Marlin(
+    const int4* __restrict__ A,  // fp16 input matrix of shape mxk
+    const int4* __restrict__ B,  // 4bit quantized weight matrix of shape kxn
+    int4* __restrict__ C,        // fp16 output buffer of shape mxn
+    const int4* __restrict__ s,  // fp16 quantization scales of shape
+                                 // (k/groupsize)xn
+    int prob_m,  // batch dimension m
+    int prob_n,  // output dimension n
+    int prob_k,  // reduction dimension k
+    int* locks   // extra global storage for barrier synchronization
+) {
   // Each threadblock processes one "stripe" of the B matrix with (roughly) the
   // same size, which might involve multiple column "slices" (of width 16 *
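The stripe/slice bookkeeping that follows this signature assigns each threadblock a contiguous run of `iters` k-tiles, laid out column-major over the tile grid. A host-side sketch of that index arithmetic (made-up sizes; the kernel's real `iters` formula also folds in a `parallel` factor for batched problems):

#include <cstdio>

// Reproduce the per-block stripe -> (slice_col, slice_row) mapping used by the
// kernel: block b owns k-tile iterations [b*iters, (b+1)*iters), walked
// column by column over a k_tiles x n_tiles grid of B tiles.
int main() {
  const int k_tiles = 8, n_tiles = 3, blocks = 6;         // made-up sizes
  const int total_iters = k_tiles * n_tiles;
  const int iters = (total_iters + blocks - 1) / blocks;  // ceildiv
  for (int b = 0; b < blocks; b++) {
    int slice_row = (iters * b) % k_tiles;  // first k-tile inside the column
    int slice_col = (iters * b) / k_tiles;  // which column slice it starts in
    int slice_iters = iters * (b + 1) - (k_tiles * slice_col + slice_row);
    if (slice_iters > k_tiles - slice_row) slice_iters = k_tiles - slice_row;
    std::printf("block %d: col %d, row %d, %d iters in first slice\n", b,
                slice_col, slice_row, slice_iters);
  }
  return 0;
}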
@ -241,11 +252,11 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
|||||||
int slice_row = (iters * blockIdx.x) % k_tiles;
|
int slice_row = (iters * blockIdx.x) % k_tiles;
|
||||||
int slice_col_par = (iters * blockIdx.x) / k_tiles;
|
int slice_col_par = (iters * blockIdx.x) / k_tiles;
|
||||||
int slice_col = slice_col_par;
|
int slice_col = slice_col_par;
|
||||||
int slice_iters; // number of threadblock tiles in the current slice
|
int slice_iters; // number of threadblock tiles in the current slice
|
||||||
int slice_count =
|
int slice_count =
|
||||||
0; // total number of active threadblocks in the current slice
|
0; // total number of active threadblocks in the current slice
|
||||||
int slice_idx; // index of threadblock in current slice; numbered bottom to
|
int slice_idx; // index of threadblock in current slice; numbered bottom to
|
||||||
// top
|
// top
|
||||||
|
|
||||||
// We can easily implement parallel problem execution by just remapping
|
// We can easily implement parallel problem execution by just remapping
|
||||||
// indices and advancing global pointers
|
// indices and advancing global pointers
|
||||||
@ -261,27 +272,22 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
|||||||
auto init_slice = [&]() {
|
auto init_slice = [&]() {
|
||||||
slice_iters =
|
slice_iters =
|
||||||
iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
|
iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
|
||||||
if (slice_iters < 0 || slice_col_par >= n_tiles * parallel)
|
if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;
|
||||||
slice_iters = 0;
|
if (slice_iters == 0) return;
|
||||||
if (slice_iters == 0)
|
if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
|
||||||
return;
|
|
||||||
if (slice_row + slice_iters > k_tiles)
|
|
||||||
slice_iters = k_tiles - slice_row;
|
|
||||||
slice_count = 1;
|
slice_count = 1;
|
||||||
slice_idx = 0;
|
slice_idx = 0;
|
||||||
int col_first = iters * ceildiv(k_tiles * slice_col_par, iters);
|
int col_first = iters * ceildiv(k_tiles * slice_col_par, iters);
|
||||||
if (col_first <= k_tiles * (slice_col_par + 1)) {
|
if (col_first <= k_tiles * (slice_col_par + 1)) {
|
||||||
int col_off = col_first - k_tiles * slice_col_par;
|
int col_off = col_first - k_tiles * slice_col_par;
|
||||||
slice_count = ceildiv(k_tiles - col_off, iters);
|
slice_count = ceildiv(k_tiles - col_off, iters);
|
||||||
if (col_off > 0)
|
if (col_off > 0) slice_count++;
|
||||||
slice_count++;
|
|
||||||
int delta_first = iters * blockIdx.x - col_first;
|
int delta_first = iters * blockIdx.x - col_first;
|
||||||
if (delta_first < 0 || (col_off == 0 && delta_first == 0))
|
if (delta_first < 0 || (col_off == 0 && delta_first == 0))
|
||||||
slice_idx = slice_count - 1;
|
slice_idx = slice_count - 1;
|
||||||
else {
|
else {
|
||||||
slice_idx = slice_count - 1 - delta_first / iters;
|
slice_idx = slice_count - 1 - delta_first / iters;
|
||||||
if (col_off > 0)
|
if (col_off > 0) slice_idx--;
|
||||||
slice_idx--;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (slice_col == n_tiles) {
|
if (slice_col == n_tiles) {
|
||||||
@ -293,29 +299,30 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
|||||||
};
|
};
|
||||||
init_slice();
|
init_slice();
|
||||||
|
|
||||||
int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory
|
int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory
|
||||||
// We typically use `constexpr` to indicate that this value is a compile-time
|
// We typically use `constexpr` to indicate that this value is a compile-time
|
||||||
// constant
|
// constant
|
||||||
constexpr int a_sh_stride =
|
constexpr int a_sh_stride =
|
||||||
16 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory
|
16 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory
|
||||||
constexpr int a_gl_rd_delta_o =
|
constexpr int a_gl_rd_delta_o =
|
||||||
16 * thread_k_blocks /
|
16 * thread_k_blocks /
|
||||||
8; // delta between subsequent A tiles in global memory
|
8; // delta between subsequent A tiles in global memory
|
||||||
int a_gl_rd_delta_i =
|
int a_gl_rd_delta_i =
|
||||||
a_gl_stride *
|
a_gl_stride *
|
||||||
(threads / a_gl_rd_delta_o); // between subsequent accesses within a tile
|
(threads / a_gl_rd_delta_o); // between subsequent accesses within a tile
|
||||||
constexpr int a_sh_wr_delta =
|
constexpr int a_sh_wr_delta =
|
||||||
a_sh_stride * (threads / a_gl_rd_delta_o); // between shared memory writes
|
a_sh_stride *
|
||||||
|
(threads / a_gl_rd_delta_o); // between shared memory writes
|
||||||
constexpr int a_sh_rd_delta_o =
|
constexpr int a_sh_rd_delta_o =
|
||||||
2 * ((threads / 32) /
|
2 * ((threads / 32) /
|
||||||
(thread_n_blocks / 4)); // between shared memory tile reads
|
(thread_n_blocks / 4)); // between shared memory tile reads
|
||||||
constexpr int a_sh_rd_delta_i =
|
constexpr int a_sh_rd_delta_i =
|
||||||
a_sh_stride * 16; // within a shared memory tile
|
a_sh_stride * 16; // within a shared memory tile
|
||||||
constexpr int a_sh_stage =
|
constexpr int a_sh_stage =
|
||||||
a_sh_stride * (16 * thread_m_blocks); // overall size of a tile
|
a_sh_stride * (16 * thread_m_blocks); // overall size of a tile
|
||||||
constexpr int a_sh_wr_iters =
|
constexpr int a_sh_wr_iters =
|
||||||
ceildiv(a_sh_stage,
|
ceildiv(a_sh_stage,
|
||||||
a_sh_wr_delta); // number of shared write iterations for a tile
|
a_sh_wr_delta); // number of shared write iterations for a tile
|
||||||
|
|
||||||
int b_gl_stride = 16 * prob_n / 32;
|
int b_gl_stride = 16 * prob_n / 32;
|
||||||
constexpr int b_sh_stride = 32 * thread_n_blocks / 4;
|
constexpr int b_sh_stride = 32 * thread_n_blocks / 4;
|
||||||
@ -368,7 +375,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
|||||||
// needed if there are more threads than required for a certain tilesize or
|
// needed if there are more threads than required for a certain tilesize or
|
||||||
// when the batchsize is not a multiple of 16.
|
// when the batchsize is not a multiple of 16.
|
||||||
bool a_sh_wr_pred[a_sh_wr_iters];
|
bool a_sh_wr_pred[a_sh_wr_iters];
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < a_sh_wr_iters; i++)
|
for (int i = 0; i < a_sh_wr_iters; i++)
|
||||||
a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;
|
a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;
|
||||||
bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
|
bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
|
||||||
@ -387,13 +394,13 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
|||||||
// loop unrolls, all shared memory accesses are static, we simply precompute
|
// loop unrolls, all shared memory accesses are static, we simply precompute
|
||||||
// both transformed reads and writes.
|
// both transformed reads and writes.
|
||||||
int a_sh_wr_trans[a_sh_wr_iters];
|
int a_sh_wr_trans[a_sh_wr_iters];
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < a_sh_wr_iters; i++)
|
for (int i = 0; i < a_sh_wr_iters; i++)
|
||||||
a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
|
a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
|
||||||
int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];
|
int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < b_sh_wr_iters; i++) {
|
for (int i = 0; i < b_sh_wr_iters; i++) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < thread_m_blocks; j++)
|
for (int j = 0; j < thread_m_blocks; j++)
|
||||||
a_sh_rd_trans[i][j] =
|
a_sh_rd_trans[i][j] =
|
||||||
transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
|
transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
|
||||||
@ -403,16 +410,16 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
|||||||
// runtime; we break dependencies between subsequent accesses with a tile by
|
// runtime; we break dependencies between subsequent accesses with a tile by
|
||||||
// maintining multiple pointers (we have enough registers), a tiny
|
// maintining multiple pointers (we have enough registers), a tiny
|
||||||
// optimization.
|
// optimization.
|
||||||
const int4 *B_ptr[b_sh_wr_iters];
|
const int4* B_ptr[b_sh_wr_iters];
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < b_sh_wr_iters; i++)
|
for (int i = 0; i < b_sh_wr_iters; i++)
|
||||||
B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
|
B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
|
||||||
|
|
||||||
extern __shared__ int4 sh[];
|
extern __shared__ int4 sh[];
|
||||||
// Shared memory storage for global fetch pipelines.
|
// Shared memory storage for global fetch pipelines.
|
||||||
int4 *sh_a = sh;
|
int4* sh_a = sh;
|
||||||
int4 *sh_b = sh_a + (stages * a_sh_stage);
|
int4* sh_b = sh_a + (stages * a_sh_stage);
|
||||||
int4 *sh_s = sh_b + (stages * b_sh_stage);
|
int4* sh_s = sh_b + (stages * b_sh_stage);
|
||||||
// Register storage for double buffer of shared memory reads.
|
// Register storage for double buffer of shared memory reads.
|
||||||
FragA frag_a[2][thread_m_blocks];
|
FragA frag_a[2][thread_m_blocks];
|
||||||
I4 frag_b_quant[2];
|
I4 frag_b_quant[2];
|
||||||
@ -421,34 +428,33 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
|||||||
|
|
||||||
// Zero accumulators.
|
// Zero accumulators.
|
||||||
auto zero_accums = [&]() {
|
auto zero_accums = [&]() {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
|
for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
|
||||||
reinterpret_cast<float *>(frag_c)[i] = 0;
|
reinterpret_cast<float*>(frag_c)[i] = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Asynchronously fetch the next A, B and s tile from global to the next
|
// Asynchronously fetch the next A, B and s tile from global to the next
|
||||||
// shared memory pipeline location.
|
// shared memory pipeline location.
|
||||||
auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
|
auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
|
||||||
if (pred) {
|
if (pred) {
|
||||||
int4 *sh_a_stage = sh_a + a_sh_stage * pipe;
|
int4* sh_a_stage = sh_a + a_sh_stage * pipe;
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < a_sh_wr_iters; i++) {
|
for (int i = 0; i < a_sh_wr_iters; i++) {
|
||||||
cp_async4_pred(
|
cp_async4_pred(
|
||||||
&sh_a_stage[a_sh_wr_trans[i]],
|
&sh_a_stage[a_sh_wr_trans[i]],
|
||||||
&A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],
|
&A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],
|
||||||
a_sh_wr_pred[i]);
|
a_sh_wr_pred[i]);
|
||||||
}
|
}
|
||||||
int4 *sh_b_stage = sh_b + b_sh_stage * pipe;
|
int4* sh_b_stage = sh_b + b_sh_stage * pipe;
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < b_sh_wr_iters; i++) {
|
for (int i = 0; i < b_sh_wr_iters; i++) {
|
||||||
cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]);
|
cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]);
|
||||||
B_ptr[i] += b_gl_rd_delta_o;
|
B_ptr[i] += b_gl_rd_delta_o;
|
||||||
}
|
}
|
||||||
// Only fetch scales if this tile starts a new group
|
// Only fetch scales if this tile starts a new group
|
||||||
if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) {
|
if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) {
|
||||||
int4 *sh_s_stage = sh_s + s_sh_stage * pipe;
|
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
|
||||||
if (s_sh_wr_pred)
|
if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
|
||||||
cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
|
|
||||||
s_gl_rd += s_gl_rd_delta;
|
s_gl_rd += s_gl_rd_delta;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -475,37 +481,35 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
|||||||
// theoretically better attempts have lead to bad instruction ordering by
|
// theoretically better attempts have lead to bad instruction ordering by
|
||||||
// the compiler and correspondingly a noticeable drop in performance.
|
// the compiler and correspondingly a noticeable drop in performance.
|
||||||
if (group_blocks != -1) {
|
if (group_blocks != -1) {
|
||||||
int4 *sh_s_stage =
|
int4* sh_s_stage =
|
||||||
sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
|
sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
|
||||||
(pipe / (group_blocks / thread_k_blocks)));
|
(pipe / (group_blocks / thread_k_blocks)));
|
||||||
reinterpret_cast<int4 *>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
|
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
|
||||||
}
|
}
|
||||||
int4 *sh_a_stage = sh_a + a_sh_stage * pipe;
|
int4* sh_a_stage = sh_a + a_sh_stage * pipe;
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < thread_m_blocks; i++)
|
for (int i = 0; i < thread_m_blocks; i++)
|
||||||
ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
|
ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
|
||||||
int4 *sh_b_stage = sh_b + b_sh_stage * pipe;
|
int4* sh_b_stage = sh_b + b_sh_stage * pipe;
|
||||||
frag_b_quant[k % 2] = *reinterpret_cast<I4 *>(
|
frag_b_quant[k % 2] = *reinterpret_cast<I4*>(
|
||||||
&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]);
|
&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]);
|
||||||
};
|
};
|
||||||
|
|
||||||
// Execute the actual tensor core matmul of a sub-tile.
|
// Execute the actual tensor core matmul of a sub-tile.
|
||||||
auto matmul = [&](int k) {
|
auto matmul = [&](int k) {
|
||||||
// We have the m dimension as the inner loop in order to encourage overlapping
|
// We have the m dimension as the inner loop in order to encourage overlapping
|
||||||
// dequantization and matmul operations.
|
// dequantization and matmul operations.
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < 4; j++) {
|
for (int j = 0; j < 4; j++) {
|
||||||
int b_quant = frag_b_quant[k % 2][j];
|
int b_quant = frag_b_quant[k % 2][j];
|
||||||
int b_quant_shift = b_quant >> 8;
|
int b_quant_shift = b_quant >> 8;
|
||||||
FragB frag_b0 = dequant(b_quant);
|
FragB frag_b0 = dequant(b_quant);
|
||||||
// If there are no groups, we can just scale the final output once and can
|
// If there are no groups, we can just scale the final output once and can
|
||||||
// avoid doing so for each weight.
|
// avoid doing so for each weight.
|
||||||
if (group_blocks != -1)
|
if (group_blocks != -1) scale(frag_b0, frag_s[k % 2][j], 0);
|
||||||
scale(frag_b0, frag_s[k % 2][j], 0);
|
|
||||||
FragB frag_b1 = dequant(b_quant_shift);
|
FragB frag_b1 = dequant(b_quant_shift);
|
||||||
if (group_blocks != -1)
|
if (group_blocks != -1) scale(frag_b1, frag_s[k % 2][j], 1);
|
||||||
scale(frag_b1, frag_s[k % 2][j], 1);
|
#pragma unroll
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < thread_m_blocks; i++) {
|
for (int i = 0; i < thread_m_blocks; i++) {
|
||||||
mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
|
mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
|
||||||
mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
|
mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
|
||||||
@ -530,38 +534,38 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
|||||||
// unnecessary read or write iterations, e.g., for two warps we write only
|
// unnecessary read or write iterations, e.g., for two warps we write only
|
||||||
// once by warp 1 and read only once by warp 0.
|
// once by warp 1 and read only once by warp 0.
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
|
for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = red_off; i > 0; i /= 2) {
|
for (int i = red_off; i > 0; i /= 2) {
|
||||||
if (i <= red_idx && red_idx < 2 * i) {
|
if (i <= red_idx && red_idx < 2 * i) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < 4 * 2; j++) {
|
for (int j = 0; j < 4 * 2; j++) {
|
||||||
int red_sh_wr =
|
int red_sh_wr =
|
||||||
red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
|
red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
|
||||||
if (i < red_off) {
|
if (i < red_off) {
|
||||||
float *c_rd = reinterpret_cast<float *>(
|
float* c_rd =
|
||||||
&sh[red_sh_delta * j + red_sh_rd]);
|
reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
|
||||||
float *c_wr = reinterpret_cast<float *>(&sh[red_sh_wr]);
|
float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int k = 0; k < 4; k++)
|
for (int k = 0; k < 4; k++)
|
||||||
reinterpret_cast<FragC *>(frag_c)[4 * 2 * m_block + j][k] +=
|
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
|
||||||
c_rd[k] + c_wr[k];
|
c_rd[k] + c_wr[k];
|
||||||
}
|
}
|
||||||
sh[red_sh_wr] =
|
sh[red_sh_wr] =
|
||||||
reinterpret_cast<int4 *>(&frag_c)[4 * 2 * m_block + j];
|
reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
}
|
}
|
||||||
if (red_idx == 0) {
|
if (red_idx == 0) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < 4 * 2; i++) {
|
for (int i = 0; i < 4 * 2; i++) {
|
||||||
float *c_rd =
|
float* c_rd =
|
||||||
reinterpret_cast<float *>(&sh[red_sh_delta * i + red_sh_rd]);
|
reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < 4; j++)
|
for (int j = 0; j < 4; j++)
|
||||||
reinterpret_cast<FragC *>(frag_c)[4 * 2 * m_block + i][j] +=
|
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
|
||||||
c_rd[j];
|
c_rd[j];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -571,9 +575,9 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Since multiple threadblocks may process parts of the same column slice, we
|
// Since multiple threadblocks may process parts of the same column slice, we
|
||||||
// finally have to globally reduce over the results. As the striped partitioning
|
// finally have to globally reduce over the results. As the striped
|
||||||
// minimizes the number of such reductions and our outputs are usually rather
|
// partitioning minimizes the number of such reductions and our outputs are
|
||||||
// small, we perform this reduction serially in L2 cache.
|
// usually rather small, we perform this reduction serially in L2 cache.
|
||||||
auto global_reduce = [&](bool first = false, bool last = false) {
|
auto global_reduce = [&](bool first = false, bool last = false) {
|
||||||
// We are very careful here to reduce directly in the output buffer to
|
// We are very careful here to reduce directly in the output buffer to
|
||||||
// maximize L2 cache utilization in this step. To do this, we write out
|
// maximize L2 cache utilization in this step. To do this, we write out
|
||||||
@ -592,39 +596,39 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
|||||||
int row = (threadIdx.x % 32) / 4;
|
int row = (threadIdx.x % 32) / 4;
|
||||||
|
|
||||||
if (!first) {
|
if (!first) {
|
||||||
// Interestingly, doing direct global accesses here really seems to mess up the
|
// Interestingly, doing direct global accesses here really seems to mess up
|
||||||
// compiler and lead to slowdowns, hence we also use async-copies even though
|
// the compiler and lead to slowdowns, hence we also use async-copies even
|
||||||
// these fetches are not actually asynchronous.
|
// though these fetches are not actually asynchronous.
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < thread_m_blocks * 4; i++) {
|
for (int i = 0; i < thread_m_blocks * 4; i++) {
|
||||||
cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i],
|
cp_async4_pred(
|
||||||
&C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
|
&sh[c_sh_wr + c_sh_wr_delta * i],
|
||||||
c_gl_wr_delta_i * (i % 2)],
|
&C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
|
||||||
i < (thread_m_blocks - 1) * 4 ||
|
c_gl_wr_delta_i * (i % 2)],
|
||||||
8 * (i / 2) + row < prob_m);
|
i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
|
||||||
}
|
}
|
||||||
cp_async_fence();
|
cp_async_fence();
|
||||||
cp_async_wait<0>();
|
cp_async_wait<0>();
|
||||||
}
|
}
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < thread_m_blocks * 4; i++) {
|
for (int i = 0; i < thread_m_blocks * 4; i++) {
|
||||||
if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
|
if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
|
||||||
if (!first) {
|
if (!first) {
|
||||||
int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
|
int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < 2 * 4; j++) {
|
for (int j = 0; j < 2 * 4; j++) {
|
||||||
reinterpret_cast<float *>(
|
reinterpret_cast<float*>(
|
||||||
&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] +=
|
&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] +=
|
||||||
__half2float(reinterpret_cast<__half *>(&c_red)[j]);
|
__half2float(reinterpret_cast<__half*>(&c_red)[j]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!last) {
|
if (!last) {
|
||||||
int4 c;
|
int4 c;
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < 2 * 4; j++) {
|
for (int j = 0; j < 2 * 4; j++) {
|
||||||
reinterpret_cast<__half *>(&c)[j] =
|
reinterpret_cast<__half*>(&c)[j] =
|
||||||
__float2half(reinterpret_cast<float *>(
|
__float2half(reinterpret_cast<float*>(
|
||||||
&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]);
|
&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]);
|
||||||
}
|
}
|
||||||
C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] =
|
C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] =
|
||||||
@ -658,17 +662,17 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
|||||||
|
|
||||||
// We first reorder in shared memory to guarantee the most efficient final
|
// We first reorder in shared memory to guarantee the most efficient final
|
||||||
// global write patterns
|
// global write patterns
|
||||||
auto write = [&](int idx, float c0, float c1, FragS &s) {
|
auto write = [&](int idx, float c0, float c1, FragS& s) {
|
||||||
half2 res = __halves2half2(__float2half(c0), __float2half(c1));
|
half2 res = __halves2half2(__float2half(c0), __float2half(c1));
|
||||||
if (group_blocks ==
|
if (group_blocks ==
|
||||||
-1) // for per-column quantization we finally apply the scale here
|
-1) // for per-column quantization we finally apply the scale here
|
||||||
res = __hmul2(res, s[0]);
|
res = __hmul2(res, s[0]);
|
||||||
((half2 *)sh)[idx] = res;
|
((half2*)sh)[idx] = res;
|
||||||
};
|
};
|
||||||
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < thread_m_blocks; i++) {
|
for (int i = 0; i < thread_m_blocks; i++) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < 4; j++) {
|
for (int j = 0; j < 4; j++) {
|
||||||
int wr = c_sh_wr + 8 * j;
|
int wr = c_sh_wr + 8 * j;
|
||||||
write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],
|
write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],
|
||||||
@ -685,7 +689,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
|||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0;
|
for (int i = 0;
|
||||||
i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
|
i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
|
||||||
i++) {
|
i++) {
|
||||||
@ -699,9 +703,8 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
|||||||
|
|
||||||
// Start global fetch and register load pipelines.
|
// Start global fetch and register load pipelines.
|
||||||
auto start_pipes = [&]() {
|
auto start_pipes = [&]() {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < stages - 1; i++)
|
for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters);
|
||||||
fetch_to_shared(i, i, i < slice_iters);
|
|
||||||
zero_accums();
|
zero_accums();
|
||||||
wait_for_stage();
|
wait_for_stage();
|
||||||
fetch_to_registers(0, 0);
|
fetch_to_registers(0, 0);
|
||||||
@ -711,12 +714,12 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
|||||||
|
|
||||||
// Main loop.
|
// Main loop.
|
||||||
while (slice_iters) {
|
while (slice_iters) {
|
||||||
// We unroll over both the global fetch and the register load pipeline to ensure
|
// We unroll over both the global fetch and the register load pipeline to
|
||||||
// all shared memory accesses are static. Note that both pipelines have even
|
// ensure all shared memory accesses are static. Note that both pipelines have
|
||||||
// length meaning that the next iteration will always start at index 0.
|
// even length meaning that the next iteration will always start at index 0.
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int pipe = 0; pipe < stages;) {
|
for (int pipe = 0; pipe < stages;) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int k = 0; k < b_sh_wr_iters; k++) {
|
for (int k = 0; k < b_sh_wr_iters; k++) {
|
||||||
fetch_to_registers(k + 1, pipe % stages);
|
fetch_to_registers(k + 1, pipe % stages);
|
||||||
if (k == b_sh_wr_iters - 2) {
|
if (k == b_sh_wr_iters - 2) {
|
||||||
@ -728,8 +731,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
|||||||
matmul(k);
|
matmul(k);
|
||||||
}
|
}
|
||||||
slice_iters--;
|
slice_iters--;
|
||||||
if (slice_iters == 0)
|
if (slice_iters == 0) break;
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
a_gl_rd += a_gl_rd_delta_o * stages;
|
a_gl_rd += a_gl_rd_delta_o * stages;
|
||||||
|
|
||||||
@ -742,8 +744,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
|||||||
// For per-column scales, we only fetch them here in the final step before
|
// For per-column scales, we only fetch them here in the final step before
|
||||||
// write-out
|
// write-out
|
||||||
if (group_blocks == -1 && last) {
|
if (group_blocks == -1 && last) {
|
||||||
if (s_sh_wr_pred)
|
if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
|
||||||
cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
|
|
||||||
cp_async_fence();
|
cp_async_fence();
|
||||||
}
|
}
|
||||||
thread_block_reduce();
|
thread_block_reduce();
|
||||||
@ -751,17 +752,17 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
|||||||
cp_async_wait<0>();
|
cp_async_wait<0>();
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
||||||
reinterpret_cast<int4 *>(&frag_s)[0] = sh_s[s_sh_rd + 0];
|
reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
|
||||||
reinterpret_cast<int4 *>(&frag_s)[1] = sh_s[s_sh_rd + 4];
|
reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (slice_count > 1) { // only globally reduce if there is more than one
|
if (slice_count > 1) { // only globally reduce if there is more than one
|
||||||
// block in a slice
|
// block in a slice
|
||||||
barrier_acquire(&locks[slice_col], slice_idx);
|
barrier_acquire(&locks[slice_col], slice_idx);
|
||||||
global_reduce(slice_idx == 0, last);
|
global_reduce(slice_idx == 0, last);
|
||||||
barrier_release(&locks[slice_col], last);
|
barrier_release(&locks[slice_col], last);
|
||||||
}
|
}
|
||||||
if (last) // only the last block in a slice actually writes the result
|
if (last) // only the last block in a slice actually writes the result
|
||||||
write_result();
|
write_result();
|
||||||
slice_row = 0;
|
slice_row = 0;
|
||||||
slice_col_par++;
|
slice_col_par++;
|
||||||
@ -770,13 +771,12 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
|||||||
if (slice_iters) {
|
if (slice_iters) {
|
||||||
a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
|
a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
|
||||||
(threadIdx.x % a_gl_rd_delta_o);
|
(threadIdx.x % a_gl_rd_delta_o);
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < b_sh_wr_iters; i++)
|
for (int i = 0; i < b_sh_wr_iters; i++)
|
||||||
B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
|
B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
|
||||||
if (slice_col == 0) {
|
if (slice_col == 0) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < b_sh_wr_iters; i++)
|
for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
|
||||||
B_ptr[i] -= b_gl_stride;
|
|
||||||
}
|
}
|
||||||
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
|
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
|
||||||
start_pipes();
|
start_pipes();
|
||||||
@ -787,26 +787,27 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
|||||||
|
|
||||||
#else
|
#else
|
||||||
|
|
||||||
template <const int threads, // number of threads in a threadblock
|
template <const int threads, // number of threads in a threadblock
|
||||||
const int thread_m_blocks, // number of 16x16 blocks in the m
|
const int thread_m_blocks, // number of 16x16 blocks in the m
|
||||||
// dimension (batchsize) of the threadblock
|
// dimension (batchsize) of the
|
||||||
const int thread_n_blocks, // same for n dimension (output)
|
// threadblock
|
||||||
const int thread_k_blocks, // same for k dimension (reduction)
|
const int thread_n_blocks, // same for n dimension (output)
|
||||||
const int stages, // number of stages for the async global->shared
|
const int thread_k_blocks, // same for k dimension (reduction)
|
||||||
// fetch pipeline
|
const int stages, // number of stages for the async global->shared
|
||||||
const int group_blocks = -1 // number of consecutive 16x16 blocks with
|
// fetch pipeline
|
||||||
// a separate quantization scale
|
const int group_blocks = -1 // number of consecutive 16x16 blocks
|
||||||
|
// with a separate quantization scale
|
||||||
>
|
>
|
||||||
__global__ void
|
__global__ void Marlin(
|
||||||
Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
const int4* __restrict__ A, // fp16 input matrix of shape mxk
|
||||||
const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn
|
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
|
||||||
int4 *__restrict__ C, // fp16 output buffer of shape mxn
|
int4* __restrict__ C, // fp16 output buffer of shape mxn
|
||||||
const int4
|
const int4* __restrict__ s, // fp16 quantization scales of shape
|
||||||
*__restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn
|
// (k/groupsize)xn
|
||||||
int prob_m, // batch dimension m
|
int prob_m, // batch dimension m
|
||||||
int prob_n, // output dimension n
|
int prob_n, // output dimension n
|
||||||
int prob_k, // reduction dimension k
|
int prob_k, // reduction dimension k
|
||||||
int *locks // extra global storage for barrier synchronization
|
int* locks // extra global storage for barrier synchronization
|
||||||
) {
|
) {
|
||||||
// Marlin is not implemented yet for SM < 8.0
|
// Marlin is not implemented yet for SM < 8.0
|
||||||
assert(false);
|
assert(false);
|
||||||
@ -819,10 +820,10 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
|||||||
// than 1 warp per schedule allows some more latency hiding. At the same time,
|
// than 1 warp per schedule allows some more latency hiding. At the same time,
|
||||||
// we want relatively few warps to have many registers per warp and small tiles.
|
// we want relatively few warps to have many registers per warp and small tiles.
|
||||||
const int USER_THREADS =
|
const int USER_THREADS =
|
||||||
256; // Note: This is only used with user-provided thread_k/n
|
256; // Note: This is only used with user-provided thread_k/n
|
||||||
const int STAGES = 4; // 4 pipeline stages fit into shared memory
|
const int STAGES = 4; // 4 pipeline stages fit into shared memory
|
||||||
const int SHARED_MEM =
|
const int SHARED_MEM =
|
||||||
96 * 1024; // max shared memory on compute capability 8.6 (< 8.0)
|
96 * 1024; // max shared memory on compute capability 8.6 (< 8.0)
|
||||||
|
|
||||||
static constexpr int min_thread_n = 64;
|
static constexpr int min_thread_n = 64;
|
||||||
static constexpr int min_thread_k = 64;
|
static constexpr int min_thread_k = 64;
|
||||||
@ -831,7 +832,7 @@ static constexpr int tile_size = 16;
|
|||||||
static constexpr int max_par = 16;
|
static constexpr int max_par = 16;
|
||||||
|
|
||||||
static constexpr int pack_factor_4bit =
|
static constexpr int pack_factor_4bit =
|
||||||
8; // We have 8 4-bit vals inside a 32 bit
|
8; // We have 8 4-bit vals inside a 32 bit
|
||||||
|
|
||||||
#define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
|
#define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
|
||||||
GROUP_BLOCKS, NUM_THREADS) \
|
GROUP_BLOCKS, NUM_THREADS) \
|
||||||
@@ -858,23 +859,23 @@ thread_config_t small_batch_thread_configs[] = {
     // Ordered by priority

     // thread_k, thread_n, num_threads
     {128, 128, 256},  // Default
     {128, 64, 128},   // Reduce N 2X, same K
     {64, 256, 256},   // Reduce K 2X, increase N 2X
     {64, 128, 128},   // Reduce K 2X, same N
 };

 thread_config_t large_batch_thread_configs[] = {
     // Ordered by priority

     // thread_k, thread_n, num_threads
     {64, 256, 256},   // Default
     {128, 128, 256},  // Reduce N 2X, increase K 2X
     {64, 128, 128},   // Reduce N 2X, same K
     {128, 64, 128},   // Reduce N 4X, increase K 2X
 };

-bool is_valid_config(thread_config_t const &th_config, int prob_m, int prob_n,
+bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n,
                      int prob_k) {
   // Sanity
   if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
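These tables are scanned in priority order and the first candidate that passes is_valid_config wins (see determine_thread_config below). A stand-alone sketch of that selection loop, with the validity test reduced to simple divisibility checks because the full is_valid_config body lies outside this hunk:

#include <cstdio>

struct thread_config_t {
  int thread_k, thread_n, num_threads;
};

// Priority-ordered candidates, as in the small-batch table above.
thread_config_t small_batch_thread_configs[] = {
    {128, 128, 256}, {128, 64, 128}, {64, 256, 256}, {64, 128, 128}};

// Simplified validity test: only require the tile sizes to divide the problem.
// (The real is_valid_config also checks thread counts and other sanity rules.)
bool is_valid_config(thread_config_t const& c, int prob_m, int prob_n,
                     int prob_k) {
  return prob_n % c.thread_n == 0 && prob_k % c.thread_k == 0;
}

thread_config_t pick_config(int prob_m, int prob_n, int prob_k) {
  for (auto c : small_batch_thread_configs)
    if (is_valid_config(c, prob_m, prob_n, prob_k)) return c;
  return {-1, -1, -1};  // no valid config found
}

int main() {
  auto c = pick_config(16, 4096, 4096);
  std::printf("thread_k=%d thread_n=%d threads=%d\n", c.thread_k, c.thread_n,
              c.num_threads);
  return 0;
}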
@ -907,7 +908,6 @@ bool is_valid_config(thread_config_t const &th_config, int prob_m, int prob_n,
|
|||||||
}
|
}
|
||||||
|
|
||||||
thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) {
|
thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) {
|
||||||
|
|
||||||
if (prob_m <= 16) {
|
if (prob_m <= 16) {
|
||||||
for (auto th_config : small_batch_thread_configs) {
|
for (auto th_config : small_batch_thread_configs) {
|
||||||
if (is_valid_config(th_config, prob_m, prob_n, prob_k)) {
|
if (is_valid_config(th_config, prob_m, prob_n, prob_k)) {
|
||||||
@@ -926,20 +926,20 @@ thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) {
   return thread_config_t{-1, -1, -1};
 }

 #define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS) \
   __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
   __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
   __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
   __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
   __CALL_IF(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
   __CALL_IF(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
   __CALL_IF(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
   __CALL_IF(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
   __CALL_IF(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
   __CALL_IF(4, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)

-void marlin_cuda(const void *A, const void *B, void *C, void *s, int prob_m,
-                 int prob_n, int prob_k, void *workspace, int groupsize = -1,
+void marlin_cuda(const void* A, const void* B, void* C, void* s, int prob_m,
+                 int prob_n, int prob_k, void* workspace, int groupsize = -1,
                  int dev = 0, cudaStream_t stream = 0, int thread_k = -1,
                  int thread_n = -1, int sms = -1, int max_par = 16) {
   int tot_m = prob_m;
@ -996,12 +996,12 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, int prob_m,
|
|||||||
" is not divisible by group_blocks = ", group_blocks);
|
" is not divisible by group_blocks = ", group_blocks);
|
||||||
}
|
}
|
||||||
|
|
||||||
const int4 *A_ptr = (const int4 *)A;
|
const int4* A_ptr = (const int4*)A;
|
||||||
const int4 *B_ptr = (const int4 *)B;
|
const int4* B_ptr = (const int4*)B;
|
||||||
int4 *C_ptr = (int4 *)C;
|
int4* C_ptr = (int4*)C;
|
||||||
const int4 *s_ptr = (const int4 *)s;
|
const int4* s_ptr = (const int4*)s;
|
||||||
|
|
||||||
int *locks = (int *)workspace;
|
int* locks = (int*)workspace;
|
||||||
|
|
||||||
for (int i = 0; i < tot_m_blocks; i += 4) {
|
for (int i = 0; i < tot_m_blocks; i += 4) {
|
||||||
int thread_m_blocks = tot_m_blocks - i;
|
int thread_m_blocks = tot_m_blocks - i;
|
||||||
@ -1011,8 +1011,7 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, int prob_m,
|
|||||||
// Note that parallel > 1 currently only works for inputs without any
|
// Note that parallel > 1 currently only works for inputs without any
|
||||||
// padding
|
// padding
|
||||||
par = (16 * thread_m_blocks - pad) / 64;
|
par = (16 * thread_m_blocks - pad) / 64;
|
||||||
if (par > max_par)
|
if (par > max_par) par = max_par;
|
||||||
par = max_par;
|
|
||||||
prob_m = 64 * par;
|
prob_m = 64 * par;
|
||||||
i += 4 * (par - 1);
|
i += 4 * (par - 1);
|
||||||
thread_m_blocks = 4;
|
thread_m_blocks = 4;
|
||||||
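A worked example of the parallelization arithmetic above may help; the values are illustrative and the surrounding loop details are elided in the hunk shown.

// Worked example of the `par` computation above (illustrative values only).
static inline int compute_par(int thread_m_blocks, int pad, int max_par) {
  int par = (16 * thread_m_blocks - pad) / 64;
  if (par > max_par) par = max_par;
  return par;
}
// For prob_m = 384 (24 blocks of 16 rows), pad = 0, max_par = 16:
//   compute_par(24, 0, 16) == 6, so the launch covers 64 * 6 = 384 rows and
//   the outer loop advances by 4 * (6 - 1) = 20 extra block groups.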
@@ -1041,12 +1040,11 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, int prob_m,
}
}

} // namespace marlin

-torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
-torch::Tensor &b_scales, torch::Tensor &workspace,
+torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
+torch::Tensor& b_scales, torch::Tensor& workspace,
int64_t size_m, int64_t size_n, int64_t size_k) {

// Verify M
TORCH_CHECK(size_m == a.size(0),
"Shape mismatch: a.size(0) = " + str(a.size(0)) +
@@ -1074,9 +1072,9 @@ torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,

int actual_size_n =
(b_q_weight.size(1) / marlin::tile_size) * marlin::pack_factor_4bit;
-TORCH_CHECK(size_n == actual_size_n,
-"size_n = " + str(size_n) +
-", actual_size_n = " + str(actual_size_n));
+TORCH_CHECK(
+size_n == actual_size_n,
+"size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n));

// Verify A device and strides
TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");

@@ -26,12 +26,14 @@ constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }
// corresponding index accesses must be compile-time constants, which is why we
// extensively use `#pragma unroll` throughout the kernel code to guarantee
// this.
-template <typename T, int n> struct Vec {
+template <typename T, int n>
+struct Vec {
T elems[n];
-__device__ T &operator[](int i) { return elems[i]; }
+__device__ T& operator[](int i) { return elems[i]; }
};

-template <int M_, int N_, int K_> struct ShapeBase {
+template <int M_, int N_, int K_>
+struct ShapeBase {
static constexpr int M = M_, N = N_, K = K_;
};

@@ -44,6 +46,6 @@ using FragA = Vec<half2, 4>;
using FragB = Vec<half2, 2>;
using FragM = Vec<uint, 1>;
using FragC = Vec<float, 4>;
using FragS = Vec<half2, 1>; // quantization scales

} // namespace marlin_24
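The fragment typedefs above are thin wrappers; here is a short self-contained sketch of the same Vec/ShapeBase pattern (host-only, with float in place of half2; names otherwise mirror the diff but the sketch is not part of this commit).

// Illustrative sketch of the Vec/ShapeBase pattern above.
#include <cstdio>

constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }

template <typename T, int n>
struct Vec {
  T elems[n];
  T& operator[](int i) { return elems[i]; }
};

template <int M_, int N_, int K_>
struct ShapeBase {
  static constexpr int M = M_, N = N_, K = K_;
};

using FragC = Vec<float, 4>;  // one accumulator fragment, as in the diff

int main() {
  FragC acc{};  // zero-initialized accumulator fragment
  for (int i = 0; i < 4; i++) acc[i] += 1.0f;
  using TileShape = ShapeBase<16, 8, 32>;  // hypothetical m16n8k32 tile
  std::printf("M=%d N=%d K=%d acc0=%f k_tiles=%d\n", TileShape::M, TileShape::N,
              TileShape::K, acc[0], ceildiv(100, TileShape::K));
  return 0;
}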
@@ -21,41 +21,44 @@
namespace marlin_24 {
// Predicated asynchronous global->shared copy; used for inputs A where we apply
// predication to handle batchsizes that are not multiples of 16.
-__device__ inline void cp_async4_pred_zfill(void *smem_ptr,
-const void *glob_ptr,
+__device__ inline void cp_async4_pred_zfill(void* smem_ptr,
+const void* glob_ptr,
bool pred = true,
const bool zfill = false) {
const int BYTES = 16;
int src_in_bytes = (zfill ? 0 : BYTES);
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
-asm volatile("{\n"
-" .reg .pred p;\n"
-" setp.ne.b32 p, %0, 0;\n"
-" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
-"}\n" ::"r"((int)pred),
-"r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes));
+asm volatile(
+"{\n"
+" .reg .pred p;\n"
+" setp.ne.b32 p, %0, 0;\n"
+" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
+"}\n" ::"r"((int)pred),
+"r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes));
}

-__device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr,
+__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) {
const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
-asm volatile("{\n"
-" .reg .pred p;\n"
-" setp.ne.b32 p, %0, 0;\n"
-" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
-"}\n" ::"r"((int)pred),
-"r"(smem), "l"(glob_ptr), "n"(BYTES));
+asm volatile(
+"{\n"
+" .reg .pred p;\n"
+" setp.ne.b32 p, %0, 0;\n"
+" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
+"}\n" ::"r"((int)pred),
+"r"(smem), "l"(glob_ptr), "n"(BYTES));
}

// Asynchronous global->shared copy
-__device__ inline void cp_async4(void *smem_ptr, const void *glob_ptr) {
+__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
-asm volatile("{\n"
-" cp.async.cg.shared.global [%0], [%1], %2;\n"
-"}\n" ::"r"(smem),
-"l"(glob_ptr), "n"(BYTES));
+asm volatile(
+"{\n"
+" cp.async.cg.shared.global [%0], [%1], %2;\n"
+"}\n" ::"r"(smem),
+"l"(glob_ptr), "n"(BYTES));
}

// Async copy fence.
@@ -64,22 +67,23 @@ __device__ inline void cp_async_fence() {
}

// Wait until at most `n` async copy stages are still pending.
-template <int n> __device__ inline void cp_async_wait() {
+template <int n>
+__device__ inline void cp_async_wait() {
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
}
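To see the helpers above in use, here is a minimal sketch of a cp.async staging step built on the same PTX instruction family (the kernel name and launch shape are invented for illustration; requires sm_80 or newer). It is a sketch, not code from this commit.

#include <cstdint>
#include <cuda_runtime.h>

// Sketch: stage 16-byte chunks of `src` through shared memory with cp.async,
// then copy them to `dst` after committing and waiting on the copy group.
__global__ void staged_copy(const int4* __restrict__ src,
                            int4* __restrict__ dst, int n_chunks) {
  extern __shared__ int4 smem[];
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  bool active = idx < n_chunks;
  uint32_t smem_addr =
      static_cast<uint32_t>(__cvta_generic_to_shared(&smem[threadIdx.x]));
  if (active) {
    // Issue the asynchronous 16-byte global->shared copy.
    asm volatile("cp.async.cg.shared.global [%0], [%1], 16;\n" ::"r"(smem_addr),
                 "l"(src + idx));
  }
  // Commit the group and wait for it to land before reading shared memory.
  asm volatile("cp.async.commit_group;\n" ::);
  asm volatile("cp.async.wait_group 0;\n" ::);
  __syncthreads();
  if (active) dst[idx] = smem[threadIdx.x];
}

A launch would pass blockDim.x * sizeof(int4) bytes of dynamic shared memory, e.g. staged_copy<<<grid, 256, 256 * sizeof(int4)>>>(src, dst, n).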
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
-__device__ inline void ldsm4(FragA &frag_a, const void *smem_ptr) {
-uint32_t *a = reinterpret_cast<uint32_t *>(&frag_a);
+__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
+uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
: "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
: "r"(smem));
}

-__device__ inline void ldsm4_m(FragM &frag_m, const void *smem_ptr) {
-uint32_t *a = reinterpret_cast<uint32_t *>(&frag_m);
+__device__ inline void ldsm4_m(FragM& frag_m, const void* smem_ptr) {
+uint32_t* a = reinterpret_cast<uint32_t*>(&frag_m);
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n"
: "=r"(a[0]), "=r"(a[1])
@@ -88,8 +92,8 @@ __device__ inline void ldsm4_m(FragM &frag_m, const void *smem_ptr) {

// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
-__device__ inline void ldsm4_t(FragA &frag_a, const void *smem_ptr) {
-uint32_t *a = reinterpret_cast<uint32_t *>(&frag_a);
+__device__ inline void ldsm4_t(FragA& frag_a, const void* smem_ptr) {
+uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0,%1,%2,%3}, [%4];\n"
@@ -98,7 +102,7 @@ __device__ inline void ldsm4_t(FragA &frag_a, const void *smem_ptr) {
}

// Wait until barrier reaches `count`, then lock for current threadblock.
-__device__ inline void barrier_acquire(int *lock, int count) {
+__device__ inline void barrier_acquire(int* lock, int count) {
if (threadIdx.x == 0) {
int state = -1;
do
@@ -113,7 +117,7 @@ __device__ inline void barrier_acquire(int *lock, int count) {
}

// Release barrier and increment visitation count.
-__device__ inline void barrier_release(int *lock, bool reset = false) {
+__device__ inline void barrier_release(int* lock, bool reset = false) {
__syncthreads();
if (threadIdx.x == 0) {
if (reset) {
@@ -129,4 +133,4 @@ __device__ inline void barrier_release(int *lock, bool reset = false) {
: "l"(lock), "r"(val));
}
}
} // namespace marlin_24
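The two barrier helpers above implement a ticket-style protocol between threadblocks; the sketch below restates the control flow with plain atomics so the idea is visible. It deliberately omits the acquire/release PTX and cache controls the real code uses for cross-block ordering, so treat it as an outline under those assumptions, not a drop-in replacement.

#include <cuda_runtime.h>

// Sketch: spin until `lock` reaches `count`, do the serialized work, then bump
// the counter so the next threadblock in line can proceed.
__device__ void sketch_barrier_acquire(int* lock, int count) {
  if (threadIdx.x == 0) {
    while (atomicAdd(lock, 0) < count) { /* spin on the global counter */ }
  }
  __syncthreads();
}

__device__ void sketch_barrier_release(int* lock, bool reset = false) {
  __syncthreads();
  if (threadIdx.x == 0) {
    if (reset) {
      *lock = 0;  // last visitor resets the barrier for the next use
      return;
    }
    atomicAdd(lock, 1);  // increment the visitation count
  }
}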
@@ -22,51 +22,56 @@ namespace marlin_24 {

// m16n8k32 sparse tensor core mma instruction with fp16 inputs and fp32
// output/accumulation.
-__device__ inline void mma_sp(const FragB &a_frag0, const FragB &a_frag1,
-const FragA &frag_b, FragC &frag_c, FragM &frag_m,
+__device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1,
+const FragA& frag_b, FragC& frag_c, FragM& frag_m,
const int psel) {
-const uint32_t *a0 = reinterpret_cast<const uint32_t *>(&a_frag0);
-const uint32_t *a1 = reinterpret_cast<const uint32_t *>(&a_frag1);
-const uint32_t *b = reinterpret_cast<const uint32_t *>(&frag_b);
-const uint32_t *e = reinterpret_cast<const uint32_t *>(&frag_m);
-float *c = reinterpret_cast<float *>(&frag_c);
+const uint32_t* a0 = reinterpret_cast<const uint32_t*>(&a_frag0);
+const uint32_t* a1 = reinterpret_cast<const uint32_t*>(&a_frag1);
+const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
+const uint32_t* e = reinterpret_cast<const uint32_t*>(&frag_m);
+float* c = reinterpret_cast<float*>(&frag_c);
if (psel == 0) {
-asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
-"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
-"{%12,%13,%14,%15}, %16, 0x0;\n"
-: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
-: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]),
-"r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]),
-"f"(c[2]), "f"(c[3]), "r"(e[0]));
-asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
-"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
-"{%12,%13,%14,%15}, %16, 0x0;\n"
-: "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
-: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]),
-"r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]),
-"f"(c[6]), "f"(c[7]), "r"(e[0]));
+asm volatile(
+"mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
+"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
+"{%12,%13,%14,%15}, %16, 0x0;\n"
+: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
+: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), "r"(b[2]),
+"r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]),
+"r"(e[0]));
+asm volatile(
+"mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
+"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
+"{%12,%13,%14,%15}, %16, 0x0;\n"
+: "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
+: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), "r"(b[3]),
+"r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), "f"(c[6]), "f"(c[7]),
+"r"(e[0]));
} else {
-asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
-"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
-"{%12,%13,%14,%15}, %16, 0x1;\n"
-: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
-: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]),
-"r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]),
-"f"(c[2]), "f"(c[3]), "r"(e[0]));
-asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
-"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
-"{%12,%13,%14,%15}, %16, 0x1;\n"
-: "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
-: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]),
-"r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]),
-"f"(c[6]), "f"(c[7]), "r"(e[0]));
+asm volatile(
+"mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
+"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
+"{%12,%13,%14,%15}, %16, 0x1;\n"
+: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
+: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), "r"(b[2]),
+"r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]),
+"r"(e[0]));
+asm volatile(
+"mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
+"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
+"{%12,%13,%14,%15}, %16, 0x1;\n"
+: "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
+: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), "r"(b[3]),
+"r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), "f"(c[6]), "f"(c[7]),
+"r"(e[0]));
}
}

// Lookup-table based 3-input logical operation; explicitly used for
// dequantization as the compiler does not seem to automatically recognize it in
// all cases.
-template <int lut> __device__ inline int lop3(int a, int b, int c) {
+template <int lut>
+__device__ inline int lop3(int a, int b, int c) {
int res;
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(res)
@@ -120,11 +125,11 @@ __device__ inline FragB dequant_4bit(int q) {
const int ADD = 0xd480d480;

FragB frag_b;
-frag_b[0] = __hsub2(*reinterpret_cast<half2 *>(&lo),
-*reinterpret_cast<const half2 *>(&SUB));
-frag_b[1] = __hfma2(*reinterpret_cast<half2 *>(&hi),
-*reinterpret_cast<const half2 *>(&MUL),
-*reinterpret_cast<const half2 *>(&ADD));
+frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
+*reinterpret_cast<const half2*>(&SUB));
+frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
+*reinterpret_cast<const half2*>(&MUL),
+*reinterpret_cast<const half2*>(&ADD));
return frag_b;
}
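For readers unfamiliar with lop3, the immediate is simply the truth table obtained by plugging the masks 0xF0, 0xCC and 0xAA into the desired boolean expression. A hedged sketch follows; the demo kernel and the (a & b) | c example are illustrative, not part of this commit.

// Sketch: the LUT constant for f(a, b, c) = (a & b) | c is
// (0xf0 & 0xcc) | 0xaa == 0xEA. The template mirrors the helper in the diff.
template <int lut>
__device__ inline int lop3_sketch(int a, int b, int c) {
  int res;
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(res)
               : "r"(a), "r"(b), "r"(c), "n"(lut));
  return res;
}

__global__ void lop3_demo(const int* a, const int* b, const int* c, int* out,
                          int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = lop3_sketch<(0xf0 & 0xcc) | 0xaa>(a[i], b[i], c[i]);
}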
@@ -143,24 +148,24 @@ __device__ inline FragB dequant_8bit(int q) {
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;

FragB frag_b;
-frag_b[0] = __hsub2(*reinterpret_cast<half2 *>(&lo),
-*reinterpret_cast<const half2 *>(&I8s_TO_F16s_MAGIC_NUM));
-frag_b[1] = __hsub2(*reinterpret_cast<half2 *>(&hi),
-*reinterpret_cast<const half2 *>(&I8s_TO_F16s_MAGIC_NUM));
+frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
+*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
+frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
+*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
return frag_b;
}

// Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization.
-__device__ inline void scale(FragB &frag_b, FragS &frag_s, int i) {
-half2 s = __half2half2(reinterpret_cast<__half *>(&frag_s)[i]);
+__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
+half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]);
frag_b[0] = __hmul2(frag_b[0], s);
frag_b[1] = __hmul2(frag_b[1], s);
}

-__device__ inline void scale_floats(float *c0, float *c1, float *c2, float *c3,
-FragS &s0, float *c4, float *c5, float *c6,
-float *c7, FragS &s1) {
+__device__ inline void scale_floats(float* c0, float* c1, float* c2, float* c3,
+FragS& s0, float* c4, float* c5, float* c6,
+float* c7, FragS& s1) {
*c0 = __fmul_rn(*c0, __half2float(s0[0].x));
*c1 = __fmul_rn(*c1, __half2float(s0[0].y));
*c2 = __fmul_rn(*c2, __half2float(s0[1].x));
@@ -172,4 +177,4 @@ __device__ inline void scale_floats(float *c0, float *c1, float *c2, float *c3,
*c7 = __fmul_rn(*c7, __half2float(s1[1].y));
}

} // namespace marlin_24
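The scaling helpers above boil down to one fused multiply per output value; a compact hedged sketch of the same idea (the function name and argument layout are illustrative):

#include <cuda_fp16.h>

// Sketch: multiply four float accumulators by per-column half-precision scales
// packed as two half2 values, converting each scale to float once.
__device__ inline void scale_four(float* c, const half2* s) {
  c[0] = __fmul_rn(c[0], __half2float(__low2half(s[0])));
  c[1] = __fmul_rn(c[1], __half2float(__high2half(s[0])));
  c[2] = __fmul_rn(c[2], __half2float(__low2half(s[1])));
  c[3] = __fmul_rn(c[3], __half2float(__high2half(s[1])));
}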
@@ -32,12 +32,15 @@

#else

#include "common/mem.h"
#include "common/mma.h"

#endif

-template <typename T> inline std::string str(T x) { return std::to_string(x); }
+template <typename T>
+inline std::string str(T x) {
+return std::to_string(x);
+}

namespace marlin_24 {

@@ -45,7 +48,7 @@ namespace marlin_24 {
// than 1 warp per schedule allows some more latency hiding. At the same time,
// we want relatively few warps to have many registers per warp and small tiles.
static constexpr int THREADS = 256;
static constexpr int STAGES = 4; // 4 pipeline stages fit into shared memory

static constexpr int min_thread_n = 128;

@@ -54,35 +57,36 @@ static constexpr int max_par = 16;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800

template <const int num_bits, // weight bits
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
-// dimension (batchsize) of the threadblock
-const int thread_n_blocks, // same for n dimension (output)
-const int thread_k_blocks, // same for k dimension (reduction)
-const int stages, // number of stages for the async global->shared
-// fetch pipeline
-const int group_blocks = -1 // number of consecutive 16x16 blocks with
-// a separate quantization scale
+// dimension (batchsize) of the
+// threadblock
+const int thread_n_blocks, // same for n dimension (output)
+const int thread_k_blocks, // same for k dimension (reduction)
+const int stages, // number of stages for the async global->shared
+// fetch pipeline
+const int group_blocks = -1 // number of consecutive 16x16 blocks
+// with a separate quantization scale
>
__global__ void Marlin_24(
-const int4 *__restrict__ A, // fp16 input matrix of shape mxk
-const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn
-const int4
-*__restrict__ meta, // 2bit metadata information about 2:4 format on B
-int4 *__restrict__ C, // fp16 output buffer of shape mxn
-const int4
-*__restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn
+const int4* __restrict__ A, // fp16 input matrix of shape mxk
+const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
+const int4* __restrict__ meta, // 2bit metadata information about 2:4
+// format on B
+int4* __restrict__ C, // fp16 output buffer of shape mxn
+const int4* __restrict__ s, // fp16 quantization scales of shape
+// (k/groupsize)xn
int prob_m, // batch dimension m
int prob_n, // output dimension n
int prob_k, // reduction dimension k
-int *locks // extra global storage for barrier synchronization
+int* locks // extra global storage for barrier synchronization
) {}

-torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
-torch::Tensor &b_meta,
-torch::Tensor &b_scales,
-torch::Tensor &workspace, int64_t num_bits,
+torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
+torch::Tensor& b_meta,
+torch::Tensor& b_scales,
+torch::Tensor& workspace, int64_t num_bits,
int64_t size_m, int64_t size_n,
int64_t size_k) {
TORCH_CHECK_NOT_IMPLEMENTED(
@@ -92,29 +96,30 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,

#else

template <const int num_bits, // weight bits
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
-// dimension (batchsize) of the threadblock
-const int thread_n_blocks, // same for n dimension (output)
-const int thread_k_blocks, // same for k dimension (reduction)
-const int stages, // number of stages for the async global->shared
-// fetch pipeline
-const int group_blocks = -1 // number of consecutive 16x16 blocks with
-// a separate quantization scale
+// dimension (batchsize) of the
+// threadblock
+const int thread_n_blocks, // same for n dimension (output)
+const int thread_k_blocks, // same for k dimension (reduction)
+const int stages, // number of stages for the async global->shared
+// fetch pipeline
+const int group_blocks = -1 // number of consecutive 16x16 blocks
+// with a separate quantization scale
>
__global__ void Marlin_24(
-const int4 *__restrict__ A, // fp16 input matrix of shape mxk
-const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn
-const int4
-*__restrict__ meta, // 2bit metadata information about 2:4 format on B
-int4 *__restrict__ C, // fp16 output buffer of shape mxn
-const int4
-*__restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn
+const int4* __restrict__ A, // fp16 input matrix of shape mxk
+const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
+const int4* __restrict__ meta, // 2bit metadata information about 2:4
+// format on B
+int4* __restrict__ C, // fp16 output buffer of shape mxn
+const int4* __restrict__ s, // fp16 quantization scales of shape
+// (k/groupsize)xn
int prob_m, // batch dimension m
int prob_n, // output dimension n
int prob_k, // reduction dimension k
-int *locks // extra global storage for barrier synchronization
+int* locks // extra global storage for barrier synchronization
) {
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
// same size, which might involve multiple column "slices" (of width 16 *
@@ -174,27 +179,22 @@ __global__ void Marlin_24(
auto init_slice = [&]() {
slice_iters =
iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
-if (slice_iters < 0 || slice_col_par >= n_tiles * parallel)
-slice_iters = 0;
-if (slice_iters == 0)
-return;
-if (slice_row + slice_iters > k_tiles)
-slice_iters = k_tiles - slice_row;
+if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;
+if (slice_iters == 0) return;
+if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
slice_count = 1;
slice_idx = 0;
int col_first = iters * ceildiv(k_tiles * slice_col_par, iters);
if (col_first <= k_tiles * (slice_col_par + 1)) {
int col_off = col_first - k_tiles * slice_col_par;
slice_count = ceildiv(k_tiles - col_off, iters);
-if (col_off > 0)
-slice_count++;
+if (col_off > 0) slice_count++;
int delta_first = iters * blockIdx.x - col_first;
if (delta_first < 0 || (col_off == 0 && delta_first == 0))
slice_idx = slice_count - 1;
else {
slice_idx = slice_count - 1 - delta_first / iters;
-if (col_off > 0)
-slice_idx--;
+if (col_off > 0) slice_idx--;
}
}
if (slice_col == n_tiles) {
@@ -207,7 +207,7 @@ __global__ void Marlin_24(
init_slice();

// RLC: 8 is vec_size -> 128-bit instructions, 8 fp16 elements
int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory

// stride of an A matrix tile in shared memory
constexpr int a_sh_stride = 32 * thread_k_blocks / 8;
@@ -239,9 +239,9 @@ __global__ void Marlin_24(
constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;

int m_gl_stride = 2 * prob_n / 8; // (16*2*4 / 8) = 16
constexpr int m_sh_stride =
(16 * thread_n_blocks) / 4; // #warps n-dim * threads/warp
int m_gl_rd_delta_o = m_gl_stride * thread_k_blocks;
int m_gl_rd_delta_i = m_gl_stride * (threads / m_sh_stride);
constexpr int m_sh_wr_delta = threads / 2;
@@ -305,7 +305,7 @@ __global__ void Marlin_24(
// needed if there are more threads than required for a certain tilesize or
// when the batchsize is not a multiple of 16.
bool a_sh_wr_pred[a_sh_wr_iters];
#pragma unroll
for (int i = 0; i < a_sh_wr_iters; i++) {
a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;
}
@@ -325,13 +325,13 @@ __global__ void Marlin_24(
// loop unrolls, all shared memory accesses are static, we simply precompute
// both transformed reads and writes.
int a_sh_wr_trans[a_sh_wr_iters];
#pragma unroll
for (int i = 0; i < a_sh_wr_iters; i++)
a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
int a_sh_rd_trans[2][b_sh_wr_iters][thread_m_blocks];
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) {
#pragma unroll
for (int j = 0; j < thread_m_blocks; j++) {
a_sh_rd_trans[0][i][j] =
transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
@@ -344,23 +344,23 @@ __global__ void Marlin_24(
// runtime; we break dependencies between subsequent accesses with a tile by
// maintining multiple pointers (we have enough registers), a tiny
// optimization.
-const int4 *B_ptr[b_sh_wr_iters];
+const int4* B_ptr[b_sh_wr_iters];
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++)
B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;

bool m_sh_wr_pred = threadIdx.x < m_sh_wr_delta;
-const int4 *meta_ptr[m_sh_iters];
+const int4* meta_ptr[m_sh_iters];
#pragma unroll
for (int i = 0; i < m_sh_iters; i++)
meta_ptr[i] = meta + m_gl_rd_delta_i * i + m_gl_rd;

extern __shared__ int4 sh[];
// Shared memory storage for global fetch pipelines.
-int4 *sh_a = sh;
-int4 *sh_b = sh_a + (stages * a_sh_stage);
-int4 *sh_s = sh_b + (stages * b_sh_stage);
-int4 *sh_m = sh_s + (stages * s_sh_stage);
+int4* sh_a = sh;
+int4* sh_b = sh_a + (stages * a_sh_stage);
+int4* sh_s = sh_b + (stages * b_sh_stage);
+int4* sh_m = sh_s + (stages * s_sh_stage);
// Register storage for double buffer of shared memory reads.
FragA frag_a[2][thread_m_blocks][2];
I4 frag_b_quant[2][b_thread_vecs];
@@ -370,46 +370,43 @@ __global__ void Marlin_24(

// Zero accumulators.
auto zero_accums = [&]() {
#pragma unroll
for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
-reinterpret_cast<float *>(frag_c)[i] = 0;
+reinterpret_cast<float*>(frag_c)[i] = 0;
};

// Asynchronously fetch the next A, B and s tile from global to the next
// shared memory pipeline location.
auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
if (pred) {
-int4 *sh_a_stage = sh_a + a_sh_stage * pipe;
+int4* sh_a_stage = sh_a + a_sh_stage * pipe;
#pragma unroll
for (int i = 0; i < a_sh_wr_iters; i++) {
cp_async4_pred(
&sh_a_stage[a_sh_wr_trans[i]],
&A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],
a_sh_wr_pred[i]);
}
-int4 *sh_b_stage = sh_b + b_sh_stage * pipe;
+int4* sh_b_stage = sh_b + b_sh_stage * pipe;
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) {
#pragma unroll
for (int j = 0; j < b_thread_vecs; j++) {
-cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j],
-B_ptr[i] + j);
+cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j);
}
B_ptr[i] += b_gl_rd_delta_o;
}
-int4 *sh_meta_stage = sh_m + m_sh_stage * pipe;
+int4* sh_meta_stage = sh_m + m_sh_stage * pipe;
#pragma unroll
for (int i = 0; i < m_sh_iters; i++) {
if (m_sh_wr_pred)
-cp_async4(&sh_meta_stage[m_sh_wr_delta * i + m_sh_wr],
-meta_ptr[i]);
+cp_async4(&sh_meta_stage[m_sh_wr_delta * i + m_sh_wr], meta_ptr[i]);
meta_ptr[i] += m_gl_rd_delta_o;
}
// Only fetch scales if this tile starts a new group
if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) {
-int4 *sh_s_stage = sh_s + s_sh_stage * pipe;
+int4* sh_s_stage = sh_s + s_sh_stage * pipe;
-if (s_sh_wr_pred)
-cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
+if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
s_gl_rd += s_gl_rd_delta;
}
}
@@ -436,13 +433,13 @@ __global__ void Marlin_24(
// theoretically better attempts have lead to bad instruction ordering by
// the compiler and correspondingly a noticeable drop in performance.
if (group_blocks != -1) {
-int4 *sh_s_stage =
+int4* sh_s_stage =
sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
-reinterpret_cast<int4 *>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
+reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
}
-int4 *sh_a_stage = sh_a + a_sh_stage * pipe;
+int4* sh_a_stage = sh_a + a_sh_stage * pipe;
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
ldsm4(frag_a[k % 2][i][0],
&sh_a_stage[a_sh_rd_trans[0][k % b_sh_wr_iters][i]]);
@@ -450,24 +447,24 @@ __global__ void Marlin_24(
&sh_a_stage[a_sh_rd_trans[1][k % b_sh_wr_iters][i]]);
}

-int4 *sh_b_stage = sh_b + b_sh_stage * pipe;
+int4* sh_b_stage = sh_b + b_sh_stage * pipe;
#pragma unroll
for (int i = 0; i < b_thread_vecs; i++) {
-frag_b_quant[k % 2][i] = *reinterpret_cast<I4 *>(
+frag_b_quant[k % 2][i] = *reinterpret_cast<I4*>(
&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]);
}

// Load meta with ldsm4
-int4 *sh_m_stage = sh_m + m_sh_stage * pipe;
+int4* sh_m_stage = sh_m + m_sh_stage * pipe;
ldsm4_m(frag_m[k % 2][0],
&sh_m_stage[m_sh_rd_delta * (k % m_sh_iters) + m_sh_rd]);
};

// Execute the actual tensor core matmul of a sub-tile.
auto matmul = [&](int k) {
// We have the m dimension as the inner loop in order to encourage overlapping
// dequantization and matmul operations.
#pragma unroll
for (int j = 0; j < 4; j++) {
FragB frag_b0;
FragB frag_b1;
@@ -480,7 +477,7 @@ __global__ void Marlin_24(
frag_b1 = dequant_4bit(b_quant_shift);

} else {
-int *frag_b_quant_ptr = reinterpret_cast<int *>(frag_b_quant[k % 2]);
+int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k % 2]);
int b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
int b_quant_1 = frag_b_quant_ptr[j * 2 + 1];

@@ -497,7 +494,7 @@ __global__ void Marlin_24(
scale(frag_b1, frag_s[k % 2][j], 1);
}

#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
mma_sp(frag_b0, frag_b1, frag_a[k % 2][i][0], frag_c[i][j][0],
frag_m[k % 2][j / 2], j % 2);
@@ -518,41 +515,41 @@ __global__ void Marlin_24(
int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
(threadIdx.x % b_sh_stride_threads);

// Parallel logarithmic shared memory reduction. We make sure to avoid any
-// unnecessary read or write iterations, e.g., for two warps we write only once
-// by warp 1 and read only once by warp 0.
+// unnecessary read or write iterations, e.g., for two warps we write only
+// once by warp 1 and read only once by warp 0.
#pragma unroll
for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
#pragma unroll
for (int i = red_off; i > 0; i /= 2) {
if (i <= red_idx && red_idx < 2 * i) {
#pragma unroll
for (int j = 0; j < 4 * 2; j++) {
int red_sh_wr =
red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
if (i < red_off) {
-float *c_rd = reinterpret_cast<float *>(
-&sh[red_sh_delta * j + red_sh_rd]);
-float *c_wr = reinterpret_cast<float *>(&sh[red_sh_wr]);
+float* c_rd =
+reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
+float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
#pragma unroll
for (int k = 0; k < 4; k++)
-reinterpret_cast<FragC *>(frag_c)[4 * 2 * m_block + j][k] +=
+reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
c_rd[k] + c_wr[k];
}
sh[red_sh_wr] =
-reinterpret_cast<int4 *>(&frag_c)[4 * 2 * m_block + j];
+reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
}
}
__syncthreads();
}
if (red_idx == 0) {
#pragma unroll
for (int i = 0; i < 4 * 2; i++) {
-float *c_rd =
-reinterpret_cast<float *>(&sh[red_sh_delta * i + red_sh_rd]);
+float* c_rd =
+reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
#pragma unroll
for (int j = 0; j < 4; j++)
-reinterpret_cast<FragC *>(frag_c)[4 * 2 * m_block + i][j] +=
+reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
c_rd[j];
}
}
@@ -562,9 +559,9 @@ __global__ void Marlin_24(
};
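The "parallel logarithmic shared memory reduction" described above follows the classic halving pattern. Below is a generic, self-contained sketch of that pattern, a plain block-wide float sum rather than the kernel's fragment-wise reduction; kernel name and launch shape are illustrative.

// Sketch: at each step the active half of the threads accumulates the other
// half's partial sums, so a block of N threads (N a power of two) finishes in
// log2(N) steps.
__global__ void block_sum(const float* in, float* out, int n) {
  extern __shared__ float sdata[];
  int tid = threadIdx.x;
  int i = blockIdx.x * blockDim.x + tid;
  sdata[tid] = (i < n) ? in[i] : 0.0f;
  __syncthreads();
  for (int stride = blockDim.x / 2; stride > 0; stride /= 2) {
    if (tid < stride) sdata[tid] += sdata[tid + stride];
    __syncthreads();
  }
  if (tid == 0) out[blockIdx.x] = sdata[0];  // one partial sum per block
}

A launch such as block_sum<<<grid, 256, 256 * sizeof(float)>>>(in, out, n) leaves one partial sum per block in out, which can then be reduced again.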
// Since multiple threadblocks may process parts of the same column slice, we
-// finally have to globally reduce over the results. As the striped partitioning
-// minimizes the number of such reductions and our outputs are usually rather
-// small, we perform this reduction serially in L2 cache.
+// finally have to globally reduce over the results. As the striped
+// partitioning minimizes the number of such reductions and our outputs are
+// usually rather small, we perform this reduction serially in L2 cache.
auto global_reduce = [&](bool first = false, bool last = false) {
// We are very careful here to reduce directly in the output buffer to
// maximize L2 cache utilization in this step. To do this, we write out
@@ -574,7 +571,7 @@ __global__ void Marlin_24(
int c_gl_stride = prob_n / 8;
int c_gl_wr_delta_o = 2 * 4 * c_gl_stride;
int c_gl_wr_delta_i =
c_gl_stride; // 8 threads (e.g., 0,4,8,12,16,20,24,28)
int c_gl_wr = 2 * c_gl_stride * (threadIdx.x % 4) +
8 * (threadIdx.x / 32) + (threadIdx.x % 32) / 4;
c_gl_wr += (2 * thread_n_blocks) * slice_col;
@@ -584,10 +581,10 @@ __global__ void Marlin_24(
int col = 2 * ((threadIdx.x % 32) % 4);

if (!first) {
-// Interestingly, doing direct global accesses here really seems to mess up the
-// compiler and lead to slowdowns, hence we also use async-copies even though
-// these fetches are not actually asynchronous.
+// Interestingly, doing direct global accesses here really seems to mess up
+// the compiler and lead to slowdowns, hence we also use async-copies even
+// though these fetches are not actually asynchronous.
#pragma unroll
for (int i = 0; i < thread_m_blocks * 4; i++) {
cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i],
&C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
@@ -599,32 +596,32 @@ __global__ void Marlin_24(
cp_async_wait<0>();
}

#pragma unroll
for (int i = 0; i < thread_m_blocks * 4; i++) {
if (i < (thread_m_blocks - 1) * 4 ||
8 * (i / 2) + col + (i % 2) < prob_m) {
if (!first) {
int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
#pragma unroll
for (int j2 = 0; j2 < 2; j2++) {
#pragma unroll
for (int j1 = 0; j1 < 4; j1++) {
-reinterpret_cast<float *>(
+reinterpret_cast<float*>(
&frag_c)[4 * 2 * 4 * (i / 4) + 8 * j1 + 2 * j2 +
4 * ((i % 4) / 2) + i % 2] +=
__half2float(
-reinterpret_cast<__half *>(&c_red)[(j2 * 4 + j1)]);
+reinterpret_cast<__half*>(&c_red)[(j2 * 4 + j1)]);
}
}
}
if (!last) {
int4 c;
#pragma unroll
for (int j2 = 0; j2 < 2; j2++) {
#pragma unroll
for (int j1 = 0; j1 < 4; j1++) {
-reinterpret_cast<__half *>(&c)[(j2 * 4 + j1)] =
-__float2half(reinterpret_cast<float *>(
+reinterpret_cast<__half*>(&c)[(j2 * 4 + j1)] =
+__float2half(reinterpret_cast<float*>(
&frag_c)[4 * 2 * 4 * (i / 4) + 8 * j1 + 2 * j2 +
4 * ((i % 4) / 2) + i % 2]);
}
}
@@ -643,9 +640,9 @@ __global__ void Marlin_24(
auto write_result = [&]() {
int c_gl_stride = prob_n / 8;

constexpr int c_sh_stride = 2 * thread_n_blocks; // RLC:
constexpr int c_sh_stride_2 = 2 * c_sh_stride + 2; // RLC:
constexpr int c_sh_stride_3 = 2 * (2 * thread_n_blocks) + 2; // RLC:

int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));

@@ -654,22 +651,22 @@ __global__ void Marlin_24(
c_gl_wr += (2 * thread_n_blocks) * slice_col;

int c_sh_wr = c_sh_stride_2 * ((threadIdx.x % 32) % 4) +
((threadIdx.x % 32) / 4); // RLC:
c_sh_wr += 8 * (threadIdx.x / 32); // 128/4(half4)

constexpr int c_sh_rd_delta =
c_sh_stride_3 * (threads / (2 * 2 * thread_n_blocks)); // RLC:
int c_sh_rd = c_sh_stride_3 * (threadIdx.x / (2 * 2 * thread_n_blocks)) +
(threadIdx.x % (2 * 2 * thread_n_blocks));

int c_gl_wr_end = c_gl_stride * prob_m;

-auto write = [&](int idx, float c0, float c1, float c2, float c3, FragS &s0,
-float c4, float c5, float c6, float c7, FragS &s1) {
+auto write = [&](int idx, float c0, float c1, float c2, float c3, FragS& s0,
+float c4, float c5, float c6, float c7, FragS& s1) {
uint2 res[2];
res[0] = to_half4(c0, c1, c2, c3);
res[1] = to_half4(c4, c5, c6, c7);
-half2 *tmp = (half2 *)&res;
+half2* tmp = (half2*)&res;
// for per-column quantization we finally apply the scale here
if constexpr (group_blocks == -1 && num_bits == 4) {
tmp[0] = __hmul2(tmp[0], s0[0]);
@@ -677,12 +674,12 @@ __global__ void Marlin_24(
tmp[2] = __hmul2(tmp[2], s1[0]);
tmp[3] = __hmul2(tmp[3], s1[1]);
}
-((int4 *)sh)[idx] = *((int4 *)&res[0]);
+((int4*)sh)[idx] = *((int4*)&res[0]);
};

// RLC: only warp 0 and 1 baseline example
if (threadIdx.x / 32 < thread_n_blocks / 4) {
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
int wr = c_sh_wr;
write(wr, frag_c[i][0][0][0], frag_c[i][1][0][0], frag_c[i][2][0][0],
@@ -707,7 +704,7 @@ __global__ void Marlin_24(
}
__syncthreads();

#pragma unroll
for (int i = 0;
i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
i++) {
@@ -721,9 +718,8 @@ __global__ void Marlin_24(

// Start global fetch and register load pipelines.
auto start_pipes = [&]() {
#pragma unroll
-for (int i = 0; i < stages - 1; i++)
-fetch_to_shared(i, i, i < slice_iters);
+for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters);
zero_accums();
wait_for_stage();
fetch_to_registers(0, 0);
@ -733,10 +729,10 @@ __global__ void Marlin_24(
|
|||||||
|
|
||||||
// Main loop.
|
// Main loop.
|
||||||
while (slice_iters) {
|
while (slice_iters) {
|
||||||
// We unroll over both the global fetch and the register load pipeline to ensure
|
// We unroll over both the global fetch and the register load pipeline to
|
||||||
// all shared memory accesses are static. Note that both pipelines have even
|
// ensure all shared memory accesses are static. Note that both pipelines have
|
||||||
// length meaning that the next iteration will always start at index 0.
|
// even length meaning that the next iteration will always start at index 0.
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int pipe = 0; pipe < stages;) {
|
for (int pipe = 0; pipe < stages;) {
|
||||||
fetch_to_shared((pipe + stages - 1) % stages, pipe,
|
fetch_to_shared((pipe + stages - 1) % stages, pipe,
|
||||||
slice_iters >= stages);
|
slice_iters >= stages);
|
||||||
@ -747,8 +743,7 @@ __global__ void Marlin_24(
|
|||||||
|
|
||||||
pipe++;
|
pipe++;
|
||||||
slice_iters--;
|
slice_iters--;
|
||||||
if (slice_iters == 0)
|
if (slice_iters == 0) break;
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
a_gl_rd += a_gl_rd_delta_o * stages;
|
a_gl_rd += a_gl_rd_delta_o * stages;
|
||||||
|
|
||||||
@ -762,13 +757,11 @@ __global__ void Marlin_24(
|
|||||||
// write-out
|
// write-out
|
||||||
if constexpr (group_blocks == -1) {
|
if constexpr (group_blocks == -1) {
|
||||||
if constexpr (num_bits == 8) {
|
if constexpr (num_bits == 8) {
|
||||||
if (s_sh_wr_pred)
|
if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
|
||||||
cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
|
|
||||||
cp_async_fence();
|
cp_async_fence();
|
||||||
} else {
|
} else {
|
||||||
if (last) {
|
if (last) {
|
||||||
if (s_sh_wr_pred)
|
if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
|
||||||
cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
|
|
||||||
cp_async_fence();
|
cp_async_fence();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -780,14 +773,14 @@ __global__ void Marlin_24(
|
|||||||
cp_async_wait<0>();
|
cp_async_wait<0>();
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
||||||
*(float4 *)(frag_s) = *(float4 *)(&sh_s[s_sh_rd]);
|
*(float4*)(frag_s) = *(float4*)(&sh_s[s_sh_rd]);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (last) {
|
if (last) {
|
||||||
cp_async_wait<0>();
|
cp_async_wait<0>();
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
||||||
*(float4 *)(frag_s) = *(float4 *)(&sh_s[s_sh_rd]);
|
*(float4*)(frag_s) = *(float4*)(&sh_s[s_sh_rd]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -798,7 +791,7 @@ __global__ void Marlin_24(
|
|||||||
// overflow in fp16)
|
// overflow in fp16)
|
||||||
if constexpr (group_blocks == -1 && num_bits == 8) {
|
if constexpr (group_blocks == -1 && num_bits == 8) {
|
||||||
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < thread_m_blocks; i++) {
|
for (int i = 0; i < thread_m_blocks; i++) {
|
||||||
scale_floats(&frag_c[i][0][0][0], &frag_c[i][1][0][0],
|
scale_floats(&frag_c[i][0][0][0], &frag_c[i][1][0][0],
|
||||||
&frag_c[i][2][0][0], &frag_c[i][3][0][0], frag_s[0][0],
|
&frag_c[i][2][0][0], &frag_c[i][3][0][0], frag_s[0][0],
|
||||||
@ -827,13 +820,13 @@ __global__ void Marlin_24(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (slice_count > 1) { // only globally reduce if there is more than one
|
if (slice_count > 1) { // only globally reduce if there is more than one
|
||||||
// block in a slice
|
// block in a slice
|
||||||
barrier_acquire(&locks[slice_col], slice_idx);
|
barrier_acquire(&locks[slice_col], slice_idx);
|
||||||
global_reduce(slice_idx == 0, last);
|
global_reduce(slice_idx == 0, last);
|
||||||
barrier_release(&locks[slice_col], last);
|
barrier_release(&locks[slice_col], last);
|
||||||
}
|
}
|
||||||
if (last) // only the last block in a slice actually writes the result
|
if (last) // only the last block in a slice actually writes the result
|
||||||
write_result();
|
write_result();
|
||||||
|
|
||||||
slice_row = 0;
|
slice_row = 0;
|
||||||
@ -843,19 +836,17 @@ __global__ void Marlin_24(
|
|||||||
if (slice_iters) {
|
if (slice_iters) {
|
||||||
a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
|
a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
|
||||||
(threadIdx.x % a_gl_rd_delta_o);
|
(threadIdx.x % a_gl_rd_delta_o);
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < b_sh_wr_iters; i++)
|
for (int i = 0; i < b_sh_wr_iters; i++)
|
||||||
B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
|
B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < m_sh_iters; i++)
|
for (int i = 0; i < m_sh_iters; i++)
|
||||||
meta_ptr[i] += (m_sh_stride)-m_gl_rd_delta_o * k_tiles;
|
meta_ptr[i] += (m_sh_stride)-m_gl_rd_delta_o * k_tiles;
|
||||||
if (slice_col == 0) {
|
if (slice_col == 0) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < b_sh_wr_iters; i++)
|
for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
|
||||||
B_ptr[i] -= b_gl_stride;
|
#pragma unroll
|
||||||
#pragma unroll
|
for (int i = 0; i < m_sh_iters; i++) meta_ptr[i] -= m_gl_stride;
|
||||||
for (int i = 0; i < m_sh_iters; i++)
|
|
||||||
meta_ptr[i] -= m_gl_stride;
|
|
||||||
}
|
}
|
||||||
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
|
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
|
||||||
start_pipes();
|
start_pipes();
|
||||||
@ -866,26 +857,26 @@ __global__ void Marlin_24(
|
|||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#define CALL_IF_2_4(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
|
#define CALL_IF_2_4(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
|
||||||
THREAD_K_BLOCKS, GROUP_BLOCKS) \
|
THREAD_K_BLOCKS, GROUP_BLOCKS) \
|
||||||
else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \
|
else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \
|
||||||
thread_n_blocks == THREAD_N_BLOCKS && \
|
thread_n_blocks == THREAD_N_BLOCKS && \
|
||||||
thread_k_blocks == THREAD_K_BLOCKS && \
|
thread_k_blocks == THREAD_K_BLOCKS && \
|
||||||
group_blocks == GROUP_BLOCKS) { \
|
group_blocks == GROUP_BLOCKS) { \
|
||||||
cudaFuncSetAttribute( \
|
cudaFuncSetAttribute( \
|
||||||
Marlin_24<NUM_BITS, THREADS, THREAD_N_BLOCKS, THREAD_M_BLOCKS, \
|
Marlin_24<NUM_BITS, THREADS, THREAD_N_BLOCKS, THREAD_M_BLOCKS, \
|
||||||
THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>, \
|
THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>, \
|
||||||
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
|
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
|
||||||
Marlin_24<NUM_BITS, THREADS, THREAD_N_BLOCKS, THREAD_M_BLOCKS, \
|
Marlin_24<NUM_BITS, THREADS, THREAD_N_BLOCKS, THREAD_M_BLOCKS, \
|
||||||
THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS> \
|
THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS> \
|
||||||
<<<blocks, THREADS, max_shared_mem, stream>>>(A_ptr, B_ptr, meta_ptr, \
|
<<<blocks, THREADS, max_shared_mem, stream>>>(A_ptr, B_ptr, meta_ptr, \
|
||||||
C_ptr, s_ptr, prob_n, \
|
C_ptr, s_ptr, prob_n, \
|
||||||
prob_m, prob_k, locks); \
|
prob_m, prob_k, locks); \
|
||||||
}
|
}
|
||||||
|
|
||||||
void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C,
|
void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
|
||||||
void *s, int prob_m, int prob_n, int prob_k,
|
void* s, int prob_m, int prob_n, int prob_k,
|
||||||
void *workspace, int num_bits, int groupsize = -1,
|
void* workspace, int num_bits, int groupsize = -1,
|
||||||
int dev = 0, cudaStream_t stream = 0, int thread_k = -1,
|
int dev = 0, cudaStream_t stream = 0, int thread_k = -1,
|
||||||
int thread_m = -1, int sms = -1, int max_par = 16) {
|
int thread_m = -1, int sms = -1, int max_par = 16) {
|
||||||
int tot_n = prob_n;
|
int tot_n = prob_n;
|
||||||
@ -904,8 +895,8 @@ void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C,
|
|||||||
|
|
||||||
if (thread_k == -1 || thread_m == -1) {
|
if (thread_k == -1 || thread_m == -1) {
|
||||||
if (prob_n <= 16) {
|
if (prob_n <= 16) {
|
||||||
// For small batchizes, better partitioningif is slightly more important than
|
// For small batchizes, better partitioningif is slightly more important
|
||||||
// better compute utilization
|
// than better compute utilization
|
||||||
thread_k = 128;
|
thread_k = 128;
|
||||||
thread_m = 128;
|
thread_m = 128;
|
||||||
} else {
|
} else {
|
||||||
@ -914,7 +905,7 @@ void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int thread_k_blocks = thread_k / 32; // 2:4 version with m16n8k32 instruction
|
int thread_k_blocks = thread_k / 32; // 2:4 version with m16n8k32 instruction
|
||||||
int thread_m_blocks = thread_m / 16;
|
int thread_m_blocks = thread_m / 16;
|
||||||
int group_blocks = (groupsize == -1) ? -1 : groupsize / 16;
|
int group_blocks = (groupsize == -1) ? -1 : groupsize / 16;
|
||||||
int blocks = sms;
|
int blocks = sms;
|
||||||
@ -931,13 +922,13 @@ void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C,
|
|||||||
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
|
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
|
||||||
", ", prob_n, ", ", prob_k, "]");
|
", ", prob_n, ", ", prob_k, "]");
|
||||||
|
|
||||||
const int4 *A_ptr = (const int4 *)A;
|
const int4* A_ptr = (const int4*)A;
|
||||||
const int4 *B_ptr = (const int4 *)B;
|
const int4* B_ptr = (const int4*)B;
|
||||||
const int4 *meta_ptr = (const int4 *)meta;
|
const int4* meta_ptr = (const int4*)meta;
|
||||||
int4 *C_ptr = (int4 *)C;
|
int4* C_ptr = (int4*)C;
|
||||||
const int4 *s_ptr = (const int4 *)s;
|
const int4* s_ptr = (const int4*)s;
|
||||||
|
|
||||||
int *locks = (int *)workspace;
|
int* locks = (int*)workspace;
|
||||||
for (int i = 0; i < tot_n_blocks; i += 4) {
|
for (int i = 0; i < tot_n_blocks; i += 4) {
|
||||||
int thread_n_blocks = tot_n_blocks - i;
|
int thread_n_blocks = tot_n_blocks - i;
|
||||||
prob_n = tot_n - 16 * i;
|
prob_n = tot_n - 16 * i;
|
||||||
@ -946,8 +937,7 @@ void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C,
|
|||||||
// Note that parallel > 1 currently only works for inputs without any
|
// Note that parallel > 1 currently only works for inputs without any
|
||||||
// padding
|
// padding
|
||||||
par = (16 * thread_n_blocks - pad) / 64;
|
par = (16 * thread_n_blocks - pad) / 64;
|
||||||
if (par > max_par)
|
if (par > max_par) par = max_par;
|
||||||
par = max_par;
|
|
||||||
prob_n = 64 * par;
|
prob_n = 64 * par;
|
||||||
i += 4 * (par - 1);
|
i += 4 * (par - 1);
|
||||||
thread_n_blocks = 4;
|
thread_n_blocks = 4;
|
||||||
@ -959,13 +949,13 @@ void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C,
|
|||||||
|
|
||||||
// the false is start of the CALL_IF macros
|
// the false is start of the CALL_IF macros
|
||||||
if (false) {
|
if (false) {
|
||||||
} // BMxBNxBK, group
|
} // BMxBNxBK, group
|
||||||
// 4-bit
|
// 4-bit
|
||||||
CALL_IF_2_4(4, 8, 1, 4, -1) // e.g., 16x128x128
|
CALL_IF_2_4(4, 8, 1, 4, -1) // e.g., 16x128x128
|
||||||
CALL_IF_2_4(4, 8, 1, 4, 4) // e.g., 16x128x128, 64
|
CALL_IF_2_4(4, 8, 1, 4, 4) // e.g., 16x128x128, 64
|
||||||
CALL_IF_2_4(4, 16, 1, 2, -1) // e.g., 16x256x64
|
CALL_IF_2_4(4, 16, 1, 2, -1) // e.g., 16x256x64
|
||||||
CALL_IF_2_4(4, 16, 1, 2, 4) // e.g., 16x256x64, 64
|
CALL_IF_2_4(4, 16, 1, 2, 4) // e.g., 16x256x64, 64
|
||||||
CALL_IF_2_4(4, 16, 2, 2, -1) // e.g.. 32x256x64
|
CALL_IF_2_4(4, 16, 2, 2, -1) // e.g.. 32x256x64
|
||||||
CALL_IF_2_4(4, 16, 2, 2, 4)
|
CALL_IF_2_4(4, 16, 2, 2, 4)
|
||||||
CALL_IF_2_4(4, 16, 3, 2, -1)
|
CALL_IF_2_4(4, 16, 3, 2, -1)
|
||||||
CALL_IF_2_4(4, 16, 3, 2, 4)
|
CALL_IF_2_4(4, 16, 3, 2, 4)
|
||||||
@ -973,11 +963,11 @@ void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C,
|
|||||||
CALL_IF_2_4(4, 16, 4, 2, 4)
|
CALL_IF_2_4(4, 16, 4, 2, 4)
|
||||||
|
|
||||||
// 8-bit
|
// 8-bit
|
||||||
CALL_IF_2_4(8, 8, 1, 4, -1) // e.g., 16x128x128
|
CALL_IF_2_4(8, 8, 1, 4, -1) // e.g., 16x128x128
|
||||||
CALL_IF_2_4(8, 8, 1, 4, 4) // e.g., 16x128x128, 64
|
CALL_IF_2_4(8, 8, 1, 4, 4) // e.g., 16x128x128, 64
|
||||||
CALL_IF_2_4(8, 16, 1, 2, -1) // e.g., 16x256x64
|
CALL_IF_2_4(8, 16, 1, 2, -1) // e.g., 16x256x64
|
||||||
CALL_IF_2_4(8, 16, 1, 2, 4) // e.g., 16x256x64, 64
|
CALL_IF_2_4(8, 16, 1, 2, 4) // e.g., 16x256x64, 64
|
||||||
CALL_IF_2_4(8, 16, 2, 2, -1) // e.g.. 32x256x64
|
CALL_IF_2_4(8, 16, 2, 2, -1) // e.g.. 32x256x64
|
||||||
CALL_IF_2_4(8, 16, 2, 2, 4)
|
CALL_IF_2_4(8, 16, 2, 2, 4)
|
||||||
CALL_IF_2_4(8, 16, 3, 2, -1)
|
CALL_IF_2_4(8, 16, 3, 2, -1)
|
||||||
CALL_IF_2_4(8, 16, 3, 2, 4)
|
CALL_IF_2_4(8, 16, 3, 2, 4)
|
||||||
@ -997,12 +987,12 @@ void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace marlin_24
|
} // namespace marlin_24
|
||||||
|
|
||||||
torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
|
torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||||
torch::Tensor &b_meta,
|
torch::Tensor& b_meta,
|
||||||
torch::Tensor &b_scales,
|
torch::Tensor& b_scales,
|
||||||
torch::Tensor &workspace, int64_t num_bits,
|
torch::Tensor& workspace, int64_t num_bits,
|
||||||
int64_t size_m, int64_t size_n,
|
int64_t size_m, int64_t size_n,
|
||||||
int64_t size_k) {
|
int64_t size_k) {
|
||||||
// Verify num_bits
|
// Verify num_bits
|
||||||
@ -1037,9 +1027,9 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
|
|||||||
" is not divisible by tile_size = " + str(marlin_24::tile_size));
|
" is not divisible by tile_size = " + str(marlin_24::tile_size));
|
||||||
|
|
||||||
int actual_size_n = (b_q_weight.size(1) / marlin_24::tile_size) * pack_factor;
|
int actual_size_n = (b_q_weight.size(1) / marlin_24::tile_size) * pack_factor;
|
||||||
TORCH_CHECK(size_n == actual_size_n,
|
TORCH_CHECK(
|
||||||
"size_n = " + str(size_n) +
|
size_n == actual_size_n,
|
||||||
", actual_size_n = " + str(actual_size_n));
|
"size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n));
|
||||||
|
|
||||||
// Verify meta
|
// Verify meta
|
||||||
TORCH_CHECK(b_meta.size(0) == size_k / 8 / 2 / 2,
|
TORCH_CHECK(b_meta.size(0) == size_k / 8 / 2 / 2,
|
||||||
@ -1081,7 +1071,7 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
|
|||||||
", is not divisible by b_scales.size(0) = " +
|
", is not divisible by b_scales.size(0) = " +
|
||||||
str(b_scales.size(0)));
|
str(b_scales.size(0)));
|
||||||
groupsize = size_k / b_scales.size(0);
|
groupsize = size_k / b_scales.size(0);
|
||||||
groupsize /= 2; // Because of 24
|
groupsize /= 2; // Because of 24
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify groupsize
|
// Verify groupsize
|
||||||
|
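The dispatch chain above begins with an empty `if (false) { }` only so that the first `CALL_IF_2_4(...)` expansion, which starts with `else if`, has a preceding `if` to attach to. A minimal standalone sketch of that pattern follows; `KERNEL_CASE`, `run_kernel`, and `launch_demo` are hypothetical names for illustration and are not part of the vLLM sources.

#include <cstdio>

template <int BITS, int BLOCKS>
void run_kernel(int n) {
  std::printf("would launch: bits=%d, thread blocks=%d, n=%d\n", BITS, BLOCKS, n);
}

// Each case expands to an `else if` that matches one configuration and
// instantiates the corresponding template.
#define KERNEL_CASE(BITS, BLOCKS)              \
  else if (bits == BITS && blocks == BLOCKS) { \
    run_kernel<BITS, BLOCKS>(n);               \
  }

void launch_demo(int bits, int blocks, int n) {
  // The empty `if (false)` exists only so the first KERNEL_CASE expansion
  // has an `if` to hang its `else` on.
  if (false) {
  }
  KERNEL_CASE(4, 8)
  KERNEL_CASE(4, 16)
  KERNEL_CASE(8, 16)
  else {
    std::printf("unsupported configuration: bits=%d, blocks=%d\n", bits, blocks);
  }
}

int main() {
  launch_demo(4, 16, 1024);
  return 0;
}

Each additional supported configuration is then a single macro line, and the trailing `else` catches anything the chain does not handle.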
@ -22,27 +22,23 @@ __device__ inline unsigned int as_unsigned(int i) {
// 4-bit matvec kernel (LUT-based)
__global__ void NUQ4MatMulKernel(
#ifndef USE_ROCM
const half2* __restrict__ vec,
#else
const __half2* __restrict__ vec,
#endif
const int* __restrict__ mat,
#ifndef USE_ROCM
half2* __restrict__ mul,
#else
float2* __restrict__ mul,
#endif
- const __half* __restrict__ lookup_table,
- int height,
- int width,
- int batch,
- int vec_height
- ) {
+ const __half* __restrict__ lookup_table, int height, int width, int batch,
+ int vec_height) {

const int blockwidth2 = BLOCKWIDTH / 2;

int row = BLOCKHEIGHT4 * blockIdx.x;
int col = BLOCKWIDTH * blockIdx.y + threadIdx.x;

#ifndef USE_ROCM
__shared__ half2 blockvec[blockwidth2];
@ -73,14 +69,16 @@ __global__ void NUQ4MatMulKernel(
unsigned int tmp1;
unsigned int lut_index1, lut_index2;

- for (int b = 0; b < batch; ++b){
+ for (int b = 0; b < batch; ++b) {
i = width * row + col;
res = __int2half_rd(0);
k = 0;

__syncthreads();
if (threadIdx.x < blockwidth2)
- blockvec[threadIdx.x] = vec[b * vec_height / 2 + (row / BLOCKHEIGHT4) * blockwidth2 + threadIdx.x];
+ blockvec[threadIdx.x] =
+ vec[b * vec_height / 2 + (row / BLOCKHEIGHT4) * blockwidth2 +
+ threadIdx.x];
__syncthreads();

while (k < blockwidth2) {
@ -143,7 +141,8 @@ __global__ void NUQ4MatMulKernel(
#ifndef USE_ROCM
res = __hadd(__hadd(res2.x, res2.y), res);
#else
- res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)), res);
+ res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)),
+ res);
#endif

i += width;
@ -179,46 +178,38 @@ __global__ void NUQ4MatMulKernel(
}
}

} // namespace squeezellm
} // namespace vllm

// 4-bit matvec kernel (LUT-based)
- void squeezellm_gemm(
- torch::Tensor vec,
- torch::Tensor mat,
- torch::Tensor mul,
- torch::Tensor lookup_table
- ) {
+ void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+ torch::Tensor lookup_table) {
int height = mat.size(0);
int width = mat.size(1);

int batch = vec.size(0);
int vec_height = vec.size(1);

- dim3 blocks(
- (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
- (width + BLOCKWIDTH - 1) / BLOCKWIDTH
- );
+ dim3 blocks((height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
+ (width + BLOCKWIDTH - 1) / BLOCKWIDTH);
dim3 threads(BLOCKWIDTH);

const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads, 0, stream>>>(
#ifndef USE_ROCM
- (half2*) vec.data<at::Half>(),
+ (half2*)vec.data<at::Half>(),
#else
- (__half2*) vec.data_ptr<at::Half>(),
+ (__half2*)vec.data_ptr<at::Half>(),
#endif
mat.data_ptr<int>(),
#ifndef USE_ROCM
- (half2*) mul.data<at::Half>(),
- (__half*) lookup_table.data<at::Half>(),
+ (half2*)mul.data<at::Half>(), (__half*)lookup_table.data<at::Half>(),
#else
- (float2*) mul.data_ptr<float>(),
- (__half*) lookup_table.data_ptr<at::Half>(),
+ (float2*)mul.data_ptr<float>(),
+ (__half*)lookup_table.data_ptr<at::Half>(),
#endif
- height, width, batch, vec_height
- );
+ height, width, batch, vec_height);
}

#undef BLOCKWIDTH
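The launch above sizes its grid with rounded-up integer division, `(n + block - 1) / block`, so a partially filled last block is still launched. A small sketch with illustrative numbers only (the real `BLOCKWIDTH`/`BLOCKHEIGHT4` values are defined elsewhere in the vLLM sources):

#include <cstdio>

// Rounded-up integer division, the same expression used for the dim3 grid
// sizes above. The block sizes below are illustrative, not vLLM's values.
constexpr int ceil_div(int n, int block) { return (n + block - 1) / block; }

int main() {
  std::printf("%d\n", ceil_div(4096, 128));  // 32: divides evenly
  std::printf("%d\n", ceil_div(4100, 128));  // 33: last block is partially full
  return 0;
}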
@ -1,5 +1,6 @@
/*
- * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh
+ * Adapted from
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
@ -20,12 +21,12 @@
#include "cuda_compat.h"

namespace vllm {
- template<typename T, int numLanes = WARP_SIZE>
+ template <typename T, int numLanes = WARP_SIZE>
__inline__ __device__ T warpReduceSum(T val) {
static_assert(numLanes > 0 && (numLanes & (numLanes - 1)) == 0,
"numLanes is not a positive power of 2!");
static_assert(numLanes <= WARP_SIZE);
#pragma unroll
for (int mask = numLanes >> 1; mask > 0; mask >>= 1)
val += VLLM_SHFL_XOR_SYNC(val, mask);
return val;
@ -38,22 +39,23 @@ static constexpr int _nextPow2(unsigned int num) {
}

/* Calculate the sum of all elements in a block */
- template<typename T, int maxBlockSize = 1024>
+ template <typename T, int maxBlockSize = 1024>
__inline__ __device__ T blockReduceSum(T val) {
static_assert(maxBlockSize <= 1024);
if constexpr (maxBlockSize > WARP_SIZE) {
val = warpReduceSum<T>(val);
- // Calculates max number of lanes that need to participate in the last warpReduce
+ // Calculates max number of lanes that need to participate in the last
+ // warpReduce
constexpr int maxActiveLanes = (maxBlockSize + WARP_SIZE - 1) / WARP_SIZE;
static __shared__ T shared[maxActiveLanes];
int lane = threadIdx.x % WARP_SIZE;
int wid = threadIdx.x / WARP_SIZE;
- if (lane == 0)
- shared[wid] = val;
+ if (lane == 0) shared[wid] = val;

__syncthreads();

- val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane] : (T)(0.0f);
+ val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane]
+ : (T)(0.0f);
val = warpReduceSum<T, _nextPow2(maxActiveLanes)>(val);
} else {
// A single warpReduce is equal to blockReduce
@ -62,4 +64,4 @@ __inline__ __device__ T blockReduceSum(T val) {
return val;
}

} // namespace vllm
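For readers unfamiliar with the warp-then-block reduction reformatted above, here is a simplified standalone sketch of how such a helper is typically used from a kernel. It assumes a warp size of 32 and a block of at most 1024 threads, and it uses plain `__shfl_xor_sync` instead of vLLM's `VLLM_SHFL_XOR_SYNC`/`WARP_SIZE` macros, so it illustrates the technique rather than reproducing the library code.

#include <cuda_runtime.h>
#include <cstdio>

__device__ float warp_reduce_sum(float val) {
  // Butterfly reduction across the 32 lanes of a warp; every lane ends up
  // holding the warp-wide sum.
  for (int mask = 16; mask > 0; mask >>= 1)
    val += __shfl_xor_sync(0xffffffffu, val, mask);
  return val;
}

__device__ float block_reduce_sum(float val) {
  static __shared__ float shared[32];  // one partial sum per warp
  int lane = threadIdx.x % 32;
  int wid = threadIdx.x / 32;
  val = warp_reduce_sum(val);          // reduce within each warp
  if (lane == 0) shared[wid] = val;    // warp leaders publish their partials
  __syncthreads();
  // The first warp reduces the per-warp partials; threads outside it end up
  // with zero, so only thread 0's return value is meaningful by convention.
  val = (threadIdx.x < blockDim.x / 32.0f) ? shared[lane] : 0.0f;
  return warp_reduce_sum(val);
}

__global__ void sum_kernel(const float* in, float* out, int n) {
  float v = 0.0f;
  for (int i = threadIdx.x; i < n; i += blockDim.x) v += in[i];
  v = block_reduce_sum(v);
  if (threadIdx.x == 0) *out = v;
}

int main() {
  const int n = 1 << 20;
  float *in, *out;
  cudaMallocManaged(&in, n * sizeof(float));
  cudaMallocManaged(&out, sizeof(float));
  for (int i = 0; i < n; ++i) in[i] = 1.0f;
  sum_kernel<<<1, 256>>>(in, out, n);
  cudaDeviceSynchronize();
  std::printf("sum = %f (expected %d)\n", *out, n);
  cudaFree(in);
  cudaFree(out);
  return 0;
}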
57
format.sh
@ -26,6 +26,7 @@ RUFF_VERSION=$(ruff --version | awk '{print $2}')
MYPY_VERSION=$(mypy --version | awk '{print $2}')
CODESPELL_VERSION=$(codespell --version)
ISORT_VERSION=$(isort --vn)
+ CLANGFORMAT_VERSION=$(clang-format --version | awk '{print $3}')

# # params: tool name, tool version, required version
tool_version_check() {
@ -40,6 +41,7 @@ tool_version_check "ruff" $RUFF_VERSION "$(grep "ruff==" requirements-dev.txt |
tool_version_check "mypy" "$MYPY_VERSION" "$(grep mypy requirements-dev.txt | cut -d'=' -f3)"
tool_version_check "isort" "$ISORT_VERSION" "$(grep isort requirements-dev.txt | cut -d'=' -f3)"
tool_version_check "codespell" "$CODESPELL_VERSION" "$(grep codespell requirements-dev.txt | cut -d'=' -f3)"
+ tool_version_check "clang-format" "$CLANGFORMAT_VERSION" "$(grep clang-format requirements-dev.txt | cut -d'=' -f3)"

YAPF_FLAGS=(
'--recursive'
@ -179,7 +181,6 @@ lint_changed() {
}

# Run Ruff
- echo 'vLLM ruff:'
### This flag lints individual files. --files *must* be the first command line
### arg to use this option.
if [[ "$1" == '--files' ]]; then
@ -192,6 +193,7 @@ else
# Format only the files that changed in last commit.
lint_changed
fi
+ echo 'vLLM ruff: Done'

# check spelling of specified files
isort_check() {
@ -233,6 +235,59 @@ else
fi
echo 'vLLM isort: Done'

+ # Clang-format section
+ # Exclude some files for formatting because they are vendored
+ # NOTE: Keep up to date with .github/workflows/clang-format.yml
+ CLANG_FORMAT_EXCLUDES=(
+ 'csrc/moe/topk_softmax_kernels.cu'
+ 'csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu'
+ 'csrc/punica/bgmv/bgmv_config.h'
+ 'csrc/punica/bgmv/bgmv_impl.cuh'
+ 'csrc/punica/bgmv/vec_dtypes.cuh'
+ 'csrc/punica/punica_ops.cu'
+ 'csrc/punica/type_convert.h'
+ )
+
+ # Format specified files with clang-format
+ clang_format() {
+ clang-format -i "$@"
+ }
+
+ # Format files that differ from main branch with clang-format.
+ clang_format_changed() {
+ # The `if` guard ensures that the list of filenames is not empty, which
+ # could cause clang-format to receive 0 positional arguments, making it hang
+ # waiting for STDIN.
+ #
+ # `diff-filter=ACM` and $MERGEBASE is to ensure we only format files that
+ # exist on both branches.
+ MERGEBASE="$(git merge-base origin/main HEAD)"
+
+ # Get the list of changed files, excluding the specified ones
+ changed_files=$(git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.h' '*.cpp' '*.cu' '*.cuh' | grep -vFf <(printf "%s\n" "${CLANG_FORMAT_EXCLUDES[@]}"))
+ if [ -n "$changed_files" ]; then
+ echo "$changed_files" | xargs -P 5 clang-format -i
+ fi
+ }
+
+ # Format all files with clang-format
+ clang_format_all() {
+ find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \
+ | grep -vFf <(printf "%s\n" "${CLANG_FORMAT_EXCLUDES[@]}") \
+ | xargs clang-format -i
+ }
+
+ # Run clang-format
+ if [[ "$1" == '--files' ]]; then
+ clang_format "${@:2}"
+ elif [[ "$1" == '--all' ]]; then
+ clang_format_all
+ else
+ clang_format_changed
+ fi
+ echo 'vLLM clang-format: Done'
+

if ! git diff --quiet &>/dev/null; then
echo 'Reformatted files. Please review and stage the changes.'
echo 'Changes not staged for commit:'
@ -5,6 +5,7 @@ tomli==2.0.1
ruff==0.1.5
codespell==2.2.6
isort==5.13.2
+ clang-format==18.1.5

# type checking
mypy==1.9.0