# SPDX-License-Identifier: Apache-2.0

import torch

from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops  # noqa: F401


def test_aqlm_dequant_opcheck():
    """Run ``opcheck`` against ``torch.ops._C.aqlm_dequant``.

    Inputs are random CUDA tensors: int16 ``codes`` of shape
    (22016, 512, 1), float16 ``codebooks`` of shape (2, 65536, 1, 8), and
    two codebook partitions of 11008 each (11008 * 2 == 22016).
    """
    codes = torch.randint(
        -32768,
        32767,
        (22016, 512, 1),
        device='cuda',
        dtype=torch.int16,
    )
    codebooks = torch.rand(
        (2, 65536, 1, 8),
        device='cuda',
        dtype=torch.float16,
    )
    codebook_partition_sizes = [11008, 11008]

    opcheck(
        torch.ops._C.aqlm_dequant,
        (codes, codebooks, codebook_partition_sizes),
    )


def test_aqlm_gemm_opcheck():
    """Run ``opcheck`` against ``torch.ops._C.aqlm_gemm``.

    The op is checked twice: once with ``None`` as the trailing bias
    argument and once with a real bias tensor, so that both forms of the
    optional argument are exercised.
    """
    # Named ``inputs`` rather than ``input`` to avoid shadowing the builtin.
    inputs = torch.rand((4, 4096), device='cuda', dtype=torch.float16)
    codes = torch.randint(
        -32768,
        32767,
        (12288, 512, 1),
        device='cuda',
        dtype=torch.int16,
    )
    codebooks = torch.rand(
        (3, 65536, 1, 8),
        device='cuda',
        dtype=torch.float16,
    )
    scales = torch.rand((12288, 1, 1, 1), device='cuda', dtype=torch.float16)
    codebook_partition_sizes = [4096, 4096, 4096]
    # NOTE(review): assumes the bias is one value per output feature, i.e.
    # length 12288 to match the leading dimension of ``codes``/``scales``
    # -- confirm against the kernel if this shape is rejected.
    bias = torch.rand((12288, ), device='cuda', dtype=torch.float16)

    # Without bias.
    opcheck(
        torch.ops._C.aqlm_gemm,
        (inputs, codes, codebooks, scales, codebook_partition_sizes, None),
    )
    # With bias.
    opcheck(
        torch.ops._C.aqlm_gemm,
        (inputs, codes, codebooks, scales, codebook_partition_sizes, bias),
    )