vllm/tests/kernels/test_permute_cols.py
Lucas Wilkinson 86e9c8df29
[Kernel] (2/N) Machete - Integrate into CompressedTensorsWNA16 and GPTQMarlin (#7701)
Co-authored-by: mgoin <michael@neuralmagic.com>
Co-authored-by: Divakar Verma <137818590+divakar-amd@users.noreply.github.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
2024-09-23 13:46:26 -04:00

15 lines
518 B
Python

import pytest
import torch
from tests.kernels.utils import opcheck
from vllm._custom_ops import permute_cols
@pytest.mark.parametrize('shape', [(1, 512), (544, 4096), (67, 8192)])
@pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16])
def test_permute_cols(shape, dtype):
    """Check ``permute_cols`` against plain tensor indexing.

    For each parametrized ``shape`` and ``dtype`` this builds a random
    tensor and a random permutation of its column indices, runs
    ``opcheck`` on ``torch.ops._C.permute_cols`` with those inputs, and
    then asserts that the ``permute_cols`` wrapper returns the same
    values as ``x[:, perm]``.
    """
    # Inputs are generated on the CPU first and then moved to the GPU.
    host_input = torch.randn(shape, dtype=dtype)
    device_input = host_input.cuda()

    # One index per column (dimension 1), cast to torch.int before the move.
    column_order = torch.randperm(device_input.shape[1])
    column_order = column_order.to(torch.int).cuda()

    # Run opcheck on the raw registered op with the same arguments.
    op_args = (device_input, column_order)
    opcheck(torch.ops._C.permute_cols, op_args)

    # Compare the Python wrapper's output with the indexing reference.
    actual = permute_cols(device_input, column_order)
    expected = device_input[:, column_order]
    torch.testing.assert_close(actual, expected)