[Kernel] fix types used in aqlm and ggml kernels to support dynamo (#7596)
This commit is contained in:
parent 7759ae958f
commit 37fd47e780

csrc/ops.h (16 changed lines)
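Context for the change: vLLM registers these kernels as PyTorch custom ops, and the dispatcher schema language has a single integer scalar type (`int`, which is `int64_t` on the C++ side) plus `int[]` for lists, so `int8_t` scalars and tensor-valued size metadata do not map cleanly and get in the way of torch.compile/Dynamo tracing. The sketch below shows the intended call pattern after the change; the wrapper function, the `fullgraph` usage, and the values are illustrative assumptions, not part of the commit.

import torch
from vllm import _custom_ops as ops  # importing this loads the compiled _C extension

def dequantize(W: torch.Tensor, quant_type: int, m: int, n: int) -> torch.Tensor:
    # quant_type is a plain Python int, matching the new int64_t parameter,
    # so Dynamo can keep this call inside the compiled graph.
    return ops.ggml_dequantize(W, quant_type, m, n)

# Hypothetical usage on a CUDA build of vLLM.
compiled_dequantize = torch.compile(dequantize, fullgraph=True)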
@@ -63,12 +63,12 @@ void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
 torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
                         const torch::Tensor& codebooks,
                         const torch::Tensor& scales,
-                        const torch::Tensor& codebook_partition_sizes,
+                        const std::vector<int64_t>& codebook_partition_sizes,
                         const std::optional<torch::Tensor>& bias);
 
-torch::Tensor aqlm_dequant(const torch::Tensor& codes,
-                           const torch::Tensor& codebooks,
-                           const torch::Tensor& codebook_partition_sizes);
+torch::Tensor aqlm_dequant(
+    const torch::Tensor& codes, const torch::Tensor& codebooks,
+    const std::vector<int64_t>& codebook_partition_sizes);
 
 torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
                        torch::Tensor _scaling_factors, torch::Tensor _zeros,
@@ -107,13 +107,13 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
 torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
                                 int64_t size_n, int64_t num_bits);
 
-torch::Tensor ggml_dequantize(torch::Tensor W, int8_t type, int64_t m,
+torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m,
                               int64_t n);
 
-torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X, int8_t type,
-                                  int64_t row);
+torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X,
+                                  int64_t type, int64_t row);
 
-torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int8_t type,
+torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type,
                               int64_t row);
 
 torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
@@ -496,14 +496,14 @@ torch::Tensor code2x8_matmat(const torch::Tensor& input,
 }
 
 // Accumulate the partition sizes.
-int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes) {
+int4 accumulate_sizes(const std::vector<int64_t>& codebook_partition_sizes) {
   int4 cumulative_sizes;
   auto cumulative_size = &cumulative_sizes.x;
-  int i = 0;
+  size_t i = 0;
   int last = 0;
-  assert(codebook_partition_sizes.size(0) <= 4);
-  for (; i < codebook_partition_sizes.size(0); ++i, ++cumulative_size) {
-    *cumulative_size = codebook_partition_sizes[i].item<int>() + last;
+  assert(codebook_partition_sizes.size() <= 4);
+  for (; i < codebook_partition_sizes.size(); ++i, ++cumulative_size) {
+    *cumulative_size = codebook_partition_sizes[i] + last;
     last = *cumulative_size;
   }
   // fill in the rest with unreachable.
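As a reading aid, the cumulative offsets that accumulate_sizes now computes over the plain vector can be sketched in Python; the sizes are hypothetical, and the padding of the remaining int4 slots with an unreachable sentinel is omitted.

from itertools import accumulate

codebook_partition_sizes = [4096, 4096, 1024]  # hypothetical, at most 4 partitions
assert len(codebook_partition_sizes) <= 4
cumulative_sizes = list(accumulate(codebook_partition_sizes))
print(cumulative_sizes)  # [4096, 8192, 9216]: end offset of each output partition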
@@ -519,12 +519,12 @@ int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes) {
 torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
                         const torch::Tensor& codebooks,
                         const torch::Tensor& scales,
-                        const torch::Tensor& codebook_partition_sizes,
+                        const std::vector<int64_t>& codebook_partition_sizes,
                         const std::optional<torch::Tensor>& bias) {
   int4 cumulative_sizes =
       vllm::aqlm::accumulate_sizes(codebook_partition_sizes);
 
-  int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0);
+  int const nbooks = codebooks.size(0) / codebook_partition_sizes.size();
   int const entries = codebooks.size(1);
 
   if (nbooks == 1 && entries == (1 << 16)) {
@@ -541,13 +541,13 @@ torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
   return {};
 }
 
-torch::Tensor aqlm_dequant(const torch::Tensor& codes,
-                           const torch::Tensor& codebooks,
-                           const torch::Tensor& codebook_partition_sizes) {
+torch::Tensor aqlm_dequant(
+    const torch::Tensor& codes, const torch::Tensor& codebooks,
+    const std::vector<int64_t>& codebook_partition_sizes) {
   int4 cumulative_sizes =
       vllm::aqlm::accumulate_sizes(codebook_partition_sizes);
 
-  int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0);
+  int const nbooks = codebooks.size(0) / codebook_partition_sizes.size();
   int const entries = codebooks.size(1);
 
   const at::cuda::OptionalCUDAGuard device_guard(device_of(codes));
@@ -557,7 +557,8 @@ torch::Tensor aqlm_dequant(const torch::Tensor& codes,
   auto in_features = codes.size(1) * 8;
   auto out_features = codes.size(0);
 
-  assert(out_features = codebook_partition_sizes.sum().item<int>());
+  assert(out_features == std::accumulate(codebook_partition_sizes.begin(),
+                                         codebook_partition_sizes.end(), 0));
 
   auto weights = torch::empty({out_features, in_features},
                               torch::TensorOptions()
@@ -487,7 +487,7 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k,
   dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
 }
 
-static to_fp16_cuda_t ggml_get_to_fp16_cuda(int type) {
+static to_fp16_cuda_t ggml_get_to_fp16_cuda(int64_t type) {
   switch (type) {
     case 2:
       return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
@@ -60,7 +60,7 @@ static void quantize_row_q8_1_cuda(const half* x, void* vy, const int kx,
 }
 
 torch::Tensor ggml_dequantize(torch::Tensor W,  // quant weight
-                              int8_t type, int64_t m, int64_t n) {
+                              int64_t type, int64_t m, int64_t n) {
   const at::cuda::OptionalCUDAGuard device_guard(device_of(W));
   auto options =
       torch::TensorOptions().dtype(torch::kFloat16).device(W.device());
@@ -73,7 +73,7 @@ torch::Tensor ggml_dequantize(torch::Tensor W,  // quant weight
 
 torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W,  // quant weight
                                   torch::Tensor X,  // input
-                                  int8_t type, int64_t row) {
+                                  int64_t type, int64_t row) {
   int col = X.sizes()[1];
   const int padded = (col + 512 - 1) / 512 * 512;
   const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
@@ -172,7 +172,7 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W,  // quant weight
 
 torch::Tensor ggml_mul_mat_a8(torch::Tensor W,  // quant weight
                               torch::Tensor X,  // input
-                              int8_t type, int64_t row) {
+                              int64_t type, int64_t row) {
   int col = X.sizes()[1];
   int padded = (col + 512 - 1) / 512 * 512;
   int batch = X.sizes()[0];
@@ -239,4 +239,4 @@ torch::Tensor ggml_mul_mat_a8(torch::Tensor W,  // quant weight
       break;
   }
   return Y;
 }
@@ -17,13 +17,7 @@ if not current_platform.is_tpu():
         logger.warning("Failed to import from vllm._C with %r", e)
 
 with contextlib.suppress(ImportError):
-    # ruff: noqa: F401
-    import vllm._moe_C
-
-
-def is_custom_op_supported(op_name: str) -> bool:
-    op, overloads = torch._C._jit_get_operation(op_name)
-    return op is not None
+    import vllm._moe_C  # noqa: F401
 
 
 def hint_on_error(fn):
@@ -280,14 +274,14 @@ def cutlass_scaled_mm_azp(a: torch.Tensor,
 # aqlm
 def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
               codebooks: torch.Tensor, scales: torch.Tensor,
-              codebook_partition_sizes: torch.Tensor,
+              codebook_partition_sizes: List[int],
               bias: Optional[torch.Tensor]) -> torch.Tensor:
     return torch.ops._C.aqlm_gemm(input, codes, codebooks, scales,
                                   codebook_partition_sizes, bias)
 
 
 def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,
-                 codebook_partition_sizes: torch.Tensor) -> torch.Tensor:
+                 codebook_partition_sizes: List[int]) -> torch.Tensor:
     return torch.ops._C.aqlm_dequant(codes, codebooks,
                                      codebook_partition_sizes)
 
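With these signatures the partition sizes are passed as a plain List[int] instead of a CPU tensor. A rough usage sketch; the shapes, dtypes, and values below are made up for illustration and assume a CUDA build of vLLM.

import torch
from vllm import _custom_ops as ops

# Hypothetical AQLM layout: 3 output partitions, one 2^16-entry codebook each.
output_partition_sizes = [4096, 4096, 1024]  # plain Python list, sums to out_features = 9216
codes = torch.zeros((9216, 512, 1), dtype=torch.int16, device="cuda")
codebooks = torch.zeros((3, 65536, 1, 8), dtype=torch.float16, device="cuda")

weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)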
@@ -434,25 +428,17 @@ def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
 
 
 # gguf
-def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int, n: int):
+def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int,
+                    n: int) -> torch.Tensor:
     return torch.ops._C.ggml_dequantize(W, quant_type, m, n)
 
 
-def ggml_mul_mat_vec(
-    W: torch.Tensor,
-    X: torch.Tensor,
-    quant_type: int,
-    row: int,
-):
-    return torch.ops._C.ggml_mul_mat_vec(W, X, quant_type, row)
-
-
 def ggml_mul_mat_vec_a8(
     W: torch.Tensor,
     X: torch.Tensor,
     quant_type: int,
     row: int,
-):
+) -> torch.Tensor:
     return torch.ops._C.ggml_mul_mat_vec_a8(W, X, quant_type, row)
 
@@ -461,7 +447,7 @@ def ggml_mul_mat_a8(
     X: torch.Tensor,
     quant_type: int,
     row: int,
-):
+) -> torch.Tensor:
     return torch.ops._C.ggml_mul_mat_a8(W, X, quant_type, row)
 
 
@@ -95,7 +95,7 @@ def generic_dequantize_gemm(
     codebooks: torch.
     Tensor,  # [num_codebooks, codebook_size, out_group_size, in_group_size]
     scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
-    output_partition_sizes: torch.IntTensor,
+    output_partition_sizes: List[int],
     bias: Optional[torch.Tensor],
 ) -> torch.Tensor:
     output_shape = input.shape[:-1] + (scales.shape[0], )
@@ -133,7 +133,7 @@ def optimized_dequantize_gemm(
     codebooks: torch.
     Tensor,  # [num_codebooks, codebook_size, out_group_size, in_group_size]
     scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
-    output_partition_sizes: torch.IntTensor,
+    output_partition_sizes: List[int],
     bias: Optional[torch.Tensor],
 ) -> torch.Tensor:
     weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
@@ -288,10 +288,8 @@ class AQLMLinearMethod(LinearMethodBase):
             codebooks,
             {
                 # metadata indicates fixed size concatenated along dim 0
-                "is_metadata":
-                True,
-                "output_partition_sizes":
-                torch.tensor(output_partition_sizes, device='cpu'),
+                "is_metadata": True,
+                "output_partition_sizes": output_partition_sizes
             },
         )
 
@@ -334,7 +332,7 @@ class AQLMLinearMethod(LinearMethodBase):
         codes = layer.codes
         scales = layer.scales
         output_partition_sizes = getattr(codebooks, "output_partition_sizes",
-                                         None)
+                                         [])
 
         nbooks = codes.shape[2]
         ingroups = codebooks.shape[3]
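The metadata attached to the codebooks parameter is now a plain Python list, read back later with an empty-list default. A minimal sketch of that attach/read pattern using ordinary attribute assignment (vLLM's set_weight_attrs helper does essentially this); the values are hypothetical.

import torch

codebooks = torch.nn.Parameter(torch.empty(3, 65536, 1, 8), requires_grad=False)
# Store the sizes as a List[int] rather than torch.tensor(..., device='cpu').
codebooks.output_partition_sizes = [4096, 4096, 1024]
codebooks.is_metadata = True

# Later, the apply() path recovers the list, falling back to [] if it was never set.
output_partition_sizes = getattr(codebooks, "output_partition_sizes", [])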
@@ -212,6 +212,7 @@ class GPTQLinearMethod(LinearMethodBase):
             layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
         else:
             layer.g_idx.data = torch.empty((0, ),
+                                           dtype=torch.int,
                                            device=layer.g_idx.device)
         layer.exllama_state = ExllamaState.READY
         ops.gptq_shuffle(layer.qweight, layer.g_idx,
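The added dtype in the GPTQ branch matters because torch.empty otherwise uses the default floating-point dtype. A small standalone sketch of the placeholder it creates:

import torch

# Empty g_idx placeholder for the non-act-order case: created as an integer
# tensor so downstream code sees a consistent dtype even when it is empty.
g_idx_placeholder = torch.empty((0, ), dtype=torch.int)
assert g_idx_placeholder.dtype == torch.int32 and g_idx_placeholder.numel() == 0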