[Kernel] support non-zero cuda devices in punica kernels (#3636)

Jee Li 2024-03-27 08:37:42 +08:00 committed by GitHub
parent 0dc72273b8
commit 566b57c5c4
2 changed files with 26 additions and 59 deletions


@@ -1,7 +1,7 @@
 #include <cuda_bf16.h>
 #include <cuda_fp16.h>
 #include <torch/extension.h>
+#include <c10/cuda/CUDAGuard.h>
 #include <cstdint>
 #include "bgmv/bgmv_config.h"
@@ -91,6 +91,7 @@ void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
   CHECK_EQ(w.size(2), h_out);
   CHECK_EQ(indicies.size(0), x.size(0));
   CHECK_EQ(y.size(0), x.size(0));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
   bool ok = false;
   if (h_in < 65536 && h_out < 65536) {
     // TODO: See if we can get rid of this massive nested switch
@@ -322,6 +323,7 @@ void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
   CHECK_EQ(w.size(2), h_out);
   CHECK_EQ(indicies.size(0), x.size(0));
   CHECK_EQ(y.size(0), x.size(0));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
   bool ok = false;
   if (h_in < 65536 && h_out < 65536) {
     // TODO: See if we can get rid of this massive nested switch
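Both dispatch_bgmv and dispatch_bgmv_low_level now open an at::cuda::OptionalCUDAGuard on the input tensor's device before any kernel is launched, so the launches target whatever GPU x lives on rather than the process's current device (typically cuda:0). The following is a rough Python-side analogue of what the guard does, for illustration only; launch_on_input_device is a hypothetical helper, not part of the punica bindings:

import torch

def launch_on_input_device(x: torch.Tensor) -> torch.Tensor:
    # torch.cuda.device(...) plays roughly the role of OptionalCUDAGuard here:
    # it makes x.device the current CUDA device while work is enqueued,
    # then restores the previous device on exit.
    with torch.cuda.device(x.device):
        return x * 2.0  # stand-in for the real bgmv kernel launch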


@@ -49,14 +49,18 @@ H1 = H2 = [
     32768, 33024
 ]
 SEED = [0xabcdabcd987]
+CUDA_DEVICES = [
+    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
+]


 @pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"])
 @pytest.mark.parametrize("h1", H1)
 @pytest.mark.parametrize("h2", H2)
 @pytest.mark.parametrize("seed", SEED)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
-def test_lora_correctness(dtype_str, h1, h2, seed):
+def test_lora_correctness(dtype_str, h1, h2, seed, device):
     torch.manual_seed(seed)
     num_loras = 4
     num_layers = 1
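The new CUDA_DEVICES list keeps single-GPU machines at cuda:0 only and adds cuda:1 on machines with two or more GPUs; it is that second entry which actually exercises the non-zero-device path guarded above. A small sketch of what the expression evaluates to (devices_under_test is a hypothetical helper, not part of the test file):

import torch

def devices_under_test() -> list:
    n = torch.cuda.device_count()
    # one visible GPU   -> ["cuda:0"]            (previous behaviour, device 0 only)
    # two or more GPUs  -> ["cuda:0", "cuda:1"]  (also covers a non-zero device)
    return [f"cuda:{i}" for i in range(1 if n == 1 else 2)]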
@@ -64,25 +68,15 @@ def test_lora_correctness(dtype_str, h1, h2, seed):
     bs = 32
     scale = 0.123
     dtype = getattr(torch, dtype_str)
-    device = torch.device("cuda")
+    torch.set_default_device(device)
-    wa_T_all = torch.randn(num_loras,
-                           num_layers,
-                           r,
-                           h1,
-                           dtype=dtype,
-                           device=device)
-    wb_T_all = torch.randn(num_loras,
-                           num_layers,
-                           h2,
-                           r,
-                           dtype=dtype,
-                           device=device)
-    indices = torch.randint(num_loras, (bs, ), dtype=torch.long, device=device)
+    wa_T_all = torch.randn(num_loras, num_layers, r, h1, dtype=dtype)
+    wb_T_all = torch.randn(num_loras, num_layers, h2, r, dtype=dtype)
+    indices = torch.randint(num_loras, (bs, ), dtype=torch.long)
     for layer_idx in range(num_layers):
-        x = torch.randn(bs, h1, dtype=dtype, device=device)
-        y = torch.randn(bs, h2, dtype=dtype, device=device)
+        x = torch.randn(bs, h1, dtype=dtype)
+        y = torch.randn(bs, h2, dtype=dtype)
         y_ref = y.clone()
         _lora_ref_impl(y_ref, x, wa_T_all, wb_T_all, indices, layer_idx, scale)
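Because the device is now a test parameter, torch.set_default_device(device) routes every later factory call (torch.randn, torch.randint, ...) to that device, which is why the per-call device=... arguments could be dropped. A minimal sketch of that behaviour, assuming a machine where a second GPU is visible:

import torch

torch.set_default_device("cuda:1")  # assumption: cuda:1 exists on this machine
x = torch.randn(32, 128, dtype=torch.float16)
assert x.device == torch.device("cuda", 1)  # allocated on cuda:1 without device=...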
@@ -98,8 +92,9 @@ def test_lora_correctness(dtype_str, h1, h2, seed):
 @pytest.mark.parametrize("h1", H1)
 @pytest.mark.parametrize("h2", H2)
 @pytest.mark.parametrize("seed", SEED)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
-def test_lora_correctness_slice(dtype_str, h1, h2, seed):
+def test_lora_correctness_slice(dtype_str, h1, h2, seed, device):
     if h2 % 3 != 0 or h2 // 3 not in H1:
         pytest.skip("h2 must be divisible by 3 and in supported shapes")
     torch.manual_seed(seed)
@@ -109,50 +104,20 @@ def test_lora_correctness_slice(dtype_str, h1, h2, seed):
     bs = 32
     scale = 0.123
     dtype = getattr(torch, dtype_str)
-    device = torch.device("cuda")
+    torch.set_default_device(device)
-    wa_T_all_0 = torch.randn(num_loras,
-                             num_layers,
-                             r,
-                             h1,
-                             dtype=dtype,
-                             device=device)
-    wa_T_all_1 = torch.randn(num_loras,
-                             num_layers,
-                             r,
-                             h1,
-                             dtype=dtype,
-                             device=device)
-    wa_T_all_2 = torch.randn(num_loras,
-                             num_layers,
-                             r,
-                             h1,
-                             dtype=dtype,
-                             device=device)
-    wb_T_all_0 = torch.randn(num_loras,
-                             num_layers,
-                             h2 // 3,
-                             r,
-                             dtype=dtype,
-                             device=device)
-    wb_T_all_1 = torch.randn(num_loras,
-                             num_layers,
-                             h2 // 3,
-                             r,
-                             dtype=dtype,
-                             device=device)
-    wb_T_all_2 = torch.randn(num_loras,
-                             num_layers,
-                             h2 // 3,
-                             r,
-                             dtype=dtype,
-                             device=device)
-    indices = torch.randint(num_loras, (bs, ), dtype=torch.long, device=device)
+    wa_T_all_0 = torch.randn(num_loras, num_layers, r, h1, dtype=dtype)
+    wa_T_all_1 = torch.randn(num_loras, num_layers, r, h1, dtype=dtype)
+    wa_T_all_2 = torch.randn(num_loras, num_layers, r, h1, dtype=dtype)
+    wb_T_all_0 = torch.randn(num_loras, num_layers, h2 // 3, r, dtype=dtype)
+    wb_T_all_1 = torch.randn(num_loras, num_layers, h2 // 3, r, dtype=dtype)
+    wb_T_all_2 = torch.randn(num_loras, num_layers, h2 // 3, r, dtype=dtype)
+    indices = torch.randint(num_loras, (bs, ), dtype=torch.long)
     for layer_idx in range(num_layers):
-        x = torch.randn(bs, h1, dtype=dtype, device=device)
-        y = torch.randn(bs, h2, dtype=dtype, device=device)
+        x = torch.randn(bs, h1, dtype=dtype)
+        y = torch.randn(bs, h2, dtype=dtype)
         s = h2 // 3
         y_ref = y.clone()