[Core] refactor aqlm quant ops (#4351)
This commit is contained in:
parent
bd7a8eef25
commit
f4bc4de1b1
@ -6,7 +6,7 @@ from typing import Optional
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from vllm._C import ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.model_executor.layers.quantization.aqlm import (
|
from vllm.model_executor.layers.quantization.aqlm import (
|
||||||
dequantize_weight, generic_dequantize_gemm, get_int_dtype,
|
dequantize_weight, generic_dequantize_gemm, get_int_dtype,
|
||||||
optimized_dequantize_gemm)
|
optimized_dequantize_gemm)
|
||||||
|
@ -153,6 +153,20 @@ def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
|
|||||||
size_n, size_k)
|
size_n, size_k)
|
||||||
|
|
||||||
|
|
||||||
|
# aqlm
|
||||||
|
def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
|
||||||
|
codebooks: torch.Tensor, scales: torch.Tensor,
|
||||||
|
codebook_partition_sizes: torch.Tensor,
|
||||||
|
bias: Optional[torch.Tensor]) -> torch.Tensor:
|
||||||
|
return vllm_ops.aqlm_gemm(input, codes, codebooks, scales,
|
||||||
|
codebook_partition_sizes, bias)
|
||||||
|
|
||||||
|
|
||||||
|
def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,
|
||||||
|
codebook_partition_sizes: torch.Tensor) -> torch.Tensor:
|
||||||
|
return vllm_ops.aqlm_dequant(codes, codebooks, codebook_partition_sizes)
|
||||||
|
|
||||||
|
|
||||||
# fp8
|
# fp8
|
||||||
def scaled_fp8_quant(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
def scaled_fp8_quant(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
|
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
|
||||||
|
@ -8,7 +8,7 @@ import torch
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
from vllm._C import ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||||
set_weight_attrs)
|
set_weight_attrs)
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
|
Loading…
x
Reference in New Issue
Block a user