[Kernel] fix types used in aqlm and ggml kernels to support dynamo (#7596)

parent 7759ae958f
commit 37fd47e780

csrc/ops.h · 16 changed lines
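Background for the type changes below: torch.compile/dynamo traces custom ops through their registered schemas, and the schema type system only covers the types c10 can box -- Tensor, int (64-bit), float, bool, str, and lists of those. An int8_t argument has no schema spelling, and passing metadata such as partition sizes as a CPU tensor makes the tracer treat a plain list of constants as a tensor input. A minimal sketch of a registration using the post-change types (the op name demo::scale_rows and its body are hypothetical, not part of this commit; the schema-string-plus-impl pattern mirrors how vLLM binds its ops):

    // Only the signature matters here: the schema spells the arguments as
    // "Tensor, int, int[]", which bind to torch::Tensor, int64_t, and
    // std::vector<int64_t> on the C++ side. Dynamo can trace this op with
    // row_sizes treated as compile-time constants rather than tensors.
    #include <torch/library.h>
    #include <vector>

    torch::Tensor scale_rows(const torch::Tensor& x, int64_t factor,
                             const std::vector<int64_t>& row_sizes) {
      (void)row_sizes;  // metadata only, as with codebook_partition_sizes
      return x * static_cast<double>(factor);
    }

    TORCH_LIBRARY_FRAGMENT(demo, m) {
      m.def("scale_rows(Tensor x, int factor, int[] row_sizes) -> Tensor");
    }

    TORCH_LIBRARY_IMPL(demo, CPU, m) {
      m.impl("scale_rows", &scale_rows);
    }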
@@ -63,12 +63,12 @@ void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
 torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
                         const torch::Tensor& codebooks,
                         const torch::Tensor& scales,
-                        const torch::Tensor& codebook_partition_sizes,
+                        const std::vector<int64_t>& codebook_partition_sizes,
                         const std::optional<torch::Tensor>& bias);
 
-torch::Tensor aqlm_dequant(const torch::Tensor& codes,
-                           const torch::Tensor& codebooks,
-                           const torch::Tensor& codebook_partition_sizes);
+torch::Tensor aqlm_dequant(
+    const torch::Tensor& codes, const torch::Tensor& codebooks,
+    const std::vector<int64_t>& codebook_partition_sizes);
 
 torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
                        torch::Tensor _scaling_factors, torch::Tensor _zeros,
@@ -107,13 +107,13 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
 torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
                                 int64_t size_n, int64_t num_bits);
 
-torch::Tensor ggml_dequantize(torch::Tensor W, int8_t type, int64_t m,
+torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m,
                               int64_t n);
 
-torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X, int8_t type,
-                                  int64_t row);
+torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X,
+                                  int64_t type, int64_t row);
 
-torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int8_t type,
+torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type,
                               int64_t row);
 
 torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
@@ -496,14 +496,14 @@ torch::Tensor code2x8_matmat(const torch::Tensor& input,
 }
 
 // Accumulate the partition sizes.
-int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes) {
+int4 accumulate_sizes(const std::vector<int64_t>& codebook_partition_sizes) {
   int4 cumulative_sizes;
   auto cumulative_size = &cumulative_sizes.x;
-  int i = 0;
+  size_t i = 0;
   int last = 0;
-  assert(codebook_partition_sizes.size(0) <= 4);
-  for (; i < codebook_partition_sizes.size(0); ++i, ++cumulative_size) {
-    *cumulative_size = codebook_partition_sizes[i].item<int>() + last;
+  assert(codebook_partition_sizes.size() <= 4);
+  for (; i < codebook_partition_sizes.size(); ++i, ++cumulative_size) {
+    *cumulative_size = codebook_partition_sizes[i] + last;
     last = *cumulative_size;
   }
   // fill in the rest with unreachable.
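For reference, a host-side sketch of what the new accumulate_sizes() computes (the input values here are made up, and the "unreachable" fill step after this hunk is omitted; the real int4 is consumed by the AQLM kernels):

    // Pack up to four cumulative partition sizes into four ints, walking the
    // fields x -> y -> z -> w through a pointer, as the diff above does.
    #include <cassert>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    struct Int4 { int x, y, z, w; };  // host stand-in for CUDA's int4

    int main() {
      const std::vector<int64_t> sizes{1024, 2048};  // assumed example input
      assert(sizes.size() <= 4);
      Int4 cum{0, 0, 0, 0};
      int* slot = &cum.x;
      int last = 0;
      for (size_t i = 0; i < sizes.size(); ++i, ++slot) {
        *slot = static_cast<int>(sizes[i]) + last;
        last = *slot;
      }
      std::printf("{%d, %d, %d, %d}\n", cum.x, cum.y, cum.z, cum.w);
      // prints {1024, 3072, 0, 0}
      return 0;
    }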
@@ -519,12 +519,12 @@ int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes) {
 torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
                         const torch::Tensor& codebooks,
                         const torch::Tensor& scales,
-                        const torch::Tensor& codebook_partition_sizes,
+                        const std::vector<int64_t>& codebook_partition_sizes,
                         const std::optional<torch::Tensor>& bias) {
   int4 cumulative_sizes =
       vllm::aqlm::accumulate_sizes(codebook_partition_sizes);
 
-  int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0);
+  int const nbooks = codebooks.size(0) / codebook_partition_sizes.size();
   int const entries = codebooks.size(1);
 
   if (nbooks == 1 && entries == (1 << 16)) {
@@ -541,13 +541,13 @@ torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
   return {};
 }
 
-torch::Tensor aqlm_dequant(const torch::Tensor& codes,
-                           const torch::Tensor& codebooks,
-                           const torch::Tensor& codebook_partition_sizes) {
+torch::Tensor aqlm_dequant(
+    const torch::Tensor& codes, const torch::Tensor& codebooks,
+    const std::vector<int64_t>& codebook_partition_sizes) {
   int4 cumulative_sizes =
       vllm::aqlm::accumulate_sizes(codebook_partition_sizes);
 
-  int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0);
+  int const nbooks = codebooks.size(0) / codebook_partition_sizes.size();
   int const entries = codebooks.size(1);
 
   const at::cuda::OptionalCUDAGuard device_guard(device_of(codes));
@@ -557,7 +557,8 @@ torch::Tensor aqlm_dequant(const torch::Tensor& codes,
   auto in_features = codes.size(1) * 8;
   auto out_features = codes.size(0);
 
-  assert(out_features = codebook_partition_sizes.sum().item<int>());
+  assert(out_features == std::accumulate(codebook_partition_sizes.begin(),
+                                         codebook_partition_sizes.end(), 0));
 
   auto weights = torch::empty({out_features, in_features},
                               torch::TensorOptions()
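Worth noting about the hunk above: the old assert used "=" rather than "==", so it assigned the sum to out_features instead of checking it; the rewrite fixes the comparison as a side effect of dropping the tensor .sum(). A minimal self-contained check in the new style (the sizes are assumed example values; std::accumulate requires the <numeric> header):

    #include <cassert>
    #include <cstdint>
    #include <numeric>
    #include <vector>

    int main() {
      const std::vector<int64_t> codebook_partition_sizes{1024, 1024, 4096};
      const int64_t out_features = 6144;  // 1024 + 1024 + 4096
      assert(out_features == std::accumulate(codebook_partition_sizes.begin(),
                                             codebook_partition_sizes.end(),
                                             int64_t{0}));
      return 0;
    }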
@@ -487,7 +487,7 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k,
   dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
 }
 
-static to_fp16_cuda_t ggml_get_to_fp16_cuda(int type) {
+static to_fp16_cuda_t ggml_get_to_fp16_cuda(int64_t type) {
   switch (type) {
     case 2:
       return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
@@ -60,7 +60,7 @@ static void quantize_row_q8_1_cuda(const half* x, void* vy, const int kx,
 }
 
 torch::Tensor ggml_dequantize(torch::Tensor W,  // quant weight
-                              int8_t type, int64_t m, int64_t n) {
+                              int64_t type, int64_t m, int64_t n) {
   const at::cuda::OptionalCUDAGuard device_guard(device_of(W));
   auto options =
       torch::TensorOptions().dtype(torch::kFloat16).device(W.device());
@@ -73,7 +73,7 @@ torch::Tensor ggml_dequantize(torch::Tensor W,  // quant weight
 
 torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W,  // quant weight
                                   torch::Tensor X,  // input
-                                  int8_t type, int64_t row) {
+                                  int64_t type, int64_t row) {
   int col = X.sizes()[1];
   const int padded = (col + 512 - 1) / 512 * 512;
   const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
@@ -172,7 +172,7 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W,  // quant weight
 
 torch::Tensor ggml_mul_mat_a8(torch::Tensor W,  // quant weight
                               torch::Tensor X,  // input
-                              int8_t type, int64_t row) {
+                              int64_t type, int64_t row) {
   int col = X.sizes()[1];
   int padded = (col + 512 - 1) / 512 * 512;
   int batch = X.sizes()[0];
@@ -239,4 +239,4 @@ torch::Tensor ggml_mul_mat_a8(torch::Tensor W,  // quant weight
       break;
   }
   return Y;
-}
+}
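The ggml "type" argument changed in the hunks above is a host-side selector that only feeds switch statements, so widening it from int8_t to int64_t is free at runtime while giving it a legal schema type ("int"). A toy dispatcher in the same shape (type code 2 matches the case 2 branch shown above; code 3 is an assumed neighboring GGML type code, and the function itself is illustrative):

    #include <cstdint>
    #include <cstdio>

    static const char* ggml_kernel_for(int64_t type) {
      switch (type) {
        case 2: return "dequantize_q4_0";  // as in the case 2 branch above
        case 3: return "dequantize_q4_1";  // assumed neighboring type code
        default: return "unsupported";
      }
    }

    int main() {
      std::printf("%s\n", ggml_kernel_for(2));  // prints dequantize_q4_0
      return 0;
    }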
@@ -17,13 +17,7 @@ if not current_platform.is_tpu():
         logger.warning("Failed to import from vllm._C with %r", e)
 
 with contextlib.suppress(ImportError):
-    # ruff: noqa: F401
-    import vllm._moe_C
-
-
-def is_custom_op_supported(op_name: str) -> bool:
-    op, overloads = torch._C._jit_get_operation(op_name)
-    return op is not None
+    import vllm._moe_C  # noqa: F401
 
 
 def hint_on_error(fn):
@@ -280,14 +274,14 @@ def cutlass_scaled_mm_azp(a: torch.Tensor,
 
 # aqlm
 def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
               codebooks: torch.Tensor, scales: torch.Tensor,
-              codebook_partition_sizes: torch.Tensor,
+              codebook_partition_sizes: List[int],
               bias: Optional[torch.Tensor]) -> torch.Tensor:
     return torch.ops._C.aqlm_gemm(input, codes, codebooks, scales,
                                   codebook_partition_sizes, bias)
 
 
 def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,
-                 codebook_partition_sizes: torch.Tensor) -> torch.Tensor:
+                 codebook_partition_sizes: List[int]) -> torch.Tensor:
     return torch.ops._C.aqlm_dequant(codes, codebooks,
                                      codebook_partition_sizes)
@@ -434,25 +428,17 @@ def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
 
 
 # gguf
-def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int, n: int):
+def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int,
+                    n: int) -> torch.Tensor:
     return torch.ops._C.ggml_dequantize(W, quant_type, m, n)
 
 
-def ggml_mul_mat_vec(
-    W: torch.Tensor,
-    X: torch.Tensor,
-    quant_type: int,
-    row: int,
-):
-    return torch.ops._C.ggml_mul_mat_vec(W, X, quant_type, row)
-
-
 def ggml_mul_mat_vec_a8(
     W: torch.Tensor,
     X: torch.Tensor,
     quant_type: int,
     row: int,
-):
+) -> torch.Tensor:
     return torch.ops._C.ggml_mul_mat_vec_a8(W, X, quant_type, row)
@@ -461,7 +447,7 @@ def ggml_mul_mat_a8(
     X: torch.Tensor,
     quant_type: int,
     row: int,
-):
+) -> torch.Tensor:
     return torch.ops._C.ggml_mul_mat_a8(W, X, quant_type, row)
@@ -95,7 +95,7 @@ def generic_dequantize_gemm(
     codebooks: torch.
     Tensor,  # [num_codebooks, codebook_size, out_group_size, in_group_size]
     scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
-    output_partition_sizes: torch.IntTensor,
+    output_partition_sizes: List[int],
     bias: Optional[torch.Tensor],
 ) -> torch.Tensor:
     output_shape = input.shape[:-1] + (scales.shape[0], )
@@ -133,7 +133,7 @@ def optimized_dequantize_gemm(
     codebooks: torch.
     Tensor,  # [num_codebooks, codebook_size, out_group_size, in_group_size]
     scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
-    output_partition_sizes: torch.IntTensor,
+    output_partition_sizes: List[int],
     bias: Optional[torch.Tensor],
 ) -> torch.Tensor:
     weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
@@ -288,10 +288,8 @@ class AQLMLinearMethod(LinearMethodBase):
             codebooks,
             {
                 # metadata indicates fixed size concatenated along dim 0
-                "is_metadata":
-                True,
-                "output_partition_sizes":
-                torch.tensor(output_partition_sizes, device='cpu'),
+                "is_metadata": True,
+                "output_partition_sizes": output_partition_sizes
             },
         )
@@ -334,7 +332,7 @@ class AQLMLinearMethod(LinearMethodBase):
         codes = layer.codes
         scales = layer.scales
         output_partition_sizes = getattr(codebooks, "output_partition_sizes",
-                                         None)
+                                         [])
 
         nbooks = codes.shape[2]
         ingroups = codebooks.shape[3]
@@ -212,6 +212,7 @@ class GPTQLinearMethod(LinearMethodBase):
                 layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
             else:
                 layer.g_idx.data = torch.empty((0, ),
                                                dtype=torch.int,
                                                device=layer.g_idx.device)
+            layer.exllama_state = ExllamaState.READY
             ops.gptq_shuffle(layer.qweight, layer.g_idx,