[Core] refactor aqlm quant ops (#4351)
This commit is contained in:
parent bd7a8eef25
commit f4bc4de1b1
@@ -6,7 +6,7 @@ from typing import Optional
 import torch
 import torch.nn.functional as F
 
-from vllm._C import ops
+from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.aqlm import (
     dequantize_weight, generic_dequantize_gemm, get_int_dtype,
     optimized_dequantize_gemm)
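Context for the hunk above: the refactor routes callers through a Python wrapper module instead of importing the compiled extension (`vllm._C`) directly. A minimal sketch of that indirection follows; the `vllm_ops` alias matches the calls in the next hunk, while the ImportError guard is an illustrative assumption, not part of this commit.

```python
# Hedged sketch of the indirection this commit introduces: callers write
# `from vllm import _custom_ops as ops` instead of `from vllm._C import ops`.
import torch

try:
    from vllm._C import ops as vllm_ops  # compiled CUDA/C++ bindings
except ImportError:
    vllm_ops = None  # e.g. a build without the compiled extension (assumption)


def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,
                 codebook_partition_sizes: torch.Tensor) -> torch.Tensor:
    # Thin pass-through: the Python wrapper gives every caller one stable
    # import path even if the binding layer underneath changes.
    if vllm_ops is None:
        raise RuntimeError("vllm._C kernels are unavailable")
    return vllm_ops.aqlm_dequant(codes, codebooks, codebook_partition_sizes)
```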
@@ -153,6 +153,20 @@ def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                        size_n, size_k)
 
 
+# aqlm
+def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
+              codebooks: torch.Tensor, scales: torch.Tensor,
+              codebook_partition_sizes: torch.Tensor,
+              bias: Optional[torch.Tensor]) -> torch.Tensor:
+    return vllm_ops.aqlm_gemm(input, codes, codebooks, scales,
+                              codebook_partition_sizes, bias)
+
+
+def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,
+                 codebook_partition_sizes: torch.Tensor) -> torch.Tensor:
+    return vllm_ops.aqlm_dequant(codes, codebooks, codebook_partition_sizes)
+
+
 # fp8
 def scaled_fp8_quant(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     scale = torch.zeros(1, device=input.device, dtype=torch.float32)
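The two wrappers added above expose the fused AQLM matmul and the standalone dequantizer. A hedged usage sketch comparing the two paths: only the `ops.aqlm_gemm`/`ops.aqlm_dequant` signatures come from the diff; how `scales` folds into the dequantized weight, and the tolerances, are assumptions here.

```python
# Hedged sketch: sanity-check the fused kernel against explicit
# dequantize + F.linear on tensors from an AQLM checkpoint.
from typing import Optional

import torch
import torch.nn.functional as F

from vllm import _custom_ops as ops


def check_close(x: torch.Tensor, codes: torch.Tensor,
                codebooks: torch.Tensor, scales: torch.Tensor,
                partition_sizes: torch.Tensor,
                bias: Optional[torch.Tensor]) -> None:
    # Fused path: decode the codes and run the matmul in one kernel.
    fused = ops.aqlm_gemm(x, codes, codebooks, scales, partition_sizes, bias)
    # Reference path: materialize the dense weight, then a plain linear.
    # Applying scales as a per-output-row multiplier is an assumption.
    weight = ops.aqlm_dequant(codes, codebooks, partition_sizes)
    ref = F.linear(x, weight * scales.flatten().unsqueeze(-1), bias)
    torch.testing.assert_close(fused, ref, rtol=1e-2, atol=1e-2)
```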
@@ -8,7 +8,7 @@ import torch
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter
 
-from vllm._C import ops
+from vllm import _custom_ops as ops
 from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                set_weight_attrs)
 from vllm.model_executor.layers.quantization.base_config import (
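For reference, the call site this import swap serves is the AQLM linear method's forward path. A minimal sketch, assuming attribute names (`layer.codes`, `layer.codebooks`, `layer.scales`, `layer.output_partition_sizes`) that this hunk does not show:

```python
# Hedged sketch of an AQLM linear method after the refactor; only the
# import line and the ops.aqlm_gemm signature come from this commit.
from typing import Optional

import torch

from vllm import _custom_ops as ops


class AQLMLinearMethodSketch:
    """Illustrative stand-in for the real linear method in aqlm.py."""

    def apply(self, layer: torch.nn.Module, x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # The call itself is unchanged by the refactor; only where `ops`
        # comes from (vllm._custom_ops vs. vllm._C) is different.
        return ops.aqlm_gemm(x, layer.codes, layer.codebooks, layer.scales,
                             layer.output_partition_sizes, bias)
```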