vllm/tests/lora/test_punica_ops.py
2025-03-02 17:34:51 -08:00

652 lines
19 KiB
Python

# SPDX-License-Identifier: Apache-2.0
from threading import Lock
import pytest
import torch
import vllm.lora.ops.triton_ops # noqa: F401
from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice,
bgmv_shrink, sgmv_expand,
sgmv_expand_slice, sgmv_shrink)
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
from vllm.platforms import current_platform
from .utils import (PunicaTensors, assert_close, generate_data,
generate_data_for_expand_nslices,
generate_data_for_nslices)
# Utility shrink and expand operations used as reference implementations.
def sgmv_shrink_for_nslices(
nslices: int, inputs_tensor: torch.Tensor,
lora_weights_lst: list[torch.Tensor], out_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor, seq_len_tensor: torch.Tensor,
prompt_lora_mapping: torch.Tensor, batches: int, max_seq_length: int,
num_tokens: int, scaling: float):
"""
Wrapper around sgmv_shrink that handles any nslices.
"""
for index in range(nslices):
sgmv_shrink(
inputs_tensor,
lora_weights_lst[index],
out_tensor[index],
b_seq_start_loc,
seq_len_tensor,
prompt_lora_mapping,
batches,
max_seq_length,
num_tokens,
scaling,
)
def sgmv_expand_for_nslices(nslices: int, hidden_size: int,
inputs_tensor: torch.Tensor,
lora_weights_lst: list[torch.Tensor],
out_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
prompt_lora_mapping: torch.Tensor, batches: int,
max_seq_length: int, num_tokens: int,
add_inputs: bool) -> None:
"""
Wrapper around sgmv_expand that handles any nslices.
"""
if nslices == 1:
# Verify the torch's sgmv_expand op
sgmv_expand(
inputs_tensor[0],
lora_weights_lst[0],
out_tensor,
b_seq_start_loc,
seq_len_tensor,
prompt_lora_mapping,
batches,
max_seq_length,
num_tokens,
add_inputs=add_inputs,
)
else:
slice_offset = 0
for index in range(nslices):
lora_weights = lora_weights_lst[index]
sgmv_expand_slice(
inputs_tensor[index],
lora_weights,
out_tensor,
b_seq_start_loc,
seq_len_tensor,
prompt_lora_mapping,
batches,
max_seq_length,
num_tokens,
slice_offset,
hidden_size,
add_inputs=add_inputs,
)
slice_offset += hidden_size
_dict_lock = Lock()
def check_sgmv_shrink(batches: int, num_loras: int, rank: int,
hidden_size: int, nslices: int, dtype: torch.dtype,
device: str, seq_length: int, scaling: float):
"""
Compare outputs of vllm.sgmv_shrink kernel against a reference
implementation.
"""
data: PunicaTensors = generate_data_for_nslices(
batches,
hidden_size,
num_loras,
rank,
seq_length,
nslices,
dtype,
"shrink",
device,
)
max_seq_length, token_nums = data.meta()
# Preventing cache error pointer.
with _dict_lock:
_LORA_A_PTR_DICT.clear()
torch.ops.vllm.sgmv_shrink(
data.inputs_tensor,
data.lora_weights,
data.our_out_tensor,
data.b_seq_start_loc,
data.seq_len_tensor,
data.prompt_lora_mapping,
batches,
max_seq_length,
token_nums,
scaling,
)
sgmv_shrink_for_nslices(
nslices,
data.inputs_tensor,
data.lora_weights,
data.ref_out_tensor,
data.b_seq_start_loc,
data.seq_len_tensor,
data.prompt_lora_mapping,
batches,
max_seq_length,
token_nums,
scaling,
)
assert_close(data.our_out_tensor, data.ref_out_tensor)
def check_sgmv_expand(batches: int, num_loras: int, rank: int,
hidden_size: int, nslices: int, dtype: torch.dtype,
device: str, seq_length: int, add_inputs: bool):
"""
Compare outputs of vllm.sgmv_expand kernel against a reference
implementation.
"""
data: PunicaTensors = generate_data_for_nslices(
batches,
hidden_size,
num_loras,
rank,
seq_length,
nslices,
dtype,
"expand",
device,
)
max_seq_length, token_nums = data.meta()
with _dict_lock:
_LORA_B_PTR_DICT.clear()
torch.ops.vllm.sgmv_expand(
data.inputs_tensor,
data.lora_weights,
data.our_out_tensor,
data.b_seq_start_loc,
data.seq_len_tensor,
data.prompt_lora_mapping,
batches,
max_seq_length,
token_nums,
offset_start=0,
add_inputs=add_inputs,
)
sgmv_expand_for_nslices(nslices,
hidden_size,
data.inputs_tensor,
data.lora_weights,
data.ref_out_tensor,
data.b_seq_start_loc,
data.seq_len_tensor,
data.prompt_lora_mapping,
batches,
max_seq_length,
token_nums,
add_inputs=add_inputs)
assert_close(data.our_out_tensor, data.ref_out_tensor)
def check_bgmv_shrink(batches: int, num_loras: int, rank: int,
hidden_size: int, dtype: torch.dtype, device: str,
scaling: float):
"""
Compare vllm.bgmv_shrink against a reference implementation.
"""
seq_length = 1
data: PunicaTensors = generate_data(
batches,
hidden_size,
num_loras,
rank,
seq_length,
dtype,
"shrink",
device,
)
torch.ops.vllm.bgmv_shrink(
data.inputs_tensor,
data.lora_weights,
data.our_out_tensor,
data.token_lora_mapping,
scaling,
)
bgmv_shrink(
data.inputs_tensor,
data.lora_weights,
data.ref_out_tensor,
data.token_lora_mapping,
scaling,
)
data.ref_out_tensor = data.ref_out_tensor.to(torch.float32)
assert_close(data.our_out_tensor, data.ref_out_tensor)
def check_bgmv_expand(batches: int, num_loras: int, rank: int,
hidden_size: int, dtype: torch.dtype, device: str,
add_inputs: bool):
"""
Compare vllm.bgmv_expand against a reference implementation.
"""
seq_length = 1
data: PunicaTensors = generate_data(
batches,
hidden_size,
num_loras,
rank,
seq_length,
dtype,
"expand",
device,
)
torch.ops.vllm.bgmv_expand(
data.inputs_tensor,
data.lora_weights,
data.our_out_tensor,
data.token_lora_mapping,
add_inputs=add_inputs,
)
bgmv_expand(
data.inputs_tensor,
data.lora_weights,
data.ref_out_tensor,
data.token_lora_mapping,
add_inputs=add_inputs,
)
assert_close(data.our_out_tensor, data.ref_out_tensor)
def check_bgmv_expand_slice(batches: int, num_loras: int, rank: int,
hidden_size: int, nslices: int, dtype: torch.dtype,
device: str, add_inputs: bool):
"""
Compare vllm.bgmv_expand_slice against a reference implementation.
"""
seq_length = 1
data: PunicaTensors = generate_data_for_expand_nslices(
batches,
hidden_size,
num_loras,
rank,
seq_length,
dtype,
nslices,
device,
)
slice_offset = 0
for index in range(nslices):
torch.ops.vllm.bgmv_expand_slice(
data.inputs_tensor,
data.lora_weights[index],
data.our_out_tensor,
data.token_lora_mapping,
slice_offset,
slice_size=hidden_size,
add_inputs=add_inputs,
)
bgmv_expand_slice(
data.inputs_tensor,
data.lora_weights[index],
data.ref_out_tensor,
data.token_lora_mapping,
slice_offset,
slice_size=hidden_size,
add_inputs=add_inputs,
)
slice_offset += hidden_size
assert_close(data.our_out_tensor, data.ref_out_tensor)
# Tests
# We test the punica kernels along 2 verticals mainly.
# 1. Variations in hidden_dim size
# 2. Variations in all other parameters like (batch_size, max_rank, num_loras
# etc.)
# We have collected the hidden_sizes included in the LoRA models
# currently supported by vLLM. It tests whether the corresponding Triton
# kernel can run normally when tensor parallelism is set to
# [1, 2, 4, 8, 16, 32, 64].
HIDDEN_SIZES = [
128,
256,
512,
896,
1024,
1152,
1216,
1280,
1536,
1664,
2048,
2240,
2304,
2368,
2432,
2560,
2752,
3072,
3328,
3456,
3584,
3712,
4096,
4480,
4608,
4736,
4864,
5120,
5504,
5632,
5888,
6144,
6400,
6848,
6912,
7168,
7424,
8192,
8960,
9216,
9472,
10240,
11008,
11264,
13824,
14336,
14784,
14848,
15360,
18944,
22016,
22528,
24576,
27392,
27648,
29568,
29696,
32000,
32256,
32512,
32768,
33024,
36864,
43264,
49152,
49408,
60544,
60672,
64000,
64256,
102400,
102656,
128000,
128256,
]
#The size of TP
divisibility = [1, 2, 8, 16, 64]
all_hidden_size = []
for div in divisibility:
for hidden_size in HIDDEN_SIZES:
all_hidden_size.append(hidden_size // div)
HIDDEN_SIZES = list(set(all_hidden_size))
# Test params that focuses on hidden_size variation.
hs_test_params = {
"hidden_sizes": HIDDEN_SIZES,
"batches": [4],
"num_loras": [4],
"max_ranks": [32],
}
# General tests params that tests for variations in all dimensions
# except hidden_size.
test_params = {
"hidden_sizes": [2049],
"batches": [1, 4, 16, 32],
"num_loras": [1, 8, 32, 128],
"max_ranks": [1, 4, 8, 16, 32, 64, 128, 256],
}
DTYPES = [torch.float16, torch.bfloat16]
DEVICES = [f"cuda:{0}"]
SEED = [0]
@pytest.mark.parametrize("batches", test_params['batches'])
@pytest.mark.parametrize("num_loras", test_params['num_loras'])
@pytest.mark.parametrize("rank", test_params['max_ranks'])
@pytest.mark.parametrize("hidden_size", test_params['hidden_sizes'])
@pytest.mark.parametrize("nslices", [1, 2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
def test_punica_sgmv(
batches: int,
num_loras: int,
rank: int,
hidden_size: int,
nslices: int,
dtype: torch.dtype,
device: str,
seed: int,
op_type: str,
):
torch.set_default_device(device)
current_platform.seed_everything(seed)
if op_type == "shrink":
check_sgmv_shrink(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
nslices=nslices,
dtype=dtype,
device=device,
seq_length=128,
scaling=0.5)
else:
check_sgmv_expand(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
nslices=nslices,
dtype=dtype,
device=device,
seq_length=128,
add_inputs=True)
@pytest.mark.parametrize("batches", hs_test_params['batches'])
@pytest.mark.parametrize("num_loras", hs_test_params['num_loras'])
@pytest.mark.parametrize("rank", hs_test_params['max_ranks'])
@pytest.mark.parametrize("hidden_size", hs_test_params['hidden_sizes'])
@pytest.mark.parametrize("nslices", [1, 2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
def test_punica_sgmv_hidden_size(
batches: int,
num_loras: int,
rank: int,
hidden_size: int,
nslices: int,
dtype: torch.dtype,
device: str,
seed: int,
op_type: str,
):
torch.set_default_device(device)
current_platform.seed_everything(seed)
if op_type == "shrink":
check_sgmv_shrink(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
nslices=nslices,
dtype=dtype,
device=device,
seq_length=128,
scaling=0.5)
else:
check_sgmv_expand(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
nslices=nslices,
dtype=dtype,
device=device,
seq_length=128,
add_inputs=True)
@pytest.mark.parametrize("batches", test_params['batches'])
@pytest.mark.parametrize("num_loras", test_params['num_loras'])
@pytest.mark.parametrize("rank", test_params['max_ranks'])
@pytest.mark.parametrize("hidden_size", test_params['hidden_sizes'])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
def test_punica_bgmv(
batches: int,
num_loras: int,
rank: int,
hidden_size: int,
dtype: torch.dtype,
device: str,
seed: int,
op_type: str,
):
torch.set_default_device(device)
current_platform.seed_everything(seed)
if op_type == "shrink":
check_bgmv_shrink(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
dtype=dtype,
device=device,
scaling=0.5)
else:
check_bgmv_expand(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
dtype=dtype,
device=device,
add_inputs=True)
@pytest.mark.parametrize("batches", hs_test_params['batches'])
@pytest.mark.parametrize("num_loras", hs_test_params['num_loras'])
@pytest.mark.parametrize("rank", hs_test_params['max_ranks'])
@pytest.mark.parametrize("hidden_size", hs_test_params['hidden_sizes'])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
def test_punica_bgmv_hidden_size(
batches: int,
num_loras: int,
rank: int,
hidden_size: int,
dtype: torch.dtype,
device: str,
seed: int,
op_type: str,
):
torch.set_default_device(device)
current_platform.seed_everything(seed)
if op_type == "shrink":
check_bgmv_shrink(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
dtype=dtype,
device=device,
scaling=0.5)
else:
check_bgmv_expand(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
dtype=dtype,
device=device,
add_inputs=True)
@pytest.mark.parametrize("batches", test_params['batches'])
@pytest.mark.parametrize("num_loras", test_params['num_loras'])
@pytest.mark.parametrize("rank", test_params['max_ranks'])
@pytest.mark.parametrize("hidden_size", test_params['hidden_sizes'])
@pytest.mark.parametrize("nslices", [2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
def test_punica_bgmv_expand_nslices(batches: int, num_loras: int, rank: int,
hidden_size: int, nslices: int,
dtype: torch.dtype, device: str,
seed: int):
torch.set_default_device(device)
current_platform.seed_everything(seed)
check_bgmv_expand_slice(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
nslices=nslices,
dtype=dtype,
device=device,
add_inputs=True)
@pytest.mark.parametrize("batches", hs_test_params['batches'])
@pytest.mark.parametrize("num_loras", hs_test_params['num_loras'])
@pytest.mark.parametrize("rank", hs_test_params['max_ranks'])
@pytest.mark.parametrize("hidden_size", hs_test_params['hidden_sizes'])
@pytest.mark.parametrize("nslices", [2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
def test_punica_bgmv_expand_nslices_hidden_size(batches: int, num_loras: int,
rank: int, hidden_size: int,
nslices: int,
dtype: torch.dtype,
device: str, seed: int):
torch.set_default_device(device)
current_platform.seed_everything(seed)
check_bgmv_expand_slice(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
nslices=nslices,
dtype=dtype,
device=device,
add_inputs=True)