// File: vllm/csrc/quantization/gptq/qdq_util.cuh
// (Captured from a web file view; commit date 2023-12-15 19:04:22 +08:00.)
/*
Copied from https://github.com/turboderp/exllamav2
*/
#ifndef _qdq_util_cuh
#define _qdq_util_cuh
namespace vllm {
namespace gptq {
// Reinterprets the same 32 bits either as a uint32_t or as a half2 (two packed
// fp16 values), with no numeric conversion. Construct from one view and read
// the other.
// NOTE(review): this header does not include <cuda_fp16.h> or <cstdint>
// itself; it relies on the including translation unit to declare half2 and
// uint32_t.
// NOTE(review): reading the member that was not last written is type punning,
// which ISO C++ leaves undefined; this relies on the CUDA compiler accepting
// the idiom.
union half2_uint32 {
  uint32_t as_uint32;
  half2 as_half2;

  // Both views must alias exactly the same 32 bits for the punning to be
  // meaningful.
  static_assert(sizeof(half2) == sizeof(uint32_t),
                "half2 must be exactly 32 bits wide");

  __device__ half2_uint32(uint32_t val) : as_uint32(val) {}
  __device__ half2_uint32(half2 val) : as_half2(val) {}
};
// Reinterprets the same 16 bits either as a uint16_t or as a half (one fp16
// value), with no numeric conversion. Construct from one view and read the
// other.
// NOTE(review): relies on the including translation unit to declare half and
// uint16_t; reading the inactive member is type punning that ISO C++ leaves
// undefined, so this depends on the CUDA compiler accepting the idiom.
union half_uint16 {
  uint16_t as_uint16;
  half as_half;

  // Both views must alias exactly the same 16 bits.
  static_assert(sizeof(half) == sizeof(uint16_t),
                "half must be exactly 16 bits wide");

  __device__ half_uint16(uint16_t val) : as_uint16(val) {}
  __device__ half_uint16(half val) : as_half(val) {}
};
// Dequantizes a scale code: returns (qs + 1)^2 * max_scale in half precision.
// max_scale is expected to be premultiplied by 1/256 by the caller.
// NOTE(review): presumably qs is a 4-bit code (0..15), so (qs + 1)^2 spans
// 1..256 and the result spans 1/256 .. 1 of the un-premultiplied max scale.
// Confirm against callers.
__forceinline__ __device__ half dq_scale(const int qs, const half max_scale) {
  const int qs_i = qs + 1;
  // Square in integer arithmetic and convert once. fp16 represents every
  // integer up to 2048 exactly, so the conversion is exact for |qs_i| <= 45.
  const half qs_h = __int2half_rn(qs_i * qs_i);
  return __hmul(qs_h, max_scale);
}
// Dequantizes one quantized value: returns (q - qzero) * scale in half
// precision.
// The zero-point is subtracted in int arithmetic, so the offset is computed
// exactly and rounded to half only once before the multiply.
__forceinline__ __device__ half dq(const int q, const int qzero,
                                   const half scale) {
  return __hmul(__int2half_rn(q - qzero), scale);
}
// Dequantizes one quantized value without applying a scale ("ns" = no scale):
// returns (q - qzero) converted to half.
__forceinline__ __device__ half dq_ns(const int q, const int qzero) {
  // Subtracting in int before converting needs one conversion instead of two
  // conversions plus a half subtraction. It is exact whenever the difference
  // is representable in fp16.
  return __int2half_rn(q - qzero);
}
// Extracts a bit-field from one packed 32-bit word: shifts q right by `shift`
// and keeps the bits selected by `mask` (e.g. mask 0x0f for a 4-bit field).
// Precondition: 0 <= shift <= 31. Shifting a 32-bit value by 32 or more is
// undefined behavior in C++.
// `mask` is converted to unsigned for the AND (usual arithmetic conversions),
// and the masked result is returned as int.
__forceinline__ __device__ int exb(const uint32_t q, const int shift,
                                   const int mask) {
  return (int)((q >> shift) & mask);
}
// Extracts a bit-field that may straddle two packed 32-bit words.
// q1 is the high word and q0 the low word, forming the 64-bit value
// (q1 << 32) | q0. That value is shifted right by min(shift, 32), and the low
// 32 bits are masked with `mask`.
// __funnelshift_rc clamps the shift amount to 32 (the "c" variant), whereas
// __funnelshift_r would wrap it modulo 32.
// NOTE(review): shift is an int passed into an unsigned parameter, so a
// negative shift would convert to a large value and clamp to 32. Callers
// presumably pass 0..32.
__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0,
                                   const int shift, const int mask) {
  // Argument order is (lo, hi, shift).
  return (int)(__funnelshift_rc(q0, q1, shift) & mask);
}
} // namespace gptq
} // namespace vllm
#endif