2023-09-16 00:03:37 -07:00
|
|
|
#include <torch/extension.h>
|
|
|
|
|
|
|
|
// Forward declaration of the AWQ quantized GEMM op. The definition is in
// another translation unit (not visible here). The parameter list, including
// the by-value torch::Tensor parameters, must stay in sync with that
// definition or the registration in this file will fail to link.
//
// Parameter meanings below are inferred from names only.
// NOTE(review): confirm against the kernel implementation.
//   _in_feats        - presumably the input activations
//   _kernel          - presumably the packed quantized weight matrix
//   _scaling_factors - presumably the dequantization scales
//   _zeros           - presumably the quantization zero points
//   split_k_iters    - presumably a split-K partition count for the reduction
// Returns a torch::Tensor (presumably the GEMM output; TODO confirm).
torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
                       torch::Tensor _scaling_factors, torch::Tensor _zeros,
                       int split_k_iters);
|
|
|
|
|
2023-10-22 03:14:59 -03:00
|
|
|
// Forward declaration of the SqueezeLLM quantized GEMM op. The definition is
// in another translation unit (not visible here). The signature must match it
// exactly for the registration in this file to link.
//
// The function returns void. Any result must therefore be written into one
// of the tensor arguments. NOTE(review): presumably `mul` is the output
// buffer; confirm against the kernel implementation.
//   vec          - presumably the input activations
//   mat          - presumably the packed quantized weights
//   mul          - presumably the output tensor (written in place)
//   lookup_table - presumably the dequantization lookup table
void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
                     torch::Tensor lookup_table);
|
|
|
|
|
2023-09-16 00:03:37 -07:00
|
|
|
// Python bindings for the quantized GEMM ops declared in this file.
// TORCH_EXTENSION_NAME is expected to be defined by the PyTorch cpp_extension
// build (passed as a compiler define), so the module name matches the
// extension being built.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, ext) {
  ext.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
  ext.def("squeezellm_gemm", &squeezellm_gemm,
          "Quantized GEMM for SqueezeLLM");
}
|