[Kernel] fix types used in aqlm and ggml kernels to support dynamo (#7596)

This commit is contained in:
bnellnm 2024-08-16 17:00:11 -04:00 committed by GitHub
parent 7759ae958f
commit 37fd47e780
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 39 additions and 53 deletions

View File

@ -63,12 +63,12 @@ void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes, torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
const torch::Tensor& codebooks, const torch::Tensor& codebooks,
const torch::Tensor& scales, const torch::Tensor& scales,
const torch::Tensor& codebook_partition_sizes, const std::vector<int64_t>& codebook_partition_sizes,
const std::optional<torch::Tensor>& bias); const std::optional<torch::Tensor>& bias);
torch::Tensor aqlm_dequant(const torch::Tensor& codes, torch::Tensor aqlm_dequant(
const torch::Tensor& codebooks, const torch::Tensor& codes, const torch::Tensor& codebooks,
const torch::Tensor& codebook_partition_sizes); const std::vector<int64_t>& codebook_partition_sizes);
torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
torch::Tensor _scaling_factors, torch::Tensor _zeros, torch::Tensor _scaling_factors, torch::Tensor _zeros,
@ -107,13 +107,13 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
int64_t size_n, int64_t num_bits); int64_t size_n, int64_t num_bits);
torch::Tensor ggml_dequantize(torch::Tensor W, int8_t type, int64_t m, torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m,
int64_t n); int64_t n);
torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X, int8_t type, torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X,
int64_t row); int64_t type, int64_t row);
torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int8_t type, torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type,
int64_t row); int64_t row);
torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,

View File

@ -496,14 +496,14 @@ torch::Tensor code2x8_matmat(const torch::Tensor& input,
} }
// Accumulate the partition sizes. // Accumulate the partition sizes.
int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes) { int4 accumulate_sizes(const std::vector<int64_t>& codebook_partition_sizes) {
int4 cumulative_sizes; int4 cumulative_sizes;
auto cumulative_size = &cumulative_sizes.x; auto cumulative_size = &cumulative_sizes.x;
int i = 0; size_t i = 0;
int last = 0; int last = 0;
assert(codebook_partition_sizes.size(0) <= 4); assert(codebook_partition_sizes.size() <= 4);
for (; i < codebook_partition_sizes.size(0); ++i, ++cumulative_size) { for (; i < codebook_partition_sizes.size(); ++i, ++cumulative_size) {
*cumulative_size = codebook_partition_sizes[i].item<int>() + last; *cumulative_size = codebook_partition_sizes[i] + last;
last = *cumulative_size; last = *cumulative_size;
} }
// fill in the rest with unreachable. // fill in the rest with unreachable.
@ -519,12 +519,12 @@ int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes) {
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes, torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
const torch::Tensor& codebooks, const torch::Tensor& codebooks,
const torch::Tensor& scales, const torch::Tensor& scales,
const torch::Tensor& codebook_partition_sizes, const std::vector<int64_t>& codebook_partition_sizes,
const std::optional<torch::Tensor>& bias) { const std::optional<torch::Tensor>& bias) {
int4 cumulative_sizes = int4 cumulative_sizes =
vllm::aqlm::accumulate_sizes(codebook_partition_sizes); vllm::aqlm::accumulate_sizes(codebook_partition_sizes);
int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0); int const nbooks = codebooks.size(0) / codebook_partition_sizes.size();
int const entries = codebooks.size(1); int const entries = codebooks.size(1);
if (nbooks == 1 && entries == (1 << 16)) { if (nbooks == 1 && entries == (1 << 16)) {
@ -541,13 +541,13 @@ torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
return {}; return {};
} }
torch::Tensor aqlm_dequant(const torch::Tensor& codes, torch::Tensor aqlm_dequant(
const torch::Tensor& codebooks, const torch::Tensor& codes, const torch::Tensor& codebooks,
const torch::Tensor& codebook_partition_sizes) { const std::vector<int64_t>& codebook_partition_sizes) {
int4 cumulative_sizes = int4 cumulative_sizes =
vllm::aqlm::accumulate_sizes(codebook_partition_sizes); vllm::aqlm::accumulate_sizes(codebook_partition_sizes);
int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0); int const nbooks = codebooks.size(0) / codebook_partition_sizes.size();
int const entries = codebooks.size(1); int const entries = codebooks.size(1);
const at::cuda::OptionalCUDAGuard device_guard(device_of(codes)); const at::cuda::OptionalCUDAGuard device_guard(device_of(codes));
@ -557,7 +557,8 @@ torch::Tensor aqlm_dequant(const torch::Tensor& codes,
auto in_features = codes.size(1) * 8; auto in_features = codes.size(1) * 8;
auto out_features = codes.size(0); auto out_features = codes.size(0);
assert(out_features = codebook_partition_sizes.sum().item<int>()); assert(out_features == std::accumulate(codebook_partition_sizes.begin(),
codebook_partition_sizes.end(), 0));
auto weights = torch::empty({out_features, in_features}, auto weights = torch::empty({out_features, in_features},
torch::TensorOptions() torch::TensorOptions()

View File

@ -487,7 +487,7 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k,
dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y); dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
} }
static to_fp16_cuda_t ggml_get_to_fp16_cuda(int type) { static to_fp16_cuda_t ggml_get_to_fp16_cuda(int64_t type) {
switch (type) { switch (type) {
case 2: case 2:
return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>; return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;

View File

@ -60,7 +60,7 @@ static void quantize_row_q8_1_cuda(const half* x, void* vy, const int kx,
} }
torch::Tensor ggml_dequantize(torch::Tensor W, // quant weight torch::Tensor ggml_dequantize(torch::Tensor W, // quant weight
int8_t type, int64_t m, int64_t n) { int64_t type, int64_t m, int64_t n) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(W)); const at::cuda::OptionalCUDAGuard device_guard(device_of(W));
auto options = auto options =
torch::TensorOptions().dtype(torch::kFloat16).device(W.device()); torch::TensorOptions().dtype(torch::kFloat16).device(W.device());
@ -73,7 +73,7 @@ torch::Tensor ggml_dequantize(torch::Tensor W, // quant weight
torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, // quant weight torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, // quant weight
torch::Tensor X, // input torch::Tensor X, // input
int8_t type, int64_t row) { int64_t type, int64_t row) {
int col = X.sizes()[1]; int col = X.sizes()[1];
const int padded = (col + 512 - 1) / 512 * 512; const int padded = (col + 512 - 1) / 512 * 512;
const at::cuda::OptionalCUDAGuard device_guard(device_of(X)); const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
@ -172,7 +172,7 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, // quant weight
torch::Tensor ggml_mul_mat_a8(torch::Tensor W, // quant weight torch::Tensor ggml_mul_mat_a8(torch::Tensor W, // quant weight
torch::Tensor X, // input torch::Tensor X, // input
int8_t type, int64_t row) { int64_t type, int64_t row) {
int col = X.sizes()[1]; int col = X.sizes()[1];
int padded = (col + 512 - 1) / 512 * 512; int padded = (col + 512 - 1) / 512 * 512;
int batch = X.sizes()[0]; int batch = X.sizes()[0];
@ -239,4 +239,4 @@ torch::Tensor ggml_mul_mat_a8(torch::Tensor W, // quant weight
break; break;
} }
return Y; return Y;
} }

View File

@ -17,13 +17,7 @@ if not current_platform.is_tpu():
logger.warning("Failed to import from vllm._C with %r", e) logger.warning("Failed to import from vllm._C with %r", e)
with contextlib.suppress(ImportError): with contextlib.suppress(ImportError):
# ruff: noqa: F401 import vllm._moe_C # noqa: F401
import vllm._moe_C
def is_custom_op_supported(op_name: str) -> bool:
op, overloads = torch._C._jit_get_operation(op_name)
return op is not None
def hint_on_error(fn): def hint_on_error(fn):
@ -280,14 +274,14 @@ def cutlass_scaled_mm_azp(a: torch.Tensor,
# aqlm # aqlm
def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor, def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
codebooks: torch.Tensor, scales: torch.Tensor, codebooks: torch.Tensor, scales: torch.Tensor,
codebook_partition_sizes: torch.Tensor, codebook_partition_sizes: List[int],
bias: Optional[torch.Tensor]) -> torch.Tensor: bias: Optional[torch.Tensor]) -> torch.Tensor:
return torch.ops._C.aqlm_gemm(input, codes, codebooks, scales, return torch.ops._C.aqlm_gemm(input, codes, codebooks, scales,
codebook_partition_sizes, bias) codebook_partition_sizes, bias)
def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor, def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,
codebook_partition_sizes: torch.Tensor) -> torch.Tensor: codebook_partition_sizes: List[int]) -> torch.Tensor:
return torch.ops._C.aqlm_dequant(codes, codebooks, return torch.ops._C.aqlm_dequant(codes, codebooks,
codebook_partition_sizes) codebook_partition_sizes)
@ -434,25 +428,17 @@ def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# gguf # gguf
def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int, n: int): def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int,
n: int) -> torch.Tensor:
return torch.ops._C.ggml_dequantize(W, quant_type, m, n) return torch.ops._C.ggml_dequantize(W, quant_type, m, n)
def ggml_mul_mat_vec(
W: torch.Tensor,
X: torch.Tensor,
quant_type: int,
row: int,
):
return torch.ops._C.ggml_mul_mat_vec(W, X, quant_type, row)
def ggml_mul_mat_vec_a8( def ggml_mul_mat_vec_a8(
W: torch.Tensor, W: torch.Tensor,
X: torch.Tensor, X: torch.Tensor,
quant_type: int, quant_type: int,
row: int, row: int,
): ) -> torch.Tensor:
return torch.ops._C.ggml_mul_mat_vec_a8(W, X, quant_type, row) return torch.ops._C.ggml_mul_mat_vec_a8(W, X, quant_type, row)
@ -461,7 +447,7 @@ def ggml_mul_mat_a8(
X: torch.Tensor, X: torch.Tensor,
quant_type: int, quant_type: int,
row: int, row: int,
): ) -> torch.Tensor:
return torch.ops._C.ggml_mul_mat_a8(W, X, quant_type, row) return torch.ops._C.ggml_mul_mat_a8(W, X, quant_type, row)

View File

@ -95,7 +95,7 @@ def generic_dequantize_gemm(
codebooks: torch. codebooks: torch.
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
scales: torch.Tensor, # [num_out_groups, 1, 1, 1] scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
output_partition_sizes: torch.IntTensor, output_partition_sizes: List[int],
bias: Optional[torch.Tensor], bias: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
output_shape = input.shape[:-1] + (scales.shape[0], ) output_shape = input.shape[:-1] + (scales.shape[0], )
@ -133,7 +133,7 @@ def optimized_dequantize_gemm(
codebooks: torch. codebooks: torch.
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
scales: torch.Tensor, # [num_out_groups, 1, 1, 1] scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
output_partition_sizes: torch.IntTensor, output_partition_sizes: List[int],
bias: Optional[torch.Tensor], bias: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes) weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
@ -288,10 +288,8 @@ class AQLMLinearMethod(LinearMethodBase):
codebooks, codebooks,
{ {
# metadata indicates fixed size concatenated along dim 0 # metadata indicates fixed size concatenated along dim 0
"is_metadata": "is_metadata": True,
True, "output_partition_sizes": output_partition_sizes
"output_partition_sizes":
torch.tensor(output_partition_sizes, device='cpu'),
}, },
) )
@ -334,7 +332,7 @@ class AQLMLinearMethod(LinearMethodBase):
codes = layer.codes codes = layer.codes
scales = layer.scales scales = layer.scales
output_partition_sizes = getattr(codebooks, "output_partition_sizes", output_partition_sizes = getattr(codebooks, "output_partition_sizes",
None) [])
nbooks = codes.shape[2] nbooks = codes.shape[2]
ingroups = codebooks.shape[3] ingroups = codebooks.shape[3]

View File

@ -212,6 +212,7 @@ class GPTQLinearMethod(LinearMethodBase):
layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int) layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
else: else:
layer.g_idx.data = torch.empty((0, ), layer.g_idx.data = torch.empty((0, ),
dtype=torch.int,
device=layer.g_idx.device) device=layer.g_idx.device)
layer.exllama_state = ExllamaState.READY layer.exllama_state = ExllamaState.READY
ops.gptq_shuffle(layer.qweight, layer.g_idx, ops.gptq_shuffle(layer.qweight, layer.g_idx,