vllm/tests/lora/test_punica.py
Antoni Baum 9b945daaf1
[Experimental] Add multi-LoRA support (#1804)
Co-authored-by: Chen Shen <scv119@gmail.com>
Co-authored-by: Shreyas Krishnaswamy <shrekris@anyscale.com>
Co-authored-by: Avnish Narayan <avnish@anyscale.com>
2024-01-23 15:26:37 -08:00

176 lines
5.8 KiB
Python

# Based on code from https://github.com/punica-ai/punica
import pytest
import torch
import vllm.lora.punica as punica
def assert_close(a, b):
rtol, atol = {
torch.float16: (5e-3, 5e-3),
torch.bfloat16: (3e-2, 2e-2),
torch.float32: (None, None),
}[a.dtype]
torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
def _lora_ref_impl(
y_final: torch.Tensor,
x: torch.Tensor,
wa_T_all: torch.Tensor,
wb_T_all: torch.Tensor,
indicies: torch.LongTensor,
layer_idx: int,
scale: float,
):
y_stage_1 = torch.empty(
(x.size(0), wa_T_all.size(-2)),
dtype=torch.float32,
device=x.device,
)
bs = x.shape[0]
s = torch.tensor(scale, dtype=torch.float32, device=x.device)
for i, lora_idx in zip(range(bs), indicies.cpu().tolist()):
xi = x[i].unsqueeze(0).to(torch.float32)
wa = wa_T_all[lora_idx, layer_idx].transpose(-1, -2).to(torch.float32)
wb = wb_T_all[lora_idx, layer_idx].transpose(-1, -2).to(torch.float32)
tmp = xi @ wa
y_stage_1[i] = tmp.squeeze(0)
y_final[i] += (tmp @ wb).squeeze(0) * s
return y_final, y_stage_1
H1 = H2 = [
128, 256, 512, 1024, 1280, 2048, 2560, 2752, 3072, 3456, 3584, 4096, 5120,
5504, 5632, 6912, 7168, 8192, 9216, 10240, 11008, 13824, 14336, 32000,
32256, 32512, 32768, 33024
]
SEED = [0xabcdabcd987]
@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"])
@pytest.mark.parametrize("h1", H1)
@pytest.mark.parametrize("h2", H2)
@pytest.mark.parametrize("seed", SEED)
@torch.inference_mode()
def test_lora_correctness(dtype_str, h1, h2, seed):
torch.manual_seed(seed)
num_loras = 4
num_layers = 1
r = 8
bs = 32
scale = 0.123
dtype = getattr(torch, dtype_str)
device = torch.device("cuda")
wa_T_all = torch.randn(num_loras,
num_layers,
r,
h1,
dtype=dtype,
device=device)
wb_T_all = torch.randn(num_loras,
num_layers,
h2,
r,
dtype=dtype,
device=device)
indices = torch.randint(num_loras, (bs, ), dtype=torch.long, device=device)
for layer_idx in range(num_layers):
x = torch.randn(bs, h1, dtype=dtype, device=device)
y = torch.randn(bs, h2, dtype=dtype, device=device)
y_ref = y.clone()
_lora_ref_impl(y_ref, x, wa_T_all, wb_T_all, indices, layer_idx, scale)
y_our = y.clone()
punica.add_lora(y_our, x, wa_T_all, wb_T_all, indices, layer_idx,
scale)
assert_close(y_ref, y_our)
@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"])
@pytest.mark.parametrize("h1", H1)
@pytest.mark.parametrize("h2", H2)
@pytest.mark.parametrize("seed", SEED)
@torch.inference_mode()
def test_lora_correctness_slice(dtype_str, h1, h2, seed):
if h2 % 3 != 0 or h2 // 3 not in H1:
pytest.skip("h2 must be divisible by 3 and in supported shapes")
torch.manual_seed(seed)
num_loras = 4
num_layers = 1
r = 8
bs = 32
scale = 0.123
dtype = getattr(torch, dtype_str)
device = torch.device("cuda")
wa_T_all_0 = torch.randn(num_loras,
num_layers,
r,
h1,
dtype=dtype,
device=device)
wa_T_all_1 = torch.randn(num_loras,
num_layers,
r,
h1,
dtype=dtype,
device=device)
wa_T_all_2 = torch.randn(num_loras,
num_layers,
r,
h1,
dtype=dtype,
device=device)
wb_T_all_0 = torch.randn(num_loras,
num_layers,
h2 // 3,
r,
dtype=dtype,
device=device)
wb_T_all_1 = torch.randn(num_loras,
num_layers,
h2 // 3,
r,
dtype=dtype,
device=device)
wb_T_all_2 = torch.randn(num_loras,
num_layers,
h2 // 3,
r,
dtype=dtype,
device=device)
indices = torch.randint(num_loras, (bs, ), dtype=torch.long, device=device)
for layer_idx in range(num_layers):
x = torch.randn(bs, h1, dtype=dtype, device=device)
y = torch.randn(bs, h2, dtype=dtype, device=device)
s = h2 // 3
y_ref = y.clone()
_lora_ref_impl(y_ref[:, :s], x, wa_T_all_0, wb_T_all_0, indices,
layer_idx, scale)
_lora_ref_impl(y_ref[:, s:s * 2], x, wa_T_all_1, wb_T_all_1, indices,
layer_idx, scale)
_lora_ref_impl(y_ref[:, s * 2:], x, wa_T_all_2, wb_T_all_2, indices,
layer_idx, scale)
y_our = y.clone()
punica.add_lora_slice(y_our, x, wa_T_all_0, wb_T_all_0, indices,
layer_idx, scale, 0, s)
punica.add_lora_slice(y_our, x, wa_T_all_1, wb_T_all_1, indices,
layer_idx, scale, s, s)
punica.add_lora_slice(y_our, x, wa_T_all_2, wb_T_all_2, indices,
layer_idx, scale, s * 2, s)
assert_close(y_ref[:, :s], y_our[:, :s])
assert_close(y_ref[:, s:s * 2], y_our[:, s:s * 2])
assert_close(y_ref[:, s * 2:], y_our[:, s * 2:])