[AMD] Add support for GGUF quantization on ROCm (#10254)

commit 7c25fe45a6 (parent 02a43f82a9)
Author: kliuae, committed by GitHub, 2024-11-23 13:14:49 +08:00
11 changed files with 234 additions and 211 deletions

View File

@ -85,7 +85,6 @@ if [[ $commands == *" kernels "* ]]; then
--ignore=kernels/test_encoder_decoder_attn.py \
--ignore=kernels/test_flash_attn.py \
--ignore=kernels/test_flashinfer.py \
--ignore=kernels/test_gguf.py \
--ignore=kernels/test_int8_quant.py \
--ignore=kernels/test_machete_gemm.py \
--ignore=kernels/test_mamba_ssm.py \

View File

@ -196,6 +196,7 @@ set(VLLM_EXT_SRC
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/fp8/common.cu"
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/prepare_inputs/advance_step.cu"
"csrc/torch_bindings.cpp")
@ -237,7 +238,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/mamba/causal_conv1d/causal_conv1d.cu"
"csrc/quantization/aqlm/gemm_kernels.cu"
"csrc/quantization/awq/gemm_kernels.cu"
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/custom_all_reduce.cu"
"csrc/permute_cols.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu")

View File

@ -128,6 +128,7 @@ torch::Tensor awq_dequantize(torch::Tensor _kernel,
int64_t thx, int64_t thy);
torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm);
#endif
torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m,
int64_t n);
@ -138,6 +139,7 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X,
torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type,
int64_t row);
#ifndef USE_ROCM
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,

View File

@ -1,7 +1,7 @@
// copied from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-common.h
#define QK_K 256
#define K_QUANTS_PER_ITERATION 2
#define WARP_SIZE 32
#define WARP_SIZE_GGUF 32
#define K_SCALE_SIZE 12
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
#define CUDA_QUANTIZE_BLOCK_SIZE 256
@ -1112,4 +1112,19 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
#endif
return c;
}
static __device__ __forceinline__ uint32_t __vcmpeq4(const uint32_t a, const uint32_t b) {
uint32_t neq = a^b;
return !(neq & 0xff000000) * 0xff000000 |
!(neq & 0x00ff0000) * 0x00ff0000 |
!(neq & 0x0000ff00) * 0x0000ff00 |
!(neq & 0x000000ff) * 0x000000ff;
}
static __device__ __forceinline__ uint32_t __vsub4(const uint32_t a, const uint32_t b) {
return (static_cast<uint8_t>(((a & 0xff000000) >> 24) - ((b & 0xff000000) >> 24)) << 24) +
(static_cast<uint8_t>(((a & 0x00ff0000) >> 16) - ((b & 0x00ff0000) >> 16)) << 16) +
(static_cast<uint8_t>(((a & 0x0000ff00) >> 8) - ((b & 0x0000ff00) >> 8)) << 8) +
(static_cast<uint8_t>(((a & 0x000000ff) >> 0) - ((b & 0x000000ff) >> 0)) << 0);
}
#endif // defined(USE_ROCM)
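The two ROCm-only helpers added above emulate NVIDIA's per-byte SIMD intrinsics that the GGUF kernels rely on: __vcmpeq4 yields 0xff in every byte lane where the operands match, and __vsub4 subtracts lane-wise with wrap-around. A minimal host-side sketch of the intended semantics (Python; function names and test values are mine, for illustration only):

def vcmpeq4(a, b):
    # each byte lane becomes 0xff if the lanes are equal, 0x00 otherwise
    neq = a ^ b
    out = 0
    for shift in (24, 16, 8, 0):
        if ((neq >> shift) & 0xFF) == 0:
            out |= 0xFF << shift
    return out

def vsub4(a, b):
    # per-byte subtraction, modulo 256 in every lane
    out = 0
    for shift in (24, 16, 8, 0):
        out |= ((((a >> shift) & 0xFF) - ((b >> shift) & 0xFF)) & 0xFF) << shift
    return out

assert vcmpeq4(0x01FF0203, 0x01000203) == 0xFF00FFFF
assert vsub4(0x00010203, 0x01010101) == 0xFF000102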

View File

@ -4,6 +4,8 @@
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include "cuda_compat.h"
#include "ggml-common.h"
#include "vecdotq.cuh"
#include "dequantize.cuh"
@ -32,8 +34,8 @@ static __global__ void quantize_q8_1(const half* __restrict__ x,
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, mask, 32));
sum += __shfl_xor_sync(0xffffffff, sum, mask, 32);
amax = fmaxf(amax, VLLM_SHFL_XOR_SYNC_WIDTH(amax, mask, 32));
sum += VLLM_SHFL_XOR_SYNC_WIDTH(sum, mask, 32);
}
const float d = amax / 127;
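In the hunk above, the raw __shfl_xor_sync calls in quantize_q8_1 are swapped for vLLM's shuffle wrapper (VLLM_SHFL_XOR_SYNC_WIDTH, pulled in via cuda_compat.h) so the same warp reduction compiles on both CUDA and ROCm. What each block computes is unchanged: a scale d that maps the block's absolute maximum to +/-127, plus the block sum. A rough Python sketch of that per-block math (my own simplification, not the kernel itself):

def quantize_q8_1_block(x):
    # x: one block of 32 floats; returns int8 codes plus (scale, sum) metadata
    assert len(x) == 32
    amax = max(abs(v) for v in x)
    s = sum(x)
    d = amax / 127
    q = [0 if d == 0 else round(v / d) for v in x]
    return q, d, s

q, d, s = quantize_q8_1_block([1.0, -0.5] + [0.0] * 30)
assert max(abs(v) for v in q) <= 127 and abs(q[0] * d - 1.0) < 1e-6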

View File

@ -10,7 +10,7 @@ static __device__ __forceinline__ void mul_mat_q(
const int blocks_per_row_x = ncols_x / qk;
const int blocks_per_col_y = nrows_y / QK8_1;
const int blocks_per_warp = WARP_SIZE / qi;
const int blocks_per_warp = WARP_SIZE_GGUF / qi;
const int & ncols_dst = ncols_y;
@ -27,10 +27,10 @@ static __device__ __forceinline__ void mul_mat_q(
allocate_tiles(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc);
__shared__ int tile_y_qs[mmq_x * WARP_SIZE];
__shared__ half2 tile_y_ds[mmq_x * WARP_SIZE/QI8_1];
__shared__ int tile_y_qs[mmq_x * WARP_SIZE_GGUF];
__shared__ half2 tile_y_ds[mmq_x * WARP_SIZE_GGUF/QI8_1];
float sum[mmq_y/WARP_SIZE][mmq_x/nwarps] = {{0.0f}};
float sum[mmq_y/WARP_SIZE_GGUF][mmq_x/nwarps] = {{0.0f}};
for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) {
@ -39,26 +39,26 @@ static __device__ __forceinline__ void mul_mat_q(
#pragma unroll
for (int ir = 0; ir < qr; ++ir) {
const int kqs = ir*WARP_SIZE + threadIdx.x;
const int kqs = ir*WARP_SIZE_GGUF + threadIdx.x;
const int kbxd = kqs / QI8_1;
#pragma unroll
for (int i = 0; i < mmq_x; i += nwarps) {
const int col_y_eff = min(col_y_0 + threadIdx.y + i, ncols_y-1); // to prevent out-of-bounds memory accesses
const block_q8_1 * by0 = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kbxd];
const int index_y = (threadIdx.y + i) * WARP_SIZE + kqs % WARP_SIZE;
const int index_y = (threadIdx.y + i) * WARP_SIZE_GGUF + kqs % WARP_SIZE_GGUF;
tile_y_qs[index_y] = get_int_from_int8_aligned(by0->qs, threadIdx.x % QI8_1);
}
#pragma unroll
for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) {
const int ids = (ids0 + threadIdx.y * QI8_1 + threadIdx.x / (WARP_SIZE/QI8_1)) % mmq_x;
const int kby = threadIdx.x % (WARP_SIZE/QI8_1);
const int ids = (ids0 + threadIdx.y * QI8_1 + threadIdx.x / (WARP_SIZE_GGUF/QI8_1)) % mmq_x;
const int kby = threadIdx.x % (WARP_SIZE_GGUF/QI8_1);
const int col_y_eff = min(col_y_0 + ids, ncols_y-1);
// if the sum is not needed it's faster to transform the scale to f32 ahead of time
const half2 * dsi_src = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + ir*(WARP_SIZE/QI8_1) + kby].ds;
half2 * dsi_dst = &tile_y_ds[ids * (WARP_SIZE/QI8_1) + kby];
const half2 * dsi_src = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + ir*(WARP_SIZE_GGUF/QI8_1) + kby].ds;
half2 * dsi_dst = &tile_y_ds[ids * (WARP_SIZE_GGUF/QI8_1) + kby];
if (need_sum) {
*dsi_dst = *dsi_src;
} else {
@ -70,12 +70,12 @@ static __device__ __forceinline__ void mul_mat_q(
__syncthreads();
// #pragma unroll // unrolling this loop causes too much register pressure
for (int k = ir*WARP_SIZE/qr; k < (ir+1)*WARP_SIZE/qr; k += vdr) {
for (int k = ir*WARP_SIZE_GGUF/qr; k < (ir+1)*WARP_SIZE_GGUF/qr; k += vdr) {
#pragma unroll
for (int j = 0; j < mmq_x; j += nwarps) {
#pragma unroll
for (int i = 0; i < mmq_y; i += WARP_SIZE) {
sum[i/WARP_SIZE][j/nwarps] += vec_dot(
for (int i = 0; i < mmq_y; i += WARP_SIZE_GGUF) {
sum[i/WARP_SIZE_GGUF][j/nwarps] += vec_dot(
tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds,
threadIdx.x + i, threadIdx.y + j, k);
}
@ -93,12 +93,12 @@ static __device__ __forceinline__ void mul_mat_q(
}
#pragma unroll
for (int i = 0; i < mmq_y; i += WARP_SIZE) {
for (int i = 0; i < mmq_y; i += WARP_SIZE_GGUF) {
const int row_dst = row_dst_0 + threadIdx.x + i;
if (row_dst >= nrows_dst) {
continue;
}
dst[col_dst*nrows_dst + row_dst] = __float2half(sum[i/WARP_SIZE][j/nwarps]);
dst[col_dst*nrows_dst + row_dst] = __float2half(sum[i/WARP_SIZE_GGUF][j/nwarps]);
}
}
}
@ -115,7 +115,7 @@ static __device__ __forceinline__ void mul_mat_q(
template <bool need_check> static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE*NWARPS_Q4_0, 2)
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q4_0, 2)
#endif
mul_mat_q4_0(
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
@ -140,7 +140,7 @@ static void ggml_mul_mat_q4_0_q8_1_cuda(
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, nwarps, 1);
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
if (nrows_x % mmq_y == 0) {
const bool need_check = false;
@ -165,7 +165,7 @@ static void ggml_mul_mat_q4_0_q8_1_cuda(
template <bool need_check> static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE*NWARPS_Q4_1, 2)
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q4_1, 2)
#endif
mul_mat_q4_1(
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
@ -190,7 +190,7 @@ static void ggml_mul_mat_q4_1_q8_1_cuda(
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, nwarps, 1);
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
if (nrows_x % mmq_y == 0) {
const bool need_check = false;
@ -215,7 +215,7 @@ static void ggml_mul_mat_q4_1_q8_1_cuda(
template <bool need_check> static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE*NWARPS_Q5_0, 2)
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q5_0, 2)
#endif
mul_mat_q5_0(
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
@ -240,7 +240,7 @@ static void ggml_mul_mat_q5_0_q8_1_cuda(
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, nwarps, 1);
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
if (nrows_x % mmq_y == 0) {
const bool need_check = false;
@ -265,7 +265,7 @@ static void ggml_mul_mat_q5_0_q8_1_cuda(
template <bool need_check> static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE*NWARPS_Q5_1, 2)
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q5_1, 2)
#endif
mul_mat_q5_1(
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
@ -289,7 +289,7 @@ static void ggml_mul_mat_q5_1_q8_1_cuda(
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, nwarps, 1);
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
if (nrows_x % mmq_y == 0) {
const bool need_check = false;
@ -314,7 +314,7 @@ static void ggml_mul_mat_q5_1_q8_1_cuda(
template <bool need_check> static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE*NWARPS_Q8_0, 2)
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q8_0, 2)
#endif
mul_mat_q8_0(
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
@ -338,7 +338,7 @@ static void ggml_mul_mat_q8_0_q8_1_cuda(
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, nwarps, 1);
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
if (nrows_x % mmq_y == 0) {
const bool need_check = false;
@ -363,7 +363,7 @@ static void ggml_mul_mat_q8_0_q8_1_cuda(
template <bool need_check> static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE*NWARPS_Q2_K, 2)
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q2_K, 2)
#endif
mul_mat_q2_K(
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
@ -387,7 +387,7 @@ static void ggml_mul_mat_q2_K_q8_1_cuda(
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, nwarps, 1);
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
if (nrows_x % mmq_y == 0) {
const bool need_check = false;
@ -412,7 +412,7 @@ static void ggml_mul_mat_q2_K_q8_1_cuda(
template <bool need_check> static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE*NWARPS_Q3_K, 2)
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q3_K, 2)
#endif
mul_mat_q3_K(
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
@ -438,7 +438,7 @@ static void ggml_mul_mat_q3_K_q8_1_cuda(
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, nwarps, 1);
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
if (nrows_x % mmq_y == 0) {
const bool need_check = false;
@ -463,7 +463,7 @@ static void ggml_mul_mat_q3_K_q8_1_cuda(
template <bool need_check> static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE*NWARPS_Q4_K, 2)
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q4_K, 2)
#endif
mul_mat_q4_K(
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
@ -487,7 +487,7 @@ static void ggml_mul_mat_q4_K_q8_1_cuda(
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, nwarps, 1);
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
if (nrows_x % mmq_y == 0) {
const bool need_check = false;
@ -512,7 +512,7 @@ static void ggml_mul_mat_q4_K_q8_1_cuda(
template <bool need_check> static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE*NWARPS_Q5_K, 2)
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q5_K, 2)
#endif
mul_mat_q5_K(
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
@ -537,7 +537,7 @@ static void ggml_mul_mat_q5_K_q8_1_cuda(
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, nwarps, 1);
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
if (nrows_x % mmq_y == 0) {
const bool need_check = false;
@ -562,7 +562,7 @@ static void ggml_mul_mat_q5_K_q8_1_cuda(
template <bool need_check> static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE*NWARPS_Q6_K, 2)
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q6_K, 2)
#endif
mul_mat_q6_K(
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
@ -586,7 +586,7 @@ static void ggml_mul_mat_q6_K_q8_1_cuda(
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, nwarps, 1);
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
if (nrows_x % mmq_y == 0) {
const bool need_check = false;
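Throughout mmq.cuh, the rename WARP_SIZE -> WARP_SIZE_GGUF pins the GGUF tile math to 32 lanes even on AMD GPUs whose hardware wavefront is 64 wide, while the per-type launchers keep the same grid arithmetic. A sketch of that launch-shape computation (the mmq_x/mmq_y/nwarps defaults below are placeholders, not the tuned constants from the kernels):

WARP_SIZE_GGUF = 32  # fixed logical warp width used by the GGUF tiles

def mmq_launch_dims(nrows_x, ncols_y, mmq_x=64, mmq_y=64, nwarps=4):
    block_num_x = (nrows_x + mmq_y - 1) // mmq_y   # ceil(nrows_x / mmq_y)
    block_num_y = (ncols_y + mmq_x - 1) // mmq_x   # ceil(ncols_y / mmq_x)
    block_nums = (block_num_x, block_num_y, 1)
    block_dims = (WARP_SIZE_GGUF, nwarps, 1)       # 32 threads per logical warp
    return block_nums, block_dims

assert mmq_launch_dims(4096, 1) == ((64, 1, 1), (32, 4, 1))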

View File

@ -28,8 +28,8 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void *
// sum up partial sums and write back result
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
for (int mask = WARP_SIZE/2; mask > 0; mask >>= 1) {
tmp += VLLM_SHFL_XOR_SYNC(tmp, mask);
}
if (threadIdx.x == 0) {

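The mul_mat_vec_q change above generalizes the final reduction from a hard-coded 32-lane butterfly (mask starting at 16) to one driven by WARP_SIZE, so on ROCm the xor-shuffle covers a full 64-lane wavefront. A small Python simulation of that butterfly pattern, with a list standing in for the warp's registers (illustrative only):

def warp_reduce_sum(vals):
    # xor-shuffle butterfly: after log2(n) steps every lane holds the full sum
    n = len(vals)              # 32 on CUDA warps, 64 on ROCm wave64
    mask = n // 2
    while mask > 0:
        vals = [vals[i] + vals[i ^ mask] for i in range(n)]
        mask //= 2
    return vals

assert all(v == sum(range(64)) for v in warp_reduce_sum(list(range(64))))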
View File

@ -43,7 +43,7 @@ static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t *
template <int vdr> static __device__ __forceinline__ float vec_dot_q4_0_q8_1_impl(
const int * v, const int * u, const float & d4, const half2 & ds8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
int sumi = 0;
#pragma unroll
@ -68,7 +68,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q4_0_q8_1_imp
template <int vdr> static __device__ __forceinline__ float vec_dot_q4_1_q8_1_impl(
const int * v, const int * u, const half2 & dm4, const half2 & ds8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
int sumi = 0;
#pragma unroll
@ -95,7 +95,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q4_1_q8_1_imp
template <int vdr> static __device__ __forceinline__ float vec_dot_q5_0_q8_1_impl(
const int * vl, const int * vh, const int * u, const float & d5, const half2 & ds8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
int sumi = 0;
#pragma unroll
@ -128,7 +128,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_0_q8_1_imp
template <int vdr> static __device__ __forceinline__ float vec_dot_q5_1_q8_1_impl(
const int * vl, const int * vh, const int * u, const half2 & dm5, const half2 & ds8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
int sumi = 0;
#pragma unroll
@ -162,7 +162,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_1_q8_1_imp
template <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_q8_1_impl(
const int * v, const int * u, const float & d8_0, const float & d8_1) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
int sumi = 0;
#pragma unroll
@ -176,7 +176,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_q8_1_imp
template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_impl(
const int * v, const int * u, const half2 & dm8, const half2 & ds8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
int sumi = 0;
@ -202,7 +202,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp
static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
const int & v, const int * __restrict__ u, const uint8_t * __restrict__ scales,
const half2 & dm2, const float * __restrict__ d8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
float sumf_d = 0.0f;
float sumf_m = 0.0f;
@ -230,7 +230,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(
const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ scales,
const half2 & dm2, const float & d8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
int sumi_d = 0;
int sumi_m = 0;
@ -267,7 +267,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(
static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq(
const int & vl, const int & vh, const int * __restrict__ u, const uint8_t * __restrict__ scales,
const int & scale_offset, const float & d3, const float * __restrict__ d8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
float sumf = 0.0f;
@ -301,7 +301,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq(
static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ scales,
const float & d3, const float & d8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
int sumi = 0;
#pragma unroll
@ -326,7 +326,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq(
const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
const uint8_t * __restrict__ m, const half2 & dm4, const float * __restrict__ d8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
float sumf_d = 0.0f;
float sumf_m = 0.0f;
@ -351,7 +351,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq(
static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
float sumf_d = 0.0f;
float sumf_m = 0.0f;
@ -382,7 +382,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq(
const int * __restrict__ vl, const int * __restrict__ vh, const int * __restrict__ u, const uint8_t * __restrict__ sc,
const uint8_t * __restrict__ m, const half2 & dm5, const float * __restrict__ d8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
float sumf_d = 0.0f;
float sumf_m = 0.0f;
@ -413,7 +413,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq(
static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq(
const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
float sumf_d = 0.0f;
float sumf_m = 0.0f;
@ -445,7 +445,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq(
static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq(
const int & vl, const int & vh, const int * __restrict__ u, const int8_t * __restrict__ scales,
const float & d, const float * __restrict__ d8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
float sumf = 0.0f;
#pragma unroll
@ -465,7 +465,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq(
static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ sc,
const float & d6, const float * __restrict__ d8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
float sumf_d = 0.0f;
#pragma unroll
@ -507,8 +507,8 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
}
template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
__shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + mmq_y];
__shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI4_0) + mmq_y/QI4_0];
__shared__ int tile_x_qs[mmq_y * (WARP_SIZE_GGUF) + mmq_y];
__shared__ float tile_x_d[mmq_y * (WARP_SIZE_GGUF/QI4_0) + mmq_y/QI4_0];
*x_ql = tile_x_qs;
*x_dm = (half2 *) tile_x_d;
}
@ -529,11 +529,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i = min(i, i_max);
}
const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbx;
x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);
// x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbx] = bxi->d;
x_ql[i * (WARP_SIZE_GGUF + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);
// x_dmf[i * (WARP_SIZE_GGUF/QI4_0) + i / QI4_0 + kbx] = bxi->d;
}
const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI4_0;
const int kbxd = k % blocks_per_tile_x_row;
#pragma unroll
@ -543,7 +543,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i = min(i, i_max);
}
const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbxd;
x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd] = __half2float(bxi->d);
x_dmf[i * (WARP_SIZE_GGUF/QI4_0) + i / QI4_0 + kbxd] = __half2float(bxi->d);
}
}
@ -559,13 +559,13 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1_mul_mat(
#pragma unroll
for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE];
u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_0) % WARP_SIZE];
u[2*l+0] = y_qs[j * WARP_SIZE_GGUF + (kyqs + l) % WARP_SIZE_GGUF];
u[2*l+1] = y_qs[j * WARP_SIZE_GGUF + (kyqs + l + QI4_0) % WARP_SIZE_GGUF];
}
return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
(&x_ql[i * (WARP_SIZE + 1) + k], u, x_dmf[i * (WARP_SIZE/QI4_0) + i/QI4_0 + k/QI4_0],
y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
(&x_ql[i * (WARP_SIZE_GGUF + 1) + k], u, x_dmf[i * (WARP_SIZE_GGUF/QI4_0) + i/QI4_0 + k/QI4_0],
y_ds[j * (WARP_SIZE_GGUF/QI8_1) + (2*k/QI8_1) % (WARP_SIZE_GGUF/QI8_1)]);
}
static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
@ -587,8 +587,8 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
}
template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
__shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + + mmq_y];
__shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI4_1) + mmq_y/QI4_1];
__shared__ int tile_x_qs[mmq_y * (WARP_SIZE_GGUF) + + mmq_y];
__shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE_GGUF/QI4_1) + mmq_y/QI4_1];
*x_ql = tile_x_qs;
*x_dm = tile_x_dm;
}
@ -608,10 +608,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i = min(i, i_max);
}
const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbx;
x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
x_ql[i * (WARP_SIZE_GGUF + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
}
const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI4_1;
const int kbxd = k % blocks_per_tile_x_row;
#pragma unroll
@ -621,7 +621,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i = min(i, i_max);
}
const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbxd;
x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd] = bxi->dm;
x_dm[i * (WARP_SIZE_GGUF/QI4_1) + i / QI4_1 + kbxd] = bxi->dm;
}
}
@ -634,13 +634,13 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1_mul_mat(
#pragma unroll
for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE];
u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_1) % WARP_SIZE];
u[2*l+0] = y_qs[j * WARP_SIZE_GGUF + (kyqs + l) % WARP_SIZE_GGUF];
u[2*l+1] = y_qs[j * WARP_SIZE_GGUF + (kyqs + l + QI4_1) % WARP_SIZE_GGUF];
}
return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
(&x_ql[i * (WARP_SIZE + 1) + k], u, x_dm[i * (WARP_SIZE/QI4_1) + i/QI4_1 + k/QI4_1],
y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
(&x_ql[i * (WARP_SIZE_GGUF + 1) + k], u, x_dm[i * (WARP_SIZE_GGUF/QI4_1) + i/QI4_1 + k/QI4_1],
y_ds[j * (WARP_SIZE_GGUF/QI8_1) + (2*k/QI8_1) % (WARP_SIZE_GGUF/QI8_1)]);
}
static __device__ __forceinline__ float vec_dot_q5_0_q8_1(
@ -664,8 +664,8 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1(
}
template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
__shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y];
__shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI5_0) + mmq_y/QI5_0];
__shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE_GGUF) + mmq_y];
__shared__ float tile_x_d[mmq_y * (WARP_SIZE_GGUF/QI5_0) + mmq_y/QI5_0];
*x_ql = tile_x_ql;
*x_dm = (half2 *) tile_x_d;
@ -697,7 +697,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
qs0 |= (qh << 25) & 0x10000000; // 3 -> 28
qs0 = __vsubss4(qs0, 0x10101010); // subtract 16
x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0;
x_ql[i * (2*WARP_SIZE_GGUF + 1) + 2*k+0] = qs0;
int qs1 = (ql >> 4) & 0x0F0F0F0F;
qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4
@ -706,10 +706,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
qs1 = __vsubss4(qs1, 0x10101010); // subtract 16
x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1;
x_ql[i * (2*WARP_SIZE_GGUF + 1) + 2*k+1] = qs1;
}
const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;
const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI5_0;
const int kbxd = k % blocks_per_tile_x_row;
float * x_dmf = (float *) x_dm;
@ -722,7 +722,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
}
const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbxd;
x_dmf[i * (WARP_SIZE/QI5_0) + i / QI5_0 + kbxd] = __half2float(bxi->d);
x_dmf[i * (WARP_SIZE_GGUF/QI5_0) + i / QI5_0 + kbxd] = __half2float(bxi->d);
}
}
@ -730,7 +730,7 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat(
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
const int index_bx = i * (WARP_SIZE/QI5_0) + i/QI5_0 + k/QI5_0;
const int index_bx = i * (WARP_SIZE_GGUF/QI5_0) + i/QI5_0 + k/QI5_0;
const float * x_dmf = (const float *) x_dm;
const float * y_df = (const float *) y_ds;
@ -738,12 +738,12 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat(
#pragma unroll
for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) {
u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE];
u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_0) % WARP_SIZE];
u[2*l+0] = y_qs[j * WARP_SIZE_GGUF + (kyqs + l) % WARP_SIZE_GGUF];
u[2*l+1] = y_qs[j * WARP_SIZE_GGUF + (kyqs + l + QI5_0) % WARP_SIZE_GGUF];
}
return vec_dot_q8_0_q8_1_impl<QR5_0*VDR_Q5_0_Q8_1_MMQ>
(&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dmf[index_bx], y_df[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
(&x_ql[i * (2*WARP_SIZE_GGUF + 1) + 2 * k], u, x_dmf[index_bx], y_df[j * (WARP_SIZE_GGUF/QI8_1) + (2*k/QI8_1) % (WARP_SIZE_GGUF/QI8_1)]);
}
static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
@ -767,8 +767,8 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
}
template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
__shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y];
__shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI5_1) + mmq_y/QI5_1];
__shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE_GGUF) + mmq_y];
__shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE_GGUF/QI5_1) + mmq_y/QI5_1];
*x_ql = tile_x_ql;
*x_dm = tile_x_dm;
@ -801,7 +801,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
qs0 |= (qh << 18) & 0x00100000; // 2 -> 20
qs0 |= (qh << 25) & 0x10000000; // 3 -> 28
x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0;
x_ql[i * (2*WARP_SIZE_GGUF + 1) + 2*k+0] = qs0;
int qs1 = (ql >> 4) & 0x0F0F0F0F;
qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4
@ -809,10 +809,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1;
x_ql[i * (2*WARP_SIZE_GGUF + 1) + 2*k+1] = qs1;
}
const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;
const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI5_1;
const int kbxd = k % blocks_per_tile_x_row;
#pragma unroll
@ -825,7 +825,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbxd;
x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd] = bxi->dm;
x_dm[i * (WARP_SIZE_GGUF/QI5_1) + i / QI5_1 + kbxd] = bxi->dm;
}
}
@ -833,18 +833,18 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1_mul_mat(
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
const int index_bx = i * (WARP_SIZE/QI5_1) + + i/QI5_1 + k/QI5_1;
const int index_bx = i * (WARP_SIZE_GGUF/QI5_1) + + i/QI5_1 + k/QI5_1;
int u[2*VDR_Q5_1_Q8_1_MMQ];
#pragma unroll
for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) {
u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE];
u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_1) % WARP_SIZE];
u[2*l+0] = y_qs[j * WARP_SIZE_GGUF + (kyqs + l) % WARP_SIZE_GGUF];
u[2*l+1] = y_qs[j * WARP_SIZE_GGUF + (kyqs + l + QI5_1) % WARP_SIZE_GGUF];
}
return vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
(&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dm[index_bx], y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
(&x_ql[i * (2*WARP_SIZE_GGUF + 1) + 2 * k], u, x_dm[index_bx], y_ds[j * (WARP_SIZE_GGUF/QI8_1) + (2*k/QI8_1) % (WARP_SIZE_GGUF/QI8_1)]);
}
static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
@ -865,8 +865,8 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
}
template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
__shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + mmq_y];
__shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI8_0) + mmq_y/QI8_0];
__shared__ int tile_x_qs[mmq_y * (WARP_SIZE_GGUF) + mmq_y];
__shared__ float tile_x_d[mmq_y * (WARP_SIZE_GGUF/QI8_0) + mmq_y/QI8_0];
*x_ql = tile_x_qs;
*x_dm = (half2 *) tile_x_d;
@ -889,10 +889,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i = min(i, i_max);
}
const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbx;
x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_int8(bxi->qs, kqsx);
x_ql[i * (WARP_SIZE_GGUF + 1) + k] = get_int_from_int8(bxi->qs, kqsx);
}
const int blocks_per_tile_x_row = WARP_SIZE / QI8_0;
const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI8_0;
const int kbxd = k % blocks_per_tile_x_row;
#pragma unroll
@ -903,7 +903,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i = min(i, i_max);
}
const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbxd;
x_dmf[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = __half2float(bxi->d);
x_dmf[i * (WARP_SIZE_GGUF/QI8_0) + i / QI8_0 + kbxd] = __half2float(bxi->d);
}
}
@ -914,8 +914,8 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat(
const float * y_df = (const float *) y_ds;
return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMQ>
(&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[j * WARP_SIZE + k], x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k/QI8_0],
y_df[j * (WARP_SIZE/QI8_1) + k/QI8_1]);
(&x_ql[i * (WARP_SIZE_GGUF + 1) + k], &y_qs[j * WARP_SIZE_GGUF + k], x_dmf[i * (WARP_SIZE_GGUF/QI8_0) + i/QI8_0 + k/QI8_0],
y_df[j * (WARP_SIZE_GGUF/QI8_1) + k/QI8_1]);
}
static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
@ -942,9 +942,9 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
}
template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q2_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
__shared__ int tile_x_ql[mmq_y * (WARP_SIZE) + mmq_y];
__shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI2_K) + mmq_y/QI2_K];
__shared__ int tile_x_sc[mmq_y * (WARP_SIZE/4) + mmq_y/4];
__shared__ int tile_x_ql[mmq_y * (WARP_SIZE_GGUF) + mmq_y];
__shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE_GGUF/QI2_K) + mmq_y/QI2_K];
__shared__ int tile_x_sc[mmq_y * (WARP_SIZE_GGUF/4) + mmq_y/4];
*x_ql = tile_x_ql;
*x_dm = tile_x_dm;
@ -967,10 +967,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i = min(i, i_max);
}
const block_q2_K * bxi = bx0 + i*blocks_per_row + kbx;
x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
x_ql[i * (WARP_SIZE_GGUF + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
}
const int blocks_per_tile_x_row = WARP_SIZE / QI2_K;
const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI2_K;
const int kbxd = k % blocks_per_tile_x_row;
#pragma unroll
@ -981,18 +981,18 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i = min(i, i_max);
}
const block_q2_K * bxi = bx0 + i*blocks_per_row + kbxd;
x_dm[i * (WARP_SIZE/QI2_K) + i / QI2_K + kbxd] = bxi->dm;
x_dm[i * (WARP_SIZE_GGUF/QI2_K) + i / QI2_K + kbxd] = bxi->dm;
}
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
int i = i0 + i_offset * 4 + k / (WARP_SIZE/4);
int i = i0 + i_offset * 4 + k / (WARP_SIZE_GGUF/4);
if (need_check) {
i = min(i, i_max);
}
const block_q2_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI2_K/4);
x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = get_int_from_uint8_aligned(bxi->scales, k % (QI2_K/4));
const block_q2_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE_GGUF/4)) / (QI2_K/4);
x_sc[i * (WARP_SIZE_GGUF/4) + i / 4 + k % (WARP_SIZE_GGUF/4)] = get_int_from_uint8_aligned(bxi->scales, k % (QI2_K/4));
}
}
@ -1005,7 +1005,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_mul_mat(
int v[QR2_K*VDR_Q2_K_Q8_1_MMQ];
const int kqsx = i * (WARP_SIZE + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2);
const int kqsx = i * (WARP_SIZE_GGUF + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2);
const int shift = 2 * ((ky % (2*QI2_K)) / (QI2_K/2));
#pragma unroll
@ -1013,10 +1013,10 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_mul_mat(
v[l] = (x_ql[kqsx + l] >> shift) & 0x03030303;
}
const uint8_t * scales = ((const uint8_t *) &x_sc[i * (WARP_SIZE/4) + i/4 + kbx*4]) + ky/4;
const uint8_t * scales = ((const uint8_t *) &x_sc[i * (WARP_SIZE_GGUF/4) + i/4 + kbx*4]) + ky/4;
const int index_y = j * WARP_SIZE + (QR2_K*k) % WARP_SIZE;
return vec_dot_q2_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dm[i * (WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[index_y/QI8_1]);
const int index_y = j * WARP_SIZE_GGUF + (QR2_K*k) % WARP_SIZE_GGUF;
return vec_dot_q2_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dm[i * (WARP_SIZE_GGUF/QI2_K) + i/QI2_K + kbx], y_df[index_y/QI8_1]);
}
static __device__ __forceinline__ float vec_dot_q3_K_q8_1(
@ -1047,10 +1047,10 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1(
}
template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q3_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
__shared__ int tile_x_ql[mmq_y * (WARP_SIZE) + mmq_y];
__shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI3_K) + mmq_y/QI3_K];
__shared__ int tile_x_qh[mmq_y * (WARP_SIZE/2) + mmq_y/2];
__shared__ int tile_x_sc[mmq_y * (WARP_SIZE/4) + mmq_y/4];
__shared__ int tile_x_ql[mmq_y * (WARP_SIZE_GGUF) + mmq_y];
__shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE_GGUF/QI3_K) + mmq_y/QI3_K];
__shared__ int tile_x_qh[mmq_y * (WARP_SIZE_GGUF/2) + mmq_y/2];
__shared__ int tile_x_sc[mmq_y * (WARP_SIZE_GGUF/4) + mmq_y/4];
*x_ql = tile_x_ql;
*x_dm = tile_x_dm;
@ -1073,10 +1073,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i = min(i, i_max);
}
const block_q3_K * bxi = bx0 + i*blocks_per_row + kbx;
x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);
x_ql[i * (WARP_SIZE_GGUF + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);
}
const int blocks_per_tile_x_row = WARP_SIZE / QI3_K;
const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI3_K;
const int kbxd = k % blocks_per_tile_x_row;
float * x_dmf = (float *) x_dm;
@ -1087,27 +1087,27 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i = min(i, i_max);
}
const block_q3_K * bxi = bx0 + i*blocks_per_row + kbxd;
x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = __half2float(bxi->d);
x_dmf[i * (WARP_SIZE_GGUF/QI3_K) + i / QI3_K + kbxd] = __half2float(bxi->d);
}
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 2) {
int i = i0 + i_offset * 2 + k / (WARP_SIZE/2);
int i = i0 + i_offset * 2 + k / (WARP_SIZE_GGUF/2);
if (need_check) {
i = min(i, i_max);
}
const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/2)) / (QI3_K/2);
const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE_GGUF/2)) / (QI3_K/2);
// invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
x_qh[i * (WARP_SIZE/2) + i / 2 + k % (WARP_SIZE/2)] = ~get_int_from_uint8(bxi->hmask, k % (QI3_K/2));
x_qh[i * (WARP_SIZE_GGUF/2) + i / 2 + k % (WARP_SIZE_GGUF/2)] = ~get_int_from_uint8(bxi->hmask, k % (QI3_K/2));
}
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
int i = i0 + i_offset * 4 + k / (WARP_SIZE/4);
int i = i0 + i_offset * 4 + k / (WARP_SIZE_GGUF/4);
if (need_check) {
i = min(i, i_max);
}
const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI3_K/4);
const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE_GGUF/4)) / (QI3_K/4);
const int ksc = k % (QI3_K/4);
@ -1121,7 +1121,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = sc;
x_sc[i * (WARP_SIZE_GGUF/4) + i / 4 + k % (WARP_SIZE_GGUF/4)] = sc;
}
}
@ -1134,24 +1134,24 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_mul_mat(
const float * x_dmf = (const float *) x_dm;
const float * y_df = (const float *) y_ds;
const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;
const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE_GGUF/4) + i/4 + kbx*4)) + ky/4;
int v[QR3_K*VDR_Q3_K_Q8_1_MMQ];
#pragma unroll
for (int l = 0; l < QR3_K*VDR_Q3_K_Q8_1_MMQ; ++l) {
const int kqsx = i * (WARP_SIZE + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2);
const int kqsx = i * (WARP_SIZE_GGUF + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2);
const int shift = 2 * ((ky % 32) / 8);
const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303;
const int vh = x_qh[i * (WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8);
const int vh = x_qh[i * (WARP_SIZE_GGUF/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8);
const int vlh = (vh << 2) & 0x04040404;
v[l] = __vsubss4(vll, vlh);
}
const int index_y = j * WARP_SIZE + (k*QR3_K) % WARP_SIZE;
return vec_dot_q3_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dmf[i * (WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[index_y/QI8_1]);
const int index_y = j * WARP_SIZE_GGUF + (k*QR3_K) % WARP_SIZE_GGUF;
return vec_dot_q3_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dmf[i * (WARP_SIZE_GGUF/QI3_K) + i/QI3_K + kbx], y_df[index_y/QI8_1]);
}
static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
@ -1200,9 +1200,9 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
}
template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
__shared__ int tile_x_ql[mmq_y * (WARP_SIZE) + mmq_y];
__shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI4_K) + mmq_y/QI4_K];
__shared__ int tile_x_sc[mmq_y * (WARP_SIZE/8) + mmq_y/8];
__shared__ int tile_x_ql[mmq_y * (WARP_SIZE_GGUF) + mmq_y];
__shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE_GGUF/QI4_K) + mmq_y/QI4_K];
__shared__ int tile_x_sc[mmq_y * (WARP_SIZE_GGUF/8) + mmq_y/8];
*x_ql = tile_x_ql;
*x_dm = tile_x_dm;
@ -1225,10 +1225,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i = min(i, i_max);
}
const block_q4_K * bxi = bx0 + i*blocks_per_row + kbx;
x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
x_ql[i * (WARP_SIZE_GGUF + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
}
const int blocks_per_tile_x_row = WARP_SIZE / QI4_K; // == 1 if QK_K == 256
const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI4_K; // == 1 if QK_K == 256
const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256
#pragma unroll
@ -1238,27 +1238,27 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i = min(i, i_max);
}
const block_q4_K * bxi = bx0 + i*blocks_per_row + kbxd;
x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = bxi->dm;
x_dm[i * (WARP_SIZE_GGUF/QI4_K) + i / QI4_K + kbxd] = bxi->dm;
}
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;
int i = (i0 + i_offset * 8 + k / (WARP_SIZE_GGUF/8)) % mmq_y;
if (need_check) {
i = min(i, i_max);
}
const block_q4_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI4_K/8);
const block_q4_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE_GGUF/8)) / (QI4_K/8);
const int * scales = (const int *) bxi->scales;
const int ksc = k % (WARP_SIZE/8);
const int ksc = k % (WARP_SIZE_GGUF/8);
// scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8
int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits
x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
x_sc[i * (WARP_SIZE_GGUF/8) + i / 8 + ksc] = scales8;
}
}
@ -1267,11 +1267,11 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_mul_mat(
const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
(void)x_qh;
const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2*((k % 16) / 8);
const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE_GGUF/8) + i/8 + k/16]) + 2*((k % 16) / 8);
const int index_y = j * WARP_SIZE + (QR4_K*k) % WARP_SIZE;
return vec_dot_q4_K_q8_1_impl_mmq(&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[index_y], sc, sc+8,
x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]);
const int index_y = j * WARP_SIZE_GGUF + (QR4_K*k) % WARP_SIZE_GGUF;
return vec_dot_q4_K_q8_1_impl_mmq(&x_ql[i * (WARP_SIZE_GGUF + 1) + k], &y_qs[index_y], sc, sc+8,
x_dm[i * (WARP_SIZE_GGUF/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]);
}
static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
@ -1321,9 +1321,9 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
}
template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
__shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y];
__shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI5_K) + mmq_y/QI5_K];
__shared__ int tile_x_sc[mmq_y * (WARP_SIZE/8) + mmq_y/8];
__shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE_GGUF) + mmq_y];
__shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE_GGUF/QI5_K) + mmq_y/QI5_K];
__shared__ int tile_x_sc[mmq_y * (WARP_SIZE_GGUF/8) + mmq_y/8];
*x_ql = tile_x_ql;
*x_dm = tile_x_dm;
@ -1360,11 +1360,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const int kq0 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + 0;
const int kq1 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + (QI5_K/4);
x_ql[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0;
x_ql[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
x_ql[i * (2*WARP_SIZE_GGUF + 1) + kq0] = ql0 | qh0;
x_ql[i * (2*WARP_SIZE_GGUF + 1) + kq1] = ql1 | qh1;
}
const int blocks_per_tile_x_row = WARP_SIZE / QI5_K; // == 1 if QK_K == 256
const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI5_K; // == 1 if QK_K == 256
const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256
#pragma unroll
@ -1376,40 +1376,40 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
}
const block_q5_K * bxi = bx0 + i*blocks_per_row + kbxd;
x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm;
x_dm[i * (WARP_SIZE_GGUF/QI5_K) + i / QI5_K + kbxd] = bxi->dm;
}
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;
int i = (i0 + i_offset * 8 + k / (WARP_SIZE_GGUF/8)) % mmq_y;
if (need_check) {
i = min(i, i_max);
}
const block_q5_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI5_K/8);
const block_q5_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE_GGUF/8)) / (QI5_K/8);
const int * scales = (const int *) bxi->scales;
const int ksc = k % (WARP_SIZE/8);
const int ksc = k % (WARP_SIZE_GGUF/8);
// scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8
int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits
x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
x_sc[i * (WARP_SIZE_GGUF/8) + i / 8 + ksc] = scales8;
}
}
static __device__ __forceinline__ float vec_dot_q5_K_q8_1_mul_mat(
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2 * ((k % 16) / 8);
const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE_GGUF/8) + i/8 + k/16]) + 2 * ((k % 16) / 8);
const int index_x = i * (QR5_K*WARP_SIZE + 1) + QR5_K*k;
const int index_y = j * WARP_SIZE + (QR5_K*k) % WARP_SIZE;
const int index_x = i * (QR5_K*WARP_SIZE_GGUF + 1) + QR5_K*k;
const int index_y = j * WARP_SIZE_GGUF + (QR5_K*k) % WARP_SIZE_GGUF;
return vec_dot_q5_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, sc+8,
x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]);
x_dm[i * (WARP_SIZE_GGUF/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]);
}
static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
@ -1439,9 +1439,9 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
}
template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q6_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
__shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y];
__shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI6_K) + mmq_y/QI6_K];
__shared__ int tile_x_sc[mmq_y * (WARP_SIZE/8) + mmq_y/8];
__shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE_GGUF) + mmq_y];
__shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE_GGUF/QI6_K) + mmq_y/QI6_K];
__shared__ int tile_x_sc[mmq_y * (WARP_SIZE_GGUF/8) + mmq_y/8];
*x_ql = tile_x_ql;
*x_dm = tile_x_dm;
@ -1478,11 +1478,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const int kq0 = ky - ky % QI6_K + k % (QI6_K/2) + 0;
const int kq1 = ky - ky % QI6_K + k % (QI6_K/2) + (QI6_K/2);
x_ql[i * (2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
x_ql[i * (2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
x_ql[i * (2*WARP_SIZE_GGUF + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
x_ql[i * (2*WARP_SIZE_GGUF + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
}
const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256
const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI6_K; // == 1 if QK_K == 256
const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256
float * x_dmf = (float *) x_dm;
@ -1496,20 +1496,20 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const block_q6_K * bxi = bx0 + i*blocks_per_row + kbxd;
x_dmf[i * (WARP_SIZE/QI6_K) + i / QI6_K + kbxd] = __half2float(bxi->d);
x_dmf[i * (WARP_SIZE_GGUF/QI6_K) + i / QI6_K + kbxd] = __half2float(bxi->d);
}
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;
int i = (i0 + i_offset * 8 + k / (WARP_SIZE_GGUF/8)) % mmq_y;
if (need_check) {
i = min(i, i_max);
}
const block_q6_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / 4;
const block_q6_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE_GGUF/8)) / 4;
x_sc[i * (WARP_SIZE/8) + i / 8 + k % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, k % (QI6_K/8));
x_sc[i * (WARP_SIZE_GGUF/8) + i / 8 + k % (WARP_SIZE_GGUF/8)] = get_int_from_int8(bxi->scales, k % (QI6_K/8));
}
}
@ -1519,11 +1519,11 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat(
const float * x_dmf = (const float *) x_dm;
const float * y_df = (const float *) y_ds;
const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/8]);
const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE_GGUF/8) + i/8 + k/8]);
const int index_x = i * (QR6_K*WARP_SIZE + 1) + QR6_K*k;
const int index_y = j * WARP_SIZE + (QR6_K*k) % WARP_SIZE;
return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]);
const int index_x = i * (QR6_K*WARP_SIZE_GGUF + 1) + QR6_K*k;
const int index_y = j * WARP_SIZE_GGUF + (QR6_K*k) % WARP_SIZE_GGUF;
return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE_GGUF/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]);
}
static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
@ -1582,7 +1582,7 @@ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
static __device__ __forceinline__ float vec_dot_iq2_s_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
const block_iq2_s * bq2 = (const block_iq2_s *) vbq;
const int ib32 = iqs;
@ -1619,7 +1619,7 @@ static __device__ __forceinline__ float vec_dot_iq2_s_q8_1(
static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
const block_iq3_xxs * bq2 = (const block_iq3_xxs *) vbq;
const int ib32 = iqs;
@ -1646,7 +1646,7 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
const block_iq3_s * bq2 = (const block_iq3_s *) vbq;
const int ib32 = iqs;
@ -1671,7 +1671,7 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
const block_iq1_s * bq1 = (const block_iq1_s *) vbq;
const int qs_packed = get_int_b2(bq1->qs, iqs);
@ -1703,7 +1703,7 @@ static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
static __device__ __forceinline__ float vec_dot_iq1_m_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
const block_iq1_m * bq1 = (const block_iq1_m *) vbq;
@ -1763,7 +1763,7 @@ static __device__ __forceinline__ void get_int_from_table_16(const uint32_t & q4
static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
const block_iq4_nl * bq = (const block_iq4_nl *) vbq;
@ -1788,7 +1788,7 @@ static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(
static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
const block_iq4_xs * bq4 = (const block_iq4_xs *) vbq;
const uint8_t * values = (const uint8_t *)kvalues_iq4nl;

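The guard changes in vecdotq.cuh (from __CUDA_ARCH__ >= 610 alone to also accepting USE_ROCM) enable the integer dot-product paths on ROCm, where __dp4a is supplied by the fallback in ggml-common.h. For reference, a plain-Python model of what dp4a computes, a four-way signed-byte dot product with accumulate (my own sketch):

def dp4a(a, b, c):
    # dot product of the four signed bytes packed in a and b, added to c
    def signed_bytes(x):
        return [((x >> s) & 0xFF) - (256 if ((x >> s) & 0xFF) > 127 else 0)
                for s in (0, 8, 16, 24)]
    return c + sum(p * q for p, q in zip(signed_bytes(a), signed_bytes(b)))

# 0x01020304 packs the bytes (4, 3, 2, 1); dotted with itself: 16 + 9 + 4 + 1
assert dp4a(0x01020304, 0x01020304, 0) == 30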
View File

@ -258,6 +258,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"awq_marlin_repack(Tensor b_q_weight, SymInt size_k, "
"SymInt size_n, int num_bits) -> Tensor");
// conditionally compiled so impl registrations are in source file
#endif
// Dequantization for GGML.
ops.def("ggml_dequantize(Tensor W, int type, SymInt m, SymInt n) -> Tensor");
@ -274,6 +275,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"ggml_mul_mat_a8(Tensor W, Tensor X, int type, SymInt row) -> Tensor");
ops.impl("ggml_mul_mat_a8", torch::kCUDA, &ggml_mul_mat_a8);
#ifndef USE_ROCM
// fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
ops.def(
"fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "

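With the hunks in ops.h and torch_bindings.cpp, the ggml_* declarations and op registrations now sit outside the #ifndef USE_ROCM region, so a ROCm build exposes them on torch.ops._C just like a CUDA build. A quick, hedged way to probe for that from Python (importing vllm._custom_ops is assumed to load the compiled _C extension):

import torch
import vllm._custom_ops  # noqa: F401 -- assumed to pull in the compiled _C extension

for name in ("ggml_dequantize", "ggml_mul_mat_vec_a8", "ggml_mul_mat_a8"):
    print(name, "registered:", hasattr(torch.ops._C, name))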
View File

@ -344,31 +344,6 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
is_zp_float: bool = False) -> torch.Tensor:
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
@register_fake("_C::ggml_dequantize")
def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int,
m: torch.SymInt,
n: torch.SymInt) -> torch.Tensor:
return torch.empty((m, n), dtype=torch.float16, device=W.device)
@register_fake("_C::ggml_mul_mat_vec_a8")
def _ggml_mul_mat_vec_a8_fake(
W: torch.Tensor,
X: torch.Tensor,
quant_type: int,
row: torch.SymInt,
) -> torch.Tensor:
return torch.empty((1, row), dtype=torch.float16, device=W.device)
@register_fake("_C::ggml_mul_mat_a8")
def _ggml_mul_mat_a8_fake(
W: torch.Tensor,
X: torch.Tensor,
quant_type: int,
row: torch.SymInt,
) -> torch.Tensor:
batch = X.size(0)
return torch.empty((batch, row), dtype=torch.float16, device=W.device)
@register_fake("_C::marlin_qqq_gemm")
def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
s_tok: torch.Tensor, s_ch: torch.Tensor,
@ -468,6 +443,34 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
memory_format=torch.contiguous_format)
if hasattr(torch.ops._C, "ggml_dequantize"):
@register_fake("_C::ggml_dequantize")
def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int,
m: torch.SymInt,
n: torch.SymInt) -> torch.Tensor:
return torch.empty((m, n), dtype=torch.float16, device=W.device)
@register_fake("_C::ggml_mul_mat_vec_a8")
def _ggml_mul_mat_vec_a8_fake(
W: torch.Tensor,
X: torch.Tensor,
quant_type: int,
row: torch.SymInt,
) -> torch.Tensor:
return torch.empty((1, row), dtype=torch.float16, device=W.device)
@register_fake("_C::ggml_mul_mat_a8")
def _ggml_mul_mat_a8_fake(
W: torch.Tensor,
X: torch.Tensor,
quant_type: int,
row: torch.SymInt,
) -> torch.Tensor:
batch = X.size(0)
return torch.empty((batch, row), dtype=torch.float16, device=W.device)
# cutlass
def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
return torch.ops._C.cutlass_scaled_mm_supports_fp8(cuda_device_capability)

View File

@ -387,7 +387,7 @@ class ModelConfig:
supported_quantization = QUANTIZATION_METHODS
rocm_supported_quantization = [
"awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors",
"fbgemm_fp8"
"fbgemm_fp8", "gguf"
]
optimized_quantization_methods = [
"fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",