[Kernel] support non-zero cuda devices in punica kernels (#3636)
parent 0dc72273b8
commit 566b57c5c4
@@ -1,7 +1,7 @@
 #include <cuda_bf16.h>
 #include <cuda_fp16.h>
 #include <torch/extension.h>

+#include <c10/cuda/CUDAGuard.h>
 #include <cstdint>

 #include "bgmv/bgmv_config.h"
@@ -91,6 +91,7 @@ void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
   CHECK_EQ(w.size(2), h_out);
   CHECK_EQ(indicies.size(0), x.size(0));
   CHECK_EQ(y.size(0), x.size(0));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
   bool ok = false;
   if (h_in < 65536 && h_out < 65536) {
     // TODO: See if we can get rid of this massive nested switch
@@ -322,6 +323,7 @@ void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
   CHECK_EQ(w.size(2), h_out);
   CHECK_EQ(indicies.size(0), x.size(0));
   CHECK_EQ(y.size(0), x.size(0));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
   bool ok = false;
   if (h_in < 65536 && h_out < 65536) {
     // TODO: See if we can get rid of this massive nested switch
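The only functional change on the C++ side is the `OptionalCUDAGuard` added in both dispatch functions: it makes the device that owns `x` the current CUDA device before any kernel is launched and restores the previous device afterwards, so the bgmv kernels also work when the tensors live on a device other than `cuda:0`. Below is a minimal Python-side sketch of the same idea, purely illustrative (it is not the extension's code, and `scale_on_own_device` is a made-up helper):

```python
import torch

def scale_on_own_device(x: torch.Tensor) -> torch.Tensor:
    # torch.cuda.device(...) plays the same role here as
    # at::cuda::OptionalCUDAGuard in the hunks above: it makes x's device
    # the current device inside the block and restores the old one on exit.
    with torch.cuda.device(x.device):
        return x * 2.0  # any CUDA kernel launched here targets x.device

if torch.cuda.device_count() > 1:
    x = torch.randn(8, device="cuda:1")
    assert scale_on_own_device(x).device == x.device
```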
@@ -49,14 +49,18 @@ H1 = H2 = [
     32768, 33024
 ]
 SEED = [0xabcdabcd987]
+CUDA_DEVICES = [
+    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
+]


 @pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"])
 @pytest.mark.parametrize("h1", H1)
 @pytest.mark.parametrize("h2", H2)
 @pytest.mark.parametrize("seed", SEED)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
-def test_lora_correctness(dtype_str, h1, h2, seed):
+def test_lora_correctness(dtype_str, h1, h2, seed, device):
     torch.manual_seed(seed)
     num_loras = 4
     num_layers = 1
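The new `CUDA_DEVICES` list drives the added `device` parameter: each test now runs once per entry, so on a multi-GPU machine the kernels are also exercised on a non-zero device. A small sketch of what the list comprehension from the hunk above evaluates to:

```python
import torch

# Same construction as in the diff above.
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]

# One visible GPU:          ["cuda:0"]
# Two or more visible GPUs: ["cuda:0", "cuda:1"]
print(CUDA_DEVICES)
```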
@@ -64,25 +68,15 @@ def test_lora_correctness(dtype_str, h1, h2, seed):
     bs = 32
     scale = 0.123
     dtype = getattr(torch, dtype_str)
-    device = torch.device("cuda")
+    torch.set_default_device(device)

-    wa_T_all = torch.randn(num_loras,
-                           num_layers,
-                           r,
-                           h1,
-                           dtype=dtype,
-                           device=device)
-    wb_T_all = torch.randn(num_loras,
-                           num_layers,
-                           h2,
-                           r,
-                           dtype=dtype,
-                           device=device)
-    indices = torch.randint(num_loras, (bs, ), dtype=torch.long, device=device)
+    wa_T_all = torch.randn(num_loras, num_layers, r, h1, dtype=dtype)
+    wb_T_all = torch.randn(num_loras, num_layers, h2, r, dtype=dtype)
+    indices = torch.randint(num_loras, (bs, ), dtype=torch.long)

     for layer_idx in range(num_layers):
-        x = torch.randn(bs, h1, dtype=dtype, device=device)
-        y = torch.randn(bs, h2, dtype=dtype, device=device)
+        x = torch.randn(bs, h1, dtype=dtype)
+        y = torch.randn(bs, h2, dtype=dtype)

         y_ref = y.clone()
         _lora_ref_impl(y_ref, x, wa_T_all, wb_T_all, indices, layer_idx, scale)
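Dropping the explicit `device=` arguments works because `torch.set_default_device(device)` makes factory functions such as `torch.randn` and `torch.randint` allocate on the parametrized device by default. A minimal sketch of that behavior (note the default device is process-global until it is changed again):

```python
import torch

if torch.cuda.is_available():
    torch.set_default_device("cuda:0")  # "cuda:1" on a multi-GPU machine
    # No device= argument needed: both tensors land on the default device.
    w = torch.randn(4, 8, dtype=torch.float16)
    idx = torch.randint(4, (32, ), dtype=torch.long)
    assert w.device.type == "cuda" and idx.device.type == "cuda"
```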
@@ -98,8 +92,9 @@ def test_lora_correctness(dtype_str, h1, h2, seed):
 @pytest.mark.parametrize("h1", H1)
 @pytest.mark.parametrize("h2", H2)
 @pytest.mark.parametrize("seed", SEED)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
-def test_lora_correctness_slice(dtype_str, h1, h2, seed):
+def test_lora_correctness_slice(dtype_str, h1, h2, seed, device):
     if h2 % 3 != 0 or h2 // 3 not in H1:
         pytest.skip("h2 must be divisible by 3 and in supported shapes")
     torch.manual_seed(seed)
@@ -109,50 +104,20 @@ def test_lora_correctness_slice(dtype_str, h1, h2, seed):
     bs = 32
     scale = 0.123
     dtype = getattr(torch, dtype_str)
-    device = torch.device("cuda")
+    torch.set_default_device(device)

-    wa_T_all_0 = torch.randn(num_loras,
-                             num_layers,
-                             r,
-                             h1,
-                             dtype=dtype,
-                             device=device)
-    wa_T_all_1 = torch.randn(num_loras,
-                             num_layers,
-                             r,
-                             h1,
-                             dtype=dtype,
-                             device=device)
-    wa_T_all_2 = torch.randn(num_loras,
-                             num_layers,
-                             r,
-                             h1,
-                             dtype=dtype,
-                             device=device)
-    wb_T_all_0 = torch.randn(num_loras,
-                             num_layers,
-                             h2 // 3,
-                             r,
-                             dtype=dtype,
-                             device=device)
-    wb_T_all_1 = torch.randn(num_loras,
-                             num_layers,
-                             h2 // 3,
-                             r,
-                             dtype=dtype,
-                             device=device)
-    wb_T_all_2 = torch.randn(num_loras,
-                             num_layers,
-                             h2 // 3,
-                             r,
-                             dtype=dtype,
-                             device=device)
+    wa_T_all_0 = torch.randn(num_loras, num_layers, r, h1, dtype=dtype)
+    wa_T_all_1 = torch.randn(num_loras, num_layers, r, h1, dtype=dtype)
+    wa_T_all_2 = torch.randn(num_loras, num_layers, r, h1, dtype=dtype)
+    wb_T_all_0 = torch.randn(num_loras, num_layers, h2 // 3, r, dtype=dtype)
+    wb_T_all_1 = torch.randn(num_loras, num_layers, h2 // 3, r, dtype=dtype)
+    wb_T_all_2 = torch.randn(num_loras, num_layers, h2 // 3, r, dtype=dtype)

-    indices = torch.randint(num_loras, (bs, ), dtype=torch.long, device=device)
+    indices = torch.randint(num_loras, (bs, ), dtype=torch.long)

     for layer_idx in range(num_layers):
-        x = torch.randn(bs, h1, dtype=dtype, device=device)
-        y = torch.randn(bs, h2, dtype=dtype, device=device)
+        x = torch.randn(bs, h1, dtype=dtype)
+        y = torch.randn(bs, h2, dtype=dtype)
         s = h2 // 3

         y_ref = y.clone()
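Taken together, each test case now pins everything to the parametrized device and compares the kernel against the reference implementation there. A rough, self-contained sketch of that flow, where the matmul is only a stand-in for the punica bgmv op and `_lora_ref_impl` (not reproduced here):

```python
import pytest
import torch

CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]

@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_runs_on_parametrized_device(device):
    torch.manual_seed(0xabcdabcd987)
    torch.set_default_device(device)
    x = torch.randn(32, 128, dtype=torch.float16)
    w = torch.randn(128, 16, dtype=torch.float16)
    # Stand-in for the kernel under test: a correct kernel must run on the
    # tensors' device, which is exactly what a missing device guard breaks
    # whenever device != "cuda:0".
    y = x @ w
    assert y.device == torch.device(device)
```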