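// vllm/csrc/quantization/squeezellm/quant_cuda_kernel.cu
//
// NOTE: this copy was captured from a web source viewer, which added page
// residue ("217 lines", "5.4 KiB", "Plaintext", "Raw Normal View History");
// it is preserved here as a comment so the translation unit compiles.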

#include <torch/all.h>
#include <torch/python.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
// half-tensor
#include <c10/cuda/CUDAStream.h>
#include <ATen/cuda/CUDATensorMethods.cuh>
#include <c10/cuda/CUDAGuard.h>
// Threads per block. Each thread owns one output column, and each block stages
// BLOCKWIDTH input halves (BLOCKWIDTH / 2 half2 values) in shared memory.
#define BLOCKWIDTH 128
// Packed int32 rows of `mat` consumed per block (row = BLOCKHEIGHT4 * blockIdx.x).
// The kernel's inner loop walks BLOCKWIDTH / 8 packed rows (4 half2 per row),
// so the two constants must satisfy BLOCKWIDTH == 8 * BLOCKHEIGHT4.
#define BLOCKHEIGHT4 16
namespace vllm {
namespace squeezellm {

// Returns the bit pattern of `i` reinterpreted as an unsigned int, so that the
// right shifts used to extract 4-bit fields below are logical, not arithmetic.
__device__ inline unsigned int as_unsigned(int i) {
  return *reinterpret_cast<unsigned int*>(&i);
}

// 4-bit matvec kernel (LUT-based).
//
// For every batch row b and output column col this accumulates
//   mul[b][col] += sum_r vec[b][r] * lookup_table[col][q(r, col)]
// where q(r, col) is the 4-bit LUT index for input element r stored in `mat`.
//
// Launch layout (as set up by squeezellm_gemm):
//   grid  = (ceil(height / BLOCKHEIGHT4), ceil(width / BLOCKWIDTH))
//   block = BLOCKWIDTH threads, no dynamic shared memory.
// Block (x, y) handles packed rows [BLOCKHEIGHT4*x, BLOCKHEIGHT4*(x+1)) of
// `mat` -- i.e. input elements [BLOCKWIDTH*x, BLOCKWIDTH*(x+1)) -- and output
// columns [BLOCKWIDTH*y, BLOCKWIDTH*(y+1)). Partial sums from different
// x-blocks are combined with atomicAdd, so `mul` is accumulated into, never
// overwritten. Per-block partial sums are accumulated in fp16 (`res` is __half).
//
// Layout assumed by the indexing below:
//   vec          : `batch` rows of `vec_height` halves, read as half2
//   mat          : row-major [height][width] int32; nibble n (bits 4n..4n+3)
//                  of mat[r][c] is the LUT index for input element 8*r + n
//   lookup_table : 16 halves per output column at [col*16, col*16 + 16)
//   mul          : `batch` rows of `width` outputs, stored as half2 (CUDA) or
//                  float2 (ROCm) pairs
// Preconditions: width and vec_height even; vec_height >= 8 * height.
// The CUDA path uses atomicAdd(half2*, half2), which requires sm_60 or newer.
__global__ void NUQ4MatMulKernel(
#ifndef USE_ROCM
    const half2* __restrict__ vec,
#else
    const __half2* __restrict__ vec,
#endif
    const int* __restrict__ mat,
#ifndef USE_ROCM
    half2* __restrict__ mul,
#else
    float2* __restrict__ mul,
#endif
    const __half* __restrict__ lookup_table, int height, int width, int batch,
    int vec_height) {
  const int blockwidth2 = BLOCKWIDTH / 2;

  const int row = BLOCKHEIGHT4 * blockIdx.x;
  const int col = BLOCKWIDTH * blockIdx.y + threadIdx.x;
  const int off = threadIdx.x;

  // When width is not a multiple of BLOCKWIDTH the last column-block has
  // threads past the final column. They must still reach every
  // __syncthreads() (and help stage `blockvec`), but must not touch `mat`,
  // `lookup_table` or `mul`: unguarded, they read past the row end and their
  // atomicAdd lands in the next batch row's outputs.
  const bool col_valid = col < width;

  // Packed rows of `mat` this block owns: BLOCKHEIGHT4 for every block except
  // possibly the last one in x, which is clipped so it never reads rows
  // >= height. Each packed row consumes 4 half2 (8 halves) of `blockvec`, so a
  // full block uses exactly blockwidth2 entries.
  const int rows_left = height - row;
  const int block_rows = rows_left < BLOCKHEIGHT4 ? rows_left : BLOCKHEIGHT4;
  const int k_end = 4 * block_rows;

#ifndef USE_ROCM
  __shared__ half2 blockvec[blockwidth2];
#else
  __shared__ __half2 blockvec[blockwidth2];
#endif
  // deq2[v][t] holds LUT entry v of this block's column t. Each thread writes
  // and later reads only its own column `off`, so no barrier is required
  // between this fill and the reads in the batch loop.
  __shared__ __half deq2[16][BLOCKWIDTH];
  if (col_valid) {
    const int column_offset = col * 16;
    for (int val = 0; val < 16; ++val) {
      deq2[val][off] = lookup_table[column_offset + val];
    }
  }

  for (int b = 0; b < batch; ++b) {
    // Make sure every thread is done reading the previous row's `blockvec`
    // before it is overwritten, then publish the new slice to the block.
    __syncthreads();
    if (off < k_end) {
      blockvec[off] =
          vec[b * vec_height / 2 + (row / BLOCKHEIGHT4) * blockwidth2 + off];
    }
    __syncthreads();

    // Out-of-range columns skip the math; the barriers above are still hit on
    // every iteration because `batch` is uniform across the block.
    if (!col_valid) continue;

    int i = width * row + col;
    __half res = __int2half_rd(0);
    for (int k = 0; k < k_end; k += 4) {
      const unsigned int packed = as_unsigned(mat[i]);
#ifndef USE_ROCM
      half2 res2 = {};
#else
      __half2 res2;
      res2.x = __half_as_ushort(__float2half(0));
      res2.y = __half_as_ushort(__float2half(0));
#endif
      // Nibbles (2j, 2j+1) of this packed int pair with blockvec[k + j].
#pragma unroll
      for (int j = 0; j < 4; ++j) {
        const unsigned int lut_index1 = (packed >> (8 * j)) & 0xF;
        const unsigned int lut_index2 = (packed >> (8 * j + 4)) & 0xF;
#ifndef USE_ROCM
        half2 tmp2 = {};
        tmp2.x = deq2[lut_index1][off];
        tmp2.y = deq2[lut_index2][off];
#else
        __half2 tmp2;
        tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
        tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
#endif
        res2 = __hfma2(tmp2, blockvec[k + j], res2);
      }
#ifndef USE_ROCM
      res = __hadd(__hadd(res2.x, res2.y), res);
#else
      res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)),
                   res);
#endif
      i += width;  // next packed row, same column
    }

    // Outputs are stored as pairs; this thread fills only lane col % 2 and
    // leaves the other lane zero so the paired atomic add does not disturb it.
#ifndef USE_ROCM
    half2 res3 = {};
    if (col % 2 == 0) {
      res3.x = res;
    } else {
      res3.y = res;
    }
    atomicAdd(&mul[b * width / 2 + col / 2], res3);
#else
    __half2 res3;
    res3.x = __half_as_ushort(__float2half(0));
    res3.y = __half_as_ushort(__float2half(0));
    if (col % 2 == 0) {
      res3.x = __half_as_ushort(res);
    } else {
      res3.y = __half_as_ushort(res);
    }
    int tmp_addr = b * width / 2 + col / 2;
    atomicAdd(&(mul[tmp_addr].x), __half2float(__ushort_as_half(res3.x)));
    atomicAdd(&(mul[tmp_addr].y), __half2float(__ushort_as_half(res3.y)));
#endif
  }
}

}  // namespace squeezellm
}  // namespace vllm
// 4-bit matvec kernel (LUT-based) -- host launcher.
//
// Accumulates vec x dequant(mat) into `mul`. The kernel uses atomicAdd, so
// `mul` is added to rather than overwritten (callers presumably pass a
// zero-initialised buffer -- TODO confirm at the call site).
//
//   vec          : [batch, vec_height] at::Half; read as half2 rows
//   mat          : [height, width] int32; each int packs 8 4-bit LUT indices
//                  for 8 consecutive vec elements
//   mul          : batch * width outputs; at::Half on CUDA, float on ROCm
//   lookup_table : at::Half, 16 entries per output column (index col*16 + v)
// Dtypes are enforced by data_ptr<T>(). Tensors are assumed contiguous
// (NOTE(review): not checked here).
//
// Throws (via TORCH_CHECK) if the shapes break the kernel's indexing
// assumptions, or if the kernel launch reports an error.
void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
                     torch::Tensor lookup_table) {
  int height = mat.size(0);
  int width = mat.size(1);
  int batch = vec.size(0);
  int vec_height = vec.size(1);

  // Preconditions the kernel's pointer arithmetic relies on; violating them
  // silently produces wrong rows or out-of-bounds accesses.
  TORCH_CHECK(width % 2 == 0,
              "squeezellm_gemm: mat.size(1) must be even, got ", width);
  TORCH_CHECK(vec_height % 2 == 0,
              "squeezellm_gemm: vec.size(1) must be even, got ", vec_height);
  TORCH_CHECK(static_cast<int64_t>(vec_height) >=
                  static_cast<int64_t>(height) * 8,
              "squeezellm_gemm: vec.size(1) (", vec_height,
              ") must be at least 8 * mat.size(0) (", height, ")");
  TORCH_CHECK(lookup_table.numel() >= static_cast<int64_t>(width) * 16,
              "squeezellm_gemm: lookup_table needs 16 entries per output "
              "column (", width, " columns), got ", lookup_table.numel());
  TORCH_CHECK(mul.numel() >= static_cast<int64_t>(batch) * width,
              "squeezellm_gemm: mul needs batch * width = ",
              static_cast<int64_t>(batch) * width, " elements, got ",
              mul.numel());

  // Nothing to accumulate; also avoids launching a grid with a zero dimension
  // (an invalid configuration that would otherwise fail silently).
  if (height == 0 || width == 0 || batch == 0) {
    return;
  }

  dim3 blocks((height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
              (width + BLOCKWIDTH - 1) / BLOCKWIDTH);
  dim3 threads(BLOCKWIDTH);

  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads, 0, stream>>>(
#ifndef USE_ROCM
      (half2*)vec.data_ptr<at::Half>(),
#else
      (__half2*)vec.data_ptr<at::Half>(),
#endif
      mat.data_ptr<int>(),
#ifndef USE_ROCM
      (half2*)mul.data_ptr<at::Half>(),
      (__half*)lookup_table.data_ptr<at::Half>(),
#else
      (float2*)mul.data_ptr<float>(),
      (__half*)lookup_table.data_ptr<at::Half>(),
#endif
      height, width, batch, vec_height);

  // Kernel launches do not return errors directly; surface configuration
  // errors here instead of at some unrelated later CUDA call.
  const cudaError_t launch_err = cudaGetLastError();
  TORCH_CHECK(launch_err == cudaSuccess,
              "squeezellm_gemm: kernel launch failed: ",
              cudaGetErrorString(launch_err));
}
// Drop the tile-size macros so they cannot leak into any code that textually
// follows this file (e.g. if it is ever #included elsewhere).
#undef BLOCKWIDTH
#undef BLOCKHEIGHT4