[V1] LoRA - Add triton kernels for V1 (#13096)

Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Varun Sundar Rabindranath 2025-03-10 17:27:53 -04:00 committed by GitHub
parent 0967110e42
commit 5ff0d32580
11 changed files with 1165 additions and 191 deletions

View File

@ -23,6 +23,7 @@ from vllm.lora.ops.triton_ops.bgmv_shrink import bgmv_shrink
from vllm.lora.ops.triton_ops.sgmv_expand import sgmv_expand
from vllm.lora.ops.triton_ops.sgmv_shrink import sgmv_shrink
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
from vllm.lora.ops.triton_ops.v1 import V1KernelMeta, v1_expand, v1_shrink
from vllm.utils import FlexibleArgumentParser
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
@ -171,6 +172,8 @@ class OpType(Enum):
SGMV_EXPAND = auto()
BGMV_EXPAND = auto()
BGMV_EXPAND_SLICE = auto()
V1_SHRINK = auto()
V1_EXPAND = auto()
@staticmethod
def from_str(s: str) -> "OpType":
@ -184,28 +187,43 @@ class OpType(Enum):
return OpType.BGMV_EXPAND
if s.lower() == "bgmv_expand_slice":
return OpType.BGMV_EXPAND_SLICE
if s.lower() == "v1_shrink":
return OpType.V1_SHRINK
if s.lower() == "v1_expand":
return OpType.V1_EXPAND
raise ValueError(f"Unrecognized str {s} to convert to OpType")
def is_shrink_fn(self) -> bool:
return self in [OpType.SGMV_SHRINK, OpType.BGMV_SHRINK]
return self in [
OpType.SGMV_SHRINK, OpType.BGMV_SHRINK, OpType.V1_SHRINK
]
def is_expand_fn(self) -> bool:
return self in [OpType.SGMV_EXPAND, OpType.BGMV_EXPAND]
return self in [
OpType.SGMV_EXPAND, OpType.BGMV_EXPAND, OpType.V1_EXPAND
]
def is_prefill_op(self) -> bool:
return self in [OpType.SGMV_SHRINK, OpType.SGMV_EXPAND]
return self in [
OpType.SGMV_SHRINK, OpType.SGMV_EXPAND, OpType.V1_SHRINK,
OpType.V1_EXPAND
]
def is_decode_op(self) -> bool:
return self in [
OpType.BGMV_SHRINK, OpType.BGMV_EXPAND, OpType.BGMV_EXPAND_SLICE
OpType.BGMV_SHRINK, OpType.BGMV_EXPAND, OpType.BGMV_EXPAND_SLICE,
OpType.V1_SHRINK, OpType.V1_EXPAND
]
def is_expand_slice_fn(self) -> bool:
return self in [OpType.BGMV_EXPAND_SLICE]
def num_slices(self) -> list[int]:
if self in [OpType.SGMV_EXPAND, OpType.SGMV_SHRINK]:
# SGMV kernels supports slices
if self in [
OpType.SGMV_EXPAND, OpType.SGMV_SHRINK, OpType.V1_SHRINK,
OpType.V1_EXPAND
]:
# SGMV and V1 kernels support slices
return [1, 2, 3]
if self in [OpType.BGMV_SHRINK, OpType.BGMV_EXPAND]:
return [1]
@ -250,11 +268,13 @@ class OpType(Enum):
m, k, n = self.mkn(batch_size, seq_length, hidden_size, lora_rank)
b_shape = (num_loras, n, k) # col-major
if self == OpType.SGMV_SHRINK:
# SGMV shrink supports num_slices inherently in the kernel
if self in [OpType.SGMV_SHRINK, OpType.V1_SHRINK]:
# SGMV shrink and V1 shrink kernels support num_slices inherently
# in the kernel.
return ((m, k), b_shape, (num_slices, m, n))
if self == OpType.SGMV_EXPAND:
# SGMV expand supports num_slices inherently in the kernel
if self in [OpType.SGMV_EXPAND, OpType.V1_EXPAND]:
# SGMV expand and V1 expand kernels support num_slices inherently
# in the kernel.
return ((num_slices, m, k), b_shape, (m, n * num_slices))
if self == OpType.BGMV_SHRINK:
return ((m, k), b_shape, (m, n))
@ -281,25 +301,30 @@ class OpType(Enum):
return bgmv_expand
if self == OpType.BGMV_EXPAND_SLICE:
return emulate_bgmv_expand_slice
if self == OpType.V1_SHRINK:
return v1_shrink
if self == OpType.V1_EXPAND:
return v1_expand
raise ValueError(f"Unrecognized optype {self}")
def run_ref_group_gemm(self, output: torch.Tensor, input: torch.Tensor,
lora_weights: list[torch.Tensor],
**kwargs) -> Callable:
"""Each benchmark operation expected the input, lora_weights and outputs
"""Each benchmark operation expects the input, lora_weights and outputs
in a slightly different format. Refer to self.matmul_shapes().
run_ref_group_gemm accounts for those differences in executing a
reference group gemm for correctness testing.
"""
w_dtype = lora_weights[0].dtype
num_slices = len(lora_weights)
if self == OpType.SGMV_SHRINK:
if self in [OpType.SGMV_SHRINK, OpType.V1_SHRINK]:
for slice_idx in range(num_slices):
ref_group_gemm(ref_out=output[slice_idx, :],
input=input,
lora_weights=lora_weights[slice_idx],
**kwargs)
if self == OpType.SGMV_EXPAND:
elif self in [OpType.SGMV_EXPAND, OpType.V1_EXPAND]:
hidden_size = lora_weights[0].shape[1]
for slice_idx in range(num_slices):
slice_offset = slice_idx * hidden_size
@ -308,19 +333,19 @@ class OpType(Enum):
input=input[slice_idx].clone().to(dtype=w_dtype),
lora_weights=lora_weights[slice_idx],
**kwargs)
if self == OpType.BGMV_SHRINK:
elif self == OpType.BGMV_SHRINK:
assert num_slices == 1
ref_group_gemm(ref_out=output,
input=input,
lora_weights=lora_weights[0],
**kwargs)
if self == OpType.BGMV_EXPAND:
elif self == OpType.BGMV_EXPAND:
assert num_slices == 1
ref_group_gemm(ref_out=output,
input=input.clone().to(dtype=w_dtype),
lora_weights=lora_weights[0],
**kwargs)
if self == OpType.BGMV_EXPAND_SLICE:
elif self == OpType.BGMV_EXPAND_SLICE:
hidden_size = lora_weights[0].shape[1]
for slice_idx in range(num_slices):
slice_offset = slice_idx * hidden_size
@ -329,7 +354,8 @@ class OpType(Enum):
input=input[slice_idx].clone().to(dtype=w_dtype),
lora_weights=lora_weights[slice_idx],
**kwargs)
raise ValueError(f"Unrecognized optype {self}")
else:
raise ValueError(f"Unrecognized optype {self}")
@dataclass
@ -390,6 +416,8 @@ class BenchmarkTensors:
seq_start_loc: torch.Tensor
prompt_lora_mapping: torch.Tensor
token_lora_mapping: torch.Tensor
# v1 kernel metadata
v1_kernel_meta: Optional[V1KernelMeta] = None
def io_types(self) -> str:
return (f"{dtype_to_str(self.input.dtype)}x"
@ -432,10 +460,19 @@ class BenchmarkTensors:
total_tokens, ctx.batch_size, prompt_lora_indices_tensor,
seq_len_tensor, "cpu")
v1_kernel_meta = None
if op_type in [OpType.V1_SHRINK, OpType.V1_EXPAND]:
v1_kernel_meta = V1KernelMeta.make(
max_loras=ctx.num_loras,
max_num_tokens=token_lora_indices_tensor.size(0),
device="cpu")
v1_kernel_meta.prepare_tensors(
token_lora_mapping=token_lora_indices_tensor)
return BenchmarkTensors(input_tensor, lora_weights, output_tensor,
seq_len_tensor, seq_start_loc_tensor,
prompt_lora_indices_tensor,
token_lora_indices_tensor)
token_lora_indices_tensor, v1_kernel_meta)
def sanity_check(self) -> None:
"""
@ -468,6 +505,13 @@ class BenchmarkTensors:
for i in range(len(self.lora_weights_lst)):
self.lora_weights_lst[i] = to_device(self.lora_weights_lst[i])
# v1 meta
if self.v1_kernel_meta:
for field_name in V1KernelMeta.__dataclass_fields__:
field = getattr(self.v1_kernel_meta, field_name)
assert isinstance(field, torch.Tensor)
setattr(self.v1_kernel_meta, field_name, to_device(field))
def metadata(self) -> tuple[int, int, int]:
"""
Return num_seqs, num_tokens and max_seq_len
@ -667,6 +711,78 @@ class BenchmarkTensors:
})
return {'kwargs_list': kwargs_list}
def as_v1_shrink_kwargs(self) -> dict[str, Any]:
assert self.v1_kernel_meta is not None
self.sanity_check()
self.to_device(self.input.device)
_, num_tokens, _, num_slices = self.metadata()
# Sanity check matrix shapes.
i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
0].shape, self.output.shape
# Expected input shape [num_tokens, hidden_size]
assert len(i_shape) == 2
assert i_shape[0] == num_tokens
hidden_size = i_shape[1]
# Expected lora weight shape [num_loras, lora_rank, hidden_size]
assert len(lw_shape) == 3
assert lw_shape[2] == hidden_size
lora_rank = lw_shape[1]
# Expected output shape [num_slices, num_tokens, lora_rank]
assert len(o_shape) == 3
assert o_shape == (num_slices, num_tokens, lora_rank)
return {
'inputs': self.input,
'lora_a_weights': self.lora_weights_lst,
'output_tensor': self.output,
'token_lora_mapping': self.v1_kernel_meta.token_lora_mapping,
'token_indices_sorted_by_lora_ids':
self.v1_kernel_meta.token_indices_sorted_by_lora_ids,
'num_tokens_per_lora': self.v1_kernel_meta.num_tokens_per_lora,
'lora_token_start_loc': self.v1_kernel_meta.lora_token_start_loc,
'lora_ids': self.v1_kernel_meta.active_lora_ids,
'scaling': 1.0,
}
def as_v1_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]:
assert self.v1_kernel_meta is not None
self.sanity_check()
self.to_device(self.input.device)
_, num_tokens, _, num_slices = self.metadata()
# Sanity check matrix shapes.
i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
0].shape, self.output.shape
# Expected input shape : [num_slices, num_tokens, lora_rank]
assert len(i_shape) == 3
assert i_shape[0] == num_slices
assert i_shape[1] == num_tokens
lora_rank = i_shape[2]
# Expected lora weight shape : [num_lora, hidden_size, lora_rank]
assert len(lw_shape) == 3
assert lw_shape[2] == lora_rank
hidden_size = lw_shape[1]
# Expected output shape : [num_tokens, hidden_size * num_slices]
assert len(o_shape) == 2
assert o_shape == (num_tokens, hidden_size * num_slices)
return {
'inputs': self.input,
'lora_b_weights': self.lora_weights_lst,
'output_tensor': self.output,
'token_lora_mapping': self.v1_kernel_meta.token_lora_mapping,
'token_indices_sorted_by_lora_ids':
self.v1_kernel_meta.token_indices_sorted_by_lora_ids,
'num_tokens_per_lora': self.v1_kernel_meta.num_tokens_per_lora,
'lora_token_start_loc': self.v1_kernel_meta.lora_token_start_loc,
'lora_ids': self.v1_kernel_meta.active_lora_ids,
'offset_start': 0,
'add_inputs': add_inputs,
}
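# A minimal sketch (hypothetical sizes, chosen only for illustration) of the
# shape contract enforced by the assertions in as_v1_shrink_kwargs and
# as_v1_expand_kwargs above.
import torch

num_tokens, hidden_size, lora_rank = 16, 128, 8
num_loras, num_slices = 4, 2

# v1_shrink: [num_tokens, hidden_size] x num_slices weights of shape
# [num_loras, lora_rank, hidden_size] -> [num_slices, num_tokens, lora_rank]
shrink_input = torch.empty(num_tokens, hidden_size)
shrink_weights = [torch.empty(num_loras, lora_rank, hidden_size)
                  for _ in range(num_slices)]
shrink_output = torch.empty(num_slices, num_tokens, lora_rank)

# v1_expand: [num_slices, num_tokens, lora_rank] x num_slices weights of shape
# [num_loras, hidden_size, lora_rank] -> [num_tokens, hidden_size * num_slices]
expand_input = torch.empty(num_slices, num_tokens, lora_rank)
expand_weights = [torch.empty(num_loras, hidden_size, lora_rank)
                  for _ in range(num_slices)]
expand_output = torch.empty(num_tokens, hidden_size * num_slices)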
def bench_fn_kwargs(self,
op_type: OpType,
add_inputs: Optional[bool] = None) -> dict[str, Any]:
@ -685,6 +801,10 @@ class BenchmarkTensors:
return self.as_bgmv_expand_kwargs(add_inputs)
if op_type == OpType.BGMV_EXPAND_SLICE:
return self.as_bgmv_expand_slice_kwargs(add_inputs)
if op_type == OpType.V1_SHRINK:
return self.as_v1_shrink_kwargs()
if op_type == OpType.V1_EXPAND:
return self.as_v1_expand_kwargs(add_inputs)
raise ValueError(f"Unrecognized optype {op_type}")
def test_correctness(self, op_type: OpType,
@ -872,12 +992,9 @@ def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]):
timers = []
for bench_ctx in bench_ctxs:
for seq_len in args.seq_lengths:
bench_ops: list[OpType] = []
if seq_len == 1:
# bench all decode ops
bench_ops = [op for op in args.op_types if op.is_decode_op()]
else:
# bench all prefill ops
bench_ops: list[OpType] = args.op_types
if seq_len > 1:
# bench only prefill ops
bench_ops = [op for op in args.op_types if op.is_prefill_op()]
seq_len_timers = []

View File

@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
import importlib
import random
from copy import deepcopy
from dataclasses import dataclass
@ -63,6 +64,36 @@ DEVICES = ([
# stages, so we need to verify this. prefill stage(True) or decode stage(False)
STAGES = [True, False]
# With the inclusion of the V1 tests (see the run_with_both_engines_lora
# fixture), the tests in this file run twice, once with the V0 engine and
# then with the V1 engine.
# NUM_RANDOM_SEEDS was previously 10. It is cut in half with the inclusion
# of the V1 tests to keep CI test times in check.
NUM_RANDOM_SEEDS = 5
# VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS was previously 256. It is
# cut in half with the inclusion of the V1 tests to keep CI test times in
# check.
VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS = 128
@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
# Reload punica_gpu as the kernels used are tied to engine type.
from vllm.lora.punica_wrapper import punica_gpu
importlib.reload(punica_gpu)
# Release any memory we might be holding on to. CI runs OOM otherwise.
from vllm.lora.ops.triton_ops.utils import (_LORA_A_PTR_DICT,
_LORA_B_PTR_DICT)
_LORA_B_PTR_DICT.clear()
_LORA_A_PTR_DICT.clear()
yield
def get_random_id_to_index(num_loras: int,
num_slots: int,
@ -226,7 +257,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
torch.set_default_device(device)
max_loras = 8
punica_wrapper = get_punica_wrapper(8192, 256, device)
punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
assert check_punica_wrapper(punica_wrapper)
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
@ -241,7 +272,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
return embedding, lora_embedding
for i in range(10):
for i in range(NUM_RANDOM_SEEDS):
set_random_seed(i)
id_to_index = get_random_id_to_index(num_loras, max_loras)
@ -329,7 +360,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
torch.set_default_device(device)
max_loras = 8
punica_wrapper = get_punica_wrapper(8192, 256, device)
punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
assert check_punica_wrapper(punica_wrapper)
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
@ -353,7 +384,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
return expanded_embedding, lora_embedding
for i in range(10):
for i in range(NUM_RANDOM_SEEDS):
set_random_seed(i)
id_to_index = get_random_id_to_index(num_loras, max_loras)
@ -468,7 +499,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
torch.set_default_device(device)
max_loras = 8
punica_wrapper = get_punica_wrapper(8192, 256, device)
punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
assert check_punica_wrapper(punica_wrapper)
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
@ -490,7 +521,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
return linear, logits_processor, lora_logits_processor
for i in range(10):
for i in range(NUM_RANDOM_SEEDS):
set_random_seed(i)
id_to_index = get_random_id_to_index(num_loras, max_loras)
@ -600,10 +631,10 @@ def test_linear_replicated(dist_init, num_loras, device, stage,
if current_platform.is_cuda_alike():
torch.cuda.set_device(device)
torch.set_default_device(device)
punica_wrapper = get_punica_wrapper(8192, 256, device)
assert check_punica_wrapper(punica_wrapper)
max_loras = 8
torch.set_default_device(device)
punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
assert check_punica_wrapper(punica_wrapper)
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
lora_dtype=torch.float16,
@ -627,7 +658,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage,
assert lora_linear.lora_bias_stacked is None
return linear, lora_linear
for i in range(10):
for i in range(NUM_RANDOM_SEEDS):
set_random_seed(i)
id_to_index = get_random_id_to_index(num_loras, max_loras)
@ -716,10 +747,10 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
if current_platform.is_cuda_alike():
torch.cuda.set_device(device)
torch.set_default_device(device)
punica_wrapper = get_punica_wrapper(8192, 256, device)
assert check_punica_wrapper(punica_wrapper)
max_loras = 8
torch.set_default_device(device)
punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
assert check_punica_wrapper(punica_wrapper)
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
fully_sharded_loras=fully_shard,
@ -753,7 +784,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
assert lora_linear.lora_bias_stacked is None
return linear, lora_linear
for i in range(10):
for i in range(NUM_RANDOM_SEEDS):
set_random_seed(i)
id_to_index = get_random_id_to_index(num_loras, max_loras)
@ -842,10 +873,10 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
if current_platform.is_cuda_alike():
torch.cuda.set_device(device)
torch.set_default_device(device)
punica_wrapper = get_punica_wrapper(8192, 256, device)
assert check_punica_wrapper(punica_wrapper)
max_loras = 8
torch.set_default_device(device)
punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
assert check_punica_wrapper(punica_wrapper)
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
fully_sharded_loras=fully_shard,
@ -900,7 +931,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
assert lora_linear.lora_bias_stacked is None
return linear, lora_linear
for i in range(10):
for i in range(NUM_RANDOM_SEEDS):
set_random_seed(i)
id_to_index = get_random_id_to_index(num_loras, max_loras)
@ -1002,12 +1033,12 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
is_neox_style, rotary_dim, head_size,
seq_len) -> None:
dtype = torch.float16
max_loras = 8
seed = 0
current_platform.seed_everything(seed)
torch.set_default_device(device)
punica_wrapper = get_punica_wrapper(8192, 256, device)
punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
assert check_punica_wrapper(punica_wrapper)
max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
long_lora_scaling_factors=scaling_factors,
@ -1083,7 +1114,8 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
@pytest.mark.parametrize("tp_size", [1, 2, 4, 8])
@pytest.mark.parametrize("seed", list(range(256)))
@pytest.mark.parametrize(
"seed", list(range(VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS)))
def test_vocab_parallel_embedding_indices(tp_size, seed):
random.seed(seed)
vocab_size = random.randint(4000, 64000)

View File

@ -5,10 +5,12 @@ import pytest
import torch
import vllm.lora.ops.triton_ops # noqa: F401
import vllm.lora.ops.triton_ops.v1 # noqa: F401
from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice,
bgmv_shrink, sgmv_expand,
sgmv_expand_slice, sgmv_shrink)
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
from vllm.lora.ops.triton_ops.v1 import V1KernelMeta
from vllm.platforms import current_platform
from .utils import (PunicaTensors, assert_close, generate_data,
@ -91,12 +93,12 @@ def sgmv_expand_for_nslices(nslices: int, hidden_size: int,
_dict_lock = Lock()
def check_sgmv_shrink(batches: int, num_loras: int, rank: int,
hidden_size: int, nslices: int, dtype: torch.dtype,
device: str, seq_length: int, scaling: float):
def check_shrink_kernels(batches: int, num_loras: int, rank: int,
hidden_size: int, nslices: int, dtype: torch.dtype,
device: str, seq_length: int, scaling: float):
"""
Compare outputs of vllm.sgmv_shrink kernel against a reference
implementation.
Compare outputs of the vllm.sgmv_shrink and vllm.v1_shrink kernels against
a reference implementation.
"""
data: PunicaTensors = generate_data_for_nslices(
batches,
@ -111,44 +113,63 @@ def check_sgmv_shrink(batches: int, num_loras: int, rank: int,
)
max_seq_length, token_nums = data.meta()
# Setup metadata information for SGMV and reference kernels
sgmv_meta_args = (data.b_seq_start_loc, data.seq_len_tensor,
data.prompt_lora_mapping, batches, max_seq_length,
token_nums)
# Setup metadata information for the V1 kernel.
v1_meta = V1KernelMeta.make(max_loras=num_loras,
max_num_tokens=token_nums,
device='cuda')
v1_meta.prepare_tensors(data.token_lora_mapping)
ref_out_tensor = data.ref_out_tensor
sgmv_out_tensor = data.our_out_tensor
v1_out_tensor = data.our_out_tensor.clone()
# Guard the pointer caches with the lock to avoid stale cached pointers.
with _dict_lock:
# SGMV shrink kernel
_LORA_A_PTR_DICT.clear()
torch.ops.vllm.sgmv_shrink(
data.inputs_tensor,
data.lora_weights,
data.our_out_tensor,
data.b_seq_start_loc,
data.seq_len_tensor,
data.prompt_lora_mapping,
batches,
max_seq_length,
token_nums,
sgmv_out_tensor,
*sgmv_meta_args,
scaling,
)
sgmv_shrink_for_nslices(
nslices,
# V1 shrink kernel
_LORA_A_PTR_DICT.clear()
torch.ops.vllm.v1_shrink(
data.inputs_tensor,
data.lora_weights,
data.ref_out_tensor,
data.b_seq_start_loc,
data.seq_len_tensor,
data.prompt_lora_mapping,
batches,
max_seq_length,
token_nums,
v1_out_tensor,
*v1_meta.meta_args(token_nums=token_nums),
scaling,
)
assert_close(data.our_out_tensor, data.ref_out_tensor)
# Reference
sgmv_shrink_for_nslices(
nslices,
data.inputs_tensor,
data.lora_weights,
ref_out_tensor,
*sgmv_meta_args,
scaling,
)
assert_close(sgmv_out_tensor, ref_out_tensor)
assert_close(v1_out_tensor, ref_out_tensor)
def check_sgmv_expand(batches: int, num_loras: int, rank: int,
hidden_size: int, nslices: int, dtype: torch.dtype,
device: str, seq_length: int, add_inputs: bool):
def check_expand_kernels(batches: int, num_loras: int, rank: int,
hidden_size: int, nslices: int, dtype: torch.dtype,
device: str, seq_length: int, add_inputs: bool):
"""
Compare outputs of vllm.sgmv_expand kernel against a reference
implementation.
Compare outputs of vllm.sgmv_expand and vllm.v1_expand kernels against a
reference implementation.
"""
data: PunicaTensors = generate_data_for_nslices(
batches,
@ -164,36 +185,54 @@ def check_sgmv_expand(batches: int, num_loras: int, rank: int,
max_seq_length, token_nums = data.meta()
# Setup metadata information for SGMV and reference kernels
sgmv_meta_args = (data.b_seq_start_loc, data.seq_len_tensor,
data.prompt_lora_mapping, batches, max_seq_length,
token_nums)
# Setup metadata information for the V1 kernel.
v1_meta = V1KernelMeta.make(max_loras=num_loras,
max_num_tokens=token_nums,
device='cuda')
v1_meta.prepare_tensors(data.token_lora_mapping)
# Setup output tensors
ref_out_tensor = data.ref_out_tensor
sgmv_out_tensor = data.our_out_tensor
v1_out_tensor = data.our_out_tensor.clone()
with _dict_lock:
# SGMV expand kernel
_LORA_B_PTR_DICT.clear()
torch.ops.vllm.sgmv_expand(
data.inputs_tensor,
data.lora_weights,
data.our_out_tensor,
data.b_seq_start_loc,
data.seq_len_tensor,
data.prompt_lora_mapping,
batches,
max_seq_length,
token_nums,
sgmv_out_tensor,
*sgmv_meta_args,
offset_start=0,
add_inputs=add_inputs,
)
# V1 expand kernel
_LORA_B_PTR_DICT.clear()
torch.ops.vllm.v1_expand(data.inputs_tensor,
data.lora_weights,
v1_out_tensor,
*v1_meta.meta_args(token_nums=token_nums),
offset_start=0,
add_inputs=add_inputs)
# Reference
sgmv_expand_for_nslices(nslices,
hidden_size,
data.inputs_tensor,
data.lora_weights,
data.ref_out_tensor,
data.b_seq_start_loc,
data.seq_len_tensor,
data.prompt_lora_mapping,
batches,
max_seq_length,
token_nums,
ref_out_tensor,
*sgmv_meta_args,
add_inputs=add_inputs)
assert_close(data.our_out_tensor, data.ref_out_tensor)
assert_close(sgmv_out_tensor, ref_out_tensor)
assert_close(v1_out_tensor, ref_out_tensor)
def check_bgmv_shrink(batches: int, num_loras: int, rank: int,
@ -439,7 +478,7 @@ SEED = [0]
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
def test_punica_sgmv(
def test_kernels(
batches: int,
num_loras: int,
rank: int,
@ -450,29 +489,32 @@ def test_punica_sgmv(
seed: int,
op_type: str,
):
"""
Tests SGMV and V1 kernels.
"""
torch.set_default_device(device)
current_platform.seed_everything(seed)
if op_type == "shrink":
check_sgmv_shrink(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
nslices=nslices,
dtype=dtype,
device=device,
seq_length=128,
scaling=0.5)
check_shrink_kernels(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
nslices=nslices,
dtype=dtype,
device=device,
seq_length=128,
scaling=0.5)
else:
check_sgmv_expand(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
nslices=nslices,
dtype=dtype,
device=device,
seq_length=128,
add_inputs=True)
check_expand_kernels(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
nslices=nslices,
dtype=dtype,
device=device,
seq_length=128,
add_inputs=True)
@pytest.mark.parametrize("batches", hs_test_params['batches'])
@ -484,7 +526,7 @@ def test_punica_sgmv(
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
def test_punica_sgmv_hidden_size(
def test_kernels_hidden_size(
batches: int,
num_loras: int,
rank: int,
@ -495,29 +537,32 @@ def test_punica_sgmv_hidden_size(
seed: int,
op_type: str,
):
"""
Tests SGMV and V1 kernels.
"""
torch.set_default_device(device)
current_platform.seed_everything(seed)
if op_type == "shrink":
check_sgmv_shrink(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
nslices=nslices,
dtype=dtype,
device=device,
seq_length=128,
scaling=0.5)
check_shrink_kernels(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
nslices=nslices,
dtype=dtype,
device=device,
seq_length=128,
scaling=0.5)
else:
check_sgmv_expand(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
nslices=nslices,
dtype=dtype,
device=device,
seq_length=128,
add_inputs=True)
check_expand_kernels(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
nslices=nslices,
dtype=dtype,
device=device,
seq_length=128,
add_inputs=True)
@pytest.mark.parametrize("batches", test_params['batches'])

View File

@ -326,9 +326,11 @@ class LoRAModelManager(AdapterModelManager):
self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots
self.vocab_size = vocab_size
self.long_lora_context: Optional[LongContextLoRAContext] = None
self.punica_wrapper = get_punica_wrapper(max_num_batched_tokens,
max_batches=self.max_num_seqs,
device=self.device)
self.punica_wrapper = get_punica_wrapper(
max_num_batched_tokens,
max_batches=self.max_num_seqs,
device=self.device,
max_loras=self.lora_config.max_loras)
# Maps a scaling factor to its offset into the sin_cos_cache.
# Used for long context lora.
self.scaling_factor_to_offset: Dict[float, int] = {}

View File

@ -54,7 +54,7 @@ _LORA_A_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {}
_LORA_B_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {}
def _get_lora_a_ptr(lora_a_weights: List[torch.Tensor], device: str):
def _get_lora_a_ptr(lora_a_weights: List[torch.Tensor], device: torch.device):
"""
`_LORA_A_PTR_DICT` collects the required information during `profile_run`.
After this, it remains constant and subsequent lookups go through the LUT.
@ -100,7 +100,7 @@ def _get_lora_a_ptr(lora_a_weights: List[torch.Tensor], device: str):
def _get_lora_b_ptr(lora_weights: List[torch.Tensor], offset_start: int,
device: str):
device: torch.device):
"""
`_LORA_B_PTR_DICT` collects the required information during `profile_run`.
After this, it remains constant and subsequent lookups go through the LUT.

View File

@ -0,0 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
from vllm.lora.ops.triton_ops.v1.v1_expand import v1_expand
from vllm.lora.ops.triton_ops.v1.v1_kernel_metadata import V1KernelMeta
from vllm.lora.ops.triton_ops.v1.v1_shrink import v1_shrink
__all__ = [
"v1_expand",
"v1_shrink",
"V1KernelMeta",
]

View File

@ -0,0 +1,282 @@
# SPDX-License-Identifier: Apache-2.0
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
from typing import List
import torch
import triton
import triton.language as tl
from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel
from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr
from vllm.utils import direct_register_custom_op
@triton.jit
def _v1_expand_kernel(
input_ptr,
lora_ptr,
out_ptr,
M,
N,
K,
token_indices_sorted_by_lora_ids,
num_tokens_per_lora,
lora_token_start_loc,
lora_ids,
slice_start_loc,
input_d0_stride,
input_d1_stride,
input_d2_stride, # 1
ls_d0_ptr,
ls_d1_ptr,
ls_d2_ptr, # 1
output_d0_stride,
output_d1_stride, # 1
output_hs_ptr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
EVEN_K: tl.constexpr,
ADD_INPUTS: tl.constexpr,
CAST_TYPE: tl.constexpr,
SLICE_NUM: tl.constexpr,
SAME_STRIDE: tl.constexpr):
cta_n_num = tl.cdiv(N, BLOCK_N)
cta_m_num = tl.cdiv(M, BLOCK_M)
pid_mn = tl.program_id(axis=0)
pid_m = pid_mn % cta_m_num
pid_n = (pid_mn // cta_m_num) % cta_n_num
slice_id = tl.program_id(axis=1)
lora_idx = tl.program_id(axis=2)
lora_id = tl.load(lora_ids + lora_idx)
if lora_id == -1:
# Early exit for the no-lora case.
return
lora_m_size = tl.load(num_tokens_per_lora + lora_idx)
cta_m_offset = pid_m * BLOCK_M
if cta_m_offset >= lora_m_size:
# Early exit CTA.
return
# When all slices share the same output dimension, curr_N = N; otherwise
# curr_N = tl.load(output_hs_ptr + slice_id). The latter occurs, for example,
# in the QKV linear layer with GQA, where the slice sizes differ.
curr_N = N if SAME_STRIDE else tl.load(output_hs_ptr + slice_id)
if pid_n * BLOCK_N >= curr_N:
# Early exit CTA.
return
# num rows this CTA should process.
cta_m_len = min(BLOCK_M, lora_m_size - cta_m_offset)
# Identify all rows that this CTA should process.
lora_m_indices_start = tl.load(lora_token_start_loc + lora_idx)
cta_lora_seq_indices = (token_indices_sorted_by_lora_ids +
lora_m_indices_start + cta_m_offset)
# Load all relevant row indices.
offset_m = tl.arange(0, BLOCK_M) % cta_m_len
ram = tl.load(cta_lora_seq_indices + offset_m)
do_expand_kernel(
pid_n,
lora_id,
slice_id,
input_ptr,
lora_ptr,
out_ptr,
curr_N,
K,
cta_m_len,
ram, # array identifying the rows of Input ptr to operate on
slice_start_loc,
# input ptr strides
input_d0_stride,
input_d1_stride,
input_d2_stride,
# lora ptr strides
ls_d0_ptr,
ls_d1_ptr,
ls_d2_ptr,
# out ptr strides
output_d0_stride,
output_d1_stride,
# constants
BLOCK_M,
BLOCK_N,
BLOCK_K,
SAME_STRIDE,
SLICE_NUM,
EVEN_K,
CAST_TYPE,
ADD_INPUTS)
@torch.inference_mode()
def _v1_expand(
inputs: torch.Tensor, # shape [num_slices, num_tokens, lora_rank]
lora_b_weights: List[
torch.Tensor], # shape [num_lora, hidden_size, lora_rank]
output_tensor: torch.
Tensor, # shape [num_tokens, hidden_size * num_slices]
token_lora_mapping: torch.Tensor, # shape [num_tokens]
token_indices_sorted_by_lora_ids: torch.Tensor, # shape [num_tokens]
num_tokens_per_lora: torch.Tensor, # shape [max-loras + 1]
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
lora_ids: torch.Tensor, # shape [max-loras + 1]
offset_start: int = 0,
add_inputs: bool = False,
) -> None:
"""
Args:
inputs (torch.Tensor): input tensor
lora_b_weights (List[torch.Tensor]): LoRA B weights
output_tensor (torch.Tensor): output tensor
token_lora_mapping (torch.Tensor): A tensor mapping each input token
to the lora-id related to that token. A value of -1 indicates that
LoRA doesn't apply to that token.
token_indices_sorted_by_lora_ids (torch.Tensor): Row/Token indices from
the A matrix grouped by LoRA IDs.
num_tokens_per_lora (torch.Tensor): num_tokens_per_lora[i] is the number
of tokens that are to be processed by LoRA ID lora_ids[i]
lora_token_start_loc (torch.Tensor): A cumulative sum of
num_tokens_per_lora. lora_token_start_loc[0] is always 0 so that
lora_token_start_loc[i], along with num_tokens_per_lora[i]
identifies the region in token_indices_sorted_by_lora_ids that
LoRA lora_ids[i] should process.
lora_ids (torch.Tensor): LoRA ids to process.
offset_start (int, optional): Offset start for output_tensor.
Defaults to 0.
add_inputs (bool, optional): Whether to add the input tensor to the
output tensor. Defaults to False.
"""
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
for weight in lora_b_weights:
assert weight.dtype in [torch.float16, torch.bfloat16]
assert inputs.size(0) == len(lora_b_weights)
assert output_tensor.is_contiguous()
# metadata sanity check.
assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(
0)
assert lora_ids.size(0) == num_tokens_per_lora.size(0)
assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1
(slice_start_tensor, lora_ptr_tensor, lora_strides_d0_tensor,
lora_strides_d1_tensor, lora_strides_d2_tensor, hidden_sizes_tensor,
same_stride, MAX_N) = _get_lora_b_ptr(lora_b_weights, offset_start,
inputs.device)
K = lora_b_weights[0].shape[-1]  # K = lora_rank
M = inputs.size(1)
ADD_INPUTS = add_inputs
MAX_LORAS = lora_ids.size(0)
CAST_TYPE = False
NUM_SLICES = len(lora_b_weights)
# Triton kernel configs.
BLOCK_M = 64
BLOCK_N = 128
BLOCK_K = 16
NUM_WARPS = 4
NUM_CTAS = 1
NUM_STAGES = 2
MAX_NREG = None
EVEN_K = K % BLOCK_K == 0 # type: ignore
if inputs.dtype == torch.float32 and lora_b_weights[0].dtype in [
torch.float16,
torch.bfloat16,
]:
CAST_TYPE = True
# TODO (varun): This grid formulation maximizes parallelization at the
# cost of wasteful thread block launch when only a few input tokens require
# LoRA. This might not be the best in all cases.
grid = (
triton.cdiv(M, BLOCK_M) * triton.cdiv(MAX_N, BLOCK_N),
NUM_SLICES,
# Each LoRA receives its own set of thread blocks for output
# computation. If some LoRA doesn't have any tokens to process, its
# thread blocks simply exit.
MAX_LORAS,
)
_v1_expand_kernel[grid](
inputs,
lora_ptr_tensor,
output_tensor,
M,
MAX_N,
K,
token_indices_sorted_by_lora_ids,
num_tokens_per_lora,
lora_token_start_loc,
lora_ids,
slice_start_tensor,
inputs.stride(0),
inputs.stride(1),
inputs.stride(2),
lora_strides_d0_tensor,
lora_strides_d1_tensor,
lora_strides_d2_tensor,
output_tensor.stride(0),
output_tensor.stride(1),
hidden_sizes_tensor,
BLOCK_M,
BLOCK_N,
BLOCK_K,
EVEN_K,
ADD_INPUTS,
CAST_TYPE,
NUM_SLICES,
same_stride,
num_warps=NUM_WARPS,
num_ctas=NUM_CTAS,
num_stages=NUM_STAGES,
maxnreg=MAX_NREG,
)
return
def _v1_expand_fake(
inputs: torch.Tensor,
lora_b_weights: List[torch.Tensor],
output_tensor: torch.Tensor,
token_lora_mapping: torch.Tensor,
token_indices_sorted_by_lora_ids: torch.Tensor,
num_tokens_per_lora: torch.Tensor,
lora_token_start_loc: torch.Tensor,
lora_ids: torch.Tensor,
offset_start: int = 0,
add_inputs: bool = False,
) -> None:
return
try:
direct_register_custom_op(
op_name="v1_expand",
op_func=_v1_expand,
mutates_args=["output_tensor"],
fake_impl=_v1_expand_fake,
)
v1_expand = torch.ops.vllm.v1_expand
except AttributeError:
v1_expand = _v1_expand
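# A small worked example (hypothetical sizes, for illustration only) of the
# grid formulation used by _v1_expand above: one thread block per
# (token-block, N-block) pair is launched for every slice and every LoRA
# slot, and blocks with no work exit early inside the kernel.
import triton

M, MAX_N = 512, 4096              # num_tokens, widest slice hidden size
BLOCK_M, BLOCK_N = 64, 128        # block sizes used by _v1_expand
NUM_SLICES, MAX_LORAS = 3, 9      # e.g. 3 slices, max_loras=8 plus no-lora
example_grid = (
    triton.cdiv(M, BLOCK_M) * triton.cdiv(MAX_N, BLOCK_N),  # 8 * 32 = 256
    NUM_SLICES,
    MAX_LORAS,
)
assert example_grid == (256, 3, 9)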

View File

@ -0,0 +1,117 @@
# SPDX-License-Identifier: Apache-2.0
"""
V1 LoRA kernels metadata preparation utilities.
"""
from dataclasses import dataclass
from typing import Tuple, Union
import torch
@dataclass
class V1KernelMeta:
token_lora_mapping: torch.Tensor
token_indices_sorted_by_lora_ids: torch.Tensor
active_lora_ids: torch.Tensor
num_tokens_per_lora: torch.Tensor
lora_token_start_loc: torch.Tensor
@staticmethod
def make(max_loras: int, max_num_tokens: int,
device: Union[torch.device, str]) -> "V1KernelMeta":
token_lora_mapping = torch.empty(max_num_tokens,
dtype=torch.int32,
device=device)
token_indices_sorted_by_lora_ids = torch.empty(max_num_tokens,
dtype=torch.int32,
device=device)
# +1 because "no-lora" is also a possibility
# example: let max_loras be 3, active_lora_ids of [-1, 0, 2, 1]
# is a possibility.
active_lora_ids = torch.empty(max_loras + 1,
dtype=torch.int32,
device=device)
# using running example, [3, 10, 5, 2] is a possibility.
num_tokens_per_lora = torch.zeros(max_loras + 1,
dtype=torch.int32,
device=device)
# +2 for this one: +1 for the no-lora slot and +1 because the first index
# is always 0.
# using running example, lora_token_start_loc
# is [0, 3, 13, 18, 20].
lora_token_start_loc = torch.zeros(max_loras + 2,
dtype=torch.int32,
device=device)
return V1KernelMeta(
token_lora_mapping=token_lora_mapping,
token_indices_sorted_by_lora_ids=token_indices_sorted_by_lora_ids,
active_lora_ids=active_lora_ids,
num_tokens_per_lora=num_tokens_per_lora,
lora_token_start_loc=lora_token_start_loc)
def _reset(self):
self.active_lora_ids.fill_(-1)
self.num_tokens_per_lora.fill_(0)
self.lora_token_start_loc.fill_(0)
def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None:
"""
Prepare kernel metadata tensors for the current forward pass.
Args:
token_lora_mapping (torch.Tensor): Tensor containing the LoRA index
for each input token.
"""
self._reset()
num_tokens = token_lora_mapping.size(0)
# copy token lora mapping
self.token_lora_mapping[:num_tokens].copy_(token_lora_mapping,
non_blocking=True)
# token_indices_sorted_by_lora_ids
_, token_indices_sorted_by_lora_ids = torch.sort(token_lora_mapping,
stable=True)
# start gpu transfer
self.token_indices_sorted_by_lora_ids[:num_tokens].copy_(
token_indices_sorted_by_lora_ids, non_blocking=True)
# active_lora_ids, num_tokens_per_lora
lora_ids, num_tokens_per_lora = torch.unique(token_lora_mapping,
sorted=False,
return_counts=True)
self.active_lora_ids[:lora_ids.size(0)].copy_(lora_ids,
non_blocking=True)
self.num_tokens_per_lora[:num_tokens_per_lora.size(0)].copy_(
num_tokens_per_lora, non_blocking=True)
# lora_token_start_loc
lora_token_start_loc = torch.cumsum(num_tokens_per_lora, dim=0)
self.lora_token_start_loc[1:1 + lora_token_start_loc.size(0)].copy_(
lora_token_start_loc, non_blocking=True)
def meta_args(
self, token_nums: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
torch.Tensor]:
"""
Return the kernel metadata required for the current forward pass, in
order, as a tuple, so it can be unpacked directly in the
v1_shrink/v1_expand function call.
Args:
token_nums (int): Number of input tokens in the current forward
pass.
"""
return (self.token_lora_mapping[:token_nums],
self.token_indices_sorted_by_lora_ids[:token_nums],
self.num_tokens_per_lora, self.lora_token_start_loc,
self.active_lora_ids)
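# A minimal usage sketch of V1KernelMeta (hypothetical values mirroring the
# running example in the comments above: max_loras=3, 20 tokens).
import torch

from vllm.lora.ops.triton_ops.v1 import V1KernelMeta

meta = V1KernelMeta.make(max_loras=3, max_num_tokens=20, device="cpu")
# Per-token LoRA ids: 3 no-lora tokens (-1), then 10 / 5 / 2 tokens mapped to
# LoRA ids 0 / 2 / 1 respectively.
token_lora_mapping = torch.tensor([-1] * 3 + [0] * 10 + [2] * 5 + [1] * 2,
                                  dtype=torch.int32)
meta.prepare_tensors(token_lora_mapping)
# After preparation (the order of the unique LoRA ids may vary):
#   active_lora_ids      ~ [-1, 0, 2, 1]
#   num_tokens_per_lora  ~ [3, 10, 5, 2]
#   lora_token_start_loc ~ [0, 3, 13, 18, 20]
# meta.meta_args(token_nums=20) can then be unpacked directly into the
# v1_shrink / v1_expand calls.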

View File

@ -0,0 +1,236 @@
# SPDX-License-Identifier: Apache-2.0
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
from typing import List
import torch
import triton
import triton.language as tl
from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel
from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr
from vllm.utils import direct_register_custom_op
@triton.jit
def _v1_shrink_kernel(input_ptr, lora_ptr, out_ptr, M, N, K,
token_indices_sorted_by_lora_ids, num_tokens_per_lora,
lora_token_start_loc, lora_ids, scaling, input_d0_stride,
input_d1_stride, lora_d0_stride, lora_d1_stride,
lora_d2_stride, output_d0_stride, output_d1_stride,
output_d2_stride, BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
EVEN_K: tl.constexpr, SPLIT_K: tl.constexpr,
SLICE_NUM: tl.constexpr):
cta_n_num = tl.cdiv(N, BLOCK_N)
cta_m_num = tl.cdiv(M, BLOCK_M)
pid_sk_m_n = tl.program_id(axis=0)
pid_sk = pid_sk_m_n % SPLIT_K
pid_m = (pid_sk_m_n // SPLIT_K) % cta_m_num
pid_n = pid_sk_m_n // (SPLIT_K * cta_m_num) % cta_n_num
slice_id = tl.program_id(axis=1)
lora_idx = tl.program_id(axis=2)
lora_id = tl.load(lora_ids + lora_idx)
if lora_id == -1:
# Early exit for the no-lora case.
return
lora_m_size = tl.load(num_tokens_per_lora + lora_idx)
cta_m_offset = pid_m * BLOCK_M
if cta_m_offset >= lora_m_size:
# Early exit CTA.
return
# num rows this CTA should process.
cta_m_len = min(BLOCK_M, lora_m_size - cta_m_offset)
# Identify all rows that this CTA should process.
lora_m_indices_start = tl.load(lora_token_start_loc + lora_idx)
cta_lora_seq_indices = (token_indices_sorted_by_lora_ids +
lora_m_indices_start + cta_m_offset)
# Load all relevant row indices.
offset_m = tl.arange(0, BLOCK_M) % cta_m_len
ram = tl.load(cta_lora_seq_indices + offset_m)
do_shrink_kernel(
pid_n,
pid_sk,
slice_id,
lora_id,
input_ptr,
lora_ptr,
out_ptr,
N,
K,
cta_m_len,
ram, # array identifying the rows of Input ptr to operate on
# input strides
input_d0_stride,
input_d1_stride,
# lora strides
lora_d0_stride,
lora_d1_stride,
lora_d2_stride,
# output strides
output_d0_stride,
output_d1_stride,
output_d2_stride,
scaling,
BLOCK_M,
BLOCK_N,
BLOCK_K,
EVEN_K,
SPLIT_K,
SLICE_NUM)
@torch.inference_mode()
def _v1_shrink(
inputs: torch.Tensor, # shape [num_tokens, hidden_size]
lora_a_weights: List[
torch.Tensor], # shape [num_loras, lora_rank, hidden_size]
output_tensor: torch.Tensor, # shape [num_slices, num_tokens, lora_rank]
token_lora_mapping: torch.Tensor, # shape [num_tokens]
token_indices_sorted_by_lora_ids: torch.Tensor, # shape [num_tokens]
num_tokens_per_lora: torch.Tensor, # shape [max-loras + 1]
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
lora_ids: torch.Tensor, # shape [max-loras + 1]
scaling: float,
) -> None:
"""
Args:
inputs (torch.Tensor): Input tensor
lora_a_weights (List[torch.Tensor]): LoRA weights
output_tensor (torch.Tensor): output tensor
token_lora_mapping (torch.Tensor): A tensor mapping each input token
to the lora-id related to that token. A value of -1 indicates that
LoRA doesn't apply to that token.
token_indices_sorted_by_lora_ids (torch.Tensor): Row/Token indices from
the A matrix grouped by LoRA IDs.
num_tokens_per_lora (torch.Tensor): num_tokens_per_lora[i] is the number
of tokens that are to be processed by LoRA ID lora_ids[i]
lora_token_start_loc (torch.Tensor): A cumulative sum of
num_tokens_per_lora. lora_token_start_loc[0] is always 0 so that
lora_token_start_loc[i], along with num_tokens_per_lora[i]
identifies the region in token_indices_sorted_by_lora_ids that
LoRA lora_ids[i] should process.
lora_ids (torch.Tensor): LoRA ids to process.
scaling (float): Scaling factor.
"""
assert inputs.dtype == lora_a_weights[0].dtype
assert inputs.dtype in [torch.float16, torch.bfloat16]
for weight in lora_a_weights:
assert weight.dtype in [torch.float16, torch.bfloat16]
assert inputs.size(1) == lora_a_weights[0].size(-1)
assert inputs.is_contiguous()
assert output_tensor.is_contiguous()
# metadata sanity check
assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(
0)
assert lora_ids.size(0) == num_tokens_per_lora.size(0)
assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1
(lora_ptr_tensor, lora_strides_d0, lora_strides_d1,
lora_strides_d2) = _get_lora_a_ptr(lora_a_weights, inputs.device)
N, K = lora_a_weights[0].shape[-2:]  # K = hidden_size, N = lora_rank
M = inputs.size(0)
NUM_SLICES = len(lora_a_weights)
MAX_LORAS = lora_ids.size(0)
# Triton kernel configs
BLOCK_M = 32
BLOCK_N = 16
BLOCK_K = 256 if M < 128 else 32
SPLIT_K = 64 if M < 128 else 8
NUM_WARPS = 4
NUM_CTAS = 1
NUM_STAGES = 2
MAX_NREG = None
EVEN_K = K % (BLOCK_K * SPLIT_K) == 0 # type: ignore
# TODO (varun): This grid formulation maximizes parallelization at the
# cost of wasteful thread block launches when only a few of the input tokens
# require LoRA. This might not be the best in all cases.
grid = (
SPLIT_K * triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),
NUM_SLICES,
# Each LoRA receives its own set of thread blocks for output
# computation. If some LoRA doesn't have any tokens to process, its
# thread blocks exit early.
MAX_LORAS,
)
_v1_shrink_kernel[grid](
inputs,
lora_ptr_tensor,
output_tensor,
M,
N,
K,
token_indices_sorted_by_lora_ids,
num_tokens_per_lora,
lora_token_start_loc,
lora_ids,
scaling,
inputs.stride(0),
inputs.stride(1),
lora_strides_d0,
lora_strides_d1,
lora_strides_d2,
output_tensor.stride(0),
output_tensor.stride(1),
output_tensor.stride(2),
BLOCK_M,
BLOCK_N,
BLOCK_K,
EVEN_K,
SPLIT_K,
NUM_SLICES,
num_warps=NUM_WARPS,
num_ctas=NUM_CTAS,
num_stages=NUM_STAGES,
maxnreg=MAX_NREG,
)
return
def _v1_shrink_fake(
inputs: torch.Tensor,
lora_a_weights: List[torch.Tensor],
output_tensor: torch.Tensor,
token_lora_mapping: torch.Tensor,
token_indices_sorted_by_lora_ids: torch.Tensor,
num_tokens_per_lora: torch.Tensor,
lora_token_start_loc: torch.Tensor,
lora_ids: torch.Tensor,
scaling: float,
) -> None:
return
try:
direct_register_custom_op(
op_name="v1_shrink",
op_func=_v1_shrink,
mutates_args=["output_tensor"],
fake_impl=_v1_shrink_fake,
)
v1_shrink = torch.ops.vllm.v1_shrink
except AttributeError:
v1_shrink = _v1_shrink
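# A small worked example (hypothetical sizes, for illustration only) of the
# shrink launch configuration chosen above: for small token counts the
# K (hidden_size) reduction is split across more thread blocks via a larger
# BLOCK_K / SPLIT_K, which adds a SPLIT_K factor to the first grid dimension.
import triton

M, N, K = 64, 16, 4096            # num_tokens, lora_rank, hidden_size
BLOCK_M, BLOCK_N = 32, 16         # block sizes used by _v1_shrink
BLOCK_K = 256 if M < 128 else 32  # -> 256 for this decode-sized M
SPLIT_K = 64 if M < 128 else 8    # -> 64
NUM_SLICES, MAX_LORAS = 3, 9
example_grid = (
    SPLIT_K * triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),  # 64 * 2 * 1
    NUM_SLICES,
    MAX_LORAS,
)
assert example_grid == (128, 3, 9)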

View File

@ -6,24 +6,83 @@ Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
from typing import Optional, Tuple, Union, final
from typing import TYPE_CHECKING, List, Optional, Tuple, Union, final
import torch
import vllm.envs as env
from vllm.lora.layers import LoRAMapping
from vllm.triton_utils import HAS_TRITON
if HAS_TRITON:
from vllm.lora.ops.triton_ops import bgmv_expand
from vllm.lora.ops.triton_ops import bgmv_expand_slice
from vllm.lora.ops.triton_ops import bgmv_shrink
from vllm.lora.ops.triton_ops import sgmv_expand
from vllm.lora.ops.triton_ops import sgmv_shrink
if env.VLLM_USE_V1:
from vllm.lora.ops.triton_ops.v1 import (V1KernelMeta, v1_expand,
v1_shrink)
else:
from vllm.lora.ops.triton_ops import bgmv_expand
from vllm.lora.ops.triton_ops import bgmv_expand_slice
from vllm.lora.ops.triton_ops import bgmv_shrink
from vllm.lora.ops.triton_ops import sgmv_expand
from vllm.lora.ops.triton_ops import sgmv_shrink
from .punica_base import PunicaWrapperBase
if TYPE_CHECKING:
# avoid circular import
from vllm.lora.models import LongContextLoRAContext
class V1KernelMixin:
def _v1_make_metadata(self, max_loras: int, max_num_batched_tokens: int,
max_batches: int, device: Union[torch.device, str]):
self.token_mapping_v1_meta = V1KernelMeta.make(max_loras,
max_num_batched_tokens,
device=device)
self.prompt_mapping_v1_meta = V1KernelMeta.make(max_loras,
max_batches,
device=device)
def _v1_prepare_metadata_tensors(self, token_lora_indices: torch.Tensor,
sampler_indices: torch.Tensor):
self.token_mapping_v1_meta.prepare_tensors(token_lora_indices)
self.prompt_mapping_v1_meta.prepare_tensors(sampler_indices)
def _v1_apply_shrink(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: Tuple[torch.Tensor, ...],
scale: float,
):
v1_shrink(
x,
w_t_all,
y,
*self.token_mapping_v1_meta.meta_args(x.size(0)),
scale,
)
def _v1_apply_expand(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: Tuple[torch.Tensor, ...],
offset_start: int,
add_inputs: bool,
):
v1_expand(
x,
w_t_all,
y,
*self.token_mapping_v1_meta.meta_args(x.size(0)),
offset_start=offset_start,
add_inputs=add_inputs,
)
@final
class PunicaWrapperGPU(PunicaWrapperBase):
class PunicaWrapperGPU(PunicaWrapperBase, V1KernelMixin):
"""
PunicaWrapperGPU is designed to manage and provide metadata for the punica
kernel. The main function is to maintain the state information for
@ -35,6 +94,36 @@ class PunicaWrapperGPU(PunicaWrapperBase):
PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches,
device)
self.max_loras = kwargs['max_loras']
if env.VLLM_USE_V1:
self._v1_make_metadata(self.max_loras, max_num_batched_tokens,
max_batches, device)
def update_metadata(
self,
mapping: LoRAMapping,
lora_index_to_id: List[Optional[int]],
max_loras: int,
vocab_size: int,
extra_vocab_size: int,
long_lora_context: Optional["LongContextLoRAContext"] = None,
**kwargs):
if env.VLLM_USE_V1:
self.is_prefill = mapping.is_prefill
self._update_base_metadata(mapping, lora_index_to_id, max_loras,
vocab_size, extra_vocab_size,
long_lora_context)
self._v1_prepare_metadata_tensors(self.token_lora_indices,
self.sampler_indices)
else:
# Forward to base class update_metadata
PunicaWrapperBase.update_metadata(self, mapping, lora_index_to_id,
max_loras, vocab_size,
extra_vocab_size,
long_lora_context, **kwargs)
def _apply_shrink_prefill(
self,
y: torch.Tensor,
@ -66,7 +155,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
w_t_all: Tuple[torch.Tensor, ...],
offset_start: int,
add_inputs: bool,
):
@ -118,14 +207,21 @@ class PunicaWrapperGPU(PunicaWrapperBase):
x = x.view(-1, x.shape[-1])
if self.is_prefill:
# NOTE fused kernel
self._apply_shrink_prefill(y, x, lora_a_stacked, scale)
if env.VLLM_USE_V1:
self._v1_apply_shrink(y, x, lora_a_stacked, scale) # type: ignore
else:
# TODO fuse these kernels
for slice_idx in range(len(lora_a_stacked)):
self._apply_shrink_decode(y[slice_idx], x,
lora_a_stacked[slice_idx], scale)
if self.is_prefill:
# NOTE fused kernel
self._apply_shrink_prefill(
y, # type: ignore
x,
lora_a_stacked,
scale)
else:
# TODO fuse these kernels
for slice_idx in range(len(lora_a_stacked)):
self._apply_shrink_decode(y[slice_idx], x,
lora_a_stacked[slice_idx], scale)
def add_expand(self,
y: torch.Tensor,
@ -160,25 +256,38 @@ class PunicaWrapperGPU(PunicaWrapperBase):
if lora_bias_stacked is not None:
self._apply_bias(self.token_lora_indices, y, output_slices,
lora_bias_stacked)
if self.is_prefill:
# NOTE fused kernel
self._apply_expand_prefill(y,
x,
lora_b_stacked,
offset_start,
add_inputs=True)
if env.VLLM_USE_V1:
# TODO (varun): Profile with add_inputs = False. i.e. move the
# addition out of the kernel
self._v1_apply_expand(
y,
x, # type: ignore
lora_b_stacked,
offset_start,
add_inputs=True)
else:
# TODO fuse these kernels
for slice_idx in range(len(lora_b_stacked)):
self._apply_expand_decode(
if self.is_prefill:
# NOTE fused kernel
self._apply_expand_prefill(
y,
x[slice_idx],
lora_b_stacked[slice_idx],
x, # type: ignore
lora_b_stacked,
offset_start,
output_slices[slice_idx],
add_inputs=add_inputs,
)
offset_start += output_slices[slice_idx]
add_inputs=True)
else:
# TODO fuse these kernels
for slice_idx in range(len(lora_b_stacked)):
self._apply_expand_decode(
y,
x[slice_idx],
lora_b_stacked[slice_idx],
offset_start,
output_slices[slice_idx],
add_inputs=add_inputs,
)
offset_start += output_slices[slice_idx]
y = y.view_as(y_org)
def add_lora_embedding(self,
@ -200,18 +309,24 @@ class PunicaWrapperGPU(PunicaWrapperBase):
add_inputs (bool): Default to True.
"""
if self.is_prefill:
sgmv_expand(
x.unsqueeze(dim=0),
[lora_b_stacked],
y,
*self.prefill_metadata,
offset_start=0,
add_inputs=add_inputs,
)
if env.VLLM_USE_V1:
self._v1_apply_expand(y,
x.unsqueeze(dim=0), (lora_b_stacked, ),
offset_start=0,
add_inputs=add_inputs)
else:
bgmv_expand(x, lora_b_stacked, y, self.token_lora_indices,
add_inputs)
if self.is_prefill:
sgmv_expand(
x.unsqueeze(dim=0),
(lora_b_stacked, ),
y,
*self.prefill_metadata,
offset_start=0,
add_inputs=add_inputs,
)
else:
bgmv_expand(x, lora_b_stacked, y, self.token_lora_indices,
add_inputs)
def add_lora_linear(self,
y: torch.Tensor,
@ -257,19 +372,25 @@ class PunicaWrapperGPU(PunicaWrapperBase):
r = lora_b_stacked[0].size(-1)
# We set the buffer to float32 by default; refer to:
# https://github.com/triton-lang/triton/issues/1387
buffer = torch.zeros(
buffer = torch.zeros( # type: ignore
(len(output_slices), x.size(0), r),
dtype=torch.float32,
device=x.device,
)
self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
self.add_expand(y,
buffer,
lora_b_stacked,
None,
output_slices,
add_inputs=True,
**kwargs)
self.add_shrink(
buffer, # type: ignore
x,
lora_a_stacked,
scale,
**kwargs)
self.add_expand(
y,
buffer, # type: ignore
lora_b_stacked,
None,
output_slices,
add_inputs=True,
**kwargs)
def add_lora_logits(self,
y: torch.Tensor,
@ -305,11 +426,22 @@ class PunicaWrapperGPU(PunicaWrapperBase):
buffer = torch.zeros((x.size(0), r),
dtype=torch.float32,
device=x.device)
# LogitsProcessorWithLoRA always using bgmv.
bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale)
bgmv_expand(buffer,
lora_b_stacked,
y,
self.sampler_indices,
add_inputs=True)
if env.VLLM_USE_V1:
v1_shrink(x, [lora_a_stacked], buffer.unsqueeze(dim=0),
*self.prompt_mapping_v1_meta.meta_args(x.size(0)), scale)
v1_expand(buffer.unsqueeze(dim=0), [lora_b_stacked],
y,
*self.prompt_mapping_v1_meta.meta_args(buffer.size(0)),
add_inputs=True)
else:
# The V0 LogitsProcessorWithLoRA always uses bgmv.
bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale)
bgmv_expand(buffer,
lora_b_stacked,
y,
self.sampler_indices,
add_inputs=True)
y = y.view_as(y_org)

View File

@ -62,9 +62,9 @@ class LoRAModelRunnerMixin:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
# We dont make any distinction between prefills and decodes in the
# scheduler. To that effect, set is_prefill to True so we use the
# sgmv punica kernels always.
# Set is_prefill to True so that the SGMV kernels are always used.
# CUDA platforms have specialized Triton kernels, and the CUDA path
# ignores `is_prefill`.
lora_mapping = LoRAMapping(token_lora_mapping,
prompt_lora_mapping,
is_prefill=True)