[Kernel][LoRA] Punica prefill kernels fusion (#11234)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: Abatom <abzhonghua@gmail.com>
Co-authored-by: Zhonghua Deng <abatom@163.com>
parent 8ceffbf315
commit b278557935
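This commit fuses the Punica prefill (SGMV) expand path: `sgmv_expand_slice` is removed and `sgmv_expand` now accepts a list of per-slice LoRA-B weights plus an `offset_start`, so a single kernel launch covers every slice instead of one launch per slice. Below is a minimal call sketch adapted from the updated tests; the tensor setup is illustrative only (shapes follow `generate_data_for_nslices`) and assumes a CUDA device.

```python
# Illustrative sketch only (not taken verbatim from the diff): driving the
# fused expand op the way the updated Punica tests do.
import torch

from vllm.lora.ops.sgmv_expand import sgmv_expand

device, dtype = "cuda:0", torch.float16
nslices, batches, seq_len, rank, hidden_size = 2, 2, 8, 16, 128
token_nums = batches * seq_len

# Per-slice inputs and LoRA-B weights (shapes as in generate_data_for_nslices).
inputs = torch.rand((nslices, token_nums, rank), dtype=dtype, device=device)
lora_b_lst = [
    torch.rand((1, hidden_size, rank), dtype=dtype, device=device)
    for _ in range(nslices)
]
# One wide output tensor holding all slices side by side.
output = torch.zeros((token_nums, hidden_size * nslices),
                     dtype=dtype,
                     device=device)
seq_len_tensor = torch.full((batches, ), seq_len, device=device)
b_seq_start_loc = torch.tensor([0, seq_len], device=device)
lora_indices = torch.zeros((batches, ), dtype=torch.long, device=device)

# Single launch for all slices; before this commit each slice needed its own
# sgmv_expand_slice call with an explicit slice_offset.
sgmv_expand(inputs, lora_b_lst, output, b_seq_start_loc, seq_len_tensor,
            lora_indices, batches, seq_len, token_nums, offset_start=0,
            add_inputs=True)
```

The old per-slice loop advanced a `slice_offset` by `hidden_size` between `sgmv_expand_slice` calls; the fused op derives those per-slice output offsets internally from the weight list.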
@@ -242,7 +242,7 @@ steps:
   source_file_dependencies:
   - vllm/lora
   - tests/lora
-  command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py
+  command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py --ignore=lora/test_minicpmv_tp.py
   parallelism: 4

 - label: "PyTorch Fullgraph Smoke Test" # 9min
@@ -535,6 +535,7 @@ steps:
   # requires multi-GPU testing for validation.
   - pytest -v -s -x lora/test_chatglm3_tp.py
   - pytest -v -s -x lora/test_llama_tp.py
+  - pytest -v -s -x lora/test_minicpmv_tp.py


 - label: Weight Loading Multiple GPU Test # 33min
@@ -1,77 +0,0 @@
-from typing import List
-
-import pytest
-
-import vllm
-from vllm.assets.image import ImageAsset
-from vllm.lora.request import LoRARequest
-from vllm.platforms import current_platform
-
-MODEL_PATH = "openbmb/MiniCPM-Llama3-V-2_5"
-
-PROMPT_TEMPLATE = (
-    "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
-    "(<image>./</image>)\nWhat is in the image?<|eot_id|>"
-    "<|start_header_id|>assistant<|end_header_id|>\n\n")
-
-IMAGE_ASSETS = [
-    ImageAsset("stop_sign"),
-    ImageAsset("cherry_blossom"),
-]
-
-# After fine-tuning with LoRA, all generated content should start begin `A`.
-EXPECTED_OUTPUT = [
-    "A red and white stop sign with a Chinese archway in the background featuring red lanterns and gold accents.",  # noqa: E501
-    "A pink cherry blossom tree with a blue sky in the background.",
-]
-
-
-def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
-    sampling_params = vllm.SamplingParams(
-        temperature=0,
-        max_tokens=5,
-        stop_token_ids=[128001, 128009],  # eos_id, eot_id
-    )
-
-    inputs = [{
-        "prompt": PROMPT_TEMPLATE,
-        "multi_modal_data": {
-            "image": asset.pil_image
-        },
-    } for asset in IMAGE_ASSETS]
-
-    outputs = llm.generate(
-        inputs,
-        sampling_params,
-        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
-        if lora_id else None,
-    )
-    # Print the outputs.
-    generated_texts: List[str] = []
-    for output in outputs:
-        prompt = output.prompt
-        generated_text = output.outputs[0].text.strip()
-        generated_texts.append(generated_text)
-        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
-    return generated_texts
-
-
-@pytest.mark.xfail(
-    current_platform.is_rocm(),
-    reason="MiniCPM-V dependency xformers incompatible with ROCm")
-def test_minicpmv_lora(minicpmv_lora_files):
-    llm = vllm.LLM(
-        MODEL_PATH,
-        max_num_seqs=2,
-        enable_lora=True,
-        max_loras=4,
-        max_lora_rank=64,
-        trust_remote_code=True,
-        enable_chunked_prefill=True,
-    )
-    output1 = do_sample(llm, minicpmv_lora_files, lora_id=1)
-    for i in range(len(EXPECTED_OUTPUT)):
-        assert EXPECTED_OUTPUT[i].startswith(output1[i])
-    output2 = do_sample(llm, minicpmv_lora_files, lora_id=2)
-    for i in range(len(EXPECTED_OUTPUT)):
-        assert EXPECTED_OUTPUT[i].startswith(output2[i])
@@ -3,10 +3,10 @@ from typing import List
 import pytest

 import vllm
+from tests.utils import fork_new_process_for_each_test
 from vllm.assets.image import ImageAsset
 from vllm.lora.request import LoRARequest
-
-from ..utils import multi_gpu_test
+from vllm.platforms import current_platform

 MODEL_PATH = "openbmb/MiniCPM-Llama3-V-2_5"

@@ -17,13 +17,11 @@ PROMPT_TEMPLATE = (

 IMAGE_ASSETS = [
     ImageAsset("stop_sign"),
-    ImageAsset("cherry_blossom"),
 ]

 # After fine-tuning with LoRA, all generated content should start begin `A`.
 EXPECTED_OUTPUT = [
     "A red and white stop sign with a Chinese archway in the background featuring red lanterns and gold accents.",  # noqa: E501
-    "A pink cherry blossom tree with a blue sky in the background.",
 ]


@@ -50,37 +48,40 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
     # Print the outputs.
     generated_texts: List[str] = []
     for output in outputs:
-        prompt = output.prompt
         generated_text = output.outputs[0].text.strip()
         generated_texts.append(generated_text)
-        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+        print(f"Generated text: {generated_text!r}")
     return generated_texts


-@multi_gpu_test(num_gpus=2)
-@pytest.mark.parametrize("fully_sharded", [True, False])
-def test_minicpmv_tp2(minicpmv_lora_files, fully_sharded):
+@pytest.mark.xfail(
+    current_platform.is_rocm(),
+    reason="MiniCPM-V dependency xformers incompatible with ROCm")
+@fork_new_process_for_each_test
+def test_minicpmv_lora(minicpmv_lora_files):
     llm = vllm.LLM(
         MODEL_PATH,
-        enable_lora=True,
         max_num_seqs=2,
-        max_loras=4,
-        max_lora_rank=64,
-        tensor_parallel_size=2,
+        enable_lora=True,
+        max_loras=2,
+        max_lora_rank=8,
+        enforce_eager=True,
         trust_remote_code=True,
-        fully_sharded_loras=fully_sharded,
         enable_chunked_prefill=True,
     )
-    output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1)
-
+    output1 = do_sample(llm, minicpmv_lora_files, lora_id=1)
     for i in range(len(EXPECTED_OUTPUT)):
-        assert EXPECTED_OUTPUT[i].startswith(output_tp[i])
+        assert EXPECTED_OUTPUT[i].startswith(output1[i])
+    output2 = do_sample(llm, minicpmv_lora_files, lora_id=2)
+    for i in range(len(EXPECTED_OUTPUT)):
+        assert EXPECTED_OUTPUT[i].startswith(output2[i])


-@multi_gpu_test(num_gpus=4)
-@pytest.mark.parametrize("fully_sharded", [True, False])
-def test_minicpmv_tp4(minicpmv_lora_files, fully_sharded):
+@pytest.mark.xfail(
+    current_platform.is_rocm(),
+    reason="MiniCPM-V dependency xformers incompatible with ROCm")
+@fork_new_process_for_each_test
+def test_minicpmv_tp4_wo_fully_sharded_loras(minicpmv_lora_files):
     llm = vllm.LLM(
         MODEL_PATH,
         enable_lora=True,
@@ -89,9 +90,33 @@ def test_minicpmv_tp4(minicpmv_lora_files, fully_sharded):
         max_lora_rank=64,
         tensor_parallel_size=4,
         trust_remote_code=True,
-        fully_sharded_loras=fully_sharded,
+        enforce_eager=True,
         enable_chunked_prefill=True,
     )
     output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1)
     for i in range(len(EXPECTED_OUTPUT)):
         assert EXPECTED_OUTPUT[i].startswith(output_tp[i])
+
+
+@pytest.mark.xfail(
+    current_platform.is_rocm(),
+    reason="MiniCPM-V dependency xformers incompatible with ROCm")
+@fork_new_process_for_each_test
+def test_minicpmv_tp4_fully_sharded_loras(minicpmv_lora_files):
+    llm = vllm.LLM(
+        MODEL_PATH,
+        enable_lora=True,
+        max_num_seqs=2,
+        max_loras=2,
+        max_lora_rank=8,
+        tensor_parallel_size=4,
+        trust_remote_code=True,
+        fully_sharded_loras=True,
+        enable_chunked_prefill=True,
+    )
+    output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1)
+    for i in range(len(EXPECTED_OUTPUT)):
+        assert EXPECTED_OUTPUT[i].startswith(output_tp[i])
+    output_tp = do_sample(llm, minicpmv_lora_files, lora_id=2)
+    for i in range(len(EXPECTED_OUTPUT)):
+        assert EXPECTED_OUTPUT[i].startswith(output_tp[i])
@@ -4,6 +4,8 @@ hidden_sizes included in the LoRA models currently supported by vLLM. It tests
 whether the corresponding Triton kernel can run normally when tensor parallelism
 is set to [1, 2, 4, 8, 16, 32, 64].
 """
+from threading import Lock
+
 import pytest
 import torch

@@ -11,12 +13,13 @@ from vllm.lora.ops.bgmv_expand import bgmv_expand
 from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice
 from vllm.lora.ops.bgmv_shrink import bgmv_shrink
 from vllm.lora.ops.sgmv_expand import sgmv_expand
-from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
 from vllm.lora.ops.sgmv_shrink import sgmv_shrink
+from vllm.lora.ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
 from vllm.platforms import current_platform

-from .utils import (generate_data, generate_data_for_expand_nslices,
-                    ref_torch_groupgemm)
+from .utils import (assert_close, generate_data,
+                    generate_data_for_expand_nslices,
+                    generate_data_for_nslices, ref_torch_groupgemm)

 HIDDEN_SIZES = [
     128,
@@ -112,14 +115,7 @@ SCALES = [0.5]
 SEED = [0]
 CUDA_DEVICES = [f"cuda:{0}"]

-
-def assert_close(a, b):
-    rtol, atol = {
-        torch.float16: (6e-2, 6e-2),
-        torch.bfloat16: (6e-2, 6e-2),
-        torch.float32: (1e-2, 1e-2),
-    }[a.dtype]
-    torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
+_dict_lock = Lock()


 @pytest.mark.parametrize("batches", BATCHES)
@@ -127,6 +123,7 @@ def assert_close(a, b):
 @pytest.mark.parametrize("rank", MAX_RANKS)
 @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
 @pytest.mark.parametrize("scaling", SCALES)
+@pytest.mark.parametrize("nslices", [1, 2, 3])
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("op_type", ["shrink", "expand"])
 @pytest.mark.parametrize("seed", SEED)
@@ -137,6 +134,7 @@ def test_punica_sgmv(
     rank: int,
     hidden_size: int,
     scaling: float,
+    nslices: int,
     dtype: torch.dtype,
     op_type: str,
     seed: int,
@@ -148,19 +146,20 @@ def test_punica_sgmv(
     seq_length = 128
     (
         inputs_tensor,
-        lora_weights,
+        lora_weights_lst,
         our_out_tensor,
         ref_out_tensor,
         b_seq_start_loc,
         lora_indices_tensor,
         seq_len_tensor,
         indices,
-    ) = generate_data(
+    ) = generate_data_for_nslices(
         batches,
         hidden_size,
         num_loras,
         rank,
         seq_length,
+        nslices,
         dtype,
         op_type,
         device,
@@ -172,43 +171,64 @@ def test_punica_sgmv(
     else:
         max_seq_length = max_seq_length.item()
     if op_type == "shrink":
-        sgmv_shrink(
-            inputs_tensor,
-            lora_weights,
-            our_out_tensor,
-            b_seq_start_loc,
-            seq_len_tensor,
-            lora_indices_tensor,
-            batches,
-            max_seq_length,
-            token_nums,
-            scaling,
-        )
+        # Preventing cache error pointer.
+        with _dict_lock:
+            _LORA_A_PTR_DICT.clear()
+            sgmv_shrink(
+                inputs_tensor,
+                lora_weights_lst,
+                our_out_tensor,
+                b_seq_start_loc,
+                seq_len_tensor,
+                lora_indices_tensor,
+                batches,
+                max_seq_length,
+                token_nums,
+                scaling,
+            )
+        for index in range(nslices):
+            ref_torch_groupgemm(
+                ref_out_tensor[index],
+                inputs_tensor,
+                lora_weights_lst[index],
+                lora_indices_tensor,
+                seq_len_tensor,
+                batches,
+                scaling,
+                op_type,
+            )
     else:
-        sgmv_expand(
-            inputs_tensor,
-            lora_weights,
-            our_out_tensor,
-            b_seq_start_loc,
-            seq_len_tensor,
-            lora_indices_tensor,
-            batches,
-            max_seq_length,
-            token_nums,
-            add_inputs=True,
-        )
-    ref_torch_groupgemm(
-        ref_out_tensor,
-        inputs_tensor,
-        lora_weights,
-        lora_indices_tensor,
-        seq_len_tensor,
-        batches,
-        scaling if op_type == "shrink" else 1.0,
-        op_type,
-    )
-    if op_type == "shrink":
-        ref_out_tensor = ref_out_tensor.to(torch.float32)
+        with _dict_lock:
+            _LORA_B_PTR_DICT.clear()
+            sgmv_expand(
+                inputs_tensor,
+                lora_weights_lst,
+                our_out_tensor,
+                b_seq_start_loc,
+                seq_len_tensor,
+                lora_indices_tensor,
+                batches,
+                max_seq_length,
+                token_nums,
+                offset_start=0,
+                add_inputs=True,
+            )
+
+        slice_offset = 0
+        for index in range(nslices):
+            lora_weights = lora_weights_lst[index]
+            ref_torch_groupgemm(
+                ref_out_tensor[:, slice_offset:slice_offset + hidden_size],
+                inputs_tensor[index],
+                lora_weights,
+                lora_indices_tensor,
+                seq_len_tensor,
+                batches,
+                1.0,
+                op_type,
+            )
+            slice_offset += hidden_size

     assert_close(our_out_tensor, ref_out_tensor)


@@ -292,25 +312,22 @@ def test_punica_bgmv(
 @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
 @pytest.mark.parametrize("nslices", [2, 3])
 @pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("op_type", ["sgmv", "bgmv"])
 @pytest.mark.parametrize("seed", SEED)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
-def test_punica_expand_nslices(
+def test_punica_bgmv_expand_nslices(
     batches: int,
     num_loras: int,
     rank: int,
     hidden_size: int,
     nslices: int,
     dtype: torch.dtype,
-    op_type: str,
     seed: int,
     device: str,
 ):
-
     torch.set_default_device(device)
     current_platform.seed_everything(seed)

-    seq_length = 128 if op_type == "sgmv" else 1
+    seq_length = 1
     (
         inputs_tensor,
         lora_weights_lst,
@@ -330,41 +347,18 @@ def test_punica_expand_nslices(
         nslices,
         device,
     )
-    max_seq_length = seq_len_tensor.max()
-    token_nums = seq_len_tensor.sum().item()
-    if isinstance(max_seq_length, tuple):
-        max_seq_length = max_seq_length[0].item()
-    else:
-        max_seq_length = max_seq_length.item()
     slice_offset = 0
     for index in range(nslices):
         lora_weights = lora_weights_lst[index]
-        if op_type == "sgmv":
-            sgmv_expand_slice(
-                inputs_tensor,
-                lora_weights,
-                our_outputs,
-                b_seq_start_loc,
-                seq_len_tensor,
-                lora_indices_tensor,
-                batches,
-                max_seq_length,
-                token_nums,
-                slice_offset,
-                hidden_size,
-                add_inputs=True,
-            )
-        else:
-
-            bgmv_expand_slice(
-                inputs_tensor,
-                lora_weights,
-                our_outputs,
-                indices,
-                slice_offset,
-                slice_size=hidden_size,
-                add_inputs=True,
-            )
+        bgmv_expand_slice(
+            inputs_tensor,
+            lora_weights,
+            our_outputs,
+            indices,
+            slice_offset,
+            slice_size=hidden_size,
+            add_inputs=True,
+        )
         ref_torch_groupgemm(
             ref_outputs[:, slice_offset:slice_offset + hidden_size],
             inputs_tensor,
@@ -3,6 +3,8 @@ This script is mainly used to test whether trtion kernels can run normally
 under different conditions, including various batches, numbers of LoRA , and
 maximum ranks.
 """
+from threading import Lock
+
 import pytest
 import torch

@@ -11,12 +13,13 @@ import vllm.lora.ops.bgmv_expand
 import vllm.lora.ops.bgmv_expand_slice
 import vllm.lora.ops.bgmv_shrink
 import vllm.lora.ops.sgmv_expand
-import vllm.lora.ops.sgmv_expand_slice
 import vllm.lora.ops.sgmv_shrink  # noqa: F401
+from vllm.lora.ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
 from vllm.platforms import current_platform

-from .utils import (generate_data, generate_data_for_expand_nslices,
-                    ref_torch_groupgemm)
+from .utils import (assert_close, generate_data,
+                    generate_data_for_expand_nslices,
+                    generate_data_for_nslices, ref_torch_groupgemm)

 HIDDEN_SIZES = [4097]

@@ -28,31 +31,23 @@ SCALES = [0.5]
 SEED = [0]
 CUDA_DEVICES = [f"cuda:{0}"]


-def assert_close(a, b):
-    rtol, atol = {
-        torch.float16: (6e-2, 6e-2),
-        torch.bfloat16: (6e-2, 6e-2),
-        torch.float32: (1e-2, 1e-2),
-    }[a.dtype]
-    torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
-
-
 # Unlike test_punica_sizes.py, we directly utilize custom op for
 # testing, which verifies the correct registration of these ops.
 bgmv_expand = torch.ops.vllm.bgmv_expand
 bgmv_expand_slice = torch.ops.vllm.bgmv_expand_slice
 bgmv_shrink = torch.ops.vllm.bgmv_shrink
 sgmv_expand = torch.ops.vllm.sgmv_expand
-sgmv_expand_slice = torch.ops.vllm.sgmv_expand_slice
 sgmv_shrink = torch.ops.vllm.sgmv_shrink

+_dict_lock = Lock()
+

 @pytest.mark.parametrize("batches", BATCHES)
 @pytest.mark.parametrize("num_loras", NUM_LORA)
 @pytest.mark.parametrize("rank", MAX_RANKS)
 @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
 @pytest.mark.parametrize("scaling", SCALES)
+@pytest.mark.parametrize("nslices", [1, 2, 3])
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("op_type", ["shrink", "expand"])
 @pytest.mark.parametrize("seed", SEED)
@@ -63,6 +58,7 @@ def test_punica_sgmv(
     rank: int,
     hidden_size: int,
     scaling: float,
+    nslices: int,
     dtype: torch.dtype,
     op_type: str,
     seed: int,
@@ -74,19 +70,20 @@ def test_punica_sgmv(
     seq_length = 128
     (
         inputs_tensor,
-        lora_weights,
+        lora_weights_lst,
         our_out_tensor,
         ref_out_tensor,
         b_seq_start_loc,
         lora_indices_tensor,
         seq_len_tensor,
         indices,
-    ) = generate_data(
+    ) = generate_data_for_nslices(
         batches,
         hidden_size,
         num_loras,
         rank,
         seq_length,
+        nslices,
         dtype,
         op_type,
         device,
|
|||||||
else:
|
else:
|
||||||
max_seq_length = max_seq_length.item()
|
max_seq_length = max_seq_length.item()
|
||||||
if op_type == "shrink":
|
if op_type == "shrink":
|
||||||
sgmv_shrink(
|
# Preventing cache error pointer.
|
||||||
inputs_tensor,
|
with _dict_lock:
|
||||||
lora_weights,
|
_LORA_A_PTR_DICT.clear()
|
||||||
our_out_tensor,
|
sgmv_shrink(
|
||||||
b_seq_start_loc,
|
inputs_tensor,
|
||||||
seq_len_tensor,
|
lora_weights_lst,
|
||||||
lora_indices_tensor,
|
our_out_tensor,
|
||||||
batches,
|
b_seq_start_loc,
|
||||||
max_seq_length,
|
seq_len_tensor,
|
||||||
token_nums,
|
lora_indices_tensor,
|
||||||
scaling,
|
batches,
|
||||||
)
|
max_seq_length,
|
||||||
|
token_nums,
|
||||||
|
scaling,
|
||||||
|
)
|
||||||
|
for index in range(nslices):
|
||||||
|
ref_torch_groupgemm(
|
||||||
|
ref_out_tensor[index],
|
||||||
|
inputs_tensor,
|
||||||
|
lora_weights_lst[index],
|
||||||
|
lora_indices_tensor,
|
||||||
|
seq_len_tensor,
|
||||||
|
batches,
|
||||||
|
scaling,
|
||||||
|
op_type,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
sgmv_expand(
|
with _dict_lock:
|
||||||
inputs_tensor,
|
_LORA_B_PTR_DICT.clear()
|
||||||
lora_weights,
|
sgmv_expand(
|
||||||
our_out_tensor,
|
inputs_tensor,
|
||||||
b_seq_start_loc,
|
lora_weights_lst,
|
||||||
seq_len_tensor,
|
our_out_tensor,
|
||||||
lora_indices_tensor,
|
b_seq_start_loc,
|
||||||
batches,
|
seq_len_tensor,
|
||||||
max_seq_length,
|
lora_indices_tensor,
|
||||||
token_nums,
|
batches,
|
||||||
add_inputs=True,
|
max_seq_length,
|
||||||
)
|
token_nums,
|
||||||
ref_torch_groupgemm(
|
offset_start=0,
|
||||||
ref_out_tensor,
|
add_inputs=True,
|
||||||
inputs_tensor,
|
)
|
||||||
lora_weights,
|
|
||||||
lora_indices_tensor,
|
slice_offset = 0
|
||||||
seq_len_tensor,
|
for index in range(nslices):
|
||||||
batches,
|
lora_weights = lora_weights_lst[index]
|
||||||
scaling if op_type == "shrink" else 1.0,
|
ref_torch_groupgemm(
|
||||||
op_type,
|
ref_out_tensor[:, slice_offset:slice_offset + hidden_size],
|
||||||
)
|
inputs_tensor[index],
|
||||||
if op_type == "shrink":
|
lora_weights,
|
||||||
ref_out_tensor = ref_out_tensor.to(torch.float32)
|
lora_indices_tensor,
|
||||||
|
seq_len_tensor,
|
||||||
|
batches,
|
||||||
|
1.0,
|
||||||
|
op_type,
|
||||||
|
)
|
||||||
|
slice_offset += hidden_size
|
||||||
|
|
||||||
assert_close(our_out_tensor, ref_out_tensor)
|
assert_close(our_out_tensor, ref_out_tensor)
|
||||||
|
|
||||||
|
|
||||||
@@ -220,24 +238,22 @@ def test_punica_bgmv(
 @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
 @pytest.mark.parametrize("nslices", [2, 3])
 @pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("op_type", ["sgmv", "bgmv"])
 @pytest.mark.parametrize("seed", SEED)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
-def test_punica_expand_nslices(
+def test_punica_bgmv_expand_nslices(
     batches: int,
     num_loras: int,
     rank: int,
     hidden_size: int,
     nslices: int,
     dtype: torch.dtype,
-    op_type: str,
     seed: int,
     device: str,
 ):
     torch.set_default_device(device)
     current_platform.seed_everything(seed)

-    seq_length = 128 if op_type == "sgmv" else 1
+    seq_length = 1
     (
         inputs_tensor,
         lora_weights_lst,
@@ -257,40 +273,18 @@ def test_punica_expand_nslices(
         nslices,
         device,
     )
-    max_seq_length = seq_len_tensor.max()
-    token_nums = seq_len_tensor.sum().item()
-    if isinstance(max_seq_length, tuple):
-        max_seq_length = max_seq_length[0].item()
-    else:
-        max_seq_length = max_seq_length.item()
     slice_offset = 0
     for index in range(nslices):
         lora_weights = lora_weights_lst[index]
-        if op_type == "sgmv":
-            sgmv_expand_slice(
-                inputs_tensor,
-                lora_weights,
-                our_outputs,
-                b_seq_start_loc,
-                seq_len_tensor,
-                lora_indices_tensor,
-                batches,
-                max_seq_length,
-                token_nums,
-                slice_offset,
-                hidden_size,
-                add_inputs=True,
-            )
-        else:
-            bgmv_expand_slice(
-                inputs_tensor,
-                lora_weights,
-                our_outputs,
-                indices,
-                slice_offset,
-                slice_size=hidden_size,
-                add_inputs=True,
-            )
+        bgmv_expand_slice(
+            inputs_tensor,
+            lora_weights,
+            our_outputs,
+            indices,
+            slice_offset,
+            slice_size=hidden_size,
+            add_inputs=True,
+        )
         ref_torch_groupgemm(
             ref_outputs[:, slice_offset:slice_offset + hidden_size],
             inputs_tensor,
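Both Punica test files above guard the fused kernels with the same pattern: the module-level `_LORA_A_PTR_DICT` / `_LORA_B_PTR_DICT` caches in `vllm.lora.ops.utils` are cleared under a lock before each call (the in-diff comment reads "Preventing cache error pointer."). Below is a minimal sketch of that guard, assuming the dicts memoize per-weight-list device pointer metadata that would go stale once a previous parametrization's weights are freed; the wrapper names are hypothetical.

```python
from threading import Lock

from vllm.lora.ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT

_dict_lock = Lock()


def run_with_fresh_shrink_cache(call_shrink):
    # Drop cached LoRA-A pointer metadata before launching sgmv_shrink.
    with _dict_lock:
        _LORA_A_PTR_DICT.clear()
        call_shrink()


def run_with_fresh_expand_cache(call_expand):
    # Drop cached LoRA-B pointer metadata before launching sgmv_expand.
    with _dict_lock:
        _LORA_B_PTR_DICT.clear()
        call_expand()
```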
@@ -18,11 +18,13 @@ class DummyLoRAManager:
     def get_module_lora(self, module_name: str) -> LoRALayerWeights:
         return self._loras[module_name]

-    def init_random_lora(self,
-                         module_name: str,
-                         weight: torch.Tensor,
-                         rank: int = 8,
-                         generate_embeddings_tensor: int = 0):
+    def init_random_lora(
+        self,
+        module_name: str,
+        weight: torch.Tensor,
+        rank: int = 8,
+        generate_embeddings_tensor: int = 0,
+    ):
         lora = LoRALayerWeights(
             module_name,
             rank=rank,
@@ -35,21 +37,25 @@ class DummyLoRAManager:
                               device=self._device),
         )
         if generate_embeddings_tensor:
-            lora.embeddings_tensor = torch.rand(5,
-                                                generate_embeddings_tensor,
-                                                dtype=weight.dtype,
-                                                device=self._device)
+            lora.embeddings_tensor = torch.rand(
+                5,
+                generate_embeddings_tensor,
+                dtype=weight.dtype,
+                device=self._device,
+            )
         self.set_module_lora(module_name, lora)

         return lora

-    def init_lora(self,
-                  module_name: str,
-                  input_dim: int,
-                  output_dim: int,
-                  rank=8,
-                  noop=False,
-                  embeddings_tensor=None):
+    def init_lora(
+        self,
+        module_name: str,
+        input_dim: int,
+        output_dim: int,
+        rank=8,
+        noop=False,
+        embeddings_tensor=None,
+    ):
         lora = LoRALayerWeights(
             module_name,
             rank=rank,
@@ -125,8 +131,16 @@ def ref_torch_groupgemm(
     return


-def generate_data(batches, hidden_size, lora_nums, max_rank, seq_length, dtype,
-                  op_type, device):
+def generate_data(
+    batches,
+    hidden_size,
+    lora_nums,
+    max_rank,
+    seq_length,
+    dtype,
+    op_type,
+    device,
+):
     seq_len_tensor = torch.randint(seq_length, seq_length + 1,
                                    (batches, )).to(device)
     b_seq_start_loc = torch.cumsum(
@@ -187,8 +201,16 @@ def generate_data(batches, hidden_size, lora_nums, max_rank, seq_length, dtype,
     )


-def generate_data_for_expand_nslices(batches, hidden_size, lora_nums, max_rank,
-                                     seq_length, dtype, nslices, device):
+def generate_data_for_expand_nslices(
+    batches,
+    hidden_size,
+    lora_nums,
+    max_rank,
+    seq_length,
+    dtype,
+    nslices,
+    device,
+):
     seq_len_tensor = torch.randint(seq_length, seq_length + 1,
                                    (batches, )).to(device)
     b_seq_start_loc = torch.cumsum(
@@ -221,7 +243,87 @@ def generate_data_for_expand_nslices(batches, hidden_size, lora_nums, max_rank,
     for b_id in range(batches):
         lora_index = lora_indices_tensor[b_id]
         indices[current_offset:current_offset +
-                seq_len_tensor[b_id]] = lora_index.item()
+                seq_len_tensor[b_id]] = (lora_index.item())
+        current_offset += seq_len_tensor[b_id].item()
+
+    lora_indices_tensor = lora_indices_tensor.to(device)
+    return (
+        inputs_tensor,
+        lora_weights_lst,
+        our_out_tensor,
+        ref_out_tensor,
+        b_seq_start_loc,
+        lora_indices_tensor,
+        seq_len_tensor,
+        indices,
+    )
+
+
+def generate_data_for_nslices(
+    batches,
+    hidden_size,
+    lora_nums,
+    max_rank,
+    seq_length,
+    nslices,
+    dtype,
+    op_type,
+    device,
+):
+    seq_len_tensor = torch.randint(seq_length, seq_length + 1,
+                                   (batches, )).to(device)
+    b_seq_start_loc = torch.cumsum(
+        torch.tensor([0] + seq_len_tensor[:-1].tolist(), dtype=torch.long),
+        dim=0,
+    ).to(device)
+    total_tokens = seq_len_tensor.sum()
+
+    lora_weights_lst = []
+    if op_type == "shrink":
+
+        inputs_tensor = torch.rand((total_tokens, hidden_size),
+                                   dtype=dtype).to(device)
+
+        for _ in range(nslices):
+            if op_type == "shrink":
+                lora_weights_lst.append(
+                    torch.rand(
+                        (lora_nums, max_rank, hidden_size),  # col-major
+                        dtype=dtype,
+                    ).to(device))
+        # NOTE shrink kernel using torch.float32 as output type
+        # shrink op need atomic_add, so output is initinized by 0
+        our_out_tensor = torch.zeros(
+            (nslices, total_tokens, max_rank),
+            dtype=torch.float32,
+        ).to(device)
+    else:
+        inputs_tensor = torch.rand(
+            (nslices, total_tokens, max_rank),
+            dtype=dtype,
+        ).to(device)
+        for _ in range(nslices):
+            lora_weights_lst.append(
+                torch.rand(
+                    (lora_nums, hidden_size, max_rank),  # col-major
+                    dtype=dtype,
+                ).to(device))
+        # expand op needs to complete y+=a@lora_b, so output is
+        # initinized randomly
+        our_out_tensor = torch.rand((total_tokens, hidden_size * nslices),
+                                    dtype=dtype).to(device)
+
+    # Ensure the same input.
+    ref_out_tensor = our_out_tensor.clone()
+    lora_indices_tensor = torch.randint(0,
+                                        lora_nums - 1 if lora_nums > 1 else 1,
+                                        (batches, ))
+    indices = torch.zeros((total_tokens), dtype=torch.long).to(device)
+    current_offset = 0
+    for b_id in range(batches):
+        lora_index = lora_indices_tensor[b_id]
+        indices[current_offset:current_offset +
+                seq_len_tensor[b_id]] = (lora_index.item())
         current_offset += seq_len_tensor[b_id].item()

     lora_indices_tensor = lora_indices_tensor.to(device)
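The new `generate_data_for_nslices` helper above is what both updated kernel tests consume. A short usage sketch follows (hypothetical standalone script, assuming the vLLM repo root is importable and a CUDA device is available), showing the layouts it hands back for the expand path.

```python
# Usage sketch for the helper added above; names and sizes are illustrative.
import torch

from tests.lora.utils import generate_data_for_nslices

batches, hidden_size, lora_nums, rank, seq_len, nslices = 2, 128, 2, 8, 8, 3
(inputs, lora_weights_lst, our_out, ref_out, b_seq_start_loc, lora_indices,
 seq_len_tensor, indices) = generate_data_for_nslices(batches, hidden_size,
                                                      lora_nums, rank, seq_len,
                                                      nslices, torch.float16,
                                                      "expand", "cuda:0")

total_tokens = int(seq_len_tensor.sum())
# For "expand": per-slice inputs, one wide output holding every slice.
assert inputs.shape == (nslices, total_tokens, rank)
assert our_out.shape == (total_tokens, hidden_size * nslices)
assert len(lora_weights_lst) == nslices
```

For "shrink" the roles flip: a single `(total_tokens, hidden_size)` input, per-slice LoRA-A weights of shape `(lora_nums, max_rank, hidden_size)`, and a zero-initialized float32 output of shape `(nslices, total_tokens, max_rank)`, as the helper's two branches show.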
@@ -1,66 +1,109 @@
 """
 Based on:
 Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
 Punica: Multi-Tenant LoRA Serving.
 https://arxiv.org/abs/2310.18547
 """

+from typing import List
+
 import torch
 import triton
 import triton.language as tl

 from vllm.utils import direct_register_custom_op

+from .utils import _get_lora_b_ptr
+

 @triton.jit
 def _sgmv_expand_kernel(
         input_ptr,
         lora_ptr,
         out_ptr,
         N,
         K,
         b_seq_start_loc,
         seq_lens,
         lora_indices,
-        xm_stride,
-        xk_stride,  # 1
-        l0_stride,  # hidden_size*max_rank
-        lora_k_stride,
-        lora_n_stride,
-        cm_stride,
-        cn_stride,
+        slice_start_loc,
+        input_d0_stride,
+        input_d1_stride,
+        input_d2_stride,  # 1
+        ls_d0_ptr,
+        ls_d1_ptr,
+        ls_d2_ptr,  # 1
+        output_d0_stride,
+        output_d1_stride,  # 1
+        output_hs_ptr,
         BLOCK_M: tl.constexpr,
         BLOCK_N: tl.constexpr,
         BLOCK_K: tl.constexpr,
         EVEN_K: tl.constexpr,
         ADD_INPUTS: tl.constexpr,
         CAST_TYPE: tl.constexpr,
-):
+        SLICE_NUM: tl.constexpr,
+        SAME_STRIDE: tl.constexpr):
     """
-    The sgmv's expand triton kernel is based on GroupGEMM.
+    Similar to the 'sgmv_expand' operator, but with an added parameter
+    'slice_offset'. The reason for not reusing the 'sgmv_expand' operator
+    might be that in the future, we could implement a fusion operator to
+    achieve the current functionality instead of having to call it multiple
+    times.
     """
     pid = tl.program_id(axis=0)
     cur_batch = tl.program_id(axis=1)
+    slice_id = tl.program_id(axis=2)
     cta_n_num = tl.cdiv(N, BLOCK_N)
+    # When the output dimensions of each slice are the same,cur_n=N, otherwise
+    # cur_n=tl.load(output_hs_ptr + slice_id), this situation exists in GQA's
+    # qkv linear.
+    curr_N = N if SAME_STRIDE else tl.load(output_hs_ptr + slice_id)
     pid_m = pid // cta_n_num
     pid_n = pid % cta_n_num
     M = tl.load(seq_lens + cur_batch)
     if pid_m * BLOCK_M > M:
         return
+    if pid_n * BLOCK_N > curr_N:
+        return
     lora_index = tl.load(lora_indices + cur_batch)
     if lora_index == -1:
         return

     cur_seq_start = tl.load(b_seq_start_loc + cur_batch)
     offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
     offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
     offset_k = tl.arange(0, BLOCK_K)
     ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M)
-    rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)
-
-    a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride +
-             offset_k[None, :] * xk_stride, )
-    b_ptr = (lora_ptr + l0_stride * lora_index +
-             offset_k[:, None] * lora_n_stride + rbn[None, :] * lora_k_stride)
+    rbn = tl.max_contiguous(tl.multiple_of(offset_n % curr_N, BLOCK_N),
+                            BLOCK_N)
+    # ls_d*_ptr can be either an integer or a pointer
+    if SAME_STRIDE:
+        # integer
+        cur_lora_d0_stride = ls_d0_ptr
+        cur_lora_d1_stride = ls_d1_ptr
+        cur_lora_d2_stride = ls_d2_ptr
+    else:
+        # pointer
+        cur_lora_d0_stride = tl.load(ls_d0_ptr + slice_id)
+        cur_lora_d1_stride = tl.load(ls_d1_ptr + slice_id)
+        cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id)
+    if SLICE_NUM == 1:
+        cur_input_ptr = input_ptr
+        cur_lora_ptr = lora_ptr
+
+    else:
+        cur_input_ptr = input_ptr + slice_id * input_d0_stride
+        cur_lora_ptr = tl.load(lora_ptr + slice_id).to(
+            tl.pointer_type(out_ptr.dtype.element_ty))
+
+    a_ptr = (cur_input_ptr + cur_seq_start * input_d1_stride +
+             ram[:, None] * input_d1_stride +
+             offset_k[None, :] * input_d2_stride, )
+    b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index +
+             offset_k[:, None] * cur_lora_d2_stride +
+             rbn[None, :] * cur_lora_d1_stride)
     accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
     for k in range(tl.cdiv(K, BLOCK_K)):
         if EVEN_K:
@@ -74,26 +117,30 @@ def _sgmv_expand_kernel(
                               mask=offset_k[:, None] < K - k * BLOCK_K,
                               other=0)
         if CAST_TYPE:
-            tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)
+            tiled_a = tiled_a.to(cur_lora_ptr.dtype.element_ty)
         accumulator += tl.dot(
             tiled_a,
             tiled_b,
         )
-        a_ptr += BLOCK_K * xk_stride
-        b_ptr += BLOCK_K * lora_n_stride
-    tiled_c = accumulator.to(lora_ptr.dtype.element_ty)
+        a_ptr += BLOCK_K * input_d2_stride
+        b_ptr += BLOCK_K * cur_lora_d2_stride
+
+    tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty)
+    if SLICE_NUM == 1:
+        cur_slice_start = slice_start_loc
+    else:
+        cur_slice_start = tl.load(slice_start_loc + slice_id)
+
     offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
-    offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
-    c_ptr = (out_ptr + offset_cm[:, None] * cm_stride +
-             offset_cn[None, :] * cn_stride)
+    offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + cur_slice_start
+    c_ptr = (out_ptr + offset_cm[:, None] * output_d0_stride +
+             offset_cn[None, :] * output_d1_stride)
     M = tl.load(seq_lens + cur_batch)
     c_mask = (offset_cm[:, None] <
-              (cur_seq_start + M)) & (offset_cn[None, :] < N)
+              (cur_seq_start + M)) & (offset_cn[None, :] <
+                                      (cur_slice_start + curr_N))
     if ADD_INPUTS:
-        # explicitly pass in other=None to tell triton that masked values
-        # can be uninitialized. This is OK because the later tl.store operation
-        # uses the same mask, eliminating the risk of garbage values propagating
-        tiled_out = tl.load(c_ptr, mask=c_mask, other=None)
+        tiled_out = tl.load(c_ptr, mask=c_mask)
         tiled_c += tiled_out
     tl.store(c_ptr, tiled_c, mask=c_mask)

@@ -101,7 +148,7 @@ def _sgmv_expand_kernel(
 @torch.inference_mode()
 def _sgmv_expand(
     inputs: torch.Tensor,
-    lora_b_weights: torch.Tensor,
+    lora_b_weights: List[torch.Tensor],
     output_tensor: torch.Tensor,
     b_seq_start_loc: torch.Tensor,
     seq_len_tensor: torch.Tensor,
@@ -109,17 +156,18 @@ def _sgmv_expand(
     batches: int,
     max_seq_length: int,
     token_nums: int,
+    offset_start: int = 0,
     add_inputs: bool = False,
 ) -> None:
     """
     Args:
         inputs (torch.Tensor): input tensor
-        lora_b_weights (torch.Tensor): lora'a weight
+        lora_b_weights (List[torch.Tensor]): lora'b weight
         output_tensor (torch.Tensor): output tensor
         b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
            sequence lengths of the sequences in the batch, used to index
            into sequence. E.g., if the sequence length is [4, 6], it is
-           [0, 4, 10].
+           [0, 4].
         seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
            length of the sequences in the batch.
         lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
@@ -130,77 +178,80 @@ def _sgmv_expand(
            batch.
         token_nums (int): The token numbers in the batch. Used to verify if the
            token numbers in the inputs matches the one in the metadata.
-        add_inputs (bool, optional): Defaults to False, adds the final lora
-            results to the output.
+        offset_start (int, optional): Offset start for output_tensor.
+            Defaults to 0.
+        add_inputs (bool, optional): Whether to add the input tensor to the
+            output tensor. Defaults to False.
     """

     assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
-    assert lora_b_weights.dtype in [
-        torch.float16,
-        torch.bfloat16,
-    ]
-    assert inputs.size(0) == token_nums
-    assert inputs.size(1) == lora_b_weights.size(-1)
+    for weight in lora_b_weights:
+        assert weight.dtype in [torch.float16, torch.bfloat16]
+
+    assert inputs.size(1) == token_nums
+    assert inputs.size(0) == len(lora_b_weights)
+
     assert b_seq_start_loc.size(0) == batches
     assert lora_indices_tensor.size(0) == batches
-    assert inputs.is_contiguous()
     assert output_tensor.is_contiguous()
-
-    if lora_b_weights.ndim == 4:  # shape:(lora_num,1,size,rank)
-        assert lora_b_weights.size(1) == 1
-        lora_b_weights = lora_b_weights.squeeze(dim=1)
-    else:
-        assert lora_b_weights.ndim == 3  # shape:(lora_num,size,rank)
-
-    assert lora_b_weights.is_contiguous()
-
+    (slice_start_tensor, lora_ptr_tensor, lora_strides_d0_tensor,
+     lora_strides_d1_tensor, lora_strides_d2_tensor, hidden_sizes_tensor,
+     same_stride, MAX_N) = _get_lora_b_ptr(lora_b_weights, offset_start,
+                                           b_seq_start_loc.device)

     # TODO tuning this config
+    K = lora_b_weights[0].shape[-1]  # K= rank

-    N, K = lora_b_weights.shape[-2:]  # K= rank,N=hidden_size
-    BLOCK_M = 32
-    BLOCK_N = 32
+    BLOCK_M = 64
+    BLOCK_N = 128
     BLOCK_K = 16
     EVEN_K = K % BLOCK_K == 0
     ADD_INPUTS = add_inputs
     CAST_TYPE = False
-    if inputs.dtype == torch.float32 and lora_b_weights.dtype in [
+
+    if inputs.dtype == torch.float32 and lora_b_weights[0].dtype in [
             torch.float16,
             torch.bfloat16,
     ]:
         CAST_TYPE = True
     grid = (
-        triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),
+        triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(MAX_N, BLOCK_N),
         batches,
+        len(lora_b_weights),
     )
     _sgmv_expand_kernel[grid](
         inputs,
-        lora_b_weights,
+        lora_ptr_tensor,
         output_tensor,
-        N,
+        MAX_N,
         K,
         b_seq_start_loc,
         seq_len_tensor,
         lora_indices_tensor,
+        slice_start_tensor,
         inputs.stride(0),
         inputs.stride(1),
-        lora_b_weights.stride(0),
-        lora_b_weights.stride(1),
-        lora_b_weights.stride(2),
+        inputs.stride(2),
+        lora_strides_d0_tensor,
+        lora_strides_d1_tensor,
+        lora_strides_d2_tensor,
         output_tensor.stride(0),
         output_tensor.stride(1),
+        hidden_sizes_tensor,
         BLOCK_M,
         BLOCK_N,
         BLOCK_K,
         EVEN_K,
         ADD_INPUTS,
         CAST_TYPE,
+        len(lora_b_weights),
+        same_stride,
     )
     return


-def sgmv_expand_fake(
+def _sgmv_expand_fake(
     inputs: torch.Tensor,
-    lora_b_weights: torch.Tensor,
+    lora_b_weights: List[torch.Tensor],
     output_tensor: torch.Tensor,
     b_seq_start_loc: torch.Tensor,
     seq_len_tensor: torch.Tensor,
@@ -208,18 +259,18 @@ def sgmv_expand_fake(
     batches: int,
     max_seq_length: int,
     token_nums: int,
+    offset_start: int = 0,
     add_inputs: bool = False,
 ) -> None:
     return


 try:
     direct_register_custom_op(
         op_name="sgmv_expand",
         op_func=_sgmv_expand,
         mutates_args=["output_tensor"],
-        fake_impl=sgmv_expand_fake,
+        fake_impl=_sgmv_expand_fake,
     )
     sgmv_expand = torch.ops.vllm.sgmv_expand

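With the rewrite above, `_sgmv_expand` launches one grid whose third axis ranges over the slices, and each program writes its slice into its own column range of the shared output (`curr_N` can differ per slice, e.g. for GQA's qkv projection). The sketch below walks through the launch arithmetic; it assumes `MAX_N` is the largest slice output size and that per-slice start columns are the cumulative hidden sizes beginning at `offset_start`, which matches how the old `sgmv_expand_slice` loop advanced its `slice_offset`. The helper name is hypothetical.

```python
import triton  # only triton.cdiv is used in this sketch


def expand_launch_geometry(max_seq_length, slice_hidden_sizes, offset_start=0,
                           BLOCK_M=64, BLOCK_N=128):
    # First grid axis: M-tiles x N-tiles (sized for the widest slice);
    # the real launch adds (batches, num_slices) as the other two axes.
    max_n = max(slice_hidden_sizes)
    tiles = triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(max_n, BLOCK_N)
    # Per-slice output column offsets: each slice starts where the previous
    # one ended, exactly like the old per-slice slice_offset accumulation.
    starts, offset = [], offset_start
    for hs in slice_hidden_sizes:
        starts.append(offset)
        offset += hs
    return tiles, len(slice_hidden_sizes), starts


tiles, num_slices, slice_starts = expand_launch_geometry(
    max_seq_length=128, slice_hidden_sizes=[4096, 4096, 4096])
print(tiles, num_slices, slice_starts)  # 64, 3, [0, 4096, 8192]
```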
@ -1,241 +0,0 @@
|
|||||||
"""
|
|
||||||
Based on:
|
|
||||||
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
|
|
||||||
Punica: Multi-Tenant LoRA Serving.
|
|
||||||
https://arxiv.org/abs/2310.18547
|
|
||||||
"""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import triton
|
|
||||||
import triton.language as tl
|
|
||||||
|
|
||||||
from vllm.utils import direct_register_custom_op
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def _sgmv_expand_slice_kernel(
|
|
||||||
input_ptr,
|
|
||||||
lora_ptr,
|
|
||||||
out_ptr,
|
|
||||||
N,
|
|
||||||
K,
|
|
||||||
b_seq_start_loc,
|
|
||||||
seq_lens,
|
|
||||||
lora_indices,
|
|
||||||
xm_stride,
|
|
||||||
xk_stride, # 1
|
|
||||||
l0_stride, # hidden_size*max_rank
|
|
||||||
lora_k_stride,
|
|
||||||
lora_n_stride,
|
|
||||||
cm_stride,
|
|
||||||
cn_stride,
|
|
||||||
slice_offset,
|
|
||||||
BLOCK_M: tl.constexpr,
|
|
||||||
BLOCK_N: tl.constexpr,
|
|
||||||
BLOCK_K: tl.constexpr,
|
|
||||||
EVEN_K: tl.constexpr,
|
|
||||||
ADD_INPUTS: tl.constexpr,
|
|
||||||
CAST_TYPE: tl.constexpr,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
|
|
||||||
Similar to the 'sgmv_expand' operator, but with an added parameter
|
|
||||||
'slice_offset'. The reason for not reusing the 'sgmv_expand' operator
|
|
||||||
might be that in the future, we could implement a fusion operator to
|
|
||||||
achieve the current functionality instead of having to call it multiple
|
|
||||||
times.
|
|
||||||
"""
|
|
||||||
pid = tl.program_id(axis=0)
|
|
||||||
cur_batch = tl.program_id(axis=1)
|
|
||||||
cta_n_num = tl.cdiv(N, BLOCK_N)
|
|
||||||
pid_m = pid // cta_n_num
|
|
||||||
pid_n = pid % cta_n_num
|
|
||||||
M = tl.load(seq_lens + cur_batch)
|
|
||||||
if pid_m * BLOCK_M > M:
|
|
||||||
return
|
|
||||||
lora_index = tl.load(lora_indices + cur_batch)
|
|
||||||
if lora_index == -1:
|
|
||||||
return
|
|
||||||
cur_seq_start = tl.load(b_seq_start_loc + cur_batch)
|
|
||||||
offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
|
|
||||||
offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
|
|
||||||
offset_k = tl.arange(0, BLOCK_K)
|
|
||||||
ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M)
|
|
||||||
rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)
|
|
||||||
|
|
||||||
a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride +
|
|
||||||
offset_k[None, :] * xk_stride, )
|
|
||||||
b_ptr = (lora_ptr + l0_stride * lora_index +
|
|
||||||
offset_k[:, None] * lora_n_stride + rbn[None, :] * lora_k_stride)
|
|
||||||
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
|
|
||||||
for k in range(tl.cdiv(K, BLOCK_K)):
|
|
||||||
if EVEN_K:
|
|
||||||
tiled_a = tl.load(a_ptr)
|
|
||||||
tiled_b = tl.load(b_ptr)
|
|
||||||
else:
|
|
||||||
tiled_a = tl.load(a_ptr,
|
|
||||||
mask=offset_k[None, :] < K - k * BLOCK_K,
|
|
||||||
other=0)
|
|
||||||
tiled_b = tl.load(b_ptr,
|
|
||||||
mask=offset_k[:, None] < K - k * BLOCK_K,
|
|
||||||
other=0)
|
|
||||||
if CAST_TYPE:
|
|
||||||
tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)
|
|
||||||
accumulator += tl.dot(
|
|
||||||
tiled_a,
|
|
||||||
tiled_b,
|
|
||||||
)
|
|
||||||
a_ptr += BLOCK_K * xk_stride
|
|
||||||
b_ptr += BLOCK_K * lora_n_stride
|
|
||||||
tiled_c = accumulator.to(lora_ptr.dtype.element_ty)
|
|
||||||
offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
|
|
||||||
offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + slice_offset
|
|
||||||
c_ptr = (out_ptr + offset_cm[:, None] * cm_stride +
|
|
||||||
offset_cn[None, :] * cn_stride)
|
|
||||||
M = tl.load(seq_lens + cur_batch)
|
|
||||||
c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :] <
|
|
||||||
(slice_offset + N))
|
|
||||||
if ADD_INPUTS:
|
|
||||||
# explicitly pass in other=None to tell triton that masked values
|
|
||||||
# can be uninitialized. This is OK because the later tl.store operation
|
|
||||||
# uses the same mask, eliminating the risk of garbage values propagating
|
|
||||||
tiled_out = tl.load(c_ptr, mask=c_mask, other=None)
|
|
||||||
tiled_c += tiled_out
|
|
||||||
tl.store(c_ptr, tiled_c, mask=c_mask)
|
|
||||||
|
|
||||||
|
|
||||||
@torch.inference_mode()
def _sgmv_expand_slice(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    slice_offset: int,
    slice_size: int,
    add_inputs: bool = False,
) -> None:
    """_summary_

    Args:
        inputs (torch.Tensor): input tensor
        lora_b_weights (torch.Tensor): lora'a weight
        output_tensor (torch.Tensor): output tensor
        b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
            sequence lengths of the sequences in the batch, used to index
            into sequence. E.g., if the sequence length is [4, 6], it is
            [0, 4, 10].
        seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
            length of the sequences in the batch
        lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
            corresponding to each batch. An index of -1 means no lora should be
            applied.
        batches (int): batch size
        max_seq_length (int): The max sequence lengths of the sequences
            in the batch
        token_nums (int): The token numbers in the batch. Used to verify if the
            token numbers in the inputs matches the one in the metadata.
        slice_offset (int): output_tensor's offset
        slice_size (int): current output_tensor's size
        add_inputs (bool, optional): Defaults to False, adds the final lora
            results to the output.
    """

    assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
    assert lora_b_weights.dtype in [
        torch.float16,
        torch.bfloat16,
    ]
    assert inputs.size(0) == token_nums
    assert inputs.size(1) == lora_b_weights.size(-1)
    assert b_seq_start_loc.size(0) == batches
    assert lora_indices_tensor.size(0) == batches
    assert slice_size == lora_b_weights.size(-2)
    assert inputs.is_contiguous()
    assert output_tensor.is_contiguous()

    if lora_b_weights.ndim == 4:  # shape:(lora_num,1,size,rank)
        assert lora_b_weights.size(1) == 1
        lora_b_weights = lora_b_weights.squeeze(dim=1)
    else:
        assert lora_b_weights.ndim == 3  # shape:(lora_num,size,rank)

    assert lora_b_weights.is_contiguous()

    # TODO tuning this config
    N, K = lora_b_weights.shape[-2:]  # K= rank,N=hidden_size

    BLOCK_M = 32
    BLOCK_N = 32
    BLOCK_K = 16
    EVEN_K = K % BLOCK_K == 0
    ADD_INPUTS = add_inputs
    CAST_TYPE = False
    if inputs.dtype == torch.float32 and lora_b_weights.dtype in [
            torch.float16,
            torch.bfloat16,
    ]:
        CAST_TYPE = True
    grid = (
        triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),
        batches,
    )
    _sgmv_expand_slice_kernel[grid](
        inputs,
        lora_b_weights,
        output_tensor,
        N,
        K,
        b_seq_start_loc,
        seq_len_tensor,
        lora_indices_tensor,
        inputs.stride(0),
        inputs.stride(1),
        lora_b_weights.stride(0),
        lora_b_weights.stride(1),
        lora_b_weights.stride(2),
        output_tensor.stride(0),
        output_tensor.stride(1),
        slice_offset,
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        EVEN_K,
        ADD_INPUTS,
        CAST_TYPE,
    )
    return


def sgmv_expand_slice_fake(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    slice_offset: int,
    slice_size: int,
    add_inputs: bool = False,
) -> None:
    return


try:
    direct_register_custom_op(
        op_name="sgmv_expand_slice",
        op_func=_sgmv_expand_slice,
        mutates_args=["output_tensor"],
        fake_impl=sgmv_expand_slice_fake,
    )
    sgmv_expand_slice = torch.ops.vllm.sgmv_expand_slice

except AttributeError:
    sgmv_expand_slice = _sgmv_expand_slice
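# Hedged illustration of the prefill metadata the docstrings above describe:
# `b_seq_start_loc` is a cumulative sum of the per-sequence lengths with a
# leading zero, so sequence lengths [4, 6] give [0, 4, 10]; each sequence's
# entry is its start offset into the flattened token dimension.
import torch

seq_len_tensor = torch.tensor([4, 6])
b_seq_start_loc = torch.cat(
    [torch.zeros(1, dtype=torch.long),
     torch.cumsum(seq_len_tensor, dim=0)])
assert b_seq_start_loc.tolist() == [0, 4, 10]
token_nums = int(seq_len_tensor.sum())      # 10 tokens in the batch
max_seq_length = int(seq_len_tensor.max())  # 6
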
@ -5,48 +5,60 @@ Punica: Multi-Tenant LoRA Serving.
 https://arxiv.org/abs/2310.18547
 """
 
+from typing import List
+
 import torch
 import triton
 import triton.language as tl
 
 from vllm.utils import direct_register_custom_op
 
+from .utils import _get_lora_a_ptr
+
 
 @triton.jit
 def _sgmv_shrink_kernel(
     input_ptr,
-    lora_ptr,
+    lora_ptr,  #1-3
     out_ptr,
     N,
     K,
     b_seq_start_loc,
     seq_lens,
     lora_indices,
     scaling,
-    xm_stride,  # hidden_size
-    xk_stride,  # 1
-    l0_stride,  # hidden_size*max_rank
-    lora_k_stride,
-    lora_n_stride,
-    cm_stride,
-    cn_stride,
+    input_d0_stride,
+    input_d1_stride,  # 1
+    lora_d0_stride,
+    lora_d1_stride,
+    lora_d2_stride,  # 1
+    output_d0_stride,
+    output_d1_stride,
+    output_d2_stride,  # 1
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
     BLOCK_K: tl.constexpr,
     EVEN_K: tl.constexpr,
     SPLIT_K: tl.constexpr,
-):
+    SLICE_NUM: tl.constexpr):
     """
     The sgmv's shrink triton kernel is based on GroupGEMM+SPLIT-K.
     The GEMM of Multi-LoRA can be considered as GroupGEMM. Additionally,
     introducing SPLIT-K can improve performance
     """
     pid = tl.program_id(axis=0)
-    pid_sk = tl.program_id(axis=1)
+    pid_mix = tl.program_id(axis=1)
     cur_batch = tl.program_id(axis=2)
     cta_n_num = tl.cdiv(N, BLOCK_N)
     pid_m = pid // cta_n_num
     pid_n = pid % cta_n_num
+    if SLICE_NUM == 1:
+        slice_id: tl.constexpr = 0
+        pid_sk = tl.program_id(axis=1)
+    else:
+        pid_mix = tl.program_id(axis=1)
+        slice_id = pid_mix // SPLIT_K
+        pid_sk = pid_mix % SPLIT_K
+
     M = tl.load(seq_lens + cur_batch)
     if pid_m * BLOCK_M > M:
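# Hedged sketch of the fused launch grid used above: axis 1 now carries both
# the SPLIT-K id and the slice id (SPLIT_K * num_slices programs), and each
# program decodes its pair exactly as the kernel does with // and %.
SPLIT_K, NUM_SLICES = 8, 3  # illustrative values
for pid_mix in range(SPLIT_K * NUM_SLICES):
    slice_id = pid_mix // SPLIT_K
    pid_sk = pid_mix % SPLIT_K
    assert 0 <= slice_id < NUM_SLICES and 0 <= pid_sk < SPLIT_K
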
@ -61,11 +73,22 @@ def _sgmv_shrink_kernel(
 
     ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M)
     rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)
-    a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride +
-             offset_k[None, :] * xk_stride)
-    b_ptr = (lora_ptr + l0_stride * lora_index + rbn[None, :] * lora_k_stride +
-             offset_k[:, None] * lora_n_stride)
+    # input ptr
+    a_ptr = (input_ptr + cur_seq_start * input_d0_stride +
+             ram[:, None] * input_d0_stride +
+             offset_k[None, :] * input_d1_stride)
+
+    if SLICE_NUM == 1:
+        # current lora ptr
+        cur_lora_ptr = lora_ptr
+    else:
+        # current lora ptr
+        cur_lora_ptr = tl.load(lora_ptr + slice_id).to(
+            tl.pointer_type(input_ptr.dtype.element_ty))
+
+    b_ptr = (cur_lora_ptr + lora_d0_stride * lora_index +
+             rbn[None, :] * lora_d1_stride +
+             offset_k[:, None] * lora_d2_stride)
+
     accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
     for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
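# Hedged sketch of the pointer-table trick the kernel relies on when
# SLICE_NUM > 1 (see _get_lora_a_ptr in utils.py further below): the data
# pointers of the per-slice weight tensors are packed into one int64 tensor,
# and the device code reinterprets the loaded integer as a typed pointer via
# tl.load(lora_ptr + slice_id).to(tl.pointer_type(...)).
import torch

slices = [torch.randn(4, 16, 64), torch.randn(4, 16, 64)]  # illustrative
lora_ptr_tensor = torch.tensor([t.data_ptr() for t in slices],
                               dtype=torch.int64)
assert lora_ptr_tensor.numel() == len(slices)
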
@ -82,13 +105,15 @@ def _sgmv_shrink_kernel(
                               other=0.0)
         accumulator += tl.dot(tiled_a, tiled_b)
 
-        a_ptr += BLOCK_K * SPLIT_K * xk_stride
-        b_ptr += BLOCK_K * SPLIT_K * lora_n_stride
+        a_ptr += BLOCK_K * SPLIT_K * input_d1_stride
+        b_ptr += BLOCK_K * SPLIT_K * lora_d2_stride
     offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
 
     offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
-    c_ptr = (out_ptr + offset_cm[:, None] * cm_stride +
-             offset_cn[None, :] * cn_stride)
+    cur_out_ptr = (out_ptr if SLICE_NUM == 1 else out_ptr +
+                   slice_id * output_d0_stride)
+    c_ptr = cur_out_ptr + offset_cm[:, None] * output_d1_stride + offset_cn[
+        None, :] * output_d2_stride
     c_mask = (offset_cm[:, None] <
               (cur_seq_start + M)) & (offset_cn[None, :] < N)
     accumulator *= scaling
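# Hedged sketch of the SPLIT-K scheme behind the k-loop above: each of the
# SPLIT_K programs accumulates a strided subset of the K dimension, and the
# partial results sum to the full GEMM (how the partials are combined on the
# GPU is outside this hunk).
import torch

M, N, K, SPLIT_K, BLOCK_K = 32, 16, 128, 4, 16
a, b = torch.randn(M, K), torch.randn(K, N)
total = torch.zeros(M, N)
for pid_sk in range(SPLIT_K):
    for k0 in range(pid_sk * BLOCK_K, K, BLOCK_K * SPLIT_K):
        total += a[:, k0:k0 + BLOCK_K] @ b[k0:k0 + BLOCK_K, :]
assert torch.allclose(total, a @ b, atol=1e-3)
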
@ -102,7 +127,7 @@ def _sgmv_shrink_kernel(
 @torch.inference_mode()
 def _sgmv_shrink(
     inputs: torch.Tensor,
-    lora_a_weights: torch.Tensor,
+    lora_a_weights: List[torch.Tensor],
     output_tensor: torch.Tensor,
     b_seq_start_loc: torch.Tensor,
     seq_len_tensor: torch.Tensor,
@ -113,10 +138,9 @@ def _sgmv_shrink(
     scaling: float,
 ) -> None:
     """
 
     Args:
         inputs (torch.Tensor): input tensor
-        lora_a_weights (torch.Tensor): lora'a weight
+        lora_a_weights (List[torch.Tensor]): lora'a weight
         output_tensor (torch.Tensor): output tensor
         b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
             sequence lengths of the sequences in the batch, used to index
@ -134,27 +158,21 @@ def _sgmv_shrink(
             token numbers in the inputs matches the one in the metadata.
         scaling (float): Scaling factor.
     """
-    assert inputs.dtype == lora_a_weights.dtype
+    assert inputs.dtype == lora_a_weights[0].dtype
     assert inputs.dtype in [torch.float16, torch.bfloat16]
-    assert lora_a_weights.dtype in [
-        torch.float16,
-        torch.bfloat16,
-    ]
+    for weight in lora_a_weights:
+        assert weight.dtype in [torch.float16, torch.bfloat16]
     assert inputs.size(0) == token_nums
-    assert inputs.size(1) == lora_a_weights.size(-1)
+    assert inputs.size(1) == lora_a_weights[0].size(-1)
     assert b_seq_start_loc.size(0) == batches
     assert lora_indices_tensor.size(0) == batches
     assert inputs.is_contiguous()
 
-    if lora_a_weights.ndim == 4:  # shape:(lora_num,1,rank, size)
-        assert lora_a_weights.size(1) == 1
-        lora_a_weights = lora_a_weights.squeeze(dim=1)
-    else:
-        assert lora_a_weights.ndim == 3  # shape:(lora_num,rank, size)
-    assert lora_a_weights.is_contiguous()
     assert output_tensor.is_contiguous()
+    (lora_ptr_tensor, lora_strides_d0, lora_strides_d1,
+     lora_strides_d2) = _get_lora_a_ptr(lora_a_weights, b_seq_start_loc.device)
     # TODO tuning this config
-    N, K = lora_a_weights.shape[-2:]  # K=hidden_size,N=rank
+    N, K = lora_a_weights[0].shape[-2:]  # K=hidden_size,N=rank
     BLOCK_M = 32
     BLOCK_N = 16
     BLOCK_K = 32
@ -162,13 +180,12 @@ def _sgmv_shrink(
     EVEN_K = K % (BLOCK_K * SPLIT_K) == 0
     grid = (
         triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),
-        SPLIT_K,
+        SPLIT_K * len(lora_a_weights),
        batches,
     )
     _sgmv_shrink_kernel[grid](
         inputs,
-        lora_a_weights,
+        lora_ptr_tensor,
         output_tensor,
         N,
         K,
@ -178,23 +195,25 @@ def _sgmv_shrink(
         scaling,
         inputs.stride(0),
         inputs.stride(1),
-        lora_a_weights.stride(0),
-        lora_a_weights.stride(1),
-        lora_a_weights.stride(2),
+        lora_strides_d0,
+        lora_strides_d1,
+        lora_strides_d2,
         output_tensor.stride(0),
         output_tensor.stride(1),
+        output_tensor.stride(2),
         BLOCK_M,
         BLOCK_N,
         BLOCK_K,
         EVEN_K,
         SPLIT_K,
+        len(lora_a_weights),
     )
     return
 
 
 def sgmv_shrink_fake(
     inputs: torch.Tensor,
-    lora_a_weights: torch.Tensor,
+    lora_a_weights: List[torch.Tensor],
     output_tensor: torch.Tensor,
     b_seq_start_loc: torch.Tensor,
     seq_len_tensor: torch.Tensor,
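# Hedged sketch of the new shrink calling convention introduced in this file:
# the LoRA-A weights are passed as a list (one stacked tensor per output
# slice) and the output buffer gains a leading slice dimension, which is why
# output_tensor.stride(2) is now forwarded to the kernel. Shapes below are
# illustrative only.
import torch

num_slices, num_tokens, rank, hidden = 3, 10, 16, 64
lora_a_weights = [torch.randn(4, rank, hidden, dtype=torch.float16)
                  for _ in range(num_slices)]   # 4 stacked adapters per slice
output_tensor = torch.zeros(num_slices, num_tokens, rank, dtype=torch.float32)
assert output_tensor.stride(0) == num_tokens * rank  # per-slice stride
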
@ -1,5 +1,7 @@
 import functools
-from typing import Dict
+from typing import Dict, List, Tuple
 
+import torch
+
 
 @functools.lru_cache
@ -44,3 +46,120 @@ def get_lora_op_configs(op_type: str, batch: int,
     if not config:
         config = _get_default_config(op_type, batch, hidden_size)
     return config
+
+
+_LORA_A_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {}
+_LORA_B_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {}
+
+
+def _get_lora_a_ptr(lora_a_weights: List[torch.Tensor], device: str):
+    """
+    `_LORA_A_PTR_DICT` collects the required information during `profile_run`,
+    After this, it remains constant and subsequent usage is through LUT.
+    Refer to:
+    https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py
+    """
+    key = tuple(lora_weight.data_ptr() for lora_weight in lora_a_weights)
+
+    if values := _LORA_A_PTR_DICT.get(key):
+        return values
+
+    lora_strides_d0 = []
+    lora_strides_d1 = []
+    lora_strides_d2 = []
+    tensor_ptrs = []
+    for lora_a_weight in lora_a_weights:
+        if lora_a_weight.ndim == 4:  # shape:(lora_num,1,size,rank)
+            assert lora_a_weight.size(1) == 1
+            lora_a_weight = lora_a_weight.squeeze(dim=1)
+        else:
+            assert lora_a_weight.ndim == 3  # shape:(lora_num,size,rank)
+        assert lora_a_weight.is_contiguous()
+        tensor_ptrs.append(lora_a_weight.data_ptr())
+        lora_strides_d0.append(lora_a_weight.stride(0))
+        lora_strides_d1.append(lora_a_weight.stride(1))
+        lora_strides_d2.append(lora_a_weight.stride(2))
+    if len(lora_a_weights) > 1:
+        lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device)
+    else:
+        lora_ptr_tensor = lora_a_weights[0]
+
+    if (len(set(lora_strides_d0)) > 1 or len(set(lora_strides_d1)) > 1
+            or len(set(lora_strides_d2)) > 1):
+        raise ValueError("All LoRA weights must have the same stride.")
+
+    _LORA_A_PTR_DICT[key] = (
+        lora_ptr_tensor,
+        lora_strides_d0[0],
+        lora_strides_d1[0],
+        lora_strides_d2[0],
+    )
+    return _LORA_A_PTR_DICT.get(key)
+
+
+def _get_lora_b_ptr(lora_weights: List[torch.Tensor], offset_start: int,
+                    device: str):
+    """
+    `_LORA_B_PTR_DICT` collects the required information during `profile_run`,
+    After this, it remains constant and subsequent usage is through LUT.
+    Refer to:
+    https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py
+
+    """
+
+    key = tuple(lora_weight.data_ptr() for lora_weight in lora_weights)
+    if values := _LORA_B_PTR_DICT.get(key):
+        return values
+    slice_offset_lst = []
+    tensor_ptrs = []
+    lora_strides_d0 = []
+    lora_strides_d1 = []
+    lora_strides_d2 = []
+    hidden_sizes = []
+    slice_offset = offset_start
+    for lora_b_weight in lora_weights:
+        if lora_b_weight.ndim == 4:  # shape:(lora_num,1,size,rank)
+            assert lora_b_weight.size(1) == 1
+            lora_b_weight = lora_b_weight.squeeze(dim=1)
+        else:
+            assert lora_b_weight.ndim == 3  # shape:(lora_num,size,rank)
+        assert lora_b_weight.is_contiguous()
+        tensor_ptrs.append(lora_b_weight.data_ptr())
+        lora_strides_d0.append(lora_b_weight.stride(0))
+        lora_strides_d1.append(lora_b_weight.stride(1))
+        lora_strides_d2.append(lora_b_weight.stride(2))
+        slice_offset_lst.append(slice_offset)
+        slice_offset += lora_b_weight.size(1)
+        hidden_sizes.append(lora_b_weight.size(1))
+
+    if len(lora_weights) > 1:
+        # note these are device tensors
+        lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device)
+        slice_start_tensor = torch.tensor(slice_offset_lst, device=device)
+    else:
+        slice_start_tensor = slice_offset_lst[0]
+        lora_ptr_tensor = lora_b_weight[0]
+
+    # If each lora has the same stride, there's no need to use a
+    # tensor for storage.
+    if (len(set(lora_strides_d0)) == 1 and len(set(lora_strides_d1)) == 1 and
+            len(set(lora_strides_d2)) == 1) and len(set(hidden_sizes)) == 1:
+        lora_strides_d0_tensor = lora_strides_d0[0]
+        lora_strides_d1_tensor = lora_strides_d1[0]
+        lora_strides_d2_tensor = lora_strides_d2[0]
+        hidden_sizes_tensor = hidden_sizes[0]
+        same_stride = True
+
+    else:
+        lora_strides_d0_tensor = torch.tensor(lora_strides_d0, device=device)
+        lora_strides_d1_tensor = torch.tensor(lora_strides_d1, device=device)
+        lora_strides_d2_tensor = torch.tensor(lora_strides_d2, device=device)
+        hidden_sizes_tensor = torch.tensor(hidden_sizes, device=device)
+        same_stride = False
+    # MAX_N is the maximum hidden size among all the lora_b weights
+    MAX_N = max(hidden_sizes)
+    _LORA_B_PTR_DICT[key] = (slice_start_tensor, lora_ptr_tensor,
+                             lora_strides_d0_tensor, lora_strides_d1_tensor,
+                             lora_strides_d2_tensor, hidden_sizes_tensor,
+                             same_stride, MAX_N)
+    return _LORA_B_PTR_DICT.get(key)
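# Hedged sketch of the slice bookkeeping in _get_lora_b_ptr above: each
# slice's column window in the fused output starts at the running sum of the
# previous slices' hidden sizes, beginning at offset_start; MAX_N is simply
# the largest hidden size. Sizes below are illustrative.
hidden_sizes = [2048, 2048, 4096]
offset_start = 0
slice_starts, running = [], offset_start
for size in hidden_sizes:
    slice_starts.append(running)
    running += size
assert slice_starts == [0, 2048, 4096]
MAX_N = max(hidden_sizes)  # 4096
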
@ -5,7 +5,7 @@ Punica: Multi-Tenant LoRA Serving.
 https://arxiv.org/abs/2310.18547
 """
 
-from typing import Callable, Optional, Tuple, Union, final
+from typing import Optional, Tuple, Union, final
 
 import torch
 
@ -16,7 +16,6 @@ if HAS_TRITON:
     from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice
     from vllm.lora.ops.bgmv_shrink import bgmv_shrink
     from vllm.lora.ops.sgmv_expand import sgmv_expand
-    from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
     from vllm.lora.ops.sgmv_shrink import sgmv_shrink
 
 from .punica_base import PunicaWrapperBase
@ -35,11 +34,11 @@ class PunicaWrapperGPU(PunicaWrapperBase):
         PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches,
                                    device)
 
-    def _shrink_prefill(
+    def _apply_shrink_prefill(
         self,
         y: torch.Tensor,
         x: torch.Tensor,
-        w_t_all: torch.Tensor,
+        w_t_all: Tuple[torch.Tensor, ...],
         scale: float,
     ):
         #No LoRA request, so return directly
@ -53,7 +52,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
             scale,
         )
 
-    def _shrink_decode(
+    def _apply_shrink_decode(
         self,
         y: torch.Tensor,
         x: torch.Tensor,
@ -62,56 +61,28 @@ class PunicaWrapperGPU(PunicaWrapperBase):
     ):
         bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale)
 
-    def _expand_prefill(
+    def _apply_expand_prefill(
         self,
         y: torch.Tensor,
         x: torch.Tensor,
         w_t_all: torch.Tensor,
+        offset_start: int,
         add_inputs: bool,
     ):
         #No LoRA request, so return directly
         if self.no_lora:
             return
 
         sgmv_expand(
             x,
             w_t_all,
             y,
             *self.prefill_metadata,
-            add_inputs,
+            offset_start=offset_start,
+            add_inputs=add_inputs,
         )
 
-    def _expand_decode(
-        self,
-        y: torch.Tensor,
-        x: torch.Tensor,
-        w_t_all: torch.Tensor,
-        add_inputs: bool,
-    ):
-        bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs)
-
-    def _expand_slice_prefill(
-        self,
-        y: torch.Tensor,
-        x: torch.Tensor,
-        w_t_all: torch.Tensor,
-        y_offset: Optional[int],
-        y_slice_size: Optional[int],
-        add_inputs: bool,
-    ):
-        #No LoRA request, so return directly
-        if self.no_lora:
-            return
-        sgmv_expand_slice(
-            x,
-            w_t_all,
-            y,
-            *self.prefill_metadata,
-            y_offset,
-            y_slice_size,
-            add_inputs,
-        )
-
-    def _expand_slice_decode(
+    def _apply_expand_decode(
         self,
         y: torch.Tensor,
         x: torch.Tensor,
@ -123,43 +94,6 @@ class PunicaWrapperGPU(PunicaWrapperBase):
         bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset,
                           y_slice_size, add_inputs)
 
-    def _apply_expand(
-        self,
-        y: torch.Tensor,
-        x: torch.Tensor,
-        w_t_all: torch.Tensor,
-        y_offset: Optional[int],
-        y_slice_size: Optional[int],
-        add_inputs: bool = True,
-    ):
-        """
-        Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all`
-        computation, which is suitable for the
-        GEMM of lora'b.
-        """
-
-        expand_slice_fun: Callable = (self._expand_slice_prefill
-                                      if self.is_prefill else
-                                      self._expand_slice_decode)
-        expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs)
-
-    def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor,
-                      w_t_all: torch.Tensor, scale: float):
-        """
-        Perform the ` y+=x@w_t_all` computation, which is suitable for the
-        GEMM of lora'a.
-        When `is_prefill is` true, it indicates that it is currently the
-        prefill stage, and the `_shrink_prefill` function should be called.
-        Otherwise, it is the decode stage, and the _shrink_decode function
-        should be called.
-        """
-        y_org = y
-        y = y.view(-1, y.shape[-1])
-        shrink_fun: Callable = (self._shrink_prefill
-                                if self.is_prefill else self._shrink_decode)
-        shrink_fun(y, x, w_t_all, scale)
-        y = y.view_as(y_org)
-
     def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
                    x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...],
                    scale: float, **kwargs):
@ -182,10 +116,15 @@ class PunicaWrapperGPU(PunicaWrapperBase):
         """
 
         x = x.view(-1, x.shape[-1])
-        # TODO fuse these kernels
-        for slice_idx in range(len(lora_a_stacked)):
-            self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx],
-                               scale)
+        if self.is_prefill:
+            # NOTE fused kernel
+            self._apply_shrink_prefill(y, x, lora_a_stacked, scale)
+        else:
+            # TODO fuse these kernels
+            for slice_idx in range(len(lora_a_stacked)):
+                self._apply_shrink_decode(y[slice_idx], x,
+                                          lora_a_stacked[slice_idx], scale)
 
     def add_expand(self,
                    y: torch.Tensor,
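# Hedged sketch of why the fused prefill path and the per-slice decode path
# of add_shrink agree on the buffer layout: shrinking every slice into one
# (num_slices, tokens, rank) tensor equals stacking the per-slice results.
# Pure-PyTorch illustration; names and shapes are not vLLM API.
import torch

tokens, hidden, rank, num_slices, scale = 6, 32, 8, 3, 0.5
x = torch.randn(tokens, hidden)
lora_a = [torch.randn(rank, hidden) for _ in range(num_slices)]
fused = torch.stack([scale * (x @ a.t()) for a in lora_a])    # prefill-style
looped = torch.zeros(num_slices, tokens, rank)
for i, a in enumerate(lora_a):                                # decode-style
    looped[i] = scale * (x @ a.t())
assert torch.allclose(fused, looped)
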
@ -217,20 +156,28 @@ class PunicaWrapperGPU(PunicaWrapperBase):
         """
         y_org = y
         y = y.view(-1, y.shape[-1])
-        offset_left = offset_start
         if lora_bias_stacked is not None:
             self._apply_bias(self.token_lora_indices, y, output_slices,
                              lora_bias_stacked)
-        for slice_idx in range(len(lora_b_stacked)):
-            self._apply_expand(
-                y,
-                x[slice_idx],
-                lora_b_stacked[slice_idx],
-                offset_left,
-                output_slices[slice_idx],
-                add_inputs=add_inputs,
-            )
-            offset_left += output_slices[slice_idx]
+        if self.is_prefill:
+            # NOTE fused kernel
+            self._apply_expand_prefill(y,
+                                       x,
+                                       lora_b_stacked,
+                                       offset_start,
+                                       add_inputs=True)
+        else:
+            # TODO fuse these kernels
+            for slice_idx in range(len(lora_b_stacked)):
+                self._apply_expand_decode(
+                    y,
+                    x[slice_idx],
+                    lora_b_stacked[slice_idx],
+                    offset_start,
+                    output_slices[slice_idx],
+                    add_inputs=add_inputs,
+                )
+                offset_start += output_slices[slice_idx]
         y = y.view_as(y_org)
 
     def add_lora_embedding(self,
@ -252,10 +199,18 @@ class PunicaWrapperGPU(PunicaWrapperBase):
             add_inputs (bool): Default to True.
         """
 
-        # Embedding layer only need expand op
-        expand_fun: Callable = (self._expand_prefill
-                                if self.is_prefill else self._expand_decode)
-        expand_fun(y, x, lora_b_stacked, add_inputs)
+        if self.is_prefill:
+            sgmv_expand(
+                x.unsqueeze(dim=0),
+                [lora_b_stacked],
+                y,
+                *self.prefill_metadata,
+                offset_start=0,
+                add_inputs=add_inputs,
+            )
+        else:
+            bgmv_expand(x, lora_b_stacked, y, self.token_lora_indices,
+                        add_inputs)
 
     def add_lora_linear(self,
                         y: torch.Tensor,
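# Hedged sketch of the embedding path above: the single LoRA-B stack is
# wrapped in a one-element list and x gains a leading slice axis via
# unsqueeze(0), so the same fused expand op can treat it as the 1-slice case.
# Shapes are illustrative.
import torch

x = torch.randn(10, 16)            # (num_tokens, rank)
lora_b_stacked = torch.randn(4, 64, 16)
x_slices = x.unsqueeze(dim=0)      # (1, num_tokens, rank)
weights = [lora_b_stacked]         # one slice
assert x_slices.size(0) == len(weights) == 1
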
@ -301,10 +256,11 @@ class PunicaWrapperGPU(PunicaWrapperBase):
         r = lora_b_stacked[0].size(-1)
         # We set the buffer to be float32 by default ,refer to:
         # https://github.com/triton-lang/triton/issues/1387
-        buffer = tuple(
-            torch.zeros(
-                (x.size(0), r), dtype=torch.float32, device=x.device)
-            for _ in range(len(output_slices)))
+        buffer = torch.zeros(
+            (len(output_slices), x.size(0), r),
+            dtype=torch.float32,
+            device=x.device,
+        )
         self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
         self.add_expand(y,
                         buffer,
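# Hedged end-to-end reference for add_lora_linear with the new single
# (num_slices, tokens, rank) float32 buffer: shrink each slice into its row
# of the buffer, then expand each slice into its column window of y.
# Pure-PyTorch sketch; names and shapes are illustrative, not vLLM API.
import torch

tokens, hidden, rank, scale = 4, 32, 8, 1.0
output_slices = (16, 16)
x = torch.randn(tokens, hidden, dtype=torch.float16)
lora_a = [torch.randn(rank, hidden, dtype=torch.float16) for _ in output_slices]
lora_b = [torch.randn(n, rank, dtype=torch.float16) for n in output_slices]
y = torch.zeros(tokens, sum(output_slices), dtype=torch.float16)

buffer = torch.zeros(len(output_slices), tokens, rank, dtype=torch.float32)
for i, a in enumerate(lora_a):                              # shrink
    buffer[i] = scale * (x.float() @ a.float().t())
offset = 0
for i, (b, n) in enumerate(zip(lora_b, output_slices)):    # expand per slice
    y[:, offset:offset + n] += (buffer[i] @ b.float().t()).to(y.dtype)
    offset += n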