[Kernel][LoRA]Punica prefill kernels fusion (#11234)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: Abatom <abzhonghua@gmail.com>
Co-authored-by: Zhonghua Deng <abatom@163.com>
Jee Jee Li 2025-01-07 12:01:39 +08:00 committed by GitHub
parent 8ceffbf315
commit b278557935
11 changed files with 710 additions and 767 deletions
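In short, this PR fuses the per-slice Punica prefill launches: sgmv_shrink and sgmv_expand now take a list of LoRA weight tensors and cover every slice in a single kernel launch, and the separate sgmv_expand_slice op is removed. Below is a minimal sketch of the new expand calling convention; the input/output/metadata tensors are assumed to be prepared as in the tests further down, and this is only an illustration of the signature visible in this diff, not vLLM wrapper code.

from typing import List

import torch

from vllm.lora.ops.sgmv_expand import sgmv_expand


def fused_prefill_expand(
    inputs: torch.Tensor,                # (nslices, num_tokens, rank)
    lora_b_weights: List[torch.Tensor],  # one (num_loras, hidden, rank) tensor per slice
    output: torch.Tensor,                # (num_tokens, sum of slice hidden sizes)
    b_seq_start_loc: torch.Tensor,       # (batches,) start offset of each sequence
    seq_len_tensor: torch.Tensor,        # (batches,) length of each sequence
    lora_indices: torch.Tensor,          # (batches,) LoRA id per sequence, -1 = none
    batches: int,
    max_seq_length: int,
    token_nums: int,
) -> None:
    # Previously each output slice needed its own sgmv_expand_slice launch;
    # the fused op walks all slices via a pointer lookup table
    # (see _get_lora_b_ptr in vllm/lora/ops/utils.py below).
    sgmv_expand(
        inputs,
        lora_b_weights,
        output,
        b_seq_start_loc,
        seq_len_tensor,
        lora_indices,
        batches,
        max_seq_length,
        token_nums,
        offset_start=0,
        add_inputs=True,
    )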

View File

@ -242,7 +242,7 @@ steps:
source_file_dependencies: source_file_dependencies:
- vllm/lora - vllm/lora
- tests/lora - tests/lora
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py --ignore=lora/test_minicpmv_tp.py
parallelism: 4 parallelism: 4
- label: "PyTorch Fullgraph Smoke Test" # 9min - label: "PyTorch Fullgraph Smoke Test" # 9min
@ -535,6 +535,7 @@ steps:
# requires multi-GPU testing for validation. # requires multi-GPU testing for validation.
- pytest -v -s -x lora/test_chatglm3_tp.py - pytest -v -s -x lora/test_chatglm3_tp.py
- pytest -v -s -x lora/test_llama_tp.py - pytest -v -s -x lora/test_llama_tp.py
- pytest -v -s -x lora/test_minicpmv_tp.py
- label: Weight Loading Multiple GPU Test # 33min - label: Weight Loading Multiple GPU Test # 33min

View File

@ -1,77 +0,0 @@
from typing import List
import pytest
import vllm
from vllm.assets.image import ImageAsset
from vllm.lora.request import LoRARequest
from vllm.platforms import current_platform
MODEL_PATH = "openbmb/MiniCPM-Llama3-V-2_5"
PROMPT_TEMPLATE = (
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
"(<image>./</image>)\nWhat is in the image?<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n")
IMAGE_ASSETS = [
ImageAsset("stop_sign"),
ImageAsset("cherry_blossom"),
]
# After fine-tuning with LoRA, all generated content should start with `A`.
EXPECTED_OUTPUT = [
"A red and white stop sign with a Chinese archway in the background featuring red lanterns and gold accents.", # noqa: E501
"A pink cherry blossom tree with a blue sky in the background.",
]
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
sampling_params = vllm.SamplingParams(
temperature=0,
max_tokens=5,
stop_token_ids=[128001, 128009], # eos_id, eot_id
)
inputs = [{
"prompt": PROMPT_TEMPLATE,
"multi_modal_data": {
"image": asset.pil_image
},
} for asset in IMAGE_ASSETS]
outputs = llm.generate(
inputs,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None,
)
# Print the outputs.
generated_texts: List[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts
@pytest.mark.xfail(
current_platform.is_rocm(),
reason="MiniCPM-V dependency xformers incompatible with ROCm")
def test_minicpmv_lora(minicpmv_lora_files):
llm = vllm.LLM(
MODEL_PATH,
max_num_seqs=2,
enable_lora=True,
max_loras=4,
max_lora_rank=64,
trust_remote_code=True,
enable_chunked_prefill=True,
)
output1 = do_sample(llm, minicpmv_lora_files, lora_id=1)
for i in range(len(EXPECTED_OUTPUT)):
assert EXPECTED_OUTPUT[i].startswith(output1[i])
output2 = do_sample(llm, minicpmv_lora_files, lora_id=2)
for i in range(len(EXPECTED_OUTPUT)):
assert EXPECTED_OUTPUT[i].startswith(output2[i])

View File

@ -3,10 +3,10 @@ from typing import List
import pytest import pytest
import vllm import vllm
from tests.utils import fork_new_process_for_each_test
from vllm.assets.image import ImageAsset from vllm.assets.image import ImageAsset
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.platforms import current_platform
from ..utils import multi_gpu_test
MODEL_PATH = "openbmb/MiniCPM-Llama3-V-2_5" MODEL_PATH = "openbmb/MiniCPM-Llama3-V-2_5"
@ -17,13 +17,11 @@ PROMPT_TEMPLATE = (
IMAGE_ASSETS = [ IMAGE_ASSETS = [
ImageAsset("stop_sign"), ImageAsset("stop_sign"),
ImageAsset("cherry_blossom"),
] ]
# After fine-tuning with LoRA, all generated content should start with `A`. # After fine-tuning with LoRA, all generated content should start with `A`.
EXPECTED_OUTPUT = [ EXPECTED_OUTPUT = [
"A red and white stop sign with a Chinese archway in the background featuring red lanterns and gold accents.", # noqa: E501 "A red and white stop sign with a Chinese archway in the background featuring red lanterns and gold accents.", # noqa: E501
"A pink cherry blossom tree with a blue sky in the background.",
] ]
@ -50,37 +48,40 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
# Print the outputs. # Print the outputs.
generated_texts: List[str] = [] generated_texts: List[str] = []
for output in outputs: for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip() generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text) generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") print(f"Generated text: {generated_text!r}")
return generated_texts return generated_texts
@multi_gpu_test(num_gpus=2) @pytest.mark.xfail(
@pytest.mark.parametrize("fully_sharded", [True, False]) current_platform.is_rocm(),
def test_minicpmv_tp2(minicpmv_lora_files, fully_sharded): reason="MiniCPM-V dependency xformers incompatible with ROCm")
@fork_new_process_for_each_test
def test_minicpmv_lora(minicpmv_lora_files):
llm = vllm.LLM( llm = vllm.LLM(
MODEL_PATH, MODEL_PATH,
enable_lora=True,
max_num_seqs=2, max_num_seqs=2,
max_loras=4, enable_lora=True,
max_lora_rank=64, max_loras=2,
tensor_parallel_size=2, max_lora_rank=8,
enforce_eager=True,
trust_remote_code=True, trust_remote_code=True,
fully_sharded_loras=fully_sharded,
enable_chunked_prefill=True, enable_chunked_prefill=True,
) )
output1 = do_sample(llm, minicpmv_lora_files, lora_id=1)
output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1)
for i in range(len(EXPECTED_OUTPUT)): for i in range(len(EXPECTED_OUTPUT)):
assert EXPECTED_OUTPUT[i].startswith(output_tp[i]) assert EXPECTED_OUTPUT[i].startswith(output1[i])
output2 = do_sample(llm, minicpmv_lora_files, lora_id=2)
for i in range(len(EXPECTED_OUTPUT)):
assert EXPECTED_OUTPUT[i].startswith(output2[i])
@multi_gpu_test(num_gpus=4) @pytest.mark.xfail(
@pytest.mark.parametrize("fully_sharded", [True, False]) current_platform.is_rocm(),
def test_minicpmv_tp4(minicpmv_lora_files, fully_sharded): reason="MiniCPM-V dependency xformers incompatible with ROCm")
@fork_new_process_for_each_test
def test_minicpmv_tp4_wo_fully_sharded_loras(minicpmv_lora_files):
llm = vllm.LLM( llm = vllm.LLM(
MODEL_PATH, MODEL_PATH,
enable_lora=True, enable_lora=True,
@ -89,9 +90,33 @@ def test_minicpmv_tp4(minicpmv_lora_files, fully_sharded):
max_lora_rank=64, max_lora_rank=64,
tensor_parallel_size=4, tensor_parallel_size=4,
trust_remote_code=True, trust_remote_code=True,
fully_sharded_loras=fully_sharded, enforce_eager=True,
enable_chunked_prefill=True, enable_chunked_prefill=True,
) )
output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1) output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1)
for i in range(len(EXPECTED_OUTPUT)): for i in range(len(EXPECTED_OUTPUT)):
assert EXPECTED_OUTPUT[i].startswith(output_tp[i]) assert EXPECTED_OUTPUT[i].startswith(output_tp[i])
@pytest.mark.xfail(
current_platform.is_rocm(),
reason="MiniCPM-V dependency xformers incompatible with ROCm")
@fork_new_process_for_each_test
def test_minicpmv_tp4_fully_sharded_loras(minicpmv_lora_files):
llm = vllm.LLM(
MODEL_PATH,
enable_lora=True,
max_num_seqs=2,
max_loras=2,
max_lora_rank=8,
tensor_parallel_size=4,
trust_remote_code=True,
fully_sharded_loras=True,
enable_chunked_prefill=True,
)
output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1)
for i in range(len(EXPECTED_OUTPUT)):
assert EXPECTED_OUTPUT[i].startswith(output_tp[i])
output_tp = do_sample(llm, minicpmv_lora_files, lora_id=2)
for i in range(len(EXPECTED_OUTPUT)):
assert EXPECTED_OUTPUT[i].startswith(output_tp[i])

View File

@ -4,6 +4,8 @@ hidden_sizes included in the LoRA models currently supported by vLLM. It tests
whether the corresponding Triton kernel can run normally when tensor parallelism whether the corresponding Triton kernel can run normally when tensor parallelism
is set to [1, 2, 4, 8, 16, 32, 64]. is set to [1, 2, 4, 8, 16, 32, 64].
""" """
from threading import Lock
import pytest import pytest
import torch import torch
@ -11,12 +13,13 @@ from vllm.lora.ops.bgmv_expand import bgmv_expand
from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice
from vllm.lora.ops.bgmv_shrink import bgmv_shrink from vllm.lora.ops.bgmv_shrink import bgmv_shrink
from vllm.lora.ops.sgmv_expand import sgmv_expand from vllm.lora.ops.sgmv_expand import sgmv_expand
from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
from vllm.lora.ops.sgmv_shrink import sgmv_shrink from vllm.lora.ops.sgmv_shrink import sgmv_shrink
from vllm.lora.ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
from vllm.platforms import current_platform from vllm.platforms import current_platform
from .utils import (generate_data, generate_data_for_expand_nslices, from .utils import (assert_close, generate_data,
ref_torch_groupgemm) generate_data_for_expand_nslices,
generate_data_for_nslices, ref_torch_groupgemm)
HIDDEN_SIZES = [ HIDDEN_SIZES = [
128, 128,
@ -112,14 +115,7 @@ SCALES = [0.5]
SEED = [0] SEED = [0]
CUDA_DEVICES = [f"cuda:{0}"] CUDA_DEVICES = [f"cuda:{0}"]
_dict_lock = Lock()
def assert_close(a, b):
rtol, atol = {
torch.float16: (6e-2, 6e-2),
torch.bfloat16: (6e-2, 6e-2),
torch.float32: (1e-2, 1e-2),
}[a.dtype]
torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
@pytest.mark.parametrize("batches", BATCHES) @pytest.mark.parametrize("batches", BATCHES)
@ -127,6 +123,7 @@ def assert_close(a, b):
@pytest.mark.parametrize("rank", MAX_RANKS) @pytest.mark.parametrize("rank", MAX_RANKS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("scaling", SCALES) @pytest.mark.parametrize("scaling", SCALES)
@pytest.mark.parametrize("nslices", [1, 2, 3])
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["shrink", "expand"]) @pytest.mark.parametrize("op_type", ["shrink", "expand"])
@pytest.mark.parametrize("seed", SEED) @pytest.mark.parametrize("seed", SEED)
@ -137,6 +134,7 @@ def test_punica_sgmv(
rank: int, rank: int,
hidden_size: int, hidden_size: int,
scaling: float, scaling: float,
nslices: int,
dtype: torch.dtype, dtype: torch.dtype,
op_type: str, op_type: str,
seed: int, seed: int,
@ -148,19 +146,20 @@ def test_punica_sgmv(
seq_length = 128 seq_length = 128
( (
inputs_tensor, inputs_tensor,
lora_weights, lora_weights_lst,
our_out_tensor, our_out_tensor,
ref_out_tensor, ref_out_tensor,
b_seq_start_loc, b_seq_start_loc,
lora_indices_tensor, lora_indices_tensor,
seq_len_tensor, seq_len_tensor,
indices, indices,
) = generate_data( ) = generate_data_for_nslices(
batches, batches,
hidden_size, hidden_size,
num_loras, num_loras,
rank, rank,
seq_length, seq_length,
nslices,
dtype, dtype,
op_type, op_type,
device, device,
@ -172,9 +171,12 @@ def test_punica_sgmv(
else: else:
max_seq_length = max_seq_length.item() max_seq_length = max_seq_length.item()
if op_type == "shrink": if op_type == "shrink":
# Prevent stale cached pointers.
with _dict_lock:
_LORA_A_PTR_DICT.clear()
sgmv_shrink( sgmv_shrink(
inputs_tensor, inputs_tensor,
lora_weights, lora_weights_lst,
our_out_tensor, our_out_tensor,
b_seq_start_loc, b_seq_start_loc,
seq_len_tensor, seq_len_tensor,
@ -184,10 +186,23 @@ def test_punica_sgmv(
token_nums, token_nums,
scaling, scaling,
) )
for index in range(nslices):
ref_torch_groupgemm(
ref_out_tensor[index],
inputs_tensor,
lora_weights_lst[index],
lora_indices_tensor,
seq_len_tensor,
batches,
scaling,
op_type,
)
else: else:
with _dict_lock:
_LORA_B_PTR_DICT.clear()
sgmv_expand( sgmv_expand(
inputs_tensor, inputs_tensor,
lora_weights, lora_weights_lst,
our_out_tensor, our_out_tensor,
b_seq_start_loc, b_seq_start_loc,
seq_len_tensor, seq_len_tensor,
@ -195,20 +210,25 @@ def test_punica_sgmv(
batches, batches,
max_seq_length, max_seq_length,
token_nums, token_nums,
offset_start=0,
add_inputs=True, add_inputs=True,
) )
slice_offset = 0
for index in range(nslices):
lora_weights = lora_weights_lst[index]
ref_torch_groupgemm( ref_torch_groupgemm(
ref_out_tensor, ref_out_tensor[:, slice_offset:slice_offset + hidden_size],
inputs_tensor, inputs_tensor[index],
lora_weights, lora_weights,
lora_indices_tensor, lora_indices_tensor,
seq_len_tensor, seq_len_tensor,
batches, batches,
scaling if op_type == "shrink" else 1.0, 1.0,
op_type, op_type,
) )
if op_type == "shrink": slice_offset += hidden_size
ref_out_tensor = ref_out_tensor.to(torch.float32)
assert_close(our_out_tensor, ref_out_tensor) assert_close(our_out_tensor, ref_out_tensor)
@ -292,25 +312,22 @@ def test_punica_bgmv(
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("nslices", [2, 3]) @pytest.mark.parametrize("nslices", [2, 3])
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["sgmv", "bgmv"])
@pytest.mark.parametrize("seed", SEED) @pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
def test_punica_expand_nslices( def test_punica_bgmv_expand_nslices(
batches: int, batches: int,
num_loras: int, num_loras: int,
rank: int, rank: int,
hidden_size: int, hidden_size: int,
nslices: int, nslices: int,
dtype: torch.dtype, dtype: torch.dtype,
op_type: str,
seed: int, seed: int,
device: str, device: str,
): ):
torch.set_default_device(device) torch.set_default_device(device)
current_platform.seed_everything(seed) current_platform.seed_everything(seed)
seq_length = 128 if op_type == "sgmv" else 1 seq_length = 1
( (
inputs_tensor, inputs_tensor,
lora_weights_lst, lora_weights_lst,
@ -330,32 +347,9 @@ def test_punica_expand_nslices(
nslices, nslices,
device, device,
) )
max_seq_length = seq_len_tensor.max()
token_nums = seq_len_tensor.sum().item()
if isinstance(max_seq_length, tuple):
max_seq_length = max_seq_length[0].item()
else:
max_seq_length = max_seq_length.item()
slice_offset = 0 slice_offset = 0
for index in range(nslices): for index in range(nslices):
lora_weights = lora_weights_lst[index] lora_weights = lora_weights_lst[index]
if op_type == "sgmv":
sgmv_expand_slice(
inputs_tensor,
lora_weights,
our_outputs,
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
batches,
max_seq_length,
token_nums,
slice_offset,
hidden_size,
add_inputs=True,
)
else:
bgmv_expand_slice( bgmv_expand_slice(
inputs_tensor, inputs_tensor,
lora_weights, lora_weights,

View File

@ -3,6 +3,8 @@ This script is mainly used to test whether triton kernels can run normally
under different conditions, including various batches, numbers of LoRAs, and under different conditions, including various batches, numbers of LoRAs, and
maximum ranks. maximum ranks.
""" """
from threading import Lock
import pytest import pytest
import torch import torch
@ -11,12 +13,13 @@ import vllm.lora.ops.bgmv_expand
import vllm.lora.ops.bgmv_expand_slice import vllm.lora.ops.bgmv_expand_slice
import vllm.lora.ops.bgmv_shrink import vllm.lora.ops.bgmv_shrink
import vllm.lora.ops.sgmv_expand import vllm.lora.ops.sgmv_expand
import vllm.lora.ops.sgmv_expand_slice
import vllm.lora.ops.sgmv_shrink # noqa: F401 import vllm.lora.ops.sgmv_shrink # noqa: F401
from vllm.lora.ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
from vllm.platforms import current_platform from vllm.platforms import current_platform
from .utils import (generate_data, generate_data_for_expand_nslices, from .utils import (assert_close, generate_data,
ref_torch_groupgemm) generate_data_for_expand_nslices,
generate_data_for_nslices, ref_torch_groupgemm)
HIDDEN_SIZES = [4097] HIDDEN_SIZES = [4097]
@ -28,31 +31,23 @@ SCALES = [0.5]
SEED = [0] SEED = [0]
CUDA_DEVICES = [f"cuda:{0}"] CUDA_DEVICES = [f"cuda:{0}"]
def assert_close(a, b):
rtol, atol = {
torch.float16: (6e-2, 6e-2),
torch.bfloat16: (6e-2, 6e-2),
torch.float32: (1e-2, 1e-2),
}[a.dtype]
torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
# Unlike test_punica_sizes.py, we directly utilize the custom ops for # Unlike test_punica_sizes.py, we directly utilize the custom ops for
# testing, which verifies the correct registration of these ops. # testing, which verifies the correct registration of these ops.
bgmv_expand = torch.ops.vllm.bgmv_expand bgmv_expand = torch.ops.vllm.bgmv_expand
bgmv_expand_slice = torch.ops.vllm.bgmv_expand_slice bgmv_expand_slice = torch.ops.vllm.bgmv_expand_slice
bgmv_shrink = torch.ops.vllm.bgmv_shrink bgmv_shrink = torch.ops.vllm.bgmv_shrink
sgmv_expand = torch.ops.vllm.sgmv_expand sgmv_expand = torch.ops.vllm.sgmv_expand
sgmv_expand_slice = torch.ops.vllm.sgmv_expand_slice
sgmv_shrink = torch.ops.vllm.sgmv_shrink sgmv_shrink = torch.ops.vllm.sgmv_shrink
_dict_lock = Lock()
@pytest.mark.parametrize("batches", BATCHES) @pytest.mark.parametrize("batches", BATCHES)
@pytest.mark.parametrize("num_loras", NUM_LORA) @pytest.mark.parametrize("num_loras", NUM_LORA)
@pytest.mark.parametrize("rank", MAX_RANKS) @pytest.mark.parametrize("rank", MAX_RANKS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("scaling", SCALES) @pytest.mark.parametrize("scaling", SCALES)
@pytest.mark.parametrize("nslices", [1, 2, 3])
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["shrink", "expand"]) @pytest.mark.parametrize("op_type", ["shrink", "expand"])
@pytest.mark.parametrize("seed", SEED) @pytest.mark.parametrize("seed", SEED)
@ -63,6 +58,7 @@ def test_punica_sgmv(
rank: int, rank: int,
hidden_size: int, hidden_size: int,
scaling: float, scaling: float,
nslices: int,
dtype: torch.dtype, dtype: torch.dtype,
op_type: str, op_type: str,
seed: int, seed: int,
@ -74,19 +70,20 @@ def test_punica_sgmv(
seq_length = 128 seq_length = 128
( (
inputs_tensor, inputs_tensor,
lora_weights, lora_weights_lst,
our_out_tensor, our_out_tensor,
ref_out_tensor, ref_out_tensor,
b_seq_start_loc, b_seq_start_loc,
lora_indices_tensor, lora_indices_tensor,
seq_len_tensor, seq_len_tensor,
indices, indices,
) = generate_data( ) = generate_data_for_nslices(
batches, batches,
hidden_size, hidden_size,
num_loras, num_loras,
rank, rank,
seq_length, seq_length,
nslices,
dtype, dtype,
op_type, op_type,
device, device,
@ -98,9 +95,12 @@ def test_punica_sgmv(
else: else:
max_seq_length = max_seq_length.item() max_seq_length = max_seq_length.item()
if op_type == "shrink": if op_type == "shrink":
# Prevent stale cached pointers.
with _dict_lock:
_LORA_A_PTR_DICT.clear()
sgmv_shrink( sgmv_shrink(
inputs_tensor, inputs_tensor,
lora_weights, lora_weights_lst,
our_out_tensor, our_out_tensor,
b_seq_start_loc, b_seq_start_loc,
seq_len_tensor, seq_len_tensor,
@ -110,10 +110,23 @@ def test_punica_sgmv(
token_nums, token_nums,
scaling, scaling,
) )
for index in range(nslices):
ref_torch_groupgemm(
ref_out_tensor[index],
inputs_tensor,
lora_weights_lst[index],
lora_indices_tensor,
seq_len_tensor,
batches,
scaling,
op_type,
)
else: else:
with _dict_lock:
_LORA_B_PTR_DICT.clear()
sgmv_expand( sgmv_expand(
inputs_tensor, inputs_tensor,
lora_weights, lora_weights_lst,
our_out_tensor, our_out_tensor,
b_seq_start_loc, b_seq_start_loc,
seq_len_tensor, seq_len_tensor,
@ -121,20 +134,25 @@ def test_punica_sgmv(
batches, batches,
max_seq_length, max_seq_length,
token_nums, token_nums,
offset_start=0,
add_inputs=True, add_inputs=True,
) )
slice_offset = 0
for index in range(nslices):
lora_weights = lora_weights_lst[index]
ref_torch_groupgemm( ref_torch_groupgemm(
ref_out_tensor, ref_out_tensor[:, slice_offset:slice_offset + hidden_size],
inputs_tensor, inputs_tensor[index],
lora_weights, lora_weights,
lora_indices_tensor, lora_indices_tensor,
seq_len_tensor, seq_len_tensor,
batches, batches,
scaling if op_type == "shrink" else 1.0, 1.0,
op_type, op_type,
) )
if op_type == "shrink": slice_offset += hidden_size
ref_out_tensor = ref_out_tensor.to(torch.float32)
assert_close(our_out_tensor, ref_out_tensor) assert_close(our_out_tensor, ref_out_tensor)
@ -220,24 +238,22 @@ def test_punica_bgmv(
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("nslices", [2, 3]) @pytest.mark.parametrize("nslices", [2, 3])
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["sgmv", "bgmv"])
@pytest.mark.parametrize("seed", SEED) @pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
def test_punica_expand_nslices( def test_punica_bgmv_expand_nslices(
batches: int, batches: int,
num_loras: int, num_loras: int,
rank: int, rank: int,
hidden_size: int, hidden_size: int,
nslices: int, nslices: int,
dtype: torch.dtype, dtype: torch.dtype,
op_type: str,
seed: int, seed: int,
device: str, device: str,
): ):
torch.set_default_device(device) torch.set_default_device(device)
current_platform.seed_everything(seed) current_platform.seed_everything(seed)
seq_length = 128 if op_type == "sgmv" else 1 seq_length = 1
( (
inputs_tensor, inputs_tensor,
lora_weights_lst, lora_weights_lst,
@ -257,31 +273,9 @@ def test_punica_expand_nslices(
nslices, nslices,
device, device,
) )
max_seq_length = seq_len_tensor.max()
token_nums = seq_len_tensor.sum().item()
if isinstance(max_seq_length, tuple):
max_seq_length = max_seq_length[0].item()
else:
max_seq_length = max_seq_length.item()
slice_offset = 0 slice_offset = 0
for index in range(nslices): for index in range(nslices):
lora_weights = lora_weights_lst[index] lora_weights = lora_weights_lst[index]
if op_type == "sgmv":
sgmv_expand_slice(
inputs_tensor,
lora_weights,
our_outputs,
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
batches,
max_seq_length,
token_nums,
slice_offset,
hidden_size,
add_inputs=True,
)
else:
bgmv_expand_slice( bgmv_expand_slice(
inputs_tensor, inputs_tensor,
lora_weights, lora_weights,

View File

@ -18,11 +18,13 @@ class DummyLoRAManager:
def get_module_lora(self, module_name: str) -> LoRALayerWeights: def get_module_lora(self, module_name: str) -> LoRALayerWeights:
return self._loras[module_name] return self._loras[module_name]
def init_random_lora(self, def init_random_lora(
self,
module_name: str, module_name: str,
weight: torch.Tensor, weight: torch.Tensor,
rank: int = 8, rank: int = 8,
generate_embeddings_tensor: int = 0): generate_embeddings_tensor: int = 0,
):
lora = LoRALayerWeights( lora = LoRALayerWeights(
module_name, module_name,
rank=rank, rank=rank,
@ -35,21 +37,25 @@ class DummyLoRAManager:
device=self._device), device=self._device),
) )
if generate_embeddings_tensor: if generate_embeddings_tensor:
lora.embeddings_tensor = torch.rand(5, lora.embeddings_tensor = torch.rand(
5,
generate_embeddings_tensor, generate_embeddings_tensor,
dtype=weight.dtype, dtype=weight.dtype,
device=self._device) device=self._device,
)
self.set_module_lora(module_name, lora) self.set_module_lora(module_name, lora)
return lora return lora
def init_lora(self, def init_lora(
self,
module_name: str, module_name: str,
input_dim: int, input_dim: int,
output_dim: int, output_dim: int,
rank=8, rank=8,
noop=False, noop=False,
embeddings_tensor=None): embeddings_tensor=None,
):
lora = LoRALayerWeights( lora = LoRALayerWeights(
module_name, module_name,
rank=rank, rank=rank,
@ -125,8 +131,16 @@ def ref_torch_groupgemm(
return return
def generate_data(batches, hidden_size, lora_nums, max_rank, seq_length, dtype, def generate_data(
op_type, device): batches,
hidden_size,
lora_nums,
max_rank,
seq_length,
dtype,
op_type,
device,
):
seq_len_tensor = torch.randint(seq_length, seq_length + 1, seq_len_tensor = torch.randint(seq_length, seq_length + 1,
(batches, )).to(device) (batches, )).to(device)
b_seq_start_loc = torch.cumsum( b_seq_start_loc = torch.cumsum(
@ -187,8 +201,16 @@ def generate_data(batches, hidden_size, lora_nums, max_rank, seq_length, dtype,
) )
def generate_data_for_expand_nslices(batches, hidden_size, lora_nums, max_rank, def generate_data_for_expand_nslices(
seq_length, dtype, nslices, device): batches,
hidden_size,
lora_nums,
max_rank,
seq_length,
dtype,
nslices,
device,
):
seq_len_tensor = torch.randint(seq_length, seq_length + 1, seq_len_tensor = torch.randint(seq_length, seq_length + 1,
(batches, )).to(device) (batches, )).to(device)
b_seq_start_loc = torch.cumsum( b_seq_start_loc = torch.cumsum(
@ -221,7 +243,87 @@ def generate_data_for_expand_nslices(batches, hidden_size, lora_nums, max_rank,
for b_id in range(batches): for b_id in range(batches):
lora_index = lora_indices_tensor[b_id] lora_index = lora_indices_tensor[b_id]
indices[current_offset:current_offset + indices[current_offset:current_offset +
seq_len_tensor[b_id]] = lora_index.item() seq_len_tensor[b_id]] = (lora_index.item())
current_offset += seq_len_tensor[b_id].item()
lora_indices_tensor = lora_indices_tensor.to(device)
return (
inputs_tensor,
lora_weights_lst,
our_out_tensor,
ref_out_tensor,
b_seq_start_loc,
lora_indices_tensor,
seq_len_tensor,
indices,
)
def generate_data_for_nslices(
batches,
hidden_size,
lora_nums,
max_rank,
seq_length,
nslices,
dtype,
op_type,
device,
):
seq_len_tensor = torch.randint(seq_length, seq_length + 1,
(batches, )).to(device)
b_seq_start_loc = torch.cumsum(
torch.tensor([0] + seq_len_tensor[:-1].tolist(), dtype=torch.long),
dim=0,
).to(device)
total_tokens = seq_len_tensor.sum()
lora_weights_lst = []
if op_type == "shrink":
inputs_tensor = torch.rand((total_tokens, hidden_size),
dtype=dtype).to(device)
for _ in range(nslices):
if op_type == "shrink":
lora_weights_lst.append(
torch.rand(
(lora_nums, max_rank, hidden_size), # col-major
dtype=dtype,
).to(device))
# NOTE: the shrink kernel uses torch.float32 as its output type.
# The shrink op needs atomic_add, so the output is initialized to 0.
our_out_tensor = torch.zeros(
(nslices, total_tokens, max_rank),
dtype=torch.float32,
).to(device)
else:
inputs_tensor = torch.rand(
(nslices, total_tokens, max_rank),
dtype=dtype,
).to(device)
for _ in range(nslices):
lora_weights_lst.append(
torch.rand(
(lora_nums, hidden_size, max_rank), # col-major
dtype=dtype,
).to(device))
# The expand op computes y += a @ lora_b, so the output is
# initialized randomly.
our_out_tensor = torch.rand((total_tokens, hidden_size * nslices),
dtype=dtype).to(device)
# Ensure the same input.
ref_out_tensor = our_out_tensor.clone()
lora_indices_tensor = torch.randint(0,
lora_nums - 1 if lora_nums > 1 else 1,
(batches, ))
indices = torch.zeros((total_tokens), dtype=torch.long).to(device)
current_offset = 0
for b_id in range(batches):
lora_index = lora_indices_tensor[b_id]
indices[current_offset:current_offset +
seq_len_tensor[b_id]] = (lora_index.item())
current_offset += seq_len_tensor[b_id].item() current_offset += seq_len_tensor[b_id].item()
lora_indices_tensor = lora_indices_tensor.to(device) lora_indices_tensor = lora_indices_tensor.to(device)
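The b_seq_start_loc convention used by all of these helpers (and by the corrected docstring in sgmv_expand below) is the exclusive cumulative sum of the per-sequence lengths. A tiny worked example in plain PyTorch:

import torch

seq_len_tensor = torch.tensor([4, 6])
# Exclusive cumsum: the start offset of each sequence in the flattened token
# dimension. For lengths [4, 6] this gives [0, 4]; the total token count (10)
# is tracked separately via seq_len_tensor.sum().
b_seq_start_loc = torch.cumsum(
    torch.tensor([0] + seq_len_tensor[:-1].tolist(), dtype=torch.long), dim=0)
print(b_seq_start_loc)  # tensor([0, 4])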

View File

@ -5,12 +5,16 @@ Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547 https://arxiv.org/abs/2310.18547
""" """
from typing import List
import torch import torch
import triton import triton
import triton.language as tl import triton.language as tl
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
from .utils import _get_lora_b_ptr
@triton.jit @triton.jit
def _sgmv_expand_kernel( def _sgmv_expand_kernel(
@ -22,45 +26,84 @@ def _sgmv_expand_kernel(
b_seq_start_loc, b_seq_start_loc,
seq_lens, seq_lens,
lora_indices, lora_indices,
xm_stride, slice_start_loc,
xk_stride, # 1 input_d0_stride,
l0_stride, # hidden_size*max_rank input_d1_stride,
lora_k_stride, input_d2_stride, # 1
lora_n_stride, ls_d0_ptr,
cm_stride, ls_d1_ptr,
cn_stride, ls_d2_ptr, # 1
output_d0_stride,
output_d1_stride, # 1
output_hs_ptr,
BLOCK_M: tl.constexpr, BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr, BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr, BLOCK_K: tl.constexpr,
EVEN_K: tl.constexpr, EVEN_K: tl.constexpr,
ADD_INPUTS: tl.constexpr, ADD_INPUTS: tl.constexpr,
CAST_TYPE: tl.constexpr, CAST_TYPE: tl.constexpr,
): SLICE_NUM: tl.constexpr,
SAME_STRIDE: tl.constexpr):
""" """
The sgmv's expand triton kernel is based on GroupGEMM.
Similar to the 'sgmv_expand' operator, but with an added parameter
'slice_offset'. The reason for not reusing the 'sgmv_expand' operator
might be that in the future, we could implement a fusion operator to
achieve the current functionality instead of having to call it multiple
times.
""" """
pid = tl.program_id(axis=0) pid = tl.program_id(axis=0)
cur_batch = tl.program_id(axis=1) cur_batch = tl.program_id(axis=1)
slice_id = tl.program_id(axis=2)
cta_n_num = tl.cdiv(N, BLOCK_N) cta_n_num = tl.cdiv(N, BLOCK_N)
# When the output dimensions of each slice are the same, curr_N = N; otherwise
# curr_N = tl.load(output_hs_ptr + slice_id). This situation arises in GQA's
# qkv linear.
curr_N = N if SAME_STRIDE else tl.load(output_hs_ptr + slice_id)
pid_m = pid // cta_n_num pid_m = pid // cta_n_num
pid_n = pid % cta_n_num pid_n = pid % cta_n_num
M = tl.load(seq_lens + cur_batch) M = tl.load(seq_lens + cur_batch)
if pid_m * BLOCK_M > M: if pid_m * BLOCK_M > M:
return return
if pid_n * BLOCK_N > curr_N:
return
lora_index = tl.load(lora_indices + cur_batch) lora_index = tl.load(lora_indices + cur_batch)
if lora_index == -1: if lora_index == -1:
return return
cur_seq_start = tl.load(b_seq_start_loc + cur_batch) cur_seq_start = tl.load(b_seq_start_loc + cur_batch)
offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
offset_k = tl.arange(0, BLOCK_K) offset_k = tl.arange(0, BLOCK_K)
ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M) ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) rbn = tl.max_contiguous(tl.multiple_of(offset_n % curr_N, BLOCK_N),
BLOCK_N)
# ls_d*_ptr can be either an integer or a pointer
if SAME_STRIDE:
# integer
cur_lora_d0_stride = ls_d0_ptr
cur_lora_d1_stride = ls_d1_ptr
cur_lora_d2_stride = ls_d2_ptr
else:
# pointer
cur_lora_d0_stride = tl.load(ls_d0_ptr + slice_id)
cur_lora_d1_stride = tl.load(ls_d1_ptr + slice_id)
cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id)
if SLICE_NUM == 1:
cur_input_ptr = input_ptr
cur_lora_ptr = lora_ptr
a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride + else:
offset_k[None, :] * xk_stride, ) cur_input_ptr = input_ptr + slice_id * input_d0_stride
b_ptr = (lora_ptr + l0_stride * lora_index + cur_lora_ptr = tl.load(lora_ptr + slice_id).to(
offset_k[:, None] * lora_n_stride + rbn[None, :] * lora_k_stride) tl.pointer_type(out_ptr.dtype.element_ty))
a_ptr = (cur_input_ptr + cur_seq_start * input_d1_stride +
ram[:, None] * input_d1_stride +
offset_k[None, :] * input_d2_stride, )
b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index +
offset_k[:, None] * cur_lora_d2_stride +
rbn[None, :] * cur_lora_d1_stride)
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(tl.cdiv(K, BLOCK_K)): for k in range(tl.cdiv(K, BLOCK_K)):
if EVEN_K: if EVEN_K:
@ -74,26 +117,30 @@ def _sgmv_expand_kernel(
mask=offset_k[:, None] < K - k * BLOCK_K, mask=offset_k[:, None] < K - k * BLOCK_K,
other=0) other=0)
if CAST_TYPE: if CAST_TYPE:
tiled_a = tiled_a.to(lora_ptr.dtype.element_ty) tiled_a = tiled_a.to(cur_lora_ptr.dtype.element_ty)
accumulator += tl.dot( accumulator += tl.dot(
tiled_a, tiled_a,
tiled_b, tiled_b,
) )
a_ptr += BLOCK_K * xk_stride a_ptr += BLOCK_K * input_d2_stride
b_ptr += BLOCK_K * lora_n_stride b_ptr += BLOCK_K * cur_lora_d2_stride
tiled_c = accumulator.to(lora_ptr.dtype.element_ty)
tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty)
if SLICE_NUM == 1:
cur_slice_start = slice_start_loc
else:
cur_slice_start = tl.load(slice_start_loc + slice_id)
offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + cur_slice_start
c_ptr = (out_ptr + offset_cm[:, None] * cm_stride + c_ptr = (out_ptr + offset_cm[:, None] * output_d0_stride +
offset_cn[None, :] * cn_stride) offset_cn[None, :] * output_d1_stride)
M = tl.load(seq_lens + cur_batch) M = tl.load(seq_lens + cur_batch)
c_mask = (offset_cm[:, None] < c_mask = (offset_cm[:, None] <
(cur_seq_start + M)) & (offset_cn[None, :] < N) (cur_seq_start + M)) & (offset_cn[None, :] <
(cur_slice_start + curr_N))
if ADD_INPUTS: if ADD_INPUTS:
# explicitly pass in other=None to tell triton that masked values tiled_out = tl.load(c_ptr, mask=c_mask)
# can be uninitialized. This is OK because the later tl.store operation
# uses the same mask, eliminating the risk of garbage values propagating
tiled_out = tl.load(c_ptr, mask=c_mask, other=None)
tiled_c += tiled_out tiled_c += tiled_out
tl.store(c_ptr, tiled_c, mask=c_mask) tl.store(c_ptr, tiled_c, mask=c_mask)
@ -101,7 +148,7 @@ def _sgmv_expand_kernel(
@torch.inference_mode() @torch.inference_mode()
def _sgmv_expand( def _sgmv_expand(
inputs: torch.Tensor, inputs: torch.Tensor,
lora_b_weights: torch.Tensor, lora_b_weights: List[torch.Tensor],
output_tensor: torch.Tensor, output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor, b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor, seq_len_tensor: torch.Tensor,
@ -109,17 +156,18 @@ def _sgmv_expand(
batches: int, batches: int,
max_seq_length: int, max_seq_length: int,
token_nums: int, token_nums: int,
offset_start: int = 0,
add_inputs: bool = False, add_inputs: bool = False,
) -> None: ) -> None:
""" """
Args: Args:
inputs (torch.Tensor): input tensor inputs (torch.Tensor): input tensor
lora_b_weights (torch.Tensor): lora'a weight lora_b_weights (List[torch.Tensor]): lora_b's weights
output_tensor (torch.Tensor): output tensor output_tensor (torch.Tensor): output tensor
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
sequence lengths of the sequences in the batch, used to index sequence lengths of the sequences in the batch, used to index
into sequence. E.g., if the sequence length is [4, 6], it is into sequence. E.g., if the sequence length is [4, 6], it is
[0, 4, 10]. [0, 4].
seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
length of the sequences in the batch. length of the sequences in the batch.
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
@ -130,77 +178,80 @@ def _sgmv_expand(
batch. batch.
token_nums (int): The token numbers in the batch. Used to verify if the token_nums (int): The token numbers in the batch. Used to verify if the
token numbers in the inputs matches the one in the metadata. token numbers in the inputs matches the one in the metadata.
add_inputs (bool, optional): Defaults to False, adds the final lora offset_start (int, optional): Offset start for output_tensor.
results to the output. Defaults to 0.
add_inputs (bool, optional): Whether to add the input tensor to the
output tensor. Defaults to False.
""" """
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
assert lora_b_weights.dtype in [ for weight in lora_b_weights:
torch.float16, assert weight.dtype in [torch.float16, torch.bfloat16]
torch.bfloat16,
] assert inputs.size(1) == token_nums
assert inputs.size(0) == token_nums assert inputs.size(0) == len(lora_b_weights)
assert inputs.size(1) == lora_b_weights.size(-1)
assert b_seq_start_loc.size(0) == batches assert b_seq_start_loc.size(0) == batches
assert lora_indices_tensor.size(0) == batches assert lora_indices_tensor.size(0) == batches
assert inputs.is_contiguous()
assert output_tensor.is_contiguous() assert output_tensor.is_contiguous()
(slice_start_tensor, lora_ptr_tensor, lora_strides_d0_tensor,
if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank) lora_strides_d1_tensor, lora_strides_d2_tensor, hidden_sizes_tensor,
assert lora_b_weights.size(1) == 1 same_stride, MAX_N) = _get_lora_b_ptr(lora_b_weights, offset_start,
lora_b_weights = lora_b_weights.squeeze(dim=1) b_seq_start_loc.device)
else:
assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank)
assert lora_b_weights.is_contiguous()
# TODO tuning this config # TODO tuning this config
K = lora_b_weights[0].shape[-1] # K= rank
N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size BLOCK_M = 64
BLOCK_M = 32 BLOCK_N = 128
BLOCK_N = 32
BLOCK_K = 16 BLOCK_K = 16
EVEN_K = K % BLOCK_K == 0 EVEN_K = K % BLOCK_K == 0
ADD_INPUTS = add_inputs ADD_INPUTS = add_inputs
CAST_TYPE = False CAST_TYPE = False
if inputs.dtype == torch.float32 and lora_b_weights.dtype in [
if inputs.dtype == torch.float32 and lora_b_weights[0].dtype in [
torch.float16, torch.float16,
torch.bfloat16, torch.bfloat16,
]: ]:
CAST_TYPE = True CAST_TYPE = True
grid = ( grid = (
triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N), triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(MAX_N, BLOCK_N),
batches, batches,
len(lora_b_weights),
) )
_sgmv_expand_kernel[grid]( _sgmv_expand_kernel[grid](
inputs, inputs,
lora_b_weights, lora_ptr_tensor,
output_tensor, output_tensor,
N, MAX_N,
K, K,
b_seq_start_loc, b_seq_start_loc,
seq_len_tensor, seq_len_tensor,
lora_indices_tensor, lora_indices_tensor,
slice_start_tensor,
inputs.stride(0), inputs.stride(0),
inputs.stride(1), inputs.stride(1),
lora_b_weights.stride(0), inputs.stride(2),
lora_b_weights.stride(1), lora_strides_d0_tensor,
lora_b_weights.stride(2), lora_strides_d1_tensor,
lora_strides_d2_tensor,
output_tensor.stride(0), output_tensor.stride(0),
output_tensor.stride(1), output_tensor.stride(1),
hidden_sizes_tensor,
BLOCK_M, BLOCK_M,
BLOCK_N, BLOCK_N,
BLOCK_K, BLOCK_K,
EVEN_K, EVEN_K,
ADD_INPUTS, ADD_INPUTS,
CAST_TYPE, CAST_TYPE,
len(lora_b_weights),
same_stride,
) )
return return
def sgmv_expand_fake( def _sgmv_expand_fake(
inputs: torch.Tensor, inputs: torch.Tensor,
lora_b_weights: torch.Tensor, lora_b_weights: List[torch.Tensor],
output_tensor: torch.Tensor, output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor, b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor, seq_len_tensor: torch.Tensor,
@ -208,18 +259,18 @@ def sgmv_expand_fake(
batches: int, batches: int,
max_seq_length: int, max_seq_length: int,
token_nums: int, token_nums: int,
offset_start: int = 0,
add_inputs: bool = False, add_inputs: bool = False,
) -> None: ) -> None:
return return
try: try:
direct_register_custom_op( direct_register_custom_op(
op_name="sgmv_expand", op_name="sgmv_expand",
op_func=_sgmv_expand, op_func=_sgmv_expand,
mutates_args=["output_tensor"], mutates_args=["output_tensor"],
fake_impl=sgmv_expand_fake, fake_impl=_sgmv_expand_fake,
) )
sgmv_expand = torch.ops.vllm.sgmv_expand sgmv_expand = torch.ops.vllm.sgmv_expand
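For intuition, here is a minimal pure-PyTorch reference of what the fused expand computes during prefill, mirroring how the tests above drive ref_torch_groupgemm slice by slice. The helper name and loop structure are illustrative only, not vLLM source; shapes follow the test data generators in this diff.

from typing import List

import torch


def expand_reference(
    inputs: torch.Tensor,              # (nslices, num_tokens, rank)
    lora_b_list: List[torch.Tensor],   # each (num_loras, hidden_size, rank)
    output: torch.Tensor,              # (num_tokens, sum of hidden sizes)
    b_seq_start_loc: torch.Tensor,     # (batches,) sequence start offsets
    seq_len_tensor: torch.Tensor,      # (batches,) sequence lengths
    lora_indices: torch.Tensor,        # (batches,), -1 means no LoRA
    add_inputs: bool = True,
) -> torch.Tensor:
    col = 0
    for s, lora_b in enumerate(lora_b_list):
        hidden_size = lora_b.size(1)
        for start, length, idx in zip(b_seq_start_loc.tolist(),
                                      seq_len_tensor.tolist(),
                                      lora_indices.tolist()):
            if idx == -1:  # sequence runs without a LoRA adapter
                continue
            x = inputs[s, start:start + length]                   # (length, rank)
            y = x.to(lora_b.dtype) @ lora_b[idx].transpose(0, 1)  # (length, hidden_size)
            dst = output[start:start + length, col:col + hidden_size]
            if add_inputs:
                dst += y
            else:
                dst.copy_(y)
        col += hidden_size
    return output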

View File

@ -1,241 +0,0 @@
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
import torch
import triton
import triton.language as tl
from vllm.utils import direct_register_custom_op
@triton.jit
def _sgmv_expand_slice_kernel(
input_ptr,
lora_ptr,
out_ptr,
N,
K,
b_seq_start_loc,
seq_lens,
lora_indices,
xm_stride,
xk_stride, # 1
l0_stride, # hidden_size*max_rank
lora_k_stride,
lora_n_stride,
cm_stride,
cn_stride,
slice_offset,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
EVEN_K: tl.constexpr,
ADD_INPUTS: tl.constexpr,
CAST_TYPE: tl.constexpr,
):
"""
Similar to the 'sgmv_expand' operator, but with an added parameter
'slice_offset'. The reason for not reusing the 'sgmv_expand' operator
might be that in the future, we could implement a fusion operator to
achieve the current functionality instead of having to call it multiple
times.
"""
pid = tl.program_id(axis=0)
cur_batch = tl.program_id(axis=1)
cta_n_num = tl.cdiv(N, BLOCK_N)
pid_m = pid // cta_n_num
pid_n = pid % cta_n_num
M = tl.load(seq_lens + cur_batch)
if pid_m * BLOCK_M > M:
return
lora_index = tl.load(lora_indices + cur_batch)
if lora_index == -1:
return
cur_seq_start = tl.load(b_seq_start_loc + cur_batch)
offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
offset_k = tl.arange(0, BLOCK_K)
ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)
a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride +
offset_k[None, :] * xk_stride, )
b_ptr = (lora_ptr + l0_stride * lora_index +
offset_k[:, None] * lora_n_stride + rbn[None, :] * lora_k_stride)
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(tl.cdiv(K, BLOCK_K)):
if EVEN_K:
tiled_a = tl.load(a_ptr)
tiled_b = tl.load(b_ptr)
else:
tiled_a = tl.load(a_ptr,
mask=offset_k[None, :] < K - k * BLOCK_K,
other=0)
tiled_b = tl.load(b_ptr,
mask=offset_k[:, None] < K - k * BLOCK_K,
other=0)
if CAST_TYPE:
tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)
accumulator += tl.dot(
tiled_a,
tiled_b,
)
a_ptr += BLOCK_K * xk_stride
b_ptr += BLOCK_K * lora_n_stride
tiled_c = accumulator.to(lora_ptr.dtype.element_ty)
offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + slice_offset
c_ptr = (out_ptr + offset_cm[:, None] * cm_stride +
offset_cn[None, :] * cn_stride)
M = tl.load(seq_lens + cur_batch)
c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :] <
(slice_offset + N))
if ADD_INPUTS:
# explicitly pass in other=None to tell triton that masked values
# can be uninitialized. This is OK because the later tl.store operation
# uses the same mask, eliminating the risk of garbage values propagating
tiled_out = tl.load(c_ptr, mask=c_mask, other=None)
tiled_c += tiled_out
tl.store(c_ptr, tiled_c, mask=c_mask)
@torch.inference_mode()
def _sgmv_expand_slice(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
slice_offset: int,
slice_size: int,
add_inputs: bool = False,
) -> None:
"""_summary_
Args:
inputs (torch.Tensor): input tensor
lora_b_weights (torch.Tensor): lora'a weight
output_tensor (torch.Tensor): output tensor
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
sequence lengths of the sequences in the batch, used to index
into sequence. E.g., if the sequence length is [4, 6], it is
[0, 4, 10].
seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
length of the sequences in the batch
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no lora should be
applied.
batches (int): batch size
max_seq_length (int): The max sequence lengths of the sequences
in the batch
token_nums (int): The token numbers in the batch. Used to verify if the
token numbers in the inputs matches the one in the metadata.
slice_offset (int): output_tensor's offset
slice_size (int): current output_tensor's size
add_inputs (bool, optional): Defaults to False, adds the final lora
results to the output.
"""
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
assert lora_b_weights.dtype in [
torch.float16,
torch.bfloat16,
]
assert inputs.size(0) == token_nums
assert inputs.size(1) == lora_b_weights.size(-1)
assert b_seq_start_loc.size(0) == batches
assert lora_indices_tensor.size(0) == batches
assert slice_size == lora_b_weights.size(-2)
assert inputs.is_contiguous()
assert output_tensor.is_contiguous()
if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank)
assert lora_b_weights.size(1) == 1
lora_b_weights = lora_b_weights.squeeze(dim=1)
else:
assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank)
assert lora_b_weights.is_contiguous()
# TODO tuning this config
N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size
BLOCK_M = 32
BLOCK_N = 32
BLOCK_K = 16
EVEN_K = K % BLOCK_K == 0
ADD_INPUTS = add_inputs
CAST_TYPE = False
if inputs.dtype == torch.float32 and lora_b_weights.dtype in [
torch.float16,
torch.bfloat16,
]:
CAST_TYPE = True
grid = (
triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),
batches,
)
_sgmv_expand_slice_kernel[grid](
inputs,
lora_b_weights,
output_tensor,
N,
K,
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
inputs.stride(0),
inputs.stride(1),
lora_b_weights.stride(0),
lora_b_weights.stride(1),
lora_b_weights.stride(2),
output_tensor.stride(0),
output_tensor.stride(1),
slice_offset,
BLOCK_M,
BLOCK_N,
BLOCK_K,
EVEN_K,
ADD_INPUTS,
CAST_TYPE,
)
return
def sgmv_expand_slice_fake(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
slice_offset: int,
slice_size: int,
add_inputs: bool = False,
) -> None:
return
try:
direct_register_custom_op(
op_name="sgmv_expand_slice",
op_func=_sgmv_expand_slice,
mutates_args=["output_tensor"],
fake_impl=sgmv_expand_slice_fake,
)
sgmv_expand_slice = torch.ops.vllm.sgmv_expand_slice
except AttributeError:
sgmv_expand_slice = _sgmv_expand_slice

View File

@ -5,17 +5,21 @@ Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547 https://arxiv.org/abs/2310.18547
""" """
from typing import List
import torch import torch
import triton import triton
import triton.language as tl import triton.language as tl
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
from .utils import _get_lora_a_ptr
@triton.jit @triton.jit
def _sgmv_shrink_kernel( def _sgmv_shrink_kernel(
input_ptr, input_ptr,
lora_ptr, lora_ptr, #1-3
out_ptr, out_ptr,
N, N,
K, K,
@ -23,30 +27,38 @@ def _sgmv_shrink_kernel(
seq_lens, seq_lens,
lora_indices, lora_indices,
scaling, scaling,
xm_stride, # hidden_size input_d0_stride,
xk_stride, # 1 input_d1_stride, # 1
l0_stride, # hidden_size*max_rank lora_d0_stride,
lora_k_stride, lora_d1_stride,
lora_n_stride, lora_d2_stride, # 1
cm_stride, output_d0_stride,
cn_stride, output_d1_stride,
output_d2_stride, # 1
BLOCK_M: tl.constexpr, BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr, BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr, BLOCK_K: tl.constexpr,
EVEN_K: tl.constexpr, EVEN_K: tl.constexpr,
SPLIT_K: tl.constexpr, SPLIT_K: tl.constexpr,
): SLICE_NUM: tl.constexpr):
""" """
The sgmv's shrink triton kernel is based on GroupGEMM+SPLIT-K. The sgmv's shrink triton kernel is based on GroupGEMM+SPLIT-K.
The GEMM of Multi-LoRA can be considered as GroupGEMM. Additionally, The GEMM of Multi-LoRA can be considered as GroupGEMM. Additionally,
introducing SPLIT-K can improve performance introducing SPLIT-K can improve performance
""" """
pid = tl.program_id(axis=0) pid = tl.program_id(axis=0)
pid_sk = tl.program_id(axis=1) pid_mix = tl.program_id(axis=1)
cur_batch = tl.program_id(axis=2) cur_batch = tl.program_id(axis=2)
cta_n_num = tl.cdiv(N, BLOCK_N) cta_n_num = tl.cdiv(N, BLOCK_N)
pid_m = pid // cta_n_num pid_m = pid // cta_n_num
pid_n = pid % cta_n_num pid_n = pid % cta_n_num
if SLICE_NUM == 1:
slice_id: tl.constexpr = 0
pid_sk = tl.program_id(axis=1)
else:
pid_mix = tl.program_id(axis=1)
slice_id = pid_mix // SPLIT_K
pid_sk = pid_mix % SPLIT_K
M = tl.load(seq_lens + cur_batch) M = tl.load(seq_lens + cur_batch)
if pid_m * BLOCK_M > M: if pid_m * BLOCK_M > M:
@ -61,11 +73,22 @@ def _sgmv_shrink_kernel(
ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M) ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)
# input ptr
a_ptr = (input_ptr + cur_seq_start * input_d0_stride +
ram[:, None] * input_d0_stride +
offset_k[None, :] * input_d1_stride)
a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride + if SLICE_NUM == 1:
offset_k[None, :] * xk_stride) # current lora ptr
b_ptr = (lora_ptr + l0_stride * lora_index + rbn[None, :] * lora_k_stride + cur_lora_ptr = lora_ptr
offset_k[:, None] * lora_n_stride) else:
# current lora ptr
cur_lora_ptr = tl.load(lora_ptr + slice_id).to(
tl.pointer_type(input_ptr.dtype.element_ty))
b_ptr = (cur_lora_ptr + lora_d0_stride * lora_index +
rbn[None, :] * lora_d1_stride +
offset_k[:, None] * lora_d2_stride)
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
@ -82,13 +105,15 @@ def _sgmv_shrink_kernel(
other=0.0) other=0.0)
accumulator += tl.dot(tiled_a, tiled_b) accumulator += tl.dot(tiled_a, tiled_b)
a_ptr += BLOCK_K * SPLIT_K * xk_stride a_ptr += BLOCK_K * SPLIT_K * input_d1_stride
b_ptr += BLOCK_K * SPLIT_K * lora_n_stride b_ptr += BLOCK_K * SPLIT_K * lora_d2_stride
offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
c_ptr = (out_ptr + offset_cm[:, None] * cm_stride + cur_out_ptr = (out_ptr if SLICE_NUM == 1 else out_ptr +
offset_cn[None, :] * cn_stride) slice_id * output_d0_stride)
c_ptr = cur_out_ptr + offset_cm[:, None] * output_d1_stride + offset_cn[
None, :] * output_d2_stride
c_mask = (offset_cm[:, None] < c_mask = (offset_cm[:, None] <
(cur_seq_start + M)) & (offset_cn[None, :] < N) (cur_seq_start + M)) & (offset_cn[None, :] < N)
accumulator *= scaling accumulator *= scaling
@ -102,7 +127,7 @@ def _sgmv_shrink_kernel(
@torch.inference_mode() @torch.inference_mode()
def _sgmv_shrink( def _sgmv_shrink(
inputs: torch.Tensor, inputs: torch.Tensor,
lora_a_weights: torch.Tensor, lora_a_weights: List[torch.Tensor],
output_tensor: torch.Tensor, output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor, b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor, seq_len_tensor: torch.Tensor,
@ -113,10 +138,9 @@ def _sgmv_shrink(
scaling: float, scaling: float,
) -> None: ) -> None:
""" """
Args: Args:
inputs (torch.Tensor): input tensor inputs (torch.Tensor): input tensor
lora_a_weights (torch.Tensor): lora'a weight lora_a_weights (List[torch.Tensor]): lora_a's weights
output_tensor (torch.Tensor): output tensor output_tensor (torch.Tensor): output tensor
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
sequence lengths of the sequences in the batch, used to index sequence lengths of the sequences in the batch, used to index
@ -134,27 +158,21 @@ def _sgmv_shrink(
token numbers in the inputs matches the one in the metadata. token numbers in the inputs matches the one in the metadata.
scaling (float): Scaling factor. scaling (float): Scaling factor.
""" """
assert inputs.dtype == lora_a_weights.dtype assert inputs.dtype == lora_a_weights[0].dtype
assert inputs.dtype in [torch.float16, torch.bfloat16] assert inputs.dtype in [torch.float16, torch.bfloat16]
assert lora_a_weights.dtype in [ for weight in lora_a_weights:
torch.float16, assert weight.dtype in [torch.float16, torch.bfloat16]
torch.bfloat16,
]
assert inputs.size(0) == token_nums assert inputs.size(0) == token_nums
assert inputs.size(1) == lora_a_weights.size(-1) assert inputs.size(1) == lora_a_weights[0].size(-1)
assert b_seq_start_loc.size(0) == batches assert b_seq_start_loc.size(0) == batches
assert lora_indices_tensor.size(0) == batches assert lora_indices_tensor.size(0) == batches
assert inputs.is_contiguous() assert inputs.is_contiguous()
if lora_a_weights.ndim == 4: # shape:(lora_num,1,rank, size)
assert lora_a_weights.size(1) == 1
lora_a_weights = lora_a_weights.squeeze(dim=1)
else:
assert lora_a_weights.ndim == 3 # shape:(lora_num,rank, size)
assert lora_a_weights.is_contiguous()
assert output_tensor.is_contiguous() assert output_tensor.is_contiguous()
(lora_ptr_tensor, lora_strides_d0, lora_strides_d1,
lora_strides_d2) = _get_lora_a_ptr(lora_a_weights, b_seq_start_loc.device)
# TODO tuning this config # TODO tuning this config
N, K = lora_a_weights.shape[-2:] # K=hidden_size,N=rank N, K = lora_a_weights[0].shape[-2:] # K=hidden_size,N=rank
BLOCK_M = 32 BLOCK_M = 32
BLOCK_N = 16 BLOCK_N = 16
BLOCK_K = 32 BLOCK_K = 32
@ -162,13 +180,12 @@ def _sgmv_shrink(
EVEN_K = K % (BLOCK_K * SPLIT_K) == 0 EVEN_K = K % (BLOCK_K * SPLIT_K) == 0
grid = ( grid = (
triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N), triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),
SPLIT_K, SPLIT_K * len(lora_a_weights),
batches, batches,
) )
_sgmv_shrink_kernel[grid]( _sgmv_shrink_kernel[grid](
inputs, inputs,
lora_a_weights, lora_ptr_tensor,
output_tensor, output_tensor,
N, N,
K, K,
@ -178,23 +195,25 @@ def _sgmv_shrink(
scaling, scaling,
inputs.stride(0), inputs.stride(0),
inputs.stride(1), inputs.stride(1),
lora_a_weights.stride(0), lora_strides_d0,
lora_a_weights.stride(1), lora_strides_d1,
lora_a_weights.stride(2), lora_strides_d2,
output_tensor.stride(0), output_tensor.stride(0),
output_tensor.stride(1), output_tensor.stride(1),
output_tensor.stride(2),
BLOCK_M, BLOCK_M,
BLOCK_N, BLOCK_N,
BLOCK_K, BLOCK_K,
EVEN_K, EVEN_K,
SPLIT_K, SPLIT_K,
len(lora_a_weights),
) )
return return
def sgmv_shrink_fake( def sgmv_shrink_fake(
inputs: torch.Tensor, inputs: torch.Tensor,
lora_a_weights: torch.Tensor, lora_a_weights: List[torch.Tensor],
output_tensor: torch.Tensor, output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor, b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor, seq_len_tensor: torch.Tensor,
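Analogously, a minimal pure-PyTorch reference of what the fused shrink computes, again illustrative only and not vLLM source: one LoRA-A GEMM per slice, accumulated into a float32 output of shape (nslices, num_tokens, rank), with the shapes taken from the tests in this diff.

from typing import List

import torch


def shrink_reference(
    inputs: torch.Tensor,              # (num_tokens, hidden_size)
    lora_a_list: List[torch.Tensor],   # each (num_loras, rank, hidden_size)
    output: torch.Tensor,              # (nslices, num_tokens, rank), float32 zeros
    b_seq_start_loc: torch.Tensor,     # (batches,) sequence start offsets
    seq_len_tensor: torch.Tensor,      # (batches,) sequence lengths
    lora_indices: torch.Tensor,        # (batches,), -1 means no LoRA
    scaling: float,
) -> torch.Tensor:
    for s, lora_a in enumerate(lora_a_list):
        for start, length, idx in zip(b_seq_start_loc.tolist(),
                                      seq_len_tensor.tolist(),
                                      lora_indices.tolist()):
            if idx == -1:  # sequence runs without a LoRA adapter
                continue
            x = inputs[start:start + length]          # (length, hidden_size)
            y = x @ lora_a[idx].transpose(0, 1)       # (length, rank)
            output[s, start:start + length] += scaling * y.to(output.dtype)
    return output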

View File

@ -1,5 +1,7 @@
import functools import functools
from typing import Dict from typing import Dict, List, Tuple
import torch
@functools.lru_cache @functools.lru_cache
@ -44,3 +46,120 @@ def get_lora_op_configs(op_type: str, batch: int,
if not config: if not config:
config = _get_default_config(op_type, batch, hidden_size) config = _get_default_config(op_type, batch, hidden_size)
return config return config
_LORA_A_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {}
_LORA_B_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {}
def _get_lora_a_ptr(lora_a_weights: List[torch.Tensor], device: str):
"""
`_LORA_A_PTR_DICT` collects the required information during `profile_run`.
After that, it remains constant, and subsequent lookups go through this LUT.
Refer to:
https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py
"""
key = tuple(lora_weight.data_ptr() for lora_weight in lora_a_weights)
if values := _LORA_A_PTR_DICT.get(key):
return values
lora_strides_d0 = []
lora_strides_d1 = []
lora_strides_d2 = []
tensor_ptrs = []
for lora_a_weight in lora_a_weights:
if lora_a_weight.ndim == 4: # shape:(lora_num,1,size,rank)
assert lora_a_weight.size(1) == 1
lora_a_weight = lora_a_weight.squeeze(dim=1)
else:
assert lora_a_weight.ndim == 3 # shape:(lora_num,size,rank)
assert lora_a_weight.is_contiguous()
tensor_ptrs.append(lora_a_weight.data_ptr())
lora_strides_d0.append(lora_a_weight.stride(0))
lora_strides_d1.append(lora_a_weight.stride(1))
lora_strides_d2.append(lora_a_weight.stride(2))
if len(lora_a_weights) > 1:
lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device)
else:
lora_ptr_tensor = lora_a_weights[0]
if (len(set(lora_strides_d0)) > 1 or len(set(lora_strides_d1)) > 1
or len(set(lora_strides_d2)) > 1):
raise ValueError("All LoRA weights must have the same stride.")
_LORA_A_PTR_DICT[key] = (
lora_ptr_tensor,
lora_strides_d0[0],
lora_strides_d1[0],
lora_strides_d2[0],
)
return _LORA_A_PTR_DICT.get(key)
def _get_lora_b_ptr(lora_weights: List[torch.Tensor], offset_start: int,
                    device: str):
    """
    `_LORA_B_PTR_DICT` collects the required information during `profile_run`.
    After that, it remains constant, and subsequent lookups go through this
    LUT. Refer to:
    https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py
    """
    key = tuple(lora_weight.data_ptr() for lora_weight in lora_weights)
    if values := _LORA_B_PTR_DICT.get(key):
        return values

    slice_offset_lst = []
    tensor_ptrs = []
    lora_strides_d0 = []
    lora_strides_d1 = []
    lora_strides_d2 = []
    hidden_sizes = []
    slice_offset = offset_start
    for lora_b_weight in lora_weights:
        if lora_b_weight.ndim == 4:  # shape:(lora_num,1,size,rank)
            assert lora_b_weight.size(1) == 1
            lora_b_weight = lora_b_weight.squeeze(dim=1)
        else:
            assert lora_b_weight.ndim == 3  # shape:(lora_num,size,rank)
        assert lora_b_weight.is_contiguous()
        tensor_ptrs.append(lora_b_weight.data_ptr())
        lora_strides_d0.append(lora_b_weight.stride(0))
        lora_strides_d1.append(lora_b_weight.stride(1))
        lora_strides_d2.append(lora_b_weight.stride(2))
        slice_offset_lst.append(slice_offset)
        slice_offset += lora_b_weight.size(1)
        hidden_sizes.append(lora_b_weight.size(1))

    if len(lora_weights) > 1:
        # note these are device tensors
        lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device)
        slice_start_tensor = torch.tensor(slice_offset_lst, device=device)
    else:
        slice_start_tensor = slice_offset_lst[0]
        lora_ptr_tensor = lora_b_weight[0]

    # If each lora has the same stride, there's no need to use a
    # tensor for storage.
    if (len(set(lora_strides_d0)) == 1 and len(set(lora_strides_d1)) == 1 and
            len(set(lora_strides_d2)) == 1) and len(set(hidden_sizes)) == 1:
        lora_strides_d0_tensor = lora_strides_d0[0]
        lora_strides_d1_tensor = lora_strides_d1[0]
        lora_strides_d2_tensor = lora_strides_d2[0]
        hidden_sizes_tensor = hidden_sizes[0]
        same_stride = True
    else:
        lora_strides_d0_tensor = torch.tensor(lora_strides_d0, device=device)
        lora_strides_d1_tensor = torch.tensor(lora_strides_d1, device=device)
        lora_strides_d2_tensor = torch.tensor(lora_strides_d2, device=device)
        hidden_sizes_tensor = torch.tensor(hidden_sizes, device=device)
        same_stride = False
    # MAX_N is the maximum hidden size among all the lora_b weights
    MAX_N = max(hidden_sizes)

    _LORA_B_PTR_DICT[key] = (slice_start_tensor, lora_ptr_tensor,
                             lora_strides_d0_tensor, lora_strides_d1_tensor,
                             lora_strides_d2_tensor, hidden_sizes_tensor,
                             same_stride, MAX_N)
    return _LORA_B_PTR_DICT.get(key)
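A hedged sketch of how the slice offsets above accumulate for a q/k/v-style stack of three LoRA-B weights; all sizes and the CPU device are invented for illustration:

import torch

lora_b = [torch.zeros(2, 1, 1024, 16), torch.zeros(2, 1, 256, 16),
          torch.zeros(2, 1, 256, 16)]  # (lora_num, 1, size, rank)
(slice_starts, ptrs, d0, d1, d2, hidden, same_stride,
 max_n) = _get_lora_b_ptr(lora_b, offset_start=0, device="cpu")
print(slice_starts.tolist())  # -> [0, 1024, 1280]
print(max_n)                  # -> 1024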

View File

@@ -5,7 +5,7 @@ Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
-from typing import Callable, Optional, Tuple, Union, final
+from typing import Optional, Tuple, Union, final

import torch

@@ -16,7 +16,6 @@ if HAS_TRITON:
    from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice
    from vllm.lora.ops.bgmv_shrink import bgmv_shrink
    from vllm.lora.ops.sgmv_expand import sgmv_expand
-    from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
    from vllm.lora.ops.sgmv_shrink import sgmv_shrink

from .punica_base import PunicaWrapperBase

@@ -35,11 +34,11 @@ class PunicaWrapperGPU(PunicaWrapperBase):
        PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches,
                                   device)

-    def _shrink_prefill(
+    def _apply_shrink_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
-        w_t_all: torch.Tensor,
+        w_t_all: Tuple[torch.Tensor, ...],
        scale: float,
    ):
        #No LoRA request, so return directly
@@ -53,7 +52,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
            scale,
        )

-    def _shrink_decode(
+    def _apply_shrink_decode(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
@@ -62,56 +61,28 @@ class PunicaWrapperGPU(PunicaWrapperBase):
    ):
        bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale)
-    def _expand_prefill(
+    def _apply_expand_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
+        offset_start: int,
        add_inputs: bool,
    ):
        #No LoRA request, so return directly
        if self.no_lora:
            return
        sgmv_expand(
            x,
            w_t_all,
            y,
            *self.prefill_metadata,
-            add_inputs,
+            offset_start=offset_start,
+            add_inputs=add_inputs,
        )
-    def _expand_decode(
-        self,
-        y: torch.Tensor,
-        x: torch.Tensor,
-        w_t_all: torch.Tensor,
-        add_inputs: bool,
-    ):
-        bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs)
-
-    def _expand_slice_prefill(
-        self,
-        y: torch.Tensor,
-        x: torch.Tensor,
-        w_t_all: torch.Tensor,
-        y_offset: Optional[int],
-        y_slice_size: Optional[int],
-        add_inputs: bool,
-    ):
-        #No LoRA request, so return directly
-        if self.no_lora:
-            return
-        sgmv_expand_slice(
-            x,
-            w_t_all,
-            y,
-            *self.prefill_metadata,
-            y_offset,
-            y_slice_size,
-            add_inputs,
-        )
-
-    def _expand_slice_decode(
+    def _apply_expand_decode(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
@@ -123,43 +94,6 @@ class PunicaWrapperGPU(PunicaWrapperBase):
        bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset,
                          y_slice_size, add_inputs)
-    def _apply_expand(
-        self,
-        y: torch.Tensor,
-        x: torch.Tensor,
-        w_t_all: torch.Tensor,
-        y_offset: Optional[int],
-        y_slice_size: Optional[int],
-        add_inputs: bool = True,
-    ):
-        """
-        Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all`
-        computation, which is suitable for the
-        GEMM of lora'b.
-        """
-        expand_slice_fun: Callable = (self._expand_slice_prefill
-                                      if self.is_prefill else
-                                      self._expand_slice_decode)
-        expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs)
-
-    def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor,
-                      w_t_all: torch.Tensor, scale: float):
-        """
-        Perform the ` y+=x@w_t_all` computation, which is suitable for the
-        GEMM of lora'a.
-        When `is_prefill is` true, it indicates that it is currently the
-        prefill stage, and the `_shrink_prefill` function should be called.
-        Otherwise, it is the decode stage, and the _shrink_decode function
-        should be called.
-        """
-        y_org = y
-        y = y.view(-1, y.shape[-1])
-        shrink_fun: Callable = (self._shrink_prefill
-                                if self.is_prefill else self._shrink_decode)
-        shrink_fun(y, x, w_t_all, scale)
-        y = y.view_as(y_org)

    def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
                   x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...],
                   scale: float, **kwargs):
@@ -182,10 +116,15 @@ class PunicaWrapperGPU(PunicaWrapperBase):
        """
        x = x.view(-1, x.shape[-1])
-        # TODO fuse these kernels
-        for slice_idx in range(len(lora_a_stacked)):
-            self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx],
-                               scale)
+        if self.is_prefill:
+            # NOTE fused kernel
+            self._apply_shrink_prefill(y, x, lora_a_stacked, scale)
+        else:
+            # TODO fuse these kernels
+            for slice_idx in range(len(lora_a_stacked)):
+                self._apply_shrink_decode(y[slice_idx], x,
+                                          lora_a_stacked[slice_idx], scale)
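For readers tracing the math, a naive PyTorch equivalent of what the fused shrink path computes per token; this is not the Triton kernel, and the shapes plus the use of a negative index for "no LoRA" are assumptions:

import torch


def naive_shrink(x, lora_a_stacked, token_lora_mapping, scale):
    # y[s, t] = scale * lora_a_stacked[s][lora_of(t)] @ x[t]
    num_slices = len(lora_a_stacked)
    rank = lora_a_stacked[0].size(-2)
    y = torch.zeros(num_slices, x.size(0), rank,
                    dtype=torch.float32, device=x.device)
    for s, w in enumerate(lora_a_stacked):
        w = w.squeeze(1)  # assumed (max_loras, rank, hidden_size)
        for t, lora_idx in enumerate(token_lora_mapping.tolist()):
            if lora_idx >= 0:  # assumed: negative index means "no LoRA"
                y[s, t] = scale * (w[lora_idx].float() @ x[t].float())
    return y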
    def add_expand(self,
                   y: torch.Tensor,
@@ -217,20 +156,28 @@ class PunicaWrapperGPU(PunicaWrapperBase):
        """
        y_org = y
        y = y.view(-1, y.shape[-1])
-        offset_left = offset_start
        if lora_bias_stacked is not None:
            self._apply_bias(self.token_lora_indices, y, output_slices,
                             lora_bias_stacked)
-        for slice_idx in range(len(lora_b_stacked)):
-            self._apply_expand(
-                y,
-                x[slice_idx],
-                lora_b_stacked[slice_idx],
-                offset_left,
-                output_slices[slice_idx],
-                add_inputs=add_inputs,
-            )
-            offset_left += output_slices[slice_idx]
+        if self.is_prefill:
+            # NOTE fused kernel
+            self._apply_expand_prefill(y,
+                                       x,
+                                       lora_b_stacked,
+                                       offset_start,
+                                       add_inputs=True)
+        else:
+            # TODO fuse these kernels
+            for slice_idx in range(len(lora_b_stacked)):
+                self._apply_expand_decode(
+                    y,
+                    x[slice_idx],
+                    lora_b_stacked[slice_idx],
+                    offset_start,
+                    output_slices[slice_idx],
+                    add_inputs=add_inputs,
+                )
+                offset_start += output_slices[slice_idx]
        y = y.view_as(y_org)
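And the per-slice expand above, written out naively in PyTorch under the same assumptions (x is the 3-D shrink buffer, each lora_b_stacked[s] is assumed to be stacked as (max_loras, 1, output_size, rank)):

import torch


def naive_add_expand(y, x, lora_b_stacked, output_slices,
                     token_lora_mapping, offset_start=0):
    # y[t, off:off+size] += lora_b_stacked[s][lora_of(t)] @ x[s, t]
    offset = offset_start
    for s, size in enumerate(output_slices):
        w = lora_b_stacked[s].squeeze(1)  # (max_loras, size, rank)
        for t, lora_idx in enumerate(token_lora_mapping.tolist()):
            if lora_idx >= 0:
                y[t, offset:offset + size] += (
                    w[lora_idx].float() @ x[s, t].float()).to(y.dtype)
        offset += size
    return y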
    def add_lora_embedding(self,
@@ -252,10 +199,18 @@ class PunicaWrapperGPU(PunicaWrapperBase):
            add_inputs (bool): Default to True.
        """
-        # Embedding layer only need expand op
-        expand_fun: Callable = (self._expand_prefill
-                                if self.is_prefill else self._expand_decode)
-        expand_fun(y, x, lora_b_stacked, add_inputs)
+        if self.is_prefill:
+            sgmv_expand(
+                x.unsqueeze(dim=0),
+                [lora_b_stacked],
+                y,
+                *self.prefill_metadata,
+                offset_start=0,
+                add_inputs=add_inputs,
+            )
+        else:
+            bgmv_expand(x, lora_b_stacked, y, self.token_lora_indices,
+                        add_inputs)
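The x.unsqueeze(dim=0) / [lora_b_stacked] pair above appears to present the single embedding slice in the (num_slices, num_tokens, rank) input layout the fused expand kernel works with; a tiny sketch with invented sizes:

import torch

x = torch.randn(6, 16)        # (num_tokens, lora_rank)
x_fused = x.unsqueeze(dim=0)  # (1, num_tokens, lora_rank): a one-slice batch
assert x_fused.shape == (1, 6, 16)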
    def add_lora_linear(self,
                        y: torch.Tensor,
@@ -301,10 +256,11 @@ class PunicaWrapperGPU(PunicaWrapperBase):
            r = lora_b_stacked[0].size(-1)
            # We set the buffer to be float32 by default, refer to:
            # https://github.com/triton-lang/triton/issues/1387
-            buffer = tuple(
-                torch.zeros(
-                    (x.size(0), r), dtype=torch.float32, device=x.device)
-                for _ in range(len(output_slices)))
+            buffer = torch.zeros(
+                (len(output_slices), x.size(0), r),
+                dtype=torch.float32,
+                device=x.device,
+            )
        self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
        self.add_expand(y,
                        buffer,